Compare commits
116 Commits
68e04a911a
...
dev
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
a94e29d99c | ||
|
|
81e48c73ca | ||
|
|
a3f78dc801 | ||
|
|
07309013d7 | ||
|
|
846fc31190 | ||
|
|
ff7a67cb58 | ||
|
|
0760a8284d | ||
|
|
ce4d0c7b0d | ||
|
|
4ceb8ad98c | ||
|
|
f8aafb250d | ||
|
|
4385d20ca6 | ||
|
|
1a36907f10 | ||
|
|
0553a1fc53 | ||
|
|
57e969ed67 | ||
|
|
68275b1dd3 | ||
|
|
80d2dc0cb2 | ||
|
|
a8aa416ecb | ||
|
|
4c6bf55bcc | ||
|
|
98b455fdc3 | ||
|
|
0646c96b19 | ||
|
|
62afb328fe | ||
|
|
b9a746bc16 | ||
|
|
de8e18e97d | ||
|
|
a3e557d022 | ||
|
|
4e357db25d | ||
|
|
568aad3673 | ||
|
|
ddcf926158 | ||
|
|
865eeece58 | ||
|
|
05fb3612f9 | ||
|
|
1b2e7dde35 | ||
|
|
29074f26a6 | ||
|
|
77ed190310 | ||
|
|
2bbe925cef | ||
|
|
4a06b96b2e | ||
|
|
088c1725b0 | ||
|
|
7ba1767cea | ||
|
|
c63b6a4f76 | ||
|
|
803b720530 | ||
|
|
7ff00426f2 | ||
|
|
b3f0dd4005 | ||
|
|
707315facd | ||
|
|
38114b79f9 | ||
|
|
1cb3658369 | ||
|
|
dc875c5c95 | ||
|
|
0ea428b718 | ||
|
|
400d6f6f75 | ||
|
|
7716468d72 | ||
|
|
48f052200f | ||
|
|
fbb030da69 | ||
|
|
d49f819469 | ||
|
|
507f2e9c00 | ||
|
|
c0b253a010 | ||
|
|
fcbcff99e9 | ||
|
|
b49678b7df | ||
|
|
aeed9dfdbc | ||
|
|
13f617828b | ||
|
|
84e0a7fe81 | ||
|
|
063a35e698 | ||
|
|
a2246fb6e1 | ||
|
|
16ee4e0cb3 | ||
|
|
e6792c2d6c | ||
|
|
1d20b149dc | ||
|
|
570848cc2d | ||
|
|
6b970765ba | ||
|
|
e79215b4de | ||
|
|
3bf28aa121 | ||
|
|
cda9810a7e | ||
|
|
d47bd34a92 | ||
|
|
5b0ae54365 | ||
|
|
372af25aaa | ||
|
|
d0b717a128 | ||
|
|
9d40aece30 | ||
|
|
487c8a3863 | ||
|
|
8659e884e9 | ||
|
|
a05def5906 | ||
|
|
9f655913b1 | ||
|
|
13abd159fa | ||
|
|
acfe59c8b3 | ||
|
|
2e4700ae9b | ||
|
|
8c83e2a699 | ||
|
|
9b6356b0db | ||
|
|
a410586cfb | ||
|
|
0e34cab921 | ||
|
|
3cf3858fca | ||
|
|
db0c555041 | ||
|
|
51ad80071a | ||
|
|
d730ab7526 | ||
|
|
b218be9318 | ||
|
|
e6813c87c3 | ||
|
|
210204eb7a | ||
|
|
6ad4cda3f4 | ||
|
|
54ceaa6f5d | ||
|
|
34e7f69465 | ||
|
|
8fdbc2b359 | ||
|
|
28b1cc6e48 | ||
|
|
5a21847382 | ||
|
|
444d495f83 | ||
|
|
a943f79ce7 | ||
|
|
f54905abd0 | ||
|
|
0105e765b3 | ||
|
|
bb06b450fd | ||
|
|
c1d6a04276 | ||
|
|
d7b333385d | ||
|
|
f02320e57c | ||
|
|
3ec589293c | ||
|
|
7b1bea2966 | ||
|
|
da7b6b5bfa | ||
|
|
7aa63d79df | ||
|
|
333c9c40af | ||
|
|
0b192ce030 | ||
|
|
da021d0640 | ||
|
|
d1b47006f4 | ||
|
|
a73d3c7d3e | ||
|
|
55ae92c460 | ||
|
|
fe6a98c379 | ||
|
|
b7c1191335 |
55
.env.demo
Normal file
55
.env.demo
Normal file
@@ -0,0 +1,55 @@
|
|||||||
|
# Common settings
|
||||||
|
PROJECT_NAME=App
|
||||||
|
VERSION=1.0.0
|
||||||
|
|
||||||
|
# Database settings
|
||||||
|
POSTGRES_USER=postgres
|
||||||
|
POSTGRES_PASSWORD=postgres
|
||||||
|
POSTGRES_DB=app
|
||||||
|
POSTGRES_HOST=db
|
||||||
|
POSTGRES_PORT=5432
|
||||||
|
DATABASE_URL=postgresql://postgres:postgres@db:5432/app
|
||||||
|
|
||||||
|
# Backend settings
|
||||||
|
BACKEND_PORT=8000
|
||||||
|
# CRITICAL: Generate a secure SECRET_KEY for production!
|
||||||
|
# Generate with: python -c 'import secrets; print(secrets.token_urlsafe(32))'
|
||||||
|
# Must be at least 32 characters
|
||||||
|
SECRET_KEY=demo_secret_key_for_testing_only_do_not_use_in_prod
|
||||||
|
ENVIRONMENT=development
|
||||||
|
DEMO_MODE=true
|
||||||
|
DEBUG=true
|
||||||
|
BACKEND_CORS_ORIGINS=["http://localhost:3000"]
|
||||||
|
FIRST_SUPERUSER_EMAIL=admin@example.com
|
||||||
|
# IMPORTANT: Use a strong password (min 12 chars, mixed case, digits)
|
||||||
|
# Default weak passwords like 'Admin123' are rejected
|
||||||
|
FIRST_SUPERUSER_PASSWORD=Admin123!
|
||||||
|
|
||||||
|
# OAuth Configuration (Social Login)
|
||||||
|
# Set OAUTH_ENABLED=true and configure at least one provider
|
||||||
|
OAUTH_ENABLED=false
|
||||||
|
OAUTH_AUTO_LINK_BY_EMAIL=true
|
||||||
|
|
||||||
|
# Google OAuth (from Google Cloud Console > APIs & Services > Credentials)
|
||||||
|
# https://console.cloud.google.com/apis/credentials
|
||||||
|
# OAUTH_GOOGLE_CLIENT_ID=your-google-client-id.apps.googleusercontent.com
|
||||||
|
# OAUTH_GOOGLE_CLIENT_SECRET=your-google-client-secret
|
||||||
|
|
||||||
|
# GitHub OAuth (from GitHub > Settings > Developer settings > OAuth Apps)
|
||||||
|
# https://github.com/settings/developers
|
||||||
|
# OAUTH_GITHUB_CLIENT_ID=your-github-client-id
|
||||||
|
# OAUTH_GITHUB_CLIENT_SECRET=your-github-client-secret
|
||||||
|
|
||||||
|
# OAuth Provider Mode (Authorization Server for MCP/third-party clients)
|
||||||
|
# Set OAUTH_PROVIDER_ENABLED=true to act as an OAuth 2.0 Authorization Server
|
||||||
|
OAUTH_PROVIDER_ENABLED=true
|
||||||
|
# IMPORTANT: Must be HTTPS in production!
|
||||||
|
OAUTH_ISSUER=http://localhost:8000
|
||||||
|
|
||||||
|
# Frontend settings
|
||||||
|
FRONTEND_PORT=3000
|
||||||
|
FRONTEND_URL=http://localhost:3000
|
||||||
|
NEXT_PUBLIC_API_URL=http://localhost:8000
|
||||||
|
NEXT_PUBLIC_API_BASE_URL=http://localhost:8000
|
||||||
|
NEXT_PUBLIC_APP_URL=http://localhost:3000
|
||||||
|
NODE_ENV=development
|
||||||
@@ -17,6 +17,7 @@ BACKEND_PORT=8000
|
|||||||
# Must be at least 32 characters
|
# Must be at least 32 characters
|
||||||
SECRET_KEY=your_secret_key_here_REPLACE_WITH_GENERATED_KEY_32_CHARS_MIN
|
SECRET_KEY=your_secret_key_here_REPLACE_WITH_GENERATED_KEY_32_CHARS_MIN
|
||||||
ENVIRONMENT=development
|
ENVIRONMENT=development
|
||||||
|
DEMO_MODE=false
|
||||||
DEBUG=true
|
DEBUG=true
|
||||||
BACKEND_CORS_ORIGINS=["http://localhost:3000"]
|
BACKEND_CORS_ORIGINS=["http://localhost:3000"]
|
||||||
FIRST_SUPERUSER_EMAIL=admin@example.com
|
FIRST_SUPERUSER_EMAIL=admin@example.com
|
||||||
@@ -24,7 +25,31 @@ FIRST_SUPERUSER_EMAIL=admin@example.com
|
|||||||
# Default weak passwords like 'Admin123' are rejected
|
# Default weak passwords like 'Admin123' are rejected
|
||||||
FIRST_SUPERUSER_PASSWORD=YourStrongPassword123!
|
FIRST_SUPERUSER_PASSWORD=YourStrongPassword123!
|
||||||
|
|
||||||
|
# OAuth Configuration (Social Login)
|
||||||
|
# Set OAUTH_ENABLED=true and configure at least one provider
|
||||||
|
OAUTH_ENABLED=false
|
||||||
|
OAUTH_AUTO_LINK_BY_EMAIL=true
|
||||||
|
|
||||||
|
# Google OAuth (from Google Cloud Console > APIs & Services > Credentials)
|
||||||
|
# https://console.cloud.google.com/apis/credentials
|
||||||
|
# OAUTH_GOOGLE_CLIENT_ID=your-google-client-id.apps.googleusercontent.com
|
||||||
|
# OAUTH_GOOGLE_CLIENT_SECRET=your-google-client-secret
|
||||||
|
|
||||||
|
# GitHub OAuth (from GitHub > Settings > Developer settings > OAuth Apps)
|
||||||
|
# https://github.com/settings/developers
|
||||||
|
# OAUTH_GITHUB_CLIENT_ID=your-github-client-id
|
||||||
|
# OAUTH_GITHUB_CLIENT_SECRET=your-github-client-secret
|
||||||
|
|
||||||
|
# OAuth Provider Mode (Authorization Server for MCP/third-party clients)
|
||||||
|
# Set OAUTH_PROVIDER_ENABLED=true to act as an OAuth 2.0 Authorization Server
|
||||||
|
OAUTH_PROVIDER_ENABLED=false
|
||||||
|
# IMPORTANT: Must be HTTPS in production!
|
||||||
|
OAUTH_ISSUER=http://localhost:8000
|
||||||
|
|
||||||
# Frontend settings
|
# Frontend settings
|
||||||
FRONTEND_PORT=3000
|
FRONTEND_PORT=3000
|
||||||
|
FRONTEND_URL=http://localhost:3000
|
||||||
NEXT_PUBLIC_API_URL=http://localhost:8000
|
NEXT_PUBLIC_API_URL=http://localhost:8000
|
||||||
|
NEXT_PUBLIC_API_BASE_URL=http://localhost:8000
|
||||||
|
NEXT_PUBLIC_APP_URL=http://localhost:3000
|
||||||
NODE_ENV=development
|
NODE_ENV=development
|
||||||
|
|||||||
2
.github/workflows/README.md
vendored
2
.github/workflows/README.md
vendored
@@ -41,7 +41,7 @@ To enable CI/CD workflows:
|
|||||||
- Runs on: Push to main/develop, PRs affecting frontend code
|
- Runs on: Push to main/develop, PRs affecting frontend code
|
||||||
- Tests: Frontend unit tests (Jest)
|
- Tests: Frontend unit tests (Jest)
|
||||||
- Coverage: Uploads to Codecov
|
- Coverage: Uploads to Codecov
|
||||||
- Fast: Uses npm cache
|
- Fast: Uses bun cache
|
||||||
|
|
||||||
### `e2e-tests.yml`
|
### `e2e-tests.yml`
|
||||||
- Runs on: All pushes and PRs
|
- Runs on: All pushes and PRs
|
||||||
|
|||||||
77
.github/workflows/backend-e2e-tests.yml.template
vendored
Normal file
77
.github/workflows/backend-e2e-tests.yml.template
vendored
Normal file
@@ -0,0 +1,77 @@
|
|||||||
|
# Backend E2E Tests CI Pipeline
|
||||||
|
#
|
||||||
|
# Runs end-to-end tests with real PostgreSQL via Testcontainers
|
||||||
|
# and validates API contracts with Schemathesis.
|
||||||
|
#
|
||||||
|
# To enable: Rename this file to backend-e2e-tests.yml
|
||||||
|
|
||||||
|
name: Backend E2E Tests
|
||||||
|
|
||||||
|
on:
|
||||||
|
push:
|
||||||
|
branches: [main, develop]
|
||||||
|
paths:
|
||||||
|
- 'backend/**'
|
||||||
|
- '.github/workflows/backend-e2e-tests.yml'
|
||||||
|
pull_request:
|
||||||
|
branches: [main, develop]
|
||||||
|
paths:
|
||||||
|
- 'backend/**'
|
||||||
|
workflow_dispatch:
|
||||||
|
|
||||||
|
jobs:
|
||||||
|
e2e-tests:
|
||||||
|
runs-on: ubuntu-latest
|
||||||
|
# E2E test failures don't block merge - they're advisory
|
||||||
|
continue-on-error: true
|
||||||
|
|
||||||
|
steps:
|
||||||
|
- name: Checkout code
|
||||||
|
uses: actions/checkout@v4
|
||||||
|
|
||||||
|
- name: Set up Python
|
||||||
|
uses: actions/setup-python@v5
|
||||||
|
with:
|
||||||
|
python-version: '3.12'
|
||||||
|
|
||||||
|
- name: Install uv
|
||||||
|
uses: astral-sh/setup-uv@v4
|
||||||
|
with:
|
||||||
|
version: "latest"
|
||||||
|
|
||||||
|
- name: Cache uv dependencies
|
||||||
|
uses: actions/cache@v4
|
||||||
|
with:
|
||||||
|
path: ~/.cache/uv
|
||||||
|
key: uv-${{ runner.os }}-${{ hashFiles('backend/uv.lock') }}
|
||||||
|
restore-keys: |
|
||||||
|
uv-${{ runner.os }}-
|
||||||
|
|
||||||
|
- name: Install dependencies (with E2E)
|
||||||
|
working-directory: ./backend
|
||||||
|
run: uv sync --extra dev --extra e2e
|
||||||
|
|
||||||
|
- name: Check Docker availability
|
||||||
|
id: docker-check
|
||||||
|
run: |
|
||||||
|
if docker info > /dev/null 2>&1; then
|
||||||
|
echo "available=true" >> $GITHUB_OUTPUT
|
||||||
|
echo "Docker is available"
|
||||||
|
else
|
||||||
|
echo "available=false" >> $GITHUB_OUTPUT
|
||||||
|
echo "::warning::Docker not available - E2E tests will be skipped"
|
||||||
|
fi
|
||||||
|
|
||||||
|
- name: Run E2E tests
|
||||||
|
if: steps.docker-check.outputs.available == 'true'
|
||||||
|
working-directory: ./backend
|
||||||
|
env:
|
||||||
|
IS_TEST: "True"
|
||||||
|
SECRET_KEY: "e2e-test-secret-key-minimum-32-characters-long"
|
||||||
|
PYTHONPATH: "."
|
||||||
|
run: |
|
||||||
|
uv run pytest tests/e2e/ -v --tb=short
|
||||||
|
|
||||||
|
- name: E2E tests skipped
|
||||||
|
if: steps.docker-check.outputs.available != 'true'
|
||||||
|
run: echo "E2E tests were skipped due to Docker unavailability"
|
||||||
3
.gitignore
vendored
3
.gitignore
vendored
@@ -187,7 +187,7 @@ coverage.xml
|
|||||||
.hypothesis/
|
.hypothesis/
|
||||||
.pytest_cache/
|
.pytest_cache/
|
||||||
cover/
|
cover/
|
||||||
|
backend/.benchmarks
|
||||||
# Translations
|
# Translations
|
||||||
*.mo
|
*.mo
|
||||||
*.pot
|
*.pot
|
||||||
@@ -268,6 +268,7 @@ celerybeat.pid
|
|||||||
.env
|
.env
|
||||||
.env.*
|
.env.*
|
||||||
!.env.template
|
!.env.template
|
||||||
|
!.env.demo
|
||||||
.venv
|
.venv
|
||||||
env/
|
env/
|
||||||
venv/
|
venv/
|
||||||
|
|||||||
315
AGENTS.md
Normal file
315
AGENTS.md
Normal file
@@ -0,0 +1,315 @@
|
|||||||
|
# AGENTS.md
|
||||||
|
|
||||||
|
AI coding assistant context for FastAPI + Next.js Full-Stack Template.
|
||||||
|
|
||||||
|
## Quick Start
|
||||||
|
|
||||||
|
```bash
|
||||||
|
# Backend (Python with uv)
|
||||||
|
cd backend
|
||||||
|
make install-dev # Install dependencies
|
||||||
|
make test # Run tests
|
||||||
|
uv run uvicorn app.main:app --reload # Start dev server
|
||||||
|
|
||||||
|
# Frontend (Node.js)
|
||||||
|
cd frontend
|
||||||
|
bun install # Install dependencies
|
||||||
|
bun run dev # Start dev server
|
||||||
|
bun run generate:api # Generate API client from OpenAPI
|
||||||
|
bun run test:e2e # Run E2E tests
|
||||||
|
```
|
||||||
|
|
||||||
|
**Access points:**
|
||||||
|
- Frontend: **http://localhost:3000**
|
||||||
|
- Backend API: **http://localhost:8000**
|
||||||
|
- API Docs: **http://localhost:8000/docs**
|
||||||
|
|
||||||
|
Default superuser (change in production):
|
||||||
|
- Email: `admin@example.com`
|
||||||
|
- Password: `admin123`
|
||||||
|
|
||||||
|
## Project Architecture
|
||||||
|
|
||||||
|
**Full-stack TypeScript/Python application:**
|
||||||
|
|
||||||
|
```
|
||||||
|
├── backend/ # FastAPI backend
|
||||||
|
│ ├── app/
|
||||||
|
│ │ ├── api/ # API routes (auth, users, organizations, admin)
|
||||||
|
│ │ ├── core/ # Core functionality (auth, config, database)
|
||||||
|
│ │ ├── repositories/ # Repository pattern (database operations)
|
||||||
|
│ │ ├── models/ # SQLAlchemy ORM models
|
||||||
|
│ │ ├── schemas/ # Pydantic request/response schemas
|
||||||
|
│ │ ├── services/ # Business logic layer
|
||||||
|
│ │ └── utils/ # Utilities (security, device detection)
|
||||||
|
│ ├── tests/ # 96% coverage, 987 tests
|
||||||
|
│ └── alembic/ # Database migrations
|
||||||
|
│
|
||||||
|
└── frontend/ # Next.js 16 frontend
|
||||||
|
├── src/
|
||||||
|
│ ├── app/ # App Router pages (Next.js 16)
|
||||||
|
│ ├── components/ # React components
|
||||||
|
│ ├── lib/
|
||||||
|
│ │ ├── api/ # Auto-generated API client
|
||||||
|
│ │ └── stores/ # Zustand state management
|
||||||
|
│ └── hooks/ # Custom React hooks
|
||||||
|
└── e2e/ # Playwright E2E tests (56 passing)
|
||||||
|
```
|
||||||
|
|
||||||
|
## Critical Implementation Notes
|
||||||
|
|
||||||
|
### Authentication Flow
|
||||||
|
- **JWT-based**: Access tokens (15 min) + refresh tokens (7 days)
|
||||||
|
- **OAuth/Social Login**: Google and GitHub with PKCE support
|
||||||
|
- **Session tracking**: Database-backed with device info, IP, user agent
|
||||||
|
- **Token refresh**: Validates JTI in database, not just JWT signature
|
||||||
|
- **Authorization**: FastAPI dependencies in `api/dependencies/auth.py`
|
||||||
|
- `get_current_user`: Requires valid access token
|
||||||
|
- `get_current_active_user`: Requires active account
|
||||||
|
- `get_optional_current_user`: Accepts authenticated or anonymous
|
||||||
|
- `get_current_superuser`: Requires superuser flag
|
||||||
|
|
||||||
|
### OAuth Provider Mode (MCP Integration)
|
||||||
|
Full OAuth 2.0 Authorization Server for MCP (Model Context Protocol) clients:
|
||||||
|
- **Authorization Code Flow with PKCE**: RFC 7636 compliant
|
||||||
|
- **JWT access tokens**: Self-contained, no DB lookup required
|
||||||
|
- **Opaque refresh tokens**: Stored hashed in database, supports rotation
|
||||||
|
- **Token introspection**: RFC 7662 compliant endpoint
|
||||||
|
- **Token revocation**: RFC 7009 compliant endpoint
|
||||||
|
- **Server metadata**: RFC 8414 compliant discovery endpoint
|
||||||
|
- **Consent management**: User can review and revoke app permissions
|
||||||
|
|
||||||
|
**API endpoints:**
|
||||||
|
- `GET /.well-known/oauth-authorization-server` - Server metadata
|
||||||
|
- `GET /oauth/provider/authorize` - Authorization endpoint
|
||||||
|
- `POST /oauth/provider/authorize/consent` - Consent submission
|
||||||
|
- `POST /oauth/provider/token` - Token endpoint
|
||||||
|
- `POST /oauth/provider/revoke` - Token revocation
|
||||||
|
- `POST /oauth/provider/introspect` - Token introspection
|
||||||
|
- Client management endpoints (admin only)
|
||||||
|
|
||||||
|
**Scopes supported:** `openid`, `profile`, `email`, `read:users`, `write:users`, `admin`
|
||||||
|
|
||||||
|
**OAuth Configuration (backend `.env`):**
|
||||||
|
```bash
|
||||||
|
# OAuth Social Login (as OAuth Consumer)
|
||||||
|
OAUTH_ENABLED=true # Enable OAuth social login
|
||||||
|
OAUTH_AUTO_LINK_BY_EMAIL=true # Auto-link accounts by email
|
||||||
|
OAUTH_STATE_EXPIRE_MINUTES=10 # CSRF state expiration
|
||||||
|
|
||||||
|
# Google OAuth
|
||||||
|
OAUTH_GOOGLE_CLIENT_ID=your-google-client-id
|
||||||
|
OAUTH_GOOGLE_CLIENT_SECRET=your-google-client-secret
|
||||||
|
|
||||||
|
# GitHub OAuth
|
||||||
|
OAUTH_GITHUB_CLIENT_ID=your-github-client-id
|
||||||
|
OAUTH_GITHUB_CLIENT_SECRET=your-github-client-secret
|
||||||
|
|
||||||
|
# OAuth Provider Mode (as Authorization Server for MCP)
|
||||||
|
OAUTH_PROVIDER_ENABLED=true # Enable OAuth provider mode
|
||||||
|
OAUTH_ISSUER=https://api.yourdomain.com # JWT issuer URL (must be HTTPS in production)
|
||||||
|
```
|
||||||
|
|
||||||
|
### Database Pattern
|
||||||
|
- **Async SQLAlchemy 2.0** with PostgreSQL
|
||||||
|
- **Connection pooling**: 20 base connections, 50 max overflow
|
||||||
|
- **Repository base class**: `repositories/base.py` with common operations
|
||||||
|
- **Migrations**: Alembic with helper script `migrate.py`
|
||||||
|
- `python migrate.py auto "message"` - Generate and apply
|
||||||
|
- `python migrate.py list` - View history
|
||||||
|
|
||||||
|
### Frontend State Management
|
||||||
|
- **Zustand stores**: Lightweight state management
|
||||||
|
- **TanStack Query**: API data fetching/caching
|
||||||
|
- **Auto-generated client**: From OpenAPI spec via `bun run generate:api`
|
||||||
|
- **Dependency Injection**: ALWAYS use `useAuth()` from `AuthContext`, NEVER import `useAuthStore` directly
|
||||||
|
|
||||||
|
### Internationalization (i18n)
|
||||||
|
- **next-intl v4**: Type-safe internationalization library
|
||||||
|
- **Locale routing**: `/en/*`, `/it/*` (English and Italian supported)
|
||||||
|
- **Translation files**: `frontend/messages/en.json`, `frontend/messages/it.json`
|
||||||
|
- **LocaleSwitcher**: Component for seamless language switching
|
||||||
|
- **SEO-friendly**: Locale-aware metadata, sitemaps, and robots.txt
|
||||||
|
- **Type safety**: Full TypeScript support for translations
|
||||||
|
- **Utilities**: `frontend/src/lib/i18n/` (metadata, routing, utils)
|
||||||
|
|
||||||
|
### Organization System
|
||||||
|
Three-tier RBAC:
|
||||||
|
- **Owner**: Full control (delete org, manage all members)
|
||||||
|
- **Admin**: Add/remove members, assign admin role (not owner)
|
||||||
|
- **Member**: Read-only organization access
|
||||||
|
|
||||||
|
Permission dependencies in `api/dependencies/permissions.py`:
|
||||||
|
- `require_organization_owner`
|
||||||
|
- `require_organization_admin`
|
||||||
|
- `require_organization_member`
|
||||||
|
- `can_manage_organization_member`
|
||||||
|
|
||||||
|
### Testing Infrastructure
|
||||||
|
|
||||||
|
**Backend Unit/Integration (pytest + SQLite):**
|
||||||
|
- 96% coverage, 987 tests
|
||||||
|
- Security-focused: JWT attacks, session hijacking, privilege escalation
|
||||||
|
- Async fixtures in `tests/conftest.py`
|
||||||
|
- Run: `IS_TEST=True uv run pytest` or `make test`
|
||||||
|
- Coverage: `make test-cov`
|
||||||
|
|
||||||
|
**Backend E2E (pytest + Testcontainers + Schemathesis):**
|
||||||
|
- Real PostgreSQL via Docker containers
|
||||||
|
- OpenAPI contract testing with Schemathesis
|
||||||
|
- Install: `make install-e2e`
|
||||||
|
- Run: `make test-e2e`
|
||||||
|
- Schema tests: `make test-e2e-schema`
|
||||||
|
- Docs: `backend/docs/E2E_TESTING.md`
|
||||||
|
|
||||||
|
**Frontend Unit Tests (Jest):**
|
||||||
|
- 97% coverage
|
||||||
|
- Component, hook, and utility testing
|
||||||
|
- Run: `bun run test`
|
||||||
|
- Coverage: `bun run test:coverage`
|
||||||
|
|
||||||
|
**Frontend E2E Tests (Playwright):**
|
||||||
|
- 56 passing, 1 skipped (zero flaky tests)
|
||||||
|
- Complete user flows (auth, navigation, settings)
|
||||||
|
- Run: `bun run test:e2e`
|
||||||
|
- UI mode: `bun run test:e2e:ui`
|
||||||
|
|
||||||
|
### Development Tooling
|
||||||
|
|
||||||
|
**Backend:**
|
||||||
|
- **uv**: Modern Python package manager (10-100x faster than pip)
|
||||||
|
- **Ruff**: All-in-one linting/formatting (replaces Black, Flake8, isort)
|
||||||
|
- **Pyright**: Static type checking (strict mode)
|
||||||
|
- **pip-audit**: Dependency vulnerability scanning (OSV database)
|
||||||
|
- **detect-secrets**: Hardcoded secrets detection
|
||||||
|
- **pip-licenses**: License compliance checking
|
||||||
|
- **pre-commit**: Git hook framework (Ruff, detect-secrets, standard checks)
|
||||||
|
- **Makefile**: `make help` for all commands
|
||||||
|
|
||||||
|
**Frontend:**
|
||||||
|
- **Next.js 16**: App Router with React 19
|
||||||
|
- **TypeScript**: Full type safety
|
||||||
|
- **TailwindCSS + shadcn/ui**: Design system
|
||||||
|
- **ESLint + Prettier**: Code quality
|
||||||
|
|
||||||
|
### Environment Configuration
|
||||||
|
|
||||||
|
**Backend** (`.env`):
|
||||||
|
```bash
|
||||||
|
POSTGRES_USER=postgres
|
||||||
|
POSTGRES_PASSWORD=your_password
|
||||||
|
POSTGRES_HOST=db
|
||||||
|
POSTGRES_PORT=5432
|
||||||
|
POSTGRES_DB=app
|
||||||
|
|
||||||
|
SECRET_KEY=your-secret-key-min-32-chars
|
||||||
|
ENVIRONMENT=development|production
|
||||||
|
CSP_MODE=relaxed|strict|disabled
|
||||||
|
|
||||||
|
FIRST_SUPERUSER_EMAIL=admin@example.com
|
||||||
|
FIRST_SUPERUSER_PASSWORD=admin123
|
||||||
|
|
||||||
|
BACKEND_CORS_ORIGINS=["http://localhost:3000"]
|
||||||
|
```
|
||||||
|
|
||||||
|
**Frontend** (`.env.local`):
|
||||||
|
```bash
|
||||||
|
NEXT_PUBLIC_API_URL=http://localhost:8000/api/v1
|
||||||
|
```
|
||||||
|
|
||||||
|
## Common Development Workflows
|
||||||
|
|
||||||
|
### Adding a New API Endpoint
|
||||||
|
|
||||||
|
1. **Define schema** in `backend/app/schemas/`
|
||||||
|
2. **Create repository** in `backend/app/repositories/`
|
||||||
|
3. **Implement route** in `backend/app/api/routes/`
|
||||||
|
4. **Register router** in `backend/app/api/main.py`
|
||||||
|
5. **Write tests** in `backend/tests/api/`
|
||||||
|
6. **Generate frontend client**: `bun run generate:api`
|
||||||
|
|
||||||
|
### Database Migrations
|
||||||
|
|
||||||
|
```bash
|
||||||
|
cd backend
|
||||||
|
python migrate.py generate "description" # Create migration
|
||||||
|
python migrate.py apply # Apply migrations
|
||||||
|
python migrate.py auto "description" # Generate + apply
|
||||||
|
```
|
||||||
|
|
||||||
|
### Frontend Component Development
|
||||||
|
|
||||||
|
1. **Create component** in `frontend/src/components/`
|
||||||
|
2. **Follow design system** (see `frontend/docs/design-system/`)
|
||||||
|
3. **Use dependency injection** for auth (`useAuth()` not `useAuthStore`)
|
||||||
|
4. **Write tests** in `frontend/tests/` or `__tests__/`
|
||||||
|
5. **Run type check**: `bun run type-check`
|
||||||
|
|
||||||
|
## Security Features
|
||||||
|
|
||||||
|
- **Password hashing**: bcrypt with salt rounds
|
||||||
|
- **Rate limiting**: 60 req/min default, 10 req/min on auth endpoints
|
||||||
|
- **Security headers**: CSP, X-Frame-Options, HSTS, etc.
|
||||||
|
- **CSRF protection**: Built into FastAPI
|
||||||
|
- **Session revocation**: Database-backed session tracking
|
||||||
|
- **Comprehensive security tests**: JWT algorithm attacks, session hijacking, privilege escalation
|
||||||
|
- **Dependency vulnerability scanning**: `make dep-audit` (pip-audit against OSV database)
|
||||||
|
- **License compliance**: `make license-check` (blocks GPL-3.0/AGPL)
|
||||||
|
- **Secrets detection**: Pre-commit hook blocks hardcoded secrets
|
||||||
|
- **Unified security pipeline**: `make audit` (all security checks), `make check` (quality + security + tests)
|
||||||
|
|
||||||
|
## Docker Deployment
|
||||||
|
|
||||||
|
```bash
|
||||||
|
# Development (with hot reload)
|
||||||
|
docker-compose -f docker-compose.dev.yml up
|
||||||
|
|
||||||
|
# Production
|
||||||
|
docker-compose up -d
|
||||||
|
|
||||||
|
# Run migrations
|
||||||
|
docker-compose exec backend alembic upgrade head
|
||||||
|
|
||||||
|
# Create first superuser
|
||||||
|
docker-compose exec backend python -c "from app.init_db import init_db; import asyncio; asyncio.run(init_db())"
|
||||||
|
```
|
||||||
|
|
||||||
|
## Documentation
|
||||||
|
|
||||||
|
**For comprehensive documentation, see:**
|
||||||
|
- **[README.md](./README.md)** - User-facing project overview
|
||||||
|
- **[CLAUDE.md](./CLAUDE.md)** - Claude Code-specific guidance
|
||||||
|
- **Backend docs**: `backend/docs/` (Architecture, Coding Standards, Common Pitfalls, Feature Examples)
|
||||||
|
- **Frontend docs**: `frontend/docs/` (Design System, Architecture, E2E Testing)
|
||||||
|
- **API docs**: http://localhost:8000/docs (Swagger UI when running)
|
||||||
|
|
||||||
|
## Current Status (Nov 2025)
|
||||||
|
|
||||||
|
### Completed Features ✅
|
||||||
|
- Authentication system (JWT with refresh tokens, OAuth/social login)
|
||||||
|
- **OAuth Provider Mode (MCP-ready)**: Full OAuth 2.0 Authorization Server
|
||||||
|
- Session management (device tracking, revocation)
|
||||||
|
- User management (full lifecycle, password change)
|
||||||
|
- Organization system (multi-tenant with RBAC)
|
||||||
|
- Admin panel (user/org management, bulk operations)
|
||||||
|
- **Internationalization (i18n)** with English and Italian
|
||||||
|
- Comprehensive test coverage (96% backend, 97% frontend unit, 56 E2E tests)
|
||||||
|
- Design system documentation
|
||||||
|
- **Marketing landing page** with animations
|
||||||
|
- **`/dev` documentation portal** with live examples
|
||||||
|
- **Toast notifications**, charts, markdown rendering
|
||||||
|
- **SEO optimization** (sitemap, robots.txt, locale metadata)
|
||||||
|
- Docker deployment
|
||||||
|
|
||||||
|
### In Progress 🚧
|
||||||
|
- Frontend admin pages (70% complete)
|
||||||
|
- Email integration (templates ready, SMTP pending)
|
||||||
|
|
||||||
|
### Planned 🔮
|
||||||
|
- GitHub Actions CI/CD
|
||||||
|
- Additional languages (Spanish, French, German, etc.)
|
||||||
|
- SSO/SAML authentication
|
||||||
|
- Real-time notifications (WebSockets)
|
||||||
|
- Webhook system
|
||||||
|
- Background job processing
|
||||||
|
- File upload/storage
|
||||||
785
CLAUDE.md
785
CLAUDE.md
@@ -1,10 +1,14 @@
|
|||||||
# CLAUDE.md
|
# CLAUDE.md
|
||||||
|
|
||||||
This file provides guidance to Claude Code (claude.ai/code) when working with code in this repository.
|
Claude Code context for FastAPI + Next.js Full-Stack Template.
|
||||||
|
|
||||||
## Critical User Preferences
|
**See [AGENTS.md](./AGENTS.md) for project context, architecture, and development commands.**
|
||||||
|
|
||||||
### File Operations - NEVER Use Heredoc/Cat Append
|
## Claude Code-Specific Guidance
|
||||||
|
|
||||||
|
### Critical User Preferences
|
||||||
|
|
||||||
|
#### File Operations - NEVER Use Heredoc/Cat Append
|
||||||
**ALWAYS use Read/Write/Edit tools instead of `cat >> file << EOF` commands.**
|
**ALWAYS use Read/Write/Edit tools instead of `cat >> file << EOF` commands.**
|
||||||
|
|
||||||
This triggers manual approval dialogs and disrupts workflow.
|
This triggers manual approval dialogs and disrupts workflow.
|
||||||
@@ -18,215 +22,53 @@ EOF
|
|||||||
# CORRECT ✅ - Use Read, then Write tools
|
# CORRECT ✅ - Use Read, then Write tools
|
||||||
```
|
```
|
||||||
|
|
||||||
### Work Style
|
#### Work Style
|
||||||
- User prefers autonomous operation without frequent interruptions
|
- User prefers autonomous operation without frequent interruptions
|
||||||
- Ask for batch permissions upfront for long work sessions
|
- Ask for batch permissions upfront for long work sessions
|
||||||
- Work independently, document decisions clearly
|
- Work independently, document decisions clearly
|
||||||
|
- Only use emojis if the user explicitly requests it
|
||||||
|
|
||||||
## Project Architecture
|
### When Working with This Stack
|
||||||
|
|
||||||
This is a **FastAPI + Next.js full-stack application** with the following structure:
|
**Dependency Management:**
|
||||||
|
- Backend uses **uv** (modern Python package manager), not pip
|
||||||
|
- Always use `uv run` prefix: `IS_TEST=True uv run pytest`
|
||||||
|
- Or use Makefile commands: `make test`, `make install-dev`
|
||||||
|
- Add dependencies: `uv add <package>` or `uv add --dev <package>`
|
||||||
|
|
||||||
### Backend (FastAPI)
|
**Database Migrations:**
|
||||||
```
|
- Use the `migrate.py` helper script, not Alembic directly
|
||||||
backend/app/
|
- Generate + apply: `python migrate.py auto "message"`
|
||||||
├── api/ # API routes organized by version
|
- Never commit migrations without testing them first
|
||||||
│ ├── routes/ # Endpoint implementations (auth, users, sessions, admin, organizations)
|
- Check current state: `python migrate.py current`
|
||||||
│ └── dependencies/ # FastAPI dependencies (auth, permissions)
|
|
||||||
├── core/ # Core functionality
|
|
||||||
│ ├── config.py # Settings (Pydantic BaseSettings)
|
|
||||||
│ ├── database.py # SQLAlchemy async engine setup
|
|
||||||
│ ├── auth.py # JWT token generation/validation
|
|
||||||
│ └── exceptions.py # Custom exception classes and handlers
|
|
||||||
├── crud/ # Database CRUD operations (base, user, session, organization)
|
|
||||||
├── models/ # SQLAlchemy ORM models
|
|
||||||
├── schemas/ # Pydantic request/response schemas
|
|
||||||
├── services/ # Business logic layer (auth_service)
|
|
||||||
└── utils/ # Utilities (security, device detection, test helpers)
|
|
||||||
```
|
|
||||||
|
|
||||||
### Frontend (Next.js 15)
|
**Frontend API Client Generation:**
|
||||||
```
|
- Run `bun run generate:api` after backend schema changes
|
||||||
frontend/src/
|
- Client is auto-generated from OpenAPI spec
|
||||||
├── app/ # Next.js App Router pages
|
- Located in `frontend/src/lib/api/generated/`
|
||||||
├── components/ # React components (auth/, ui/)
|
- NEVER manually edit generated files
|
||||||
├── lib/
|
|
||||||
│ ├── api/ # API client (auto-generated from OpenAPI)
|
|
||||||
│ ├── stores/ # Zustand state management
|
|
||||||
│ └── utils/ # Utility functions
|
|
||||||
└── hooks/ # Custom React hooks
|
|
||||||
```
|
|
||||||
|
|
||||||
## Development Commands
|
**Testing Commands:**
|
||||||
|
- Backend unit/integration: `IS_TEST=True uv run pytest` (always prefix with `IS_TEST=True`)
|
||||||
|
- Backend E2E (requires Docker): `make test-e2e`
|
||||||
|
- Frontend unit: `bun run test`
|
||||||
|
- Frontend E2E: `bun run test:e2e`
|
||||||
|
- Use `make test` or `make test-cov` in backend for convenience
|
||||||
|
|
||||||
### Backend
|
**Security & Quality Commands (Backend):**
|
||||||
|
- `make validate` — lint + format + type checks
|
||||||
|
- `make audit` — dependency vulnerabilities + license compliance
|
||||||
|
- `make validate-all` — quality + security checks
|
||||||
|
- `make check` — **full pipeline**: quality + security + tests
|
||||||
|
|
||||||
#### Setup
|
**Backend E2E Testing (requires Docker):**
|
||||||
|
- Install deps: `make install-e2e`
|
||||||
**Dependencies are managed with [uv](https://docs.astral.sh/uv/) - the modern, fast Python package manager.**
|
- Run all E2E tests: `make test-e2e`
|
||||||
|
- Run schema tests only: `make test-e2e-schema`
|
||||||
```bash
|
- Run all tests: `make test-all` (unit + E2E)
|
||||||
cd backend
|
- Uses Testcontainers (real PostgreSQL) + Schemathesis (OpenAPI contract testing)
|
||||||
|
- Markers: `@pytest.mark.e2e`, `@pytest.mark.postgres`, `@pytest.mark.schemathesis`
|
||||||
# Install uv (if not already installed)
|
- See: `backend/docs/E2E_TESTING.md` for complete guide
|
||||||
curl -LsSf https://astral.sh/uv/install.sh | sh
|
|
||||||
|
|
||||||
# Install all dependencies (production + dev) from uv.lock
|
|
||||||
uv sync --extra dev
|
|
||||||
|
|
||||||
# Or use the Makefile
|
|
||||||
make install-dev
|
|
||||||
```
|
|
||||||
|
|
||||||
**Why uv?**
|
|
||||||
- 🚀 10-100x faster than pip
|
|
||||||
- 🔒 Reproducible builds with `uv.lock`
|
|
||||||
- 📦 Modern dependency resolution
|
|
||||||
- ⚡ Built by Astral (creators of Ruff)
|
|
||||||
|
|
||||||
#### Database Migrations
|
|
||||||
```bash
|
|
||||||
# Using the migration helper (preferred)
|
|
||||||
python migrate.py generate "migration message" # Generate migration
|
|
||||||
python migrate.py apply # Apply migrations
|
|
||||||
python migrate.py auto "message" # Generate and apply in one step
|
|
||||||
python migrate.py list # List all migrations
|
|
||||||
python migrate.py current # Show current revision
|
|
||||||
python migrate.py check # Check DB connection
|
|
||||||
|
|
||||||
# Or using Alembic directly
|
|
||||||
alembic revision --autogenerate -m "message"
|
|
||||||
alembic upgrade head
|
|
||||||
```
|
|
||||||
|
|
||||||
#### Testing
|
|
||||||
|
|
||||||
**Test Coverage: High (comprehensive test suite)**
|
|
||||||
- Security-focused testing with JWT algorithm attack prevention (CVE-2015-9235)
|
|
||||||
- Session hijacking and privilege escalation tests included
|
|
||||||
- Missing lines justified as defensive code, error handlers, and production-only code
|
|
||||||
|
|
||||||
```bash
|
|
||||||
# Run all tests (uses pytest-xdist for parallel execution)
|
|
||||||
make test
|
|
||||||
|
|
||||||
# Run with coverage report
|
|
||||||
make test-cov
|
|
||||||
|
|
||||||
# Or run directly with uv
|
|
||||||
IS_TEST=True uv run pytest
|
|
||||||
|
|
||||||
# Run specific test file
|
|
||||||
IS_TEST=True uv run pytest tests/api/test_auth.py -v
|
|
||||||
|
|
||||||
# Run single test
|
|
||||||
IS_TEST=True uv run pytest tests/api/test_auth.py::TestLogin::test_login_success -v
|
|
||||||
```
|
|
||||||
|
|
||||||
**Available Make Commands:**
|
|
||||||
```bash
|
|
||||||
make help # Show all available commands
|
|
||||||
make install-dev # Install all dependencies
|
|
||||||
make validate # Run lint + format + type checks
|
|
||||||
make test # Run tests
|
|
||||||
make test-cov # Run tests with coverage
|
|
||||||
```
|
|
||||||
|
|
||||||
#### Running Locally
|
|
||||||
```bash
|
|
||||||
cd backend
|
|
||||||
uv run uvicorn app.main:app --reload --host 0.0.0.0 --port 8000
|
|
||||||
```
|
|
||||||
|
|
||||||
### Frontend
|
|
||||||
|
|
||||||
#### Setup
|
|
||||||
```bash
|
|
||||||
cd frontend
|
|
||||||
npm install
|
|
||||||
```
|
|
||||||
|
|
||||||
#### Development
|
|
||||||
```bash
|
|
||||||
npm run dev # Start dev server on http://localhost:3000
|
|
||||||
npm run build # Production build
|
|
||||||
npm run lint # ESLint
|
|
||||||
npm run type-check # TypeScript checking
|
|
||||||
```
|
|
||||||
|
|
||||||
#### Testing
|
|
||||||
```bash
|
|
||||||
# Unit tests (Jest)
|
|
||||||
npm test # Run all unit tests
|
|
||||||
npm run test:watch # Watch mode
|
|
||||||
npm run test:coverage # With coverage
|
|
||||||
|
|
||||||
# E2E tests (Playwright)
|
|
||||||
npm run test:e2e # Run all E2E tests
|
|
||||||
npm run test:e2e:ui # Open Playwright UI
|
|
||||||
npm run test:e2e:debug # Debug mode
|
|
||||||
npx playwright test auth-login.spec.ts # Run specific file
|
|
||||||
```
|
|
||||||
|
|
||||||
**E2E Test Best Practices:**
|
|
||||||
- Use `Promise.all()` pattern for Next.js Link navigation:
|
|
||||||
```typescript
|
|
||||||
await Promise.all([
|
|
||||||
page.waitForURL('/target', { timeout: 10000 }),
|
|
||||||
link.click()
|
|
||||||
]);
|
|
||||||
```
|
|
||||||
- Use ID-based selectors for validation errors (e.g., `#email-error`)
|
|
||||||
- Error IDs use dashes not underscores (`#new-password-error`)
|
|
||||||
- Target `.border-destructive[role="alert"]` to avoid Next.js route announcer conflicts
|
|
||||||
- Uses 12 workers in non-CI mode (`workers: 12` in `playwright.config.ts`)
|
|
||||||
- URL assertions should use regex to handle query params: `/\/auth\/login/`
|
|
||||||
|
|
||||||
### Docker
|
|
||||||
|
|
||||||
```bash
|
|
||||||
# Development (with hot reload)
|
|
||||||
docker-compose -f docker-compose.dev.yml up
|
|
||||||
|
|
||||||
# Production
|
|
||||||
docker-compose up -d
|
|
||||||
|
|
||||||
# Rebuild specific service
|
|
||||||
docker-compose build backend
|
|
||||||
docker-compose build frontend
|
|
||||||
```
|
|
||||||
|
|
||||||
## Key Architectural Patterns
|
|
||||||
|
|
||||||
### Authentication Flow
|
|
||||||
1. **Login**: `POST /api/v1/auth/login` returns access + refresh tokens
|
|
||||||
- Access token: 15 minutes expiry (JWT)
|
|
||||||
- Refresh token: 7 days expiry (JWT with JTI stored in DB)
|
|
||||||
- Session tracking with device info (IP, user agent, device ID)
|
|
||||||
|
|
||||||
2. **Token Refresh**: `POST /api/v1/auth/refresh` validates refresh token JTI
|
|
||||||
- Checks session is active in database
|
|
||||||
- Issues new access token (refresh token remains valid)
|
|
||||||
- Updates session `last_used_at`
|
|
||||||
|
|
||||||
3. **Authorization**: FastAPI dependencies in `api/dependencies/auth.py`
|
|
||||||
- `get_current_user`: Validates access token, returns User (raises 401 if invalid)
|
|
||||||
- `get_current_active_user`: Requires valid access token + active account
|
|
||||||
- `get_optional_current_user`: Accepts both authenticated and anonymous users (returns User or None)
|
|
||||||
- `get_current_superuser`: Requires superuser flag
|
|
||||||
|
|
||||||
### Database Pattern: Async SQLAlchemy
|
|
||||||
- **Engine**: Created in `core/database.py` with connection pooling
|
|
||||||
- **Sessions**: AsyncSession from `async_sessionmaker`
|
|
||||||
- **CRUD**: Base class in `crud/base.py` with common operations
|
|
||||||
- Inherits: `CRUDUser`, `CRUDSession`, `CRUDOrganization`
|
|
||||||
- Pattern: `async def get(db: AsyncSession, id: str) -> Model | None`
|
|
||||||
|
|
||||||
### Frontend State Management
|
|
||||||
- **Zustand stores**: `lib/stores/` (authStore, etc.)
|
|
||||||
- **TanStack Query**: API data fetching/caching
|
|
||||||
- **Auto-generated client**: `lib/api/generated/` from OpenAPI spec
|
|
||||||
- Generate with: `npm run generate:api` (runs `scripts/generate-api-client.sh`)
|
|
||||||
|
|
||||||
### 🔴 CRITICAL: Auth Store Dependency Injection Pattern
|
### 🔴 CRITICAL: Auth Store Dependency Injection Pattern
|
||||||
|
|
||||||
@@ -252,423 +94,160 @@ const { user, isAuthenticated } = useAuth();
|
|||||||
1. `AuthContext.tsx` - DI boundary, legitimately needs real store
|
1. `AuthContext.tsx` - DI boundary, legitimately needs real store
|
||||||
2. `client.ts` - Non-React context, uses dynamic import + `__TEST_AUTH_STORE__` check
|
2. `client.ts` - Non-React context, uses dynamic import + `__TEST_AUTH_STORE__` check
|
||||||
|
|
||||||
**See**: `frontend/docs/ARCHITECTURE_FIX_REPORT.md` for full details.
|
### E2E Test Best Practices
|
||||||
|
|
||||||
### Session Management Architecture
|
When writing or fixing Playwright tests:
|
||||||
**Database-backed session tracking** (not just JWT):
|
|
||||||
- Each refresh token has a corresponding `UserSession` record
|
|
||||||
- Tracks: device info, IP, location, last used timestamp
|
|
||||||
- Supports session revocation (logout from specific devices)
|
|
||||||
- Cleanup job removes expired sessions
|
|
||||||
|
|
||||||
### Permission System
|
**Navigation Pattern:**
|
||||||
Three-tier organization roles:
|
```typescript
|
||||||
- **Owner**: Full control (delete org, manage all members)
|
// ✅ CORRECT - Use Promise.all for Next.js Link clicks
|
||||||
- **Admin**: Can add/remove members, assign admin role (not owner)
|
await Promise.all([
|
||||||
- **Member**: Read-only organization access
|
page.waitForURL('/target', { timeout: 10000 }),
|
||||||
|
link.click()
|
||||||
|
]);
|
||||||
|
```
|
||||||
|
|
||||||
Dependencies in `api/dependencies/permissions.py`:
|
**Selectors:**
|
||||||
- `require_organization_owner`
|
- Use ID-based selectors for validation errors: `#email-error`
|
||||||
- `require_organization_admin`
|
- Error IDs use dashes not underscores: `#new-password-error`
|
||||||
- `require_organization_member`
|
- Target `.border-destructive[role="alert"]` to avoid Next.js route announcer conflicts
|
||||||
- `can_manage_organization_member` (owner or admin, but not self-demotion)
|
- Avoid generic `[role="alert"]` which matches multiple elements
|
||||||
|
|
||||||
## Testing Infrastructure
|
**URL Assertions:**
|
||||||
|
```typescript
|
||||||
|
// ✅ Use regex to handle query params
|
||||||
|
await expect(page).toHaveURL(/\/auth\/login/);
|
||||||
|
|
||||||
### Backend Test Patterns
|
// ❌ Don't use exact strings (fails with query params)
|
||||||
|
await expect(page).toHaveURL('/auth/login');
|
||||||
|
```
|
||||||
|
|
||||||
**Fixtures** (in `tests/conftest.py`):
|
**Configuration:**
|
||||||
- `async_test_db`: Fresh SQLite in-memory database per test
|
- Uses 12 workers in non-CI mode (`playwright.config.ts`)
|
||||||
- `client`: AsyncClient with test database override
|
- Reduces to 2 workers in CI for stability
|
||||||
- `async_test_user`: Pre-created regular user
|
- Tests are designed to be non-flaky with proper waits
|
||||||
- `async_test_superuser`: Pre-created superuser
|
|
||||||
- `user_token` / `superuser_token`: Access tokens for API calls
|
|
||||||
|
|
||||||
**Database Mocking for Exception Testing**:
|
### Important Implementation Details
|
||||||
|
|
||||||
|
**Authentication Testing:**
|
||||||
|
- Backend fixtures in `tests/conftest.py`:
|
||||||
|
- `async_test_db`: Fresh SQLite per test
|
||||||
|
- `async_test_user` / `async_test_superuser`: Pre-created users
|
||||||
|
- `user_token` / `superuser_token`: Access tokens for API calls
|
||||||
|
- Always use `@pytest.mark.asyncio` for async tests
|
||||||
|
- Use `@pytest_asyncio.fixture` for async fixtures
|
||||||
|
|
||||||
|
**Database Testing:**
|
||||||
```python
|
```python
|
||||||
|
# Mock database exceptions correctly
|
||||||
from unittest.mock import patch, AsyncMock
|
from unittest.mock import patch, AsyncMock
|
||||||
|
|
||||||
# Mock database commit to raise exception
|
|
||||||
async def mock_commit():
|
async def mock_commit():
|
||||||
raise OperationalError("Connection lost", {}, Exception())
|
raise OperationalError("Connection lost", {}, Exception())
|
||||||
|
|
||||||
with patch.object(session, 'commit', side_effect=mock_commit):
|
with patch.object(session, 'commit', side_effect=mock_commit):
|
||||||
with patch.object(session, 'rollback', new_callable=AsyncMock) as mock_rollback:
|
with patch.object(session, 'rollback', new_callable=AsyncMock) as mock_rollback:
|
||||||
with pytest.raises(OperationalError):
|
with pytest.raises(OperationalError):
|
||||||
await crud_method(session, obj_in=data)
|
await repo_method(session, obj_in=data)
|
||||||
mock_rollback.assert_called_once()
|
mock_rollback.assert_called_once()
|
||||||
```
|
```
|
||||||
|
|
||||||
**Testing Routes**:
|
**Frontend Component Development:**
|
||||||
```python
|
- Follow design system docs in `frontend/docs/design-system/`
|
||||||
@pytest.mark.asyncio
|
- Read `08-ai-guidelines.md` for AI code generation rules
|
||||||
async def test_endpoint(client, user_token):
|
- Use parent-controlled spacing (see `04-spacing-philosophy.md`)
|
||||||
response = await client.get(
|
- WCAG AA compliance required (see `07-accessibility.md`)
|
||||||
"/api/v1/endpoint",
|
|
||||||
headers={"Authorization": f"Bearer {user_token}"}
|
**Security Considerations:**
|
||||||
)
|
- Backend has comprehensive security tests (JWT attacks, session hijacking)
|
||||||
assert response.status_code == 200
|
- Never skip security headers in production
|
||||||
```
|
- Rate limiting is configured in route decorators: `@limiter.limit("10/minute")`
|
||||||
|
- Session revocation is database-backed, not just JWT expiry
|
||||||
**IMPORTANT**: Use `@pytest_asyncio.fixture` for async fixtures, not `@pytest.fixture`
|
- Run `make audit` to check for dependency vulnerabilities and license compliance
|
||||||
|
- Run `make check` for the full pipeline: quality + security + tests
|
||||||
### Frontend Test Patterns
|
- Pre-commit hooks enforce Ruff lint/format and detect-secrets on every commit
|
||||||
|
- Setup hooks: `cd backend && uv run pre-commit install`
|
||||||
**Unit Tests (Jest)**:
|
|
||||||
```typescript
|
### Common Workflows Guidance
|
||||||
// SSR-safe mocking
|
|
||||||
jest.mock('@/lib/stores/authStore', () => ({
|
**When Adding a New Feature:**
|
||||||
useAuthStore: jest.fn()
|
1. Start with backend schema and repository
|
||||||
}));
|
2. Implement API route with proper authorization
|
||||||
|
3. Write backend tests (aim for >90% coverage)
|
||||||
beforeEach(() => {
|
4. Generate frontend API client: `bun run generate:api`
|
||||||
(useAuthStore as jest.Mock).mockReturnValue({
|
5. Implement frontend components
|
||||||
user: mockUser,
|
6. Write frontend unit tests
|
||||||
login: mockLogin
|
7. Add E2E tests for critical flows
|
||||||
});
|
8. Update relevant documentation
|
||||||
});
|
|
||||||
```
|
**When Fixing Tests:**
|
||||||
|
- Backend: Check test database isolation and async fixture usage
|
||||||
**E2E Tests (Playwright)**:
|
- Frontend unit: Verify mocking of `useAuth()` not `useAuthStore`
|
||||||
```typescript
|
- E2E: Use `Promise.all()` pattern and regex URL assertions
|
||||||
test('navigation', async ({ page }) => {
|
|
||||||
await page.goto('/');
|
**When Debugging:**
|
||||||
|
- Backend: Check `IS_TEST=True` environment variable is set
|
||||||
const link = page.getByRole('link', { name: 'Login' });
|
- Frontend: Run `bun run type-check` first
|
||||||
await Promise.all([
|
- E2E: Use `bun run test:e2e:debug` for step-by-step debugging
|
||||||
page.waitForURL(/\/auth\/login/, { timeout: 10000 }),
|
- Check logs: Backend has detailed error logging
|
||||||
link.click()
|
|
||||||
]);
|
**Demo Mode (Frontend-Only Showcase):**
|
||||||
|
- Enable: `echo "NEXT_PUBLIC_DEMO_MODE=true" > frontend/.env.local`
|
||||||
await expect(page).toHaveURL(/\/auth\/login/);
|
- Uses MSW (Mock Service Worker) to intercept API calls in browser
|
||||||
});
|
- Zero backend required - perfect for Vercel deployments
|
||||||
```
|
- **Fully Automated**: MSW handlers auto-generated from OpenAPI spec
|
||||||
|
- Run `bun run generate:api` → updates both API client AND MSW handlers
|
||||||
## Configuration
|
- No manual synchronization needed!
|
||||||
|
- Demo credentials (any password ≥8 chars works):
|
||||||
### Environment Variables
|
- User: `demo@example.com` / `DemoPass123`
|
||||||
|
- Admin: `admin@example.com` / `AdminPass123`
|
||||||
**Backend** (`.env`):
|
- **Safe**: MSW never runs during tests (Jest or Playwright)
|
||||||
```bash
|
- **Coverage**: Mock files excluded from linting and coverage
|
||||||
# Database
|
- **Documentation**: `frontend/docs/DEMO_MODE.md` for complete guide
|
||||||
POSTGRES_USER=postgres
|
|
||||||
POSTGRES_PASSWORD=your_password
|
### Tool Usage Preferences
|
||||||
POSTGRES_HOST=db
|
|
||||||
POSTGRES_PORT=5432
|
**Prefer specialized tools over bash:**
|
||||||
POSTGRES_DB=app
|
- Use Read/Write/Edit tools for file operations
|
||||||
|
- Never use `cat`, `echo >`, or heredoc for file manipulation
|
||||||
# Security
|
- Use Task tool with `subagent_type=Explore` for codebase exploration
|
||||||
SECRET_KEY=your-secret-key-min-32-chars
|
- Use Grep tool for code search, not bash `grep`
|
||||||
ENVIRONMENT=development|production
|
|
||||||
CSP_MODE=relaxed|strict|disabled
|
**When to use parallel tool calls:**
|
||||||
|
- Independent git commands: `git status`, `git diff`, `git log`
|
||||||
# First Superuser (auto-created on init)
|
- Reading multiple unrelated files
|
||||||
FIRST_SUPERUSER_EMAIL=admin@example.com
|
- Running multiple test suites simultaneously
|
||||||
FIRST_SUPERUSER_PASSWORD=admin123
|
- Independent validation steps
|
||||||
|
|
||||||
# CORS
|
## Custom Skills
|
||||||
BACKEND_CORS_ORIGINS=["http://localhost:3000"]
|
|
||||||
```
|
No Claude Code Skills installed yet. To create one, invoke the built-in "skill-creator" skill.
|
||||||
|
|
||||||
**Frontend** (`.env.local`):
|
**Potential skill ideas for this project:**
|
||||||
```bash
|
- API endpoint generator workflow (schema → repository → route → tests → frontend client)
|
||||||
NEXT_PUBLIC_API_URL=http://localhost:8000/api/v1
|
- Component generator with design system compliance
|
||||||
```
|
- Database migration troubleshooting helper
|
||||||
|
- Test coverage analyzer and improvement suggester
|
||||||
### Database Connection Pooling
|
- E2E test generator for new features
|
||||||
Configured in `core/config.py`:
|
|
||||||
- `db_pool_size`: 20 (default connections)
|
## Additional Resources
|
||||||
- `db_max_overflow`: 50 (max overflow)
|
|
||||||
- `db_pool_timeout`: 30 seconds
|
**Comprehensive Documentation:**
|
||||||
- `db_pool_recycle`: 3600 seconds (recycle after 1 hour)
|
- [AGENTS.md](./AGENTS.md) - Framework-agnostic AI assistant context
|
||||||
|
- [README.md](./README.md) - User-facing project overview
|
||||||
### Security Headers
|
- `backend/docs/` - Backend architecture, coding standards, common pitfalls
|
||||||
Automatically applied via middleware in `main.py`:
|
- `frontend/docs/design-system/` - Complete design system guide
|
||||||
- `X-Frame-Options: DENY`
|
|
||||||
- `X-Content-Type-Options: nosniff`
|
**API Documentation (when running):**
|
||||||
- `X-XSS-Protection: 1; mode=block`
|
- Swagger UI: http://localhost:8000/docs
|
||||||
- `Strict-Transport-Security` (production only)
|
- ReDoc: http://localhost:8000/redoc
|
||||||
- Content-Security-Policy (configurable via `CSP_MODE`)
|
- OpenAPI JSON: http://localhost:8000/api/v1/openapi.json
|
||||||
|
|
||||||
### Rate Limiting
|
**Testing Documentation:**
|
||||||
- Implemented with `slowapi`
|
- Backend tests: `backend/tests/` (97% coverage)
|
||||||
- Default: 60 requests/minute per IP
|
- Frontend E2E: `frontend/e2e/README.md`
|
||||||
- Applied to auth endpoints (login, register, password reset)
|
- Design system: `frontend/docs/design-system/08-ai-guidelines.md`
|
||||||
- Override in route decorators: `@limiter.limit("10/minute")`
|
|
||||||
|
---
|
||||||
## Common Workflows
|
|
||||||
|
**For project architecture, development commands, and general context, see [AGENTS.md](./AGENTS.md).**
|
||||||
### Adding a New API Endpoint
|
|
||||||
|
|
||||||
1. **Create schema** (`backend/app/schemas/`):
|
|
||||||
```python
|
|
||||||
class ItemCreate(BaseModel):
|
|
||||||
name: str
|
|
||||||
description: Optional[str] = None
|
|
||||||
|
|
||||||
class ItemResponse(BaseModel):
|
|
||||||
id: UUID
|
|
||||||
name: str
|
|
||||||
created_at: datetime
|
|
||||||
```
|
|
||||||
|
|
||||||
2. **Create CRUD operations** (`backend/app/crud/`):
|
|
||||||
```python
|
|
||||||
class CRUDItem(CRUDBase[Item, ItemCreate, ItemUpdate]):
|
|
||||||
async def get_by_name(self, db: AsyncSession, name: str) -> Item | None:
|
|
||||||
result = await db.execute(select(Item).where(Item.name == name))
|
|
||||||
return result.scalar_one_or_none()
|
|
||||||
|
|
||||||
item = CRUDItem(Item)
|
|
||||||
```
|
|
||||||
|
|
||||||
3. **Create route** (`backend/app/api/routes/items.py`):
|
|
||||||
```python
|
|
||||||
from app.api.dependencies.auth import get_current_user
|
|
||||||
|
|
||||||
@router.post("/", response_model=ItemResponse)
|
|
||||||
async def create_item(
|
|
||||||
item_in: ItemCreate,
|
|
||||||
current_user: User = Depends(get_current_user),
|
|
||||||
db: AsyncSession = Depends(get_db)
|
|
||||||
):
|
|
||||||
item = await item_crud.create(db, obj_in=item_in)
|
|
||||||
return item
|
|
||||||
```
|
|
||||||
|
|
||||||
4. **Register router** (`backend/app/api/main.py`):
|
|
||||||
```python
|
|
||||||
from app.api.routes import items
|
|
||||||
api_router.include_router(items.router, prefix="/items", tags=["Items"])
|
|
||||||
```
|
|
||||||
|
|
||||||
5. **Write tests** (`backend/tests/api/test_items.py`):
|
|
||||||
```python
|
|
||||||
@pytest.mark.asyncio
|
|
||||||
async def test_create_item(client, user_token):
|
|
||||||
response = await client.post(
|
|
||||||
"/api/v1/items",
|
|
||||||
headers={"Authorization": f"Bearer {user_token}"},
|
|
||||||
json={"name": "Test Item"}
|
|
||||||
)
|
|
||||||
assert response.status_code == 201
|
|
||||||
```
|
|
||||||
|
|
||||||
6. **Generate frontend client**:
|
|
||||||
```bash
|
|
||||||
cd frontend
|
|
||||||
npm run generate:api
|
|
||||||
```
|
|
||||||
|
|
||||||
### Adding a New React Component
|
|
||||||
|
|
||||||
1. **Create component** (`frontend/src/components/`):
|
|
||||||
```typescript
|
|
||||||
export function MyComponent() {
|
|
||||||
const { user } = useAuthStore();
|
|
||||||
return <div>Hello {user?.firstName}</div>;
|
|
||||||
}
|
|
||||||
```
|
|
||||||
|
|
||||||
2. **Add tests** (`frontend/src/components/__tests__/`):
|
|
||||||
```typescript
|
|
||||||
import { render, screen } from '@testing-library/react';
|
|
||||||
|
|
||||||
test('renders component', () => {
|
|
||||||
render(<MyComponent />);
|
|
||||||
expect(screen.getByText(/Hello/)).toBeInTheDocument();
|
|
||||||
});
|
|
||||||
```
|
|
||||||
|
|
||||||
3. **Add to page** (`frontend/src/app/page.tsx`):
|
|
||||||
```typescript
|
|
||||||
import { MyComponent } from '@/components/MyComponent';
|
|
||||||
|
|
||||||
export default function Page() {
|
|
||||||
return <MyComponent />;
|
|
||||||
}
|
|
||||||
```
|
|
||||||
|
|
||||||
## Development Tooling Stack
|
|
||||||
|
|
||||||
**State-of-the-art Python tooling (Nov 2025):**
|
|
||||||
|
|
||||||
### Dependency Management: uv
|
|
||||||
- **Fast**: 10-100x faster than pip
|
|
||||||
- **Reliable**: Reproducible builds with `uv.lock` lockfile
|
|
||||||
- **Modern**: Built by Astral (Ruff creators) in Rust
|
|
||||||
- **Commands**:
|
|
||||||
- `make install-dev` - Install all dependencies
|
|
||||||
- `make sync` - Sync from lockfile
|
|
||||||
- `uv add <package>` - Add new dependency
|
|
||||||
- `uv add --dev <package>` - Add dev dependency
|
|
||||||
|
|
||||||
### Code Quality: Ruff + mypy
|
|
||||||
- **Ruff**: All-in-one linting, formatting, and import sorting
|
|
||||||
- Replaces: Black, Flake8, isort
|
|
||||||
- **10-100x faster** than alternatives
|
|
||||||
- `make lint`, `make format`, `make validate`
|
|
||||||
- **mypy**: Type checking with Pydantic plugin
|
|
||||||
- Gradual typing approach
|
|
||||||
- Strategic per-module configurations
|
|
||||||
|
|
||||||
### Configuration: pyproject.toml
|
|
||||||
- Single source of truth for all tools
|
|
||||||
- Dependencies defined in `[project.dependencies]`
|
|
||||||
- Dev dependencies in `[project.optional-dependencies]`
|
|
||||||
- Tool configs: Ruff, mypy, pytest, coverage
|
|
||||||
|
|
||||||
## Current Project Status (Nov 2025)
|
|
||||||
|
|
||||||
### Completed Features
|
|
||||||
- ✅ Authentication system (JWT with refresh tokens)
|
|
||||||
- ✅ Session management (device tracking, revocation)
|
|
||||||
- ✅ User management (CRUD, password change)
|
|
||||||
- ✅ Organization system (multi-tenant with roles)
|
|
||||||
- ✅ Admin panel (user/org management, bulk operations)
|
|
||||||
- ✅ E2E test suite (56 passing, 1 skipped, zero flaky tests)
|
|
||||||
|
|
||||||
### Test Coverage
|
|
||||||
- **Backend**: 97% overall (743 tests, all passing) ✅
|
|
||||||
- Comprehensive security testing (JWT attacks, session hijacking, privilege escalation)
|
|
||||||
- User CRUD: 100% ✅
|
|
||||||
- Session CRUD: 100% ✅
|
|
||||||
- Auth routes: 99% ✅
|
|
||||||
- Organization routes: 100% ✅
|
|
||||||
- Permissions: 100% ✅
|
|
||||||
- 84 missing lines justified (defensive code, error handlers, production-only code)
|
|
||||||
|
|
||||||
- **Frontend E2E**: 56 passing, 1 skipped across 7 files ✅
|
|
||||||
- auth-login.spec.ts (19 tests)
|
|
||||||
- auth-register.spec.ts (14 tests)
|
|
||||||
- auth-password-reset.spec.ts (10 tests)
|
|
||||||
- navigation.spec.ts (10 tests)
|
|
||||||
- settings-password.spec.ts (3 tests)
|
|
||||||
- settings-profile.spec.ts (2 tests)
|
|
||||||
- settings-navigation.spec.ts (5 tests)
|
|
||||||
- settings-sessions.spec.ts (1 skipped - route not yet implemented)
|
|
||||||
|
|
||||||
## Email Service Integration
|
|
||||||
|
|
||||||
The project includes a **placeholder email service** (`backend/app/services/email_service.py`) designed for easy integration with production email providers.
|
|
||||||
|
|
||||||
### Current Implementation
|
|
||||||
|
|
||||||
**Console Backend (Default)**:
|
|
||||||
- Logs email content to console/logs instead of sending
|
|
||||||
- Safe for development and testing
|
|
||||||
- No external dependencies required
|
|
||||||
|
|
||||||
### Production Integration
|
|
||||||
|
|
||||||
To enable email functionality, implement one of these approaches:
|
|
||||||
|
|
||||||
**Option 1: SMTP Integration** (Recommended for most use cases)
|
|
||||||
```python
|
|
||||||
# In app/services/email_service.py, complete the SMTPEmailBackend implementation
|
|
||||||
from aiosmtplib import SMTP
|
|
||||||
from email.mime.text import MIMEText
|
|
||||||
from email.mime.multipart import MIMEMultipart
|
|
||||||
|
|
||||||
# Add environment variables to .env:
|
|
||||||
# SMTP_HOST=smtp.gmail.com
|
|
||||||
# SMTP_PORT=587
|
|
||||||
# SMTP_USERNAME=your-email@gmail.com
|
|
||||||
# SMTP_PASSWORD=your-app-password
|
|
||||||
```
|
|
||||||
|
|
||||||
**Option 2: Third-Party Service** (SendGrid, AWS SES, Mailgun, etc.)
|
|
||||||
```python
|
|
||||||
# Create a new backend class, e.g., SendGridEmailBackend
|
|
||||||
class SendGridEmailBackend(EmailBackend):
|
|
||||||
def __init__(self, api_key: str):
|
|
||||||
self.api_key = api_key
|
|
||||||
self.client = sendgrid.SendGridAPIClient(api_key)
|
|
||||||
|
|
||||||
async def send_email(self, to, subject, html_content, text_content=None):
|
|
||||||
# Implement SendGrid sending logic
|
|
||||||
pass
|
|
||||||
|
|
||||||
# Update global instance in email_service.py:
|
|
||||||
# email_service = EmailService(SendGridEmailBackend(settings.SENDGRID_API_KEY))
|
|
||||||
```
|
|
||||||
|
|
||||||
**Option 3: External Microservice**
|
|
||||||
- Use a dedicated email microservice via HTTP API
|
|
||||||
- Implement `HTTPEmailBackend` that makes async HTTP requests
|
|
||||||
|
|
||||||
### Email Templates Included
|
|
||||||
|
|
||||||
The service includes pre-built templates for:
|
|
||||||
- **Password Reset**: `send_password_reset_email()` - 1 hour expiry
|
|
||||||
- **Email Verification**: `send_email_verification()` - 24 hour expiry
|
|
||||||
|
|
||||||
Both include responsive HTML and plain text versions.
|
|
||||||
|
|
||||||
### Integration Points
|
|
||||||
|
|
||||||
Email sending is called from:
|
|
||||||
- `app/api/routes/auth.py` - Password reset flow (placeholder comments)
|
|
||||||
- Registration flow - Ready for email verification integration
|
|
||||||
|
|
||||||
**Note**: Current auth routes have placeholder comments where email functionality should be integrated. Search for "TODO: Send email" in the codebase.
|
|
||||||
|
|
||||||
## API Documentation
|
|
||||||
|
|
||||||
Once backend is running:
|
|
||||||
- **Swagger UI**: http://localhost:8000/docs
|
|
||||||
- **ReDoc**: http://localhost:8000/redoc
|
|
||||||
- **OpenAPI JSON**: http://localhost:8000/api/v1/openapi.json
|
|
||||||
|
|
||||||
## Troubleshooting
|
|
||||||
|
|
||||||
### Tests failing with "Module was never imported"
|
|
||||||
Run with single process: `pytest -n 0`
|
|
||||||
|
|
||||||
### Coverage not improving despite new tests
|
|
||||||
- Verify tests actually execute endpoints (check response.status_code)
|
|
||||||
- Generate HTML coverage: `pytest --cov=app --cov-report=html -n 0`
|
|
||||||
- Check for dependency override issues in test fixtures
|
|
||||||
|
|
||||||
### Frontend type errors
|
|
||||||
```bash
|
|
||||||
npm run type-check # Check all types
|
|
||||||
npx tsc --noEmit # Same but shorter
|
|
||||||
```
|
|
||||||
|
|
||||||
### E2E tests flaking
|
|
||||||
- Check worker count (should be 4, not 16+)
|
|
||||||
- Use `Promise.all()` for navigation
|
|
||||||
- Use regex for URL assertions
|
|
||||||
- Target specific selectors (avoid generic `[role="alert"]`)
|
|
||||||
|
|
||||||
### Database migration conflicts
|
|
||||||
```bash
|
|
||||||
python migrate.py list # Check migration history
|
|
||||||
alembic downgrade -1 # Downgrade one revision
|
|
||||||
alembic upgrade head # Re-apply
|
|
||||||
```
|
|
||||||
|
|
||||||
## Additional Documentation
|
|
||||||
|
|
||||||
### Backend Documentation
|
|
||||||
- `backend/docs/ARCHITECTURE.md`: System architecture and design patterns
|
|
||||||
- `backend/docs/CODING_STANDARDS.md`: Code quality standards and best practices
|
|
||||||
- `backend/docs/COMMON_PITFALLS.md`: Common mistakes and how to avoid them
|
|
||||||
- `backend/docs/FEATURE_EXAMPLE.md`: Step-by-step feature implementation guide
|
|
||||||
|
|
||||||
### Frontend Documentation
|
|
||||||
- **`frontend/docs/ARCHITECTURE_FIX_REPORT.md`**: ⭐ Critical DI pattern fixes (READ THIS!)
|
|
||||||
- `frontend/e2e/README.md`: E2E testing setup and guidelines
|
|
||||||
- **`frontend/docs/design-system/`**: Comprehensive design system documentation
|
|
||||||
- `README.md`: Hub with learning paths (start here)
|
|
||||||
- `00-quick-start.md`: 5-minute crash course
|
|
||||||
- `01-foundations.md`: Colors (OKLCH), typography, spacing, shadows
|
|
||||||
- `02-components.md`: shadcn/ui component library guide
|
|
||||||
- `03-layouts.md`: Layout patterns (Grid vs Flex decision trees)
|
|
||||||
- `04-spacing-philosophy.md`: Parent-controlled spacing strategy
|
|
||||||
- `05-component-creation.md`: When to create vs compose components
|
|
||||||
- `06-forms.md`: Form patterns with react-hook-form + Zod
|
|
||||||
- `07-accessibility.md`: WCAG AA compliance, keyboard navigation, screen readers
|
|
||||||
- `08-ai-guidelines.md`: **AI code generation rules (read this!)**
|
|
||||||
- `99-reference.md`: Quick reference cheat sheet (bookmark this)
|
|
||||||
|
|||||||
@@ -90,22 +90,27 @@ Ready to write some code? Awesome!
|
|||||||
```bash
|
```bash
|
||||||
cd backend
|
cd backend
|
||||||
|
|
||||||
# Setup virtual environment
|
# Install dependencies (uv manages virtual environment automatically)
|
||||||
python -m venv .venv
|
make install-dev
|
||||||
source .venv/bin/activate
|
|
||||||
|
|
||||||
# Install dependencies
|
# Setup pre-commit hooks
|
||||||
pip install -r requirements.txt
|
uv run pre-commit install
|
||||||
|
|
||||||
# Setup environment
|
# Setup environment
|
||||||
cp .env.example .env
|
cp .env.example .env
|
||||||
# Edit .env with your settings
|
# Edit .env with your settings
|
||||||
|
|
||||||
# Run migrations
|
# Run migrations
|
||||||
alembic upgrade head
|
python migrate.py apply
|
||||||
|
|
||||||
|
# Run quality + security checks
|
||||||
|
make validate-all
|
||||||
|
|
||||||
# Run tests
|
# Run tests
|
||||||
IS_TEST=True pytest
|
make test
|
||||||
|
|
||||||
|
# Run full pipeline (quality + security + tests)
|
||||||
|
make check
|
||||||
|
|
||||||
# Start dev server
|
# Start dev server
|
||||||
uvicorn app.main:app --reload
|
uvicorn app.main:app --reload
|
||||||
@@ -117,20 +122,20 @@ uvicorn app.main:app --reload
|
|||||||
cd frontend
|
cd frontend
|
||||||
|
|
||||||
# Install dependencies
|
# Install dependencies
|
||||||
npm install
|
bun install
|
||||||
|
|
||||||
# Setup environment
|
# Setup environment
|
||||||
cp .env.local.example .env.local
|
cp .env.local.example .env.local
|
||||||
|
|
||||||
# Generate API client
|
# Generate API client
|
||||||
npm run generate:api
|
bun run generate:api
|
||||||
|
|
||||||
# Run tests
|
# Run tests
|
||||||
npm test
|
bun run test
|
||||||
npm run test:e2e:ui
|
bun run test:e2e:ui
|
||||||
|
|
||||||
# Start dev server
|
# Start dev server
|
||||||
npm run dev
|
bun run dev
|
||||||
```
|
```
|
||||||
|
|
||||||
---
|
---
|
||||||
@@ -199,7 +204,7 @@ export function UserProfile({ userId }: UserProfileProps) {
|
|||||||
|
|
||||||
### Key Patterns
|
### Key Patterns
|
||||||
|
|
||||||
- **Backend**: Use CRUD pattern, keep routes thin, business logic in services
|
- **Backend**: Use repository pattern, keep routes thin, business logic in services
|
||||||
- **Frontend**: Use React Query for server state, Zustand for client state
|
- **Frontend**: Use React Query for server state, Zustand for client state
|
||||||
- **Both**: Handle errors gracefully, log appropriately, write tests
|
- **Both**: Handle errors gracefully, log appropriately, write tests
|
||||||
|
|
||||||
@@ -320,7 +325,7 @@ Fixed stuff
|
|||||||
### Before Submitting
|
### Before Submitting
|
||||||
|
|
||||||
- [ ] Code follows project style guidelines
|
- [ ] Code follows project style guidelines
|
||||||
- [ ] All tests pass locally
|
- [ ] `make check` passes (quality + security + tests) in backend
|
||||||
- [ ] New tests added for new features
|
- [ ] New tests added for new features
|
||||||
- [ ] Documentation updated if needed
|
- [ ] Documentation updated if needed
|
||||||
- [ ] No merge conflicts with `main`
|
- [ ] No merge conflicts with `main`
|
||||||
|
|||||||
File diff suppressed because it is too large
Load Diff
108
Makefile
108
Makefile
@@ -1,8 +1,40 @@
|
|||||||
.PHONY: dev dev-full prod down clean clean-slate
|
.PHONY: help dev dev-full prod down logs logs-dev clean clean-slate drop-db reset-db push-images deploy scan-images
|
||||||
|
|
||||||
VERSION ?= latest
|
VERSION ?= latest
|
||||||
REGISTRY := gitea.pragmazest.com/cardosofelipe/app
|
REGISTRY ?= ghcr.io/cardosofelipe/pragma-stack
|
||||||
|
|
||||||
|
# Default target
|
||||||
|
help:
|
||||||
|
@echo "FastAPI + Next.js Full-Stack Template"
|
||||||
|
@echo ""
|
||||||
|
@echo "Development:"
|
||||||
|
@echo " make dev - Start backend + db (frontend runs separately)"
|
||||||
|
@echo " make dev-full - Start all services including frontend"
|
||||||
|
@echo " make down - Stop all services"
|
||||||
|
@echo " make logs-dev - Follow dev container logs"
|
||||||
|
@echo ""
|
||||||
|
@echo "Database:"
|
||||||
|
@echo " make drop-db - Drop and recreate empty database"
|
||||||
|
@echo " make reset-db - Drop database and apply all migrations"
|
||||||
|
@echo ""
|
||||||
|
@echo "Production:"
|
||||||
|
@echo " make prod - Start production stack"
|
||||||
|
@echo " make deploy - Pull and deploy latest images"
|
||||||
|
@echo " make push-images - Build and push images to registry"
|
||||||
|
@echo " make scan-images - Scan production images for CVEs (requires trivy)"
|
||||||
|
@echo " make logs - Follow production container logs"
|
||||||
|
@echo ""
|
||||||
|
@echo "Cleanup:"
|
||||||
|
@echo " make clean - Stop containers"
|
||||||
|
@echo " make clean-slate - Stop containers AND delete volumes (DATA LOSS!)"
|
||||||
|
@echo ""
|
||||||
|
@echo "Subdirectory commands:"
|
||||||
|
@echo " cd backend && make help - Backend-specific commands"
|
||||||
|
@echo " cd frontend && npm run - Frontend-specific commands"
|
||||||
|
|
||||||
|
# ============================================================================
|
||||||
|
# Development
|
||||||
|
# ============================================================================
|
||||||
|
|
||||||
dev:
|
dev:
|
||||||
# Bring up all dev services except the frontend
|
# Bring up all dev services except the frontend
|
||||||
@@ -16,25 +48,77 @@ dev-full:
|
|||||||
# Bring up all dev services including the frontend (full stack)
|
# Bring up all dev services including the frontend (full stack)
|
||||||
docker compose -f docker-compose.dev.yml up --build -d
|
docker compose -f docker-compose.dev.yml up --build -d
|
||||||
|
|
||||||
prod:
|
|
||||||
docker compose up --build -d
|
|
||||||
|
|
||||||
down:
|
down:
|
||||||
docker compose down
|
docker compose down
|
||||||
|
|
||||||
|
logs:
|
||||||
|
docker compose logs -f
|
||||||
|
|
||||||
|
logs-dev:
|
||||||
|
docker compose -f docker-compose.dev.yml logs -f
|
||||||
|
|
||||||
|
# ============================================================================
|
||||||
|
# Database Management
|
||||||
|
# ============================================================================
|
||||||
|
|
||||||
|
drop-db:
|
||||||
|
@echo "Dropping local database..."
|
||||||
|
@docker compose -f docker-compose.dev.yml exec -T db psql -U postgres -c "DROP DATABASE IF EXISTS app WITH (FORCE);" 2>/dev/null || \
|
||||||
|
docker compose -f docker-compose.dev.yml exec -T db psql -U postgres -c "DROP DATABASE IF EXISTS app;"
|
||||||
|
@docker compose -f docker-compose.dev.yml exec -T db psql -U postgres -c "CREATE DATABASE app;"
|
||||||
|
@echo "Database dropped and recreated (empty)"
|
||||||
|
|
||||||
|
reset-db: drop-db
|
||||||
|
@echo "Applying migrations..."
|
||||||
|
@cd backend && uv run python migrate.py --local apply
|
||||||
|
@echo "Database reset complete!"
|
||||||
|
|
||||||
|
# ============================================================================
|
||||||
|
# Production / Deployment
|
||||||
|
# ============================================================================
|
||||||
|
|
||||||
|
prod:
|
||||||
|
docker compose up --build -d
|
||||||
|
|
||||||
deploy:
|
deploy:
|
||||||
docker compose -f docker-compose.deploy.yml pull
|
docker compose -f docker-compose.deploy.yml pull
|
||||||
docker compose -f docker-compose.deploy.yml up -d
|
docker compose -f docker-compose.deploy.yml up -d
|
||||||
|
|
||||||
clean:
|
|
||||||
docker compose down -
|
|
||||||
|
|
||||||
# WARNING! THIS REMOVES CONTAINERS AND VOLUMES AS WELL - DO NOT USE THIS UNLESS YOU WANT TO START OVER WITH DATA AND ALL
|
|
||||||
clean-slate:
|
|
||||||
docker compose down -v
|
|
||||||
|
|
||||||
push-images:
|
push-images:
|
||||||
docker build -t $(REGISTRY)/backend:$(VERSION) ./backend
|
docker build -t $(REGISTRY)/backend:$(VERSION) ./backend
|
||||||
docker build -t $(REGISTRY)/frontend:$(VERSION) ./frontend
|
docker build -t $(REGISTRY)/frontend:$(VERSION) ./frontend
|
||||||
docker push $(REGISTRY)/backend:$(VERSION)
|
docker push $(REGISTRY)/backend:$(VERSION)
|
||||||
docker push $(REGISTRY)/frontend:$(VERSION)
|
docker push $(REGISTRY)/frontend:$(VERSION)
|
||||||
|
|
||||||
|
scan-images:
|
||||||
|
@docker info > /dev/null 2>&1 || (echo "❌ Docker is not running!"; exit 1)
|
||||||
|
@echo "🐳 Building and scanning production images for CVEs..."
|
||||||
|
docker build -t $(REGISTRY)/backend:scan --target production ./backend
|
||||||
|
docker build -t $(REGISTRY)/frontend:scan --target runner ./frontend
|
||||||
|
@echo ""
|
||||||
|
@echo "=== Backend Image Scan ==="
|
||||||
|
@if command -v trivy > /dev/null 2>&1; then \
|
||||||
|
trivy image --severity HIGH,CRITICAL --exit-code 1 $(REGISTRY)/backend:scan; \
|
||||||
|
else \
|
||||||
|
echo "ℹ️ Trivy not found locally, using Docker to run Trivy..."; \
|
||||||
|
docker run --rm -v /var/run/docker.sock:/var/run/docker.sock aquasec/trivy image --severity HIGH,CRITICAL --exit-code 1 $(REGISTRY)/backend:scan; \
|
||||||
|
fi
|
||||||
|
@echo ""
|
||||||
|
@echo "=== Frontend Image Scan ==="
|
||||||
|
@if command -v trivy > /dev/null 2>&1; then \
|
||||||
|
trivy image --severity HIGH,CRITICAL --exit-code 1 $(REGISTRY)/frontend:scan; \
|
||||||
|
else \
|
||||||
|
docker run --rm -v /var/run/docker.sock:/var/run/docker.sock aquasec/trivy image --severity HIGH,CRITICAL --exit-code 1 $(REGISTRY)/frontend:scan; \
|
||||||
|
fi
|
||||||
|
@echo "✅ No HIGH/CRITICAL CVEs found in production images!"
|
||||||
|
|
||||||
|
# ============================================================================
|
||||||
|
# Cleanup
|
||||||
|
# ============================================================================
|
||||||
|
|
||||||
|
clean:
|
||||||
|
docker compose down
|
||||||
|
|
||||||
|
# WARNING! THIS REMOVES CONTAINERS AND VOLUMES AS WELL - DO NOT USE THIS UNLESS YOU WANT TO START OVER WITH DATA AND ALL
|
||||||
|
clean-slate:
|
||||||
|
docker compose -f docker-compose.dev.yml down -v --remove-orphans
|
||||||
|
|||||||
224
README.md
224
README.md
@@ -1,29 +1,29 @@
|
|||||||
# FastAPI + Next.js Full-Stack Template
|
# <img src="frontend/public/logo.svg" alt="PragmaStack" width="32" height="32" style="vertical-align: middle" /> PragmaStack
|
||||||
|
|
||||||
> **Production-ready, security-first, full-stack TypeScript/Python template with authentication, multi-tenancy, and a comprehensive admin panel.**
|
> **The Pragmatic Full-Stack Template. Production-ready, security-first, and opinionated.**
|
||||||
|
|
||||||
<!--
|
|
||||||
TODO: Replace these static badges with dynamic CI/CD badges when GitHub Actions is set up
|
|
||||||
Example: https://github.com/YOUR_ORG/YOUR_REPO/actions/workflows/backend-tests.yml/badge.svg
|
|
||||||
-->
|
|
||||||
|
|
||||||
[](./backend/tests)
|
|
||||||
[](./backend/tests)
|
[](./backend/tests)
|
||||||
[](./frontend/tests)
|
|
||||||
[](./frontend/tests)
|
[](./frontend/tests)
|
||||||
[](./frontend/e2e)
|
[](./frontend/e2e)
|
||||||
[](./LICENSE)
|
[](./LICENSE)
|
||||||
[](./CONTRIBUTING.md)
|
[](./CONTRIBUTING.md)
|
||||||
|
|
||||||
|

|
||||||
|
|
||||||
---
|
---
|
||||||
|
|
||||||
## Why This Template?
|
## Why PragmaStack?
|
||||||
|
|
||||||
Building a modern full-stack application from scratch means solving the same problems over and over: authentication, authorization, multi-tenancy, admin panels, session management, database migrations, API documentation, testing infrastructure...
|
Building a modern full-stack application often leads to "analysis paralysis" or "boilerplate fatigue". You spend weeks setting up authentication, testing, and linting before writing a single line of business logic.
|
||||||
|
|
||||||
**This template gives you all of that, battle-tested and ready to go.**
|
**PragmaStack cuts through the noise.**
|
||||||
|
|
||||||
Instead of spending weeks on boilerplate, you can focus on building your unique features. Whether you're building a SaaS product, an internal tool, or a side project, this template provides a rock-solid foundation with modern best practices baked in.
|
We provide a **pragmatic**, opinionated foundation that prioritizes:
|
||||||
|
- **Speed**: Ship features, not config files.
|
||||||
|
- **Robustness**: Security and testing are not optional.
|
||||||
|
- **Clarity**: Code that is easy to read and maintain.
|
||||||
|
|
||||||
|
Whether you're building a SaaS, an internal tool, or a side project, PragmaStack gives you a solid starting point without the bloat.
|
||||||
|
|
||||||
---
|
---
|
||||||
|
|
||||||
@@ -31,12 +31,26 @@ Instead of spending weeks on boilerplate, you can focus on building your unique
|
|||||||
|
|
||||||
### 🔐 **Authentication & Security**
|
### 🔐 **Authentication & Security**
|
||||||
- JWT-based authentication with access + refresh tokens
|
- JWT-based authentication with access + refresh tokens
|
||||||
|
- **OAuth/Social Login** (Google, GitHub) with PKCE support
|
||||||
|
- **OAuth 2.0 Authorization Server** (MCP-ready) for third-party integrations
|
||||||
- Session management with device tracking and revocation
|
- Session management with device tracking and revocation
|
||||||
- Password reset flow (email integration ready)
|
- Password reset flow (email integration ready)
|
||||||
- Secure password hashing (bcrypt)
|
- Secure password hashing (bcrypt)
|
||||||
- CSRF protection, rate limiting, and security headers
|
- CSRF protection, rate limiting, and security headers
|
||||||
- Comprehensive security tests (JWT algorithm attacks, session hijacking, privilege escalation)
|
- Comprehensive security tests (JWT algorithm attacks, session hijacking, privilege escalation)
|
||||||
|
|
||||||
|
### 🔌 **OAuth Provider Mode (MCP Integration)**
|
||||||
|
Full OAuth 2.0 Authorization Server for Model Context Protocol (MCP) and third-party clients:
|
||||||
|
- **RFC 7636**: Authorization Code Flow with PKCE (S256 only)
|
||||||
|
- **RFC 8414**: Server metadata discovery at `/.well-known/oauth-authorization-server`
|
||||||
|
- **RFC 7662**: Token introspection endpoint
|
||||||
|
- **RFC 7009**: Token revocation endpoint
|
||||||
|
- **JWT access tokens**: Self-contained, configurable lifetime
|
||||||
|
- **Opaque refresh tokens**: Secure rotation, database-backed revocation
|
||||||
|
- **Consent management**: Users can review and revoke app permissions
|
||||||
|
- **Client management**: Admin endpoints for registering OAuth clients
|
||||||
|
- **Scopes**: `openid`, `profile`, `email`, `read:users`, `write:users`, `admin`
|
||||||
|
|
||||||
### 👥 **Multi-Tenancy & Organizations**
|
### 👥 **Multi-Tenancy & Organizations**
|
||||||
- Full organization system with role-based access control (Owner, Admin, Member)
|
- Full organization system with role-based access control (Owner, Admin, Member)
|
||||||
- Invite/remove members, manage permissions
|
- Invite/remove members, manage permissions
|
||||||
@@ -44,18 +58,35 @@ Instead of spending weeks on boilerplate, you can focus on building your unique
|
|||||||
- User can belong to multiple organizations
|
- User can belong to multiple organizations
|
||||||
|
|
||||||
### 🛠️ **Admin Panel**
|
### 🛠️ **Admin Panel**
|
||||||
- Complete user management (CRUD, activate/deactivate, bulk operations)
|
- Complete user management (full lifecycle, activate/deactivate, bulk operations)
|
||||||
- Organization management (create, edit, delete, member management)
|
- Organization management (create, edit, delete, member management)
|
||||||
- Session monitoring across all users
|
- Session monitoring across all users
|
||||||
- Real-time statistics dashboard
|
- Real-time statistics dashboard
|
||||||
- Admin-only routes with proper authorization
|
- Admin-only routes with proper authorization
|
||||||
|
|
||||||
### 🎨 **Modern Frontend**
|
### 🎨 **Modern Frontend**
|
||||||
- Next.js 15 with App Router and React 19
|
- Next.js 16 with App Router and React 19
|
||||||
- Comprehensive design system built on shadcn/ui + TailwindCSS
|
- **PragmaStack Design System** built on shadcn/ui + TailwindCSS
|
||||||
- Pre-configured theme with dark mode support (coming soon)
|
- Pre-configured theme with dark mode support (coming soon)
|
||||||
- Responsive, accessible components (WCAG AA compliant)
|
- Responsive, accessible components (WCAG AA compliant)
|
||||||
- Developer documentation at `/dev` (in progress)
|
- Rich marketing landing page with animated components
|
||||||
|
- Live component showcase and documentation at `/dev`
|
||||||
|
|
||||||
|
### 🌍 **Internationalization (i18n)**
|
||||||
|
- Built-in multi-language support with next-intl v4
|
||||||
|
- Locale-based routing (`/en/*`, `/it/*`)
|
||||||
|
- Seamless language switching with LocaleSwitcher component
|
||||||
|
- SEO-friendly URLs and metadata per locale
|
||||||
|
- Translation files for English and Italian (easily extensible)
|
||||||
|
- Type-safe translations throughout the app
|
||||||
|
|
||||||
|
### 🎯 **Content & UX Features**
|
||||||
|
- **Toast notifications** with Sonner for elegant user feedback
|
||||||
|
- **Smooth animations** powered by Framer Motion
|
||||||
|
- **Markdown rendering** with syntax highlighting (GitHub Flavored Markdown)
|
||||||
|
- **Charts and visualizations** ready with Recharts
|
||||||
|
- **SEO optimization** with dynamic sitemap and robots.txt generation
|
||||||
|
- **Session tracking UI** with device information and revocation controls
|
||||||
|
|
||||||
### 🧪 **Comprehensive Testing**
|
### 🧪 **Comprehensive Testing**
|
||||||
- **Backend Testing**: ~97% unit test coverage
|
- **Backend Testing**: ~97% unit test coverage
|
||||||
@@ -75,9 +106,10 @@ Instead of spending weeks on boilerplate, you can focus on building your unique
|
|||||||
### 📚 **Developer Experience**
|
### 📚 **Developer Experience**
|
||||||
- Auto-generated TypeScript API client from OpenAPI spec
|
- Auto-generated TypeScript API client from OpenAPI spec
|
||||||
- Interactive API documentation (Swagger + ReDoc)
|
- Interactive API documentation (Swagger + ReDoc)
|
||||||
- Database migrations with Alembic
|
- Database migrations with Alembic helper script
|
||||||
- Hot reload in development
|
- Hot reload in development for both frontend and backend
|
||||||
- Comprehensive code documentation
|
- Comprehensive code documentation and design system docs
|
||||||
|
- Live component playground at `/dev` with code examples
|
||||||
- Docker support for easy deployment
|
- Docker support for easy deployment
|
||||||
- VSCode workspace settings included
|
- VSCode workspace settings included
|
||||||
|
|
||||||
@@ -89,6 +121,68 @@ Instead of spending weeks on boilerplate, you can focus on building your unique
|
|||||||
- Health check endpoints
|
- Health check endpoints
|
||||||
- Production security headers
|
- Production security headers
|
||||||
- Rate limiting on sensitive endpoints
|
- Rate limiting on sensitive endpoints
|
||||||
|
- SEO optimization with dynamic sitemaps and robots.txt
|
||||||
|
- Multi-language SEO with locale-specific metadata
|
||||||
|
- Performance monitoring and bundle analysis
|
||||||
|
|
||||||
|
---
|
||||||
|
|
||||||
|
## 📸 Screenshots
|
||||||
|
|
||||||
|
<details>
|
||||||
|
<summary>Click to view screenshots</summary>
|
||||||
|
|
||||||
|
### Landing Page
|
||||||
|

|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
### Authentication
|
||||||
|

|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
### Admin Dashboard
|
||||||
|

|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
### Design System
|
||||||
|

|
||||||
|
|
||||||
|
</details>
|
||||||
|
|
||||||
|
---
|
||||||
|
|
||||||
|
## 🎭 Demo Mode
|
||||||
|
|
||||||
|
**Try the frontend without a backend!** Perfect for:
|
||||||
|
- **Free deployment** on Vercel (no backend costs)
|
||||||
|
- **Portfolio showcasing** with live demos
|
||||||
|
- **Client presentations** without infrastructure setup
|
||||||
|
|
||||||
|
### Quick Start
|
||||||
|
|
||||||
|
```bash
|
||||||
|
cd frontend
|
||||||
|
echo "NEXT_PUBLIC_DEMO_MODE=true" > .env.local
|
||||||
|
bun run dev
|
||||||
|
```
|
||||||
|
|
||||||
|
**Demo Credentials:**
|
||||||
|
- Regular user: `demo@example.com` / `DemoPass123`
|
||||||
|
- Admin user: `admin@example.com` / `AdminPass123`
|
||||||
|
|
||||||
|
Demo mode uses [Mock Service Worker (MSW)](https://mswjs.io/) to intercept API calls in the browser. Your code remains unchanged - the same components work with both real and mocked backends.
|
||||||
|
|
||||||
|
**Key Features:**
|
||||||
|
- ✅ Zero backend required
|
||||||
|
- ✅ All features functional (auth, admin, stats)
|
||||||
|
- ✅ Realistic network delays and errors
|
||||||
|
- ✅ Does NOT interfere with tests (97%+ coverage maintained)
|
||||||
|
- ✅ One-line toggle: `NEXT_PUBLIC_DEMO_MODE=true`
|
||||||
|
|
||||||
|
📖 **[Complete Demo Mode Documentation](./frontend/docs/DEMO_MODE.md)**
|
||||||
|
|
||||||
---
|
---
|
||||||
|
|
||||||
@@ -103,13 +197,18 @@ Instead of spending weeks on boilerplate, you can focus on building your unique
|
|||||||
- **[pytest](https://pytest.org/)** - Testing framework with async support
|
- **[pytest](https://pytest.org/)** - Testing framework with async support
|
||||||
|
|
||||||
### Frontend
|
### Frontend
|
||||||
- **[Next.js 15](https://nextjs.org/)** - React framework with App Router
|
- **[Next.js 16](https://nextjs.org/)** - React framework with App Router
|
||||||
- **[React 19](https://react.dev/)** - UI library
|
- **[React 19](https://react.dev/)** - UI library
|
||||||
- **[TypeScript](https://www.typescriptlang.org/)** - Type-safe JavaScript
|
- **[TypeScript](https://www.typescriptlang.org/)** - Type-safe JavaScript
|
||||||
- **[TailwindCSS](https://tailwindcss.com/)** - Utility-first CSS framework
|
- **[TailwindCSS](https://tailwindcss.com/)** - Utility-first CSS framework
|
||||||
- **[shadcn/ui](https://ui.shadcn.com/)** - Beautiful, accessible component library
|
- **[shadcn/ui](https://ui.shadcn.com/)** - Beautiful, accessible component library
|
||||||
|
- **[next-intl](https://next-intl.dev/)** - Internationalization (i18n) with type safety
|
||||||
- **[TanStack Query](https://tanstack.com/query)** - Powerful data fetching/caching
|
- **[TanStack Query](https://tanstack.com/query)** - Powerful data fetching/caching
|
||||||
- **[Zustand](https://zustand-demo.pmnd.rs/)** - Lightweight state management
|
- **[Zustand](https://zustand-demo.pmnd.rs/)** - Lightweight state management
|
||||||
|
- **[Framer Motion](https://www.framer.com/motion/)** - Production-ready animation library
|
||||||
|
- **[Sonner](https://sonner.emilkowal.ski/)** - Beautiful toast notifications
|
||||||
|
- **[Recharts](https://recharts.org/)** - Composable charting library
|
||||||
|
- **[React Markdown](https://github.com/remarkjs/react-markdown)** - Markdown rendering with GFM support
|
||||||
- **[Playwright](https://playwright.dev/)** - End-to-end testing
|
- **[Playwright](https://playwright.dev/)** - End-to-end testing
|
||||||
|
|
||||||
### DevOps
|
### DevOps
|
||||||
@@ -135,12 +234,11 @@ The fastest way to get started is with Docker:
|
|||||||
|
|
||||||
```bash
|
```bash
|
||||||
# Clone the repository
|
# Clone the repository
|
||||||
git clone https://github.com/yourusername/fast-next-template.git
|
git clone https://github.com/cardosofelipe/pragma-stack.git
|
||||||
cd fast-next-template
|
cd fast-next-template
|
||||||
|
|
||||||
# Copy environment files
|
# Copy environment file
|
||||||
cp backend/.env.example backend/.env
|
cp .env.template .env
|
||||||
cp frontend/.env.local.example frontend/.env.local
|
|
||||||
|
|
||||||
# Start all services (backend, frontend, database)
|
# Start all services (backend, frontend, database)
|
||||||
docker-compose up
|
docker-compose up
|
||||||
@@ -200,17 +298,17 @@ uvicorn app.main:app --reload --host 0.0.0.0 --port 8000
|
|||||||
cd frontend
|
cd frontend
|
||||||
|
|
||||||
# Install dependencies
|
# Install dependencies
|
||||||
npm install
|
bun install
|
||||||
|
|
||||||
# Setup environment
|
# Setup environment
|
||||||
cp .env.local.example .env.local
|
cp .env.local.example .env.local
|
||||||
# Edit .env.local with your backend URL
|
# Edit .env.local with your backend URL
|
||||||
|
|
||||||
# Generate API client
|
# Generate API client
|
||||||
npm run generate:api
|
bun run generate:api
|
||||||
|
|
||||||
# Start development server
|
# Start development server
|
||||||
npm run dev
|
bun run dev
|
||||||
```
|
```
|
||||||
|
|
||||||
Visit http://localhost:3000 to see your app!
|
Visit http://localhost:3000 to see your app!
|
||||||
@@ -224,7 +322,7 @@ Visit http://localhost:3000 to see your app!
|
|||||||
│ ├── app/
|
│ ├── app/
|
||||||
│ │ ├── api/ # API routes and dependencies
|
│ │ ├── api/ # API routes and dependencies
|
||||||
│ │ ├── core/ # Core functionality (auth, config, database)
|
│ │ ├── core/ # Core functionality (auth, config, database)
|
||||||
│ │ ├── crud/ # Database operations
|
│ │ ├── repositories/ # Repository pattern (database operations)
|
||||||
│ │ ├── models/ # SQLAlchemy models
|
│ │ ├── models/ # SQLAlchemy models
|
||||||
│ │ ├── schemas/ # Pydantic schemas
|
│ │ ├── schemas/ # Pydantic schemas
|
||||||
│ │ ├── services/ # Business logic
|
│ │ ├── services/ # Business logic
|
||||||
@@ -279,7 +377,7 @@ open htmlcov/index.html
|
|||||||
```
|
```
|
||||||
|
|
||||||
**Test types:**
|
**Test types:**
|
||||||
- **Unit tests**: CRUD operations, utilities, business logic
|
- **Unit tests**: Repository operations, utilities, business logic
|
||||||
- **Integration tests**: API endpoints with database
|
- **Integration tests**: API endpoints with database
|
||||||
- **Security tests**: JWT algorithm attacks, session hijacking, privilege escalation
|
- **Security tests**: JWT algorithm attacks, session hijacking, privilege escalation
|
||||||
- **Error handling tests**: Database failures, validation errors
|
- **Error handling tests**: Database failures, validation errors
|
||||||
@@ -292,13 +390,13 @@ open htmlcov/index.html
|
|||||||
cd frontend
|
cd frontend
|
||||||
|
|
||||||
# Run unit tests
|
# Run unit tests
|
||||||
npm test
|
bun run test
|
||||||
|
|
||||||
# Run with coverage
|
# Run with coverage
|
||||||
npm run test:coverage
|
bun run test:coverage
|
||||||
|
|
||||||
# Watch mode
|
# Watch mode
|
||||||
npm run test:watch
|
bun run test:watch
|
||||||
```
|
```
|
||||||
|
|
||||||
**Test types:**
|
**Test types:**
|
||||||
@@ -316,10 +414,10 @@ npm run test:watch
|
|||||||
cd frontend
|
cd frontend
|
||||||
|
|
||||||
# Run E2E tests
|
# Run E2E tests
|
||||||
npm run test:e2e
|
bun run test:e2e
|
||||||
|
|
||||||
# Run E2E tests in UI mode (recommended for development)
|
# Run E2E tests in UI mode (recommended for development)
|
||||||
npm run test:e2e:ui
|
bun run test:e2e:ui
|
||||||
|
|
||||||
# Run specific test file
|
# Run specific test file
|
||||||
npx playwright test auth-login.spec.ts
|
npx playwright test auth-login.spec.ts
|
||||||
@@ -338,6 +436,17 @@ npx playwright show-report
|
|||||||
|
|
||||||
---
|
---
|
||||||
|
|
||||||
|
## 🤖 AI-Friendly Documentation
|
||||||
|
|
||||||
|
This project includes comprehensive documentation designed for AI coding assistants:
|
||||||
|
|
||||||
|
- **[AGENTS.md](./AGENTS.md)** - Framework-agnostic AI assistant context for PragmaStack
|
||||||
|
- **[CLAUDE.md](./CLAUDE.md)** - Claude Code-specific guidance
|
||||||
|
|
||||||
|
These files provide AI assistants with the **PragmaStack** architecture, patterns, and best practices.
|
||||||
|
|
||||||
|
---
|
||||||
|
|
||||||
## 🗄️ Database Migrations
|
## 🗄️ Database Migrations
|
||||||
|
|
||||||
The template uses Alembic for database migrations:
|
The template uses Alembic for database migrations:
|
||||||
@@ -365,22 +474,25 @@ python migrate.py current
|
|||||||
|
|
||||||
## 📖 Documentation
|
## 📖 Documentation
|
||||||
|
|
||||||
|
### AI Assistant Documentation
|
||||||
|
|
||||||
|
- **[AGENTS.md](./AGENTS.md)** - Framework-agnostic AI coding assistant context
|
||||||
|
- **[CLAUDE.md](./CLAUDE.md)** - Claude Code-specific guidance and preferences
|
||||||
|
|
||||||
### Backend Documentation
|
### Backend Documentation
|
||||||
|
|
||||||
- **[ARCHITECTURE.md](./backend/docs/ARCHITECTURE.md)** - System architecture and design patterns
|
- **[ARCHITECTURE.md](./backend/docs/ARCHITECTURE.md)** - System architecture and design patterns
|
||||||
- **[CODING_STANDARDS.md](./backend/docs/CODING_STANDARDS.md)** - Code quality standards
|
- **[CODING_STANDARDS.md](./backend/docs/CODING_STANDARDS.md)** - Code quality standards
|
||||||
- **[COMMON_PITFALLS.md](./backend/docs/COMMON_PITFALLS.md)** - Common mistakes to avoid
|
- **[COMMON_PITFALLS.md](./backend/docs/COMMON_PITFALLS.md)** - Common mistakes to avoid
|
||||||
- **[FEATURE_EXAMPLE.md](./backend/docs/FEATURE_EXAMPLE.md)** - Step-by-step feature guide
|
- **[FEATURE_EXAMPLE.md](./backend/docs/FEATURE_EXAMPLE.md)** - Step-by-step feature guide
|
||||||
- **[CLAUDE.md](./CLAUDE.md)** - Comprehensive development guide
|
|
||||||
|
|
||||||
### Frontend Documentation
|
### Frontend Documentation
|
||||||
|
|
||||||
- **[Design System Docs](./frontend/docs/design-system/)** - Complete design system guide
|
- **[PragmaStack Design System](./frontend/docs/design-system/)** - Complete design system guide
|
||||||
- Quick start, foundations (colors, typography, spacing)
|
- Quick start, foundations (colors, typography, spacing)
|
||||||
- Component library guide
|
- Component library guide
|
||||||
- Layout patterns, spacing philosophy
|
- Layout patterns, spacing philosophy
|
||||||
- Forms, accessibility, AI guidelines
|
- Forms, accessibility, AI guidelines
|
||||||
- **[ARCHITECTURE_FIX_REPORT.md](./frontend/docs/ARCHITECTURE_FIX_REPORT.md)** - Critical dependency injection patterns
|
|
||||||
- **[E2E Testing Guide](./frontend/e2e/README.md)** - E2E testing setup and best practices
|
- **[E2E Testing Guide](./frontend/e2e/README.md)** - E2E testing setup and best practices
|
||||||
|
|
||||||
### API Documentation
|
### API Documentation
|
||||||
@@ -429,37 +541,43 @@ docker-compose down
|
|||||||
## 🛣️ Roadmap & Status
|
## 🛣️ Roadmap & Status
|
||||||
|
|
||||||
### ✅ Completed
|
### ✅ Completed
|
||||||
- [x] Authentication system (JWT, refresh tokens, session management)
|
- [x] Authentication system (JWT, refresh tokens, session management, OAuth)
|
||||||
- [x] User management (CRUD, profile, password change)
|
- [x] User management (full lifecycle, profile, password change)
|
||||||
- [x] Organization system with RBAC (Owner, Admin, Member)
|
- [x] Organization system with RBAC (Owner, Admin, Member)
|
||||||
- [x] Admin panel (users, organizations, sessions, statistics)
|
- [x] Admin panel (users, organizations, sessions, statistics)
|
||||||
|
- [x] **Internationalization (i18n)** with next-intl (English + Italian)
|
||||||
- [x] Backend testing infrastructure (~97% coverage)
|
- [x] Backend testing infrastructure (~97% coverage)
|
||||||
- [x] Frontend unit testing infrastructure (~97% coverage)
|
- [x] Frontend unit testing infrastructure (~97% coverage)
|
||||||
- [x] Frontend E2E testing (Playwright, zero flaky tests)
|
- [x] Frontend E2E testing (Playwright, zero flaky tests)
|
||||||
- [x] Design system documentation
|
- [x] Design system documentation
|
||||||
- [x] Database migrations
|
- [x] **Marketing landing page** with animated components
|
||||||
|
- [x] **`/dev` documentation portal** with live component examples
|
||||||
|
- [x] **Toast notifications** system (Sonner)
|
||||||
|
- [x] **Charts and visualizations** (Recharts)
|
||||||
|
- [x] **Animation system** (Framer Motion)
|
||||||
|
- [x] **Markdown rendering** with syntax highlighting
|
||||||
|
- [x] **SEO optimization** (sitemap, robots.txt, locale-aware metadata)
|
||||||
|
- [x] Database migrations with helper script
|
||||||
- [x] Docker deployment
|
- [x] Docker deployment
|
||||||
- [x] API documentation (OpenAPI/Swagger)
|
- [x] API documentation (OpenAPI/Swagger)
|
||||||
|
|
||||||
### 🚧 In Progress
|
### 🚧 In Progress
|
||||||
- [ ] Frontend admin pages (70% complete)
|
|
||||||
- [ ] Dark mode theme
|
|
||||||
- [ ] `/dev` documentation page with examples
|
|
||||||
- [ ] Email integration (templates ready, SMTP pending)
|
- [ ] Email integration (templates ready, SMTP pending)
|
||||||
- [ ] Chart/visualization components
|
|
||||||
|
|
||||||
### 🔮 Planned
|
### 🔮 Planned
|
||||||
- [ ] GitHub Actions CI/CD pipelines
|
- [ ] GitHub Actions CI/CD pipelines
|
||||||
- [ ] Dynamic test coverage badges from CI
|
- [ ] Dynamic test coverage badges from CI
|
||||||
- [ ] E2E test coverage reporting
|
- [ ] E2E test coverage reporting
|
||||||
- [ ] Additional authentication methods (OAuth, SSO)
|
- [ ] OAuth token encryption at rest (security hardening)
|
||||||
|
- [ ] Additional languages (Spanish, French, German, etc.)
|
||||||
|
- [ ] SSO/SAML authentication
|
||||||
|
- [ ] Real-time notifications with WebSockets
|
||||||
- [ ] Webhook system
|
- [ ] Webhook system
|
||||||
- [ ] Background job processing
|
- [ ] File upload/storage (S3-compatible)
|
||||||
- [ ] File upload/storage
|
- [ ] Audit logging system
|
||||||
- [ ] Notification system
|
|
||||||
- [ ] Audit logging
|
|
||||||
- [ ] API versioning example
|
- [ ] API versioning example
|
||||||
|
|
||||||
|
|
||||||
---
|
---
|
||||||
|
|
||||||
## 🤝 Contributing
|
## 🤝 Contributing
|
||||||
@@ -489,7 +607,7 @@ Contributions are welcome! Whether you're fixing bugs, improving documentation,
|
|||||||
|
|
||||||
### Reporting Issues
|
### Reporting Issues
|
||||||
|
|
||||||
Found a bug? Have a suggestion? [Open an issue](https://github.com/yourusername/fast-next-template/issues)!
|
Found a bug? Have a suggestion? [Open an issue](https://github.com/cardosofelipe/pragma-stack/issues)!
|
||||||
|
|
||||||
Please include:
|
Please include:
|
||||||
- Clear description of the issue/suggestion
|
- Clear description of the issue/suggestion
|
||||||
@@ -523,8 +641,8 @@ This template is built on the shoulders of giants:
|
|||||||
## 💬 Questions?
|
## 💬 Questions?
|
||||||
|
|
||||||
- **Documentation**: Check the `/docs` folders in backend and frontend
|
- **Documentation**: Check the `/docs` folders in backend and frontend
|
||||||
- **Issues**: [GitHub Issues](https://github.com/yourusername/fast-next-template/issues)
|
- **Issues**: [GitHub Issues](https://github.com/cardosofelipe/pragma-stack/issues)
|
||||||
- **Discussions**: [GitHub Discussions](https://github.com/yourusername/fast-next-template/discussions)
|
- **Discussions**: [GitHub Discussions](https://github.com/cardosofelipe/pragma-stack/discussions)
|
||||||
|
|
||||||
---
|
---
|
||||||
|
|
||||||
|
|||||||
@@ -11,16 +11,19 @@ omit =
|
|||||||
app/utils/auth_test_utils.py
|
app/utils/auth_test_utils.py
|
||||||
|
|
||||||
# Async implementations not yet in use
|
# Async implementations not yet in use
|
||||||
app/crud/base_async.py
|
app/repositories/base_async.py
|
||||||
app/core/database_async.py
|
app/core/database_async.py
|
||||||
|
|
||||||
|
# CLI scripts - run manually, not tested
|
||||||
|
app/init_db.py
|
||||||
|
|
||||||
# __init__ files with no logic
|
# __init__ files with no logic
|
||||||
app/__init__.py
|
app/__init__.py
|
||||||
app/api/__init__.py
|
app/api/__init__.py
|
||||||
app/api/routes/__init__.py
|
app/api/routes/__init__.py
|
||||||
app/api/dependencies/__init__.py
|
app/api/dependencies/__init__.py
|
||||||
app/core/__init__.py
|
app/core/__init__.py
|
||||||
app/crud/__init__.py
|
app/repositories/__init__.py
|
||||||
app/models/__init__.py
|
app/models/__init__.py
|
||||||
app/schemas/__init__.py
|
app/schemas/__init__.py
|
||||||
app/services/__init__.py
|
app/services/__init__.py
|
||||||
|
|||||||
@@ -1,2 +1,17 @@
|
|||||||
.venv
|
.venv
|
||||||
*.iml
|
*.iml
|
||||||
|
|
||||||
|
# Python build and cache artifacts
|
||||||
|
__pycache__/
|
||||||
|
.pytest_cache/
|
||||||
|
.mypy_cache/
|
||||||
|
.ruff_cache/
|
||||||
|
*.pyc
|
||||||
|
*.pyo
|
||||||
|
|
||||||
|
# Packaging artifacts
|
||||||
|
*.egg-info/
|
||||||
|
build/
|
||||||
|
dist/
|
||||||
|
htmlcov/
|
||||||
|
.uv_cache/
|
||||||
44
backend/.pre-commit-config.yaml
Normal file
44
backend/.pre-commit-config.yaml
Normal file
@@ -0,0 +1,44 @@
|
|||||||
|
# Pre-commit hooks for backend quality and security checks.
|
||||||
|
#
|
||||||
|
# Install:
|
||||||
|
# cd backend && uv run pre-commit install
|
||||||
|
#
|
||||||
|
# Run manually on all files:
|
||||||
|
# cd backend && uv run pre-commit run --all-files
|
||||||
|
#
|
||||||
|
# Skip hooks temporarily:
|
||||||
|
# git commit --no-verify
|
||||||
|
#
|
||||||
|
repos:
|
||||||
|
# ── Code Quality ──────────────────────────────────────────────────────────
|
||||||
|
- repo: https://github.com/astral-sh/ruff-pre-commit
|
||||||
|
rev: v0.14.4
|
||||||
|
hooks:
|
||||||
|
- id: ruff
|
||||||
|
args: [--fix, --exit-non-zero-on-fix]
|
||||||
|
- id: ruff-format
|
||||||
|
|
||||||
|
# ── General File Hygiene ──────────────────────────────────────────────────
|
||||||
|
- repo: https://github.com/pre-commit/pre-commit-hooks
|
||||||
|
rev: v5.0.0
|
||||||
|
hooks:
|
||||||
|
- id: trailing-whitespace
|
||||||
|
- id: end-of-file-fixer
|
||||||
|
- id: check-yaml
|
||||||
|
- id: check-toml
|
||||||
|
- id: check-merge-conflict
|
||||||
|
- id: check-added-large-files
|
||||||
|
args: [--maxkb=500]
|
||||||
|
- id: debug-statements
|
||||||
|
|
||||||
|
# ── Security ──────────────────────────────────────────────────────────────
|
||||||
|
- repo: https://github.com/Yelp/detect-secrets
|
||||||
|
rev: v1.5.0
|
||||||
|
hooks:
|
||||||
|
- id: detect-secrets
|
||||||
|
args: ['--baseline', '.secrets.baseline']
|
||||||
|
exclude: |
|
||||||
|
(?x)^(
|
||||||
|
.*\.lock$|
|
||||||
|
.*\.svg$
|
||||||
|
)$
|
||||||
1073
backend/.secrets.baseline
Normal file
1073
backend/.secrets.baseline
Normal file
File diff suppressed because it is too large
Load Diff
@@ -1,9 +1,6 @@
|
|||||||
# Development stage
|
# Development stage
|
||||||
FROM python:3.12-slim AS development
|
FROM python:3.12-slim AS development
|
||||||
|
|
||||||
# Create non-root user
|
|
||||||
RUN groupadd -r appuser && useradd -r -g appuser appuser
|
|
||||||
|
|
||||||
WORKDIR /app
|
WORKDIR /app
|
||||||
ENV PYTHONDONTWRITEBYTECODE=1 \
|
ENV PYTHONDONTWRITEBYTECODE=1 \
|
||||||
PYTHONUNBUFFERED=1 \
|
PYTHONUNBUFFERED=1 \
|
||||||
@@ -31,19 +28,16 @@ COPY . .
|
|||||||
COPY entrypoint.sh /usr/local/bin/
|
COPY entrypoint.sh /usr/local/bin/
|
||||||
RUN chmod +x /usr/local/bin/entrypoint.sh
|
RUN chmod +x /usr/local/bin/entrypoint.sh
|
||||||
|
|
||||||
# Set ownership to non-root user
|
# Note: Running as root in development for bind mount compatibility
|
||||||
RUN chown -R appuser:appuser /app
|
# Production stage uses non-root user for security
|
||||||
|
|
||||||
# Switch to non-root user
|
|
||||||
USER appuser
|
|
||||||
|
|
||||||
ENTRYPOINT ["/usr/local/bin/entrypoint.sh"]
|
ENTRYPOINT ["/usr/local/bin/entrypoint.sh"]
|
||||||
|
|
||||||
# Production stage
|
# Production stage — Alpine eliminates glibc CVEs (e.g. CVE-2026-0861)
|
||||||
FROM python:3.12-slim AS production
|
FROM python:3.12-alpine AS production
|
||||||
|
|
||||||
# Create non-root user
|
# Create non-root user
|
||||||
RUN groupadd -r appuser && useradd -r -g appuser appuser
|
RUN addgroup -S appuser && adduser -S -G appuser appuser
|
||||||
|
|
||||||
WORKDIR /app
|
WORKDIR /app
|
||||||
ENV PYTHONDONTWRITEBYTECODE=1 \
|
ENV PYTHONDONTWRITEBYTECODE=1 \
|
||||||
@@ -54,18 +48,18 @@ ENV PYTHONDONTWRITEBYTECODE=1 \
|
|||||||
UV_NO_CACHE=1
|
UV_NO_CACHE=1
|
||||||
|
|
||||||
# Install system dependencies and uv
|
# Install system dependencies and uv
|
||||||
RUN apt-get update && \
|
RUN apk add --no-cache postgresql-client curl ca-certificates && \
|
||||||
apt-get install -y --no-install-recommends postgresql-client curl ca-certificates && \
|
|
||||||
curl -LsSf https://astral.sh/uv/install.sh | sh && \
|
curl -LsSf https://astral.sh/uv/install.sh | sh && \
|
||||||
mv /root/.local/bin/uv* /usr/local/bin/ && \
|
mv /root/.local/bin/uv* /usr/local/bin/
|
||||||
apt-get clean && \
|
|
||||||
rm -rf /var/lib/apt/lists/*
|
|
||||||
|
|
||||||
# Copy dependency files
|
# Copy dependency files
|
||||||
COPY pyproject.toml uv.lock ./
|
COPY pyproject.toml uv.lock ./
|
||||||
|
|
||||||
# Install only production dependencies using uv (no dev dependencies)
|
# Install build dependencies, compile Python packages, then remove build deps
|
||||||
RUN uv sync --frozen --no-dev
|
RUN apk add --no-cache --virtual .build-deps \
|
||||||
|
gcc g++ musl-dev python3-dev linux-headers libffi-dev openssl-dev && \
|
||||||
|
uv sync --frozen --no-dev && \
|
||||||
|
apk del .build-deps
|
||||||
|
|
||||||
# Copy application code
|
# Copy application code
|
||||||
COPY . .
|
COPY . .
|
||||||
|
|||||||
140
backend/Makefile
140
backend/Makefile
@@ -1,4 +1,7 @@
|
|||||||
.PHONY: help lint lint-fix format format-check type-check test test-cov validate clean install-dev sync
|
.PHONY: help lint lint-fix format format-check type-check test test-cov validate clean install-dev sync check-docker install-e2e test-e2e test-e2e-schema test-all dep-audit license-check audit validate-all check benchmark benchmark-check benchmark-save scan-image test-api-security
|
||||||
|
|
||||||
|
# Prevent a stale VIRTUAL_ENV in the caller's shell from confusing uv
|
||||||
|
unexport VIRTUAL_ENV
|
||||||
|
|
||||||
# Default target
|
# Default target
|
||||||
help:
|
help:
|
||||||
@@ -6,6 +9,7 @@ help:
|
|||||||
@echo ""
|
@echo ""
|
||||||
@echo "Setup:"
|
@echo "Setup:"
|
||||||
@echo " make install-dev - Install all dependencies with uv (includes dev)"
|
@echo " make install-dev - Install all dependencies with uv (includes dev)"
|
||||||
|
@echo " make install-e2e - Install E2E test dependencies (requires Docker)"
|
||||||
@echo " make sync - Sync dependencies from uv.lock"
|
@echo " make sync - Sync dependencies from uv.lock"
|
||||||
@echo ""
|
@echo ""
|
||||||
@echo "Quality Checks:"
|
@echo "Quality Checks:"
|
||||||
@@ -13,12 +17,30 @@ help:
|
|||||||
@echo " make lint-fix - Run Ruff linter with auto-fix"
|
@echo " make lint-fix - Run Ruff linter with auto-fix"
|
||||||
@echo " make format - Format code with Ruff"
|
@echo " make format - Format code with Ruff"
|
||||||
@echo " make format-check - Check if code is formatted"
|
@echo " make format-check - Check if code is formatted"
|
||||||
@echo " make type-check - Run mypy type checking"
|
@echo " make type-check - Run pyright type checking"
|
||||||
@echo " make validate - Run all checks (lint + format + types)"
|
@echo " make validate - Run all checks (lint + format + types + schema fuzz)"
|
||||||
|
@echo ""
|
||||||
|
@echo "Performance:"
|
||||||
|
@echo " make benchmark - Run performance benchmarks"
|
||||||
|
@echo " make benchmark-save - Run benchmarks and save as baseline"
|
||||||
|
@echo " make benchmark-check - Run benchmarks and compare against baseline"
|
||||||
|
@echo ""
|
||||||
|
@echo "Security & Audit:"
|
||||||
|
@echo " make dep-audit - Scan dependencies for known vulnerabilities"
|
||||||
|
@echo " make license-check - Check dependency license compliance"
|
||||||
|
@echo " make audit - Run all security audits (deps + licenses)"
|
||||||
|
@echo " make scan-image - Scan Docker image for CVEs (requires trivy)"
|
||||||
|
@echo " make validate-all - Run all quality + security checks"
|
||||||
|
@echo " make check - Full pipeline: quality + security + tests"
|
||||||
@echo ""
|
@echo ""
|
||||||
@echo "Testing:"
|
@echo "Testing:"
|
||||||
@echo " make test - Run pytest"
|
@echo " make test - Run pytest (unit/integration, SQLite)"
|
||||||
@echo " make test-cov - Run pytest with coverage report"
|
@echo " make test-cov - Run pytest with coverage report"
|
||||||
|
@echo " make test-e2e - Run E2E tests (PostgreSQL, requires Docker)"
|
||||||
|
@echo " make test-e2e-schema - Run Schemathesis API schema tests"
|
||||||
|
@echo " make test-all - Run all tests (unit + E2E)"
|
||||||
|
@echo " make check-docker - Check if Docker is available"
|
||||||
|
@echo " make check - Full pipeline: quality + security + tests"
|
||||||
@echo ""
|
@echo ""
|
||||||
@echo "Cleanup:"
|
@echo "Cleanup:"
|
||||||
@echo " make clean - Remove cache and build artifacts"
|
@echo " make clean - Remove cache and build artifacts"
|
||||||
@@ -58,12 +80,52 @@ format-check:
|
|||||||
@uv run ruff format --check app/ tests/
|
@uv run ruff format --check app/ tests/
|
||||||
|
|
||||||
type-check:
|
type-check:
|
||||||
@echo "🔎 Running mypy type checking..."
|
@echo "🔎 Running pyright type checking..."
|
||||||
@uv run mypy app/
|
@uv run pyright app/
|
||||||
|
|
||||||
validate: lint format-check type-check
|
validate: lint format-check type-check test-api-security
|
||||||
@echo "✅ All quality checks passed!"
|
@echo "✅ All quality checks passed!"
|
||||||
|
|
||||||
|
# API Security Testing (Schemathesis property-based fuzzing)
|
||||||
|
test-api-security: check-docker
|
||||||
|
@echo "🔐 Running Schemathesis API security fuzzing..."
|
||||||
|
@IS_TEST=True PYTHONPATH=. uv run pytest tests/e2e/ -v -m "schemathesis" --tb=short -n 0
|
||||||
|
@echo "✅ API schema security tests passed!"
|
||||||
|
|
||||||
|
# ============================================================================
|
||||||
|
# Security & Audit
|
||||||
|
# ============================================================================
|
||||||
|
|
||||||
|
dep-audit:
|
||||||
|
@echo "🔒 Scanning dependencies for known vulnerabilities..."
|
||||||
|
@uv run pip-audit --desc --skip-editable
|
||||||
|
@echo "✅ No known vulnerabilities found!"
|
||||||
|
|
||||||
|
license-check:
|
||||||
|
@echo "📜 Checking dependency license compliance..."
|
||||||
|
@uv run pip-licenses --fail-on="GPL-3.0-or-later;AGPL-3.0-or-later" --format=plain > /dev/null
|
||||||
|
@echo "✅ All dependency licenses are compliant!"
|
||||||
|
|
||||||
|
audit: dep-audit license-check
|
||||||
|
@echo "✅ All security audits passed!"
|
||||||
|
|
||||||
|
scan-image: check-docker
|
||||||
|
@echo "🐳 Scanning Docker image for OS-level CVEs with Trivy..."
|
||||||
|
@docker build -t pragma-backend:scan -q --target production .
|
||||||
|
@if command -v trivy > /dev/null 2>&1; then \
|
||||||
|
trivy image --severity HIGH,CRITICAL --exit-code 1 pragma-backend:scan; \
|
||||||
|
else \
|
||||||
|
echo "ℹ️ Trivy not found locally, using Docker to run Trivy..."; \
|
||||||
|
docker run --rm -v /var/run/docker.sock:/var/run/docker.sock aquasec/trivy image --severity HIGH,CRITICAL --exit-code 1 pragma-backend:scan; \
|
||||||
|
fi
|
||||||
|
@echo "✅ No HIGH/CRITICAL CVEs found in Docker image!"
|
||||||
|
|
||||||
|
validate-all: validate audit
|
||||||
|
@echo "✅ All quality + security checks passed!"
|
||||||
|
|
||||||
|
check: validate-all test
|
||||||
|
@echo "✅ Full validation pipeline complete!"
|
||||||
|
|
||||||
# ============================================================================
|
# ============================================================================
|
||||||
# Testing
|
# Testing
|
||||||
# ============================================================================
|
# ============================================================================
|
||||||
@@ -77,6 +139,68 @@ test-cov:
|
|||||||
@IS_TEST=True PYTHONPATH=. uv run pytest --cov=app --cov-report=term-missing --cov-report=html -n 16
|
@IS_TEST=True PYTHONPATH=. uv run pytest --cov=app --cov-report=term-missing --cov-report=html -n 16
|
||||||
@echo "📊 Coverage report generated in htmlcov/index.html"
|
@echo "📊 Coverage report generated in htmlcov/index.html"
|
||||||
|
|
||||||
|
# ============================================================================
|
||||||
|
# E2E Testing (requires Docker)
|
||||||
|
# ============================================================================
|
||||||
|
|
||||||
|
check-docker:
|
||||||
|
@docker info > /dev/null 2>&1 || (echo ""; \
|
||||||
|
echo "Docker is not running!"; \
|
||||||
|
echo ""; \
|
||||||
|
echo "E2E tests require Docker to be running."; \
|
||||||
|
echo "Please start Docker Desktop or Docker Engine and try again."; \
|
||||||
|
echo ""; \
|
||||||
|
echo "Quick start:"; \
|
||||||
|
echo " macOS/Windows: Open Docker Desktop"; \
|
||||||
|
echo " Linux: sudo systemctl start docker"; \
|
||||||
|
echo ""; \
|
||||||
|
exit 1)
|
||||||
|
@echo "Docker is available"
|
||||||
|
|
||||||
|
install-e2e:
|
||||||
|
@echo "📦 Installing E2E test dependencies..."
|
||||||
|
@uv sync --extra dev --extra e2e
|
||||||
|
@echo "✅ E2E dependencies installed!"
|
||||||
|
|
||||||
|
test-e2e: check-docker
|
||||||
|
@echo "🧪 Running E2E tests with PostgreSQL..."
|
||||||
|
@IS_TEST=True PYTHONPATH=. uv run pytest tests/e2e/ -v --tb=short -n 0
|
||||||
|
@echo "✅ E2E tests complete!"
|
||||||
|
|
||||||
|
test-e2e-schema: check-docker
|
||||||
|
@echo "🧪 Running Schemathesis API schema tests..."
|
||||||
|
@IS_TEST=True PYTHONPATH=. uv run pytest tests/e2e/ -v -m "schemathesis" --tb=short -n 0
|
||||||
|
|
||||||
|
# ============================================================================
|
||||||
|
# Performance Benchmarks
|
||||||
|
# ============================================================================
|
||||||
|
|
||||||
|
benchmark:
|
||||||
|
@echo "⏱️ Running performance benchmarks..."
|
||||||
|
@IS_TEST=True PYTHONPATH=. uv run pytest tests/benchmarks/ -v --benchmark-only --benchmark-sort=mean -p no:xdist --override-ini='addopts='
|
||||||
|
|
||||||
|
benchmark-save:
|
||||||
|
@echo "⏱️ Running benchmarks and saving baseline..."
|
||||||
|
@IS_TEST=True PYTHONPATH=. uv run pytest tests/benchmarks/ -v --benchmark-only --benchmark-save=baseline --benchmark-sort=mean -p no:xdist --override-ini='addopts='
|
||||||
|
@echo "✅ Benchmark baseline saved to .benchmarks/"
|
||||||
|
|
||||||
|
benchmark-check:
|
||||||
|
@echo "⏱️ Running benchmarks and comparing against baseline..."
|
||||||
|
@if find .benchmarks -name '*_baseline*' -print -quit 2>/dev/null | grep -q .; then \
|
||||||
|
IS_TEST=True PYTHONPATH=. uv run pytest tests/benchmarks/ -v --benchmark-only --benchmark-compare=0001_baseline --benchmark-sort=mean --benchmark-compare-fail=mean:200% -p no:xdist --override-ini='addopts='; \
|
||||||
|
echo "✅ No performance regressions detected!"; \
|
||||||
|
else \
|
||||||
|
echo "⚠️ No benchmark baseline found. Run 'make benchmark-save' first to create one."; \
|
||||||
|
echo " Running benchmarks without comparison..."; \
|
||||||
|
IS_TEST=True PYTHONPATH=. uv run pytest tests/benchmarks/ -v --benchmark-only --benchmark-save=baseline --benchmark-sort=mean -p no:xdist --override-ini='addopts='; \
|
||||||
|
echo "✅ Benchmark baseline created. Future runs of 'make benchmark-check' will compare against it."; \
|
||||||
|
fi
|
||||||
|
|
||||||
|
test-all:
|
||||||
|
@echo "🧪 Running ALL tests (unit + E2E)..."
|
||||||
|
@$(MAKE) test
|
||||||
|
@$(MAKE) test-e2e
|
||||||
|
|
||||||
# ============================================================================
|
# ============================================================================
|
||||||
# Cleanup
|
# Cleanup
|
||||||
# ============================================================================
|
# ============================================================================
|
||||||
@@ -85,7 +209,7 @@ clean:
|
|||||||
@echo "🧹 Cleaning up..."
|
@echo "🧹 Cleaning up..."
|
||||||
@find . -type d -name "__pycache__" -exec rm -rf {} + 2>/dev/null || true
|
@find . -type d -name "__pycache__" -exec rm -rf {} + 2>/dev/null || true
|
||||||
@find . -type d -name ".pytest_cache" -exec rm -rf {} + 2>/dev/null || true
|
@find . -type d -name ".pytest_cache" -exec rm -rf {} + 2>/dev/null || true
|
||||||
@find . -type d -name ".mypy_cache" -exec rm -rf {} + 2>/dev/null || true
|
@find . -type d -name ".pyright" -exec rm -rf {} + 2>/dev/null || true
|
||||||
@find . -type d -name ".ruff_cache" -exec rm -rf {} + 2>/dev/null || true
|
@find . -type d -name ".ruff_cache" -exec rm -rf {} + 2>/dev/null || true
|
||||||
@find . -type d -name "*.egg-info" -exec rm -rf {} + 2>/dev/null || true
|
@find . -type d -name "*.egg-info" -exec rm -rf {} + 2>/dev/null || true
|
||||||
@find . -type d -name "htmlcov" -exec rm -rf {} + 2>/dev/null || true
|
@find . -type d -name "htmlcov" -exec rm -rf {} + 2>/dev/null || true
|
||||||
|
|||||||
@@ -1,10 +1,12 @@
|
|||||||
# Backend API
|
# PragmaStack Backend API
|
||||||
|
|
||||||
> FastAPI-based REST API with async SQLAlchemy, JWT authentication, and comprehensive testing.
|
> The pragmatic, production-ready FastAPI backend for PragmaStack.
|
||||||
|
|
||||||
## Overview
|
## Overview
|
||||||
|
|
||||||
Production-ready FastAPI backend featuring:
|
Opinionated, secure, and fast. This backend provides the solid foundation you need to ship features, not boilerplate.
|
||||||
|
|
||||||
|
Features:
|
||||||
|
|
||||||
- **Authentication**: JWT with refresh tokens, session management, device tracking
|
- **Authentication**: JWT with refresh tokens, session management, device tracking
|
||||||
- **Database**: Async PostgreSQL with SQLAlchemy 2.0, Alembic migrations
|
- **Database**: Async PostgreSQL with SQLAlchemy 2.0, Alembic migrations
|
||||||
@@ -12,7 +14,9 @@ Production-ready FastAPI backend featuring:
|
|||||||
- **Multi-tenancy**: Organization-based access control with roles (Owner/Admin/Member)
|
- **Multi-tenancy**: Organization-based access control with roles (Owner/Admin/Member)
|
||||||
- **Testing**: 97%+ coverage with security-focused test suite
|
- **Testing**: 97%+ coverage with security-focused test suite
|
||||||
- **Performance**: Async throughout, connection pooling, optimized queries
|
- **Performance**: Async throughout, connection pooling, optimized queries
|
||||||
- **Modern Tooling**: uv for dependencies, Ruff for linting/formatting, mypy for type checking
|
- **Modern Tooling**: uv for dependencies, Ruff for linting/formatting, Pyright for type checking
|
||||||
|
- **Security Auditing**: Automated dependency vulnerability scanning, license compliance, secrets detection
|
||||||
|
- **Pre-commit Hooks**: Ruff, detect-secrets, and standard checks on every commit
|
||||||
|
|
||||||
## Quick Start
|
## Quick Start
|
||||||
|
|
||||||
@@ -147,7 +151,7 @@ uv pip list --outdated
|
|||||||
# Run any Python command via uv (no activation needed)
|
# Run any Python command via uv (no activation needed)
|
||||||
uv run python script.py
|
uv run python script.py
|
||||||
uv run pytest
|
uv run pytest
|
||||||
uv run mypy app/
|
uv run pyright app/
|
||||||
|
|
||||||
# Or activate the virtual environment
|
# Or activate the virtual environment
|
||||||
source .venv/bin/activate
|
source .venv/bin/activate
|
||||||
@@ -169,12 +173,22 @@ make lint # Run Ruff linter (check only)
|
|||||||
make lint-fix # Run Ruff with auto-fix
|
make lint-fix # Run Ruff with auto-fix
|
||||||
make format # Format code with Ruff
|
make format # Format code with Ruff
|
||||||
make format-check # Check if code is formatted
|
make format-check # Check if code is formatted
|
||||||
make type-check # Run mypy type checking
|
make type-check # Run Pyright type checking
|
||||||
make validate # Run all checks (lint + format + types)
|
make validate # Run all checks (lint + format + types)
|
||||||
|
|
||||||
|
# Security & Audit
|
||||||
|
make dep-audit # Scan dependencies for known vulnerabilities (CVEs)
|
||||||
|
make license-check # Check dependency license compliance
|
||||||
|
make audit # Run all security audits (deps + licenses)
|
||||||
|
make validate-all # Run all quality + security checks
|
||||||
|
make check # Full pipeline: quality + security + tests
|
||||||
|
|
||||||
# Testing
|
# Testing
|
||||||
make test # Run all tests
|
make test # Run all tests
|
||||||
make test-cov # Run tests with coverage report
|
make test-cov # Run tests with coverage report
|
||||||
|
make test-e2e # Run E2E tests (PostgreSQL, requires Docker)
|
||||||
|
make test-e2e-schema # Run Schemathesis API schema tests
|
||||||
|
make test-all # Run all tests (unit + E2E)
|
||||||
|
|
||||||
# Utilities
|
# Utilities
|
||||||
make clean # Remove cache and build artifacts
|
make clean # Remove cache and build artifacts
|
||||||
@@ -250,7 +264,7 @@ app/
|
|||||||
│ ├── database.py # Database engine setup
|
│ ├── database.py # Database engine setup
|
||||||
│ ├── auth.py # JWT token handling
|
│ ├── auth.py # JWT token handling
|
||||||
│ └── exceptions.py # Custom exceptions
|
│ └── exceptions.py # Custom exceptions
|
||||||
├── crud/ # Database operations
|
├── repositories/ # Repository pattern (database operations)
|
||||||
├── models/ # SQLAlchemy ORM models
|
├── models/ # SQLAlchemy ORM models
|
||||||
├── schemas/ # Pydantic request/response schemas
|
├── schemas/ # Pydantic request/response schemas
|
||||||
├── services/ # Business logic layer
|
├── services/ # Business logic layer
|
||||||
@@ -350,18 +364,29 @@ open htmlcov/index.html
|
|||||||
# Using Makefile (recommended)
|
# Using Makefile (recommended)
|
||||||
make lint # Ruff linting
|
make lint # Ruff linting
|
||||||
make format # Ruff formatting
|
make format # Ruff formatting
|
||||||
make type-check # mypy type checking
|
make type-check # Pyright type checking
|
||||||
make validate # All checks at once
|
make validate # All checks at once
|
||||||
|
|
||||||
|
# Security audits
|
||||||
|
make dep-audit # Scan dependencies for CVEs
|
||||||
|
make license-check # Check license compliance
|
||||||
|
make audit # All security audits
|
||||||
|
make validate-all # Quality + security checks
|
||||||
|
make check # Full pipeline: quality + security + tests
|
||||||
|
|
||||||
# Using uv directly
|
# Using uv directly
|
||||||
uv run ruff check app/ tests/
|
uv run ruff check app/ tests/
|
||||||
uv run ruff format app/ tests/
|
uv run ruff format app/ tests/
|
||||||
uv run mypy app/
|
uv run pyright app/
|
||||||
```
|
```
|
||||||
|
|
||||||
**Tools:**
|
**Tools:**
|
||||||
- **Ruff**: All-in-one linting, formatting, and import sorting (replaces Black, Flake8, isort)
|
- **Ruff**: All-in-one linting, formatting, and import sorting (replaces Black, Flake8, isort)
|
||||||
- **mypy**: Static type checking with Pydantic plugin
|
- **Pyright**: Static type checking (strict mode)
|
||||||
|
- **pip-audit**: Dependency vulnerability scanning against the OSV database
|
||||||
|
- **pip-licenses**: Dependency license compliance checking
|
||||||
|
- **detect-secrets**: Hardcoded secrets/credentials detection
|
||||||
|
- **pre-commit**: Git hook framework for automated checks on every commit
|
||||||
|
|
||||||
All configurations are in `pyproject.toml`.
|
All configurations are in `pyproject.toml`.
|
||||||
|
|
||||||
@@ -437,7 +462,7 @@ See [docs/FEATURE_EXAMPLE.md](docs/FEATURE_EXAMPLE.md) for step-by-step guide.
|
|||||||
|
|
||||||
Quick overview:
|
Quick overview:
|
||||||
1. Create Pydantic schemas in `app/schemas/`
|
1. Create Pydantic schemas in `app/schemas/`
|
||||||
2. Create CRUD operations in `app/crud/`
|
2. Create repository in `app/repositories/`
|
||||||
3. Create route in `app/api/routes/`
|
3. Create route in `app/api/routes/`
|
||||||
4. Register router in `app/api/main.py`
|
4. Register router in `app/api/main.py`
|
||||||
5. Write tests in `tests/api/`
|
5. Write tests in `tests/api/`
|
||||||
@@ -587,13 +612,42 @@ Configured in `app/core/config.py`:
|
|||||||
- **Security Headers**: CSP, HSTS, X-Frame-Options, etc.
|
- **Security Headers**: CSP, HSTS, X-Frame-Options, etc.
|
||||||
- **Input Validation**: Pydantic schemas, SQL injection prevention (ORM)
|
- **Input Validation**: Pydantic schemas, SQL injection prevention (ORM)
|
||||||
|
|
||||||
|
### Security Auditing
|
||||||
|
|
||||||
|
Automated, deterministic security checks are built into the development workflow:
|
||||||
|
|
||||||
|
```bash
|
||||||
|
# Scan dependencies for known vulnerabilities (CVEs)
|
||||||
|
make dep-audit
|
||||||
|
|
||||||
|
# Check dependency license compliance (blocks GPL-3.0/AGPL)
|
||||||
|
make license-check
|
||||||
|
|
||||||
|
# Run all security audits
|
||||||
|
make audit
|
||||||
|
|
||||||
|
# Full pipeline: quality + security + tests
|
||||||
|
make check
|
||||||
|
```
|
||||||
|
|
||||||
|
**Pre-commit hooks** automatically run on every commit:
|
||||||
|
- **Ruff** lint + format checks
|
||||||
|
- **detect-secrets** blocks commits containing hardcoded secrets
|
||||||
|
- **Standard checks**: trailing whitespace, YAML/TOML validation, merge conflict detection, large file prevention
|
||||||
|
|
||||||
|
Setup pre-commit hooks:
|
||||||
|
```bash
|
||||||
|
uv run pre-commit install
|
||||||
|
```
|
||||||
|
|
||||||
### Security Best Practices
|
### Security Best Practices
|
||||||
|
|
||||||
1. **Never commit secrets**: Use `.env` files (git-ignored)
|
1. **Never commit secrets**: Use `.env` files (git-ignored), enforced by detect-secrets pre-commit hook
|
||||||
2. **Strong SECRET_KEY**: Min 32 chars, cryptographically random
|
2. **Strong SECRET_KEY**: Min 32 chars, cryptographically random
|
||||||
3. **HTTPS in production**: Required for token security
|
3. **HTTPS in production**: Required for token security
|
||||||
4. **Regular updates**: Keep dependencies current (`uv sync --upgrade`)
|
4. **Regular updates**: Keep dependencies current (`uv sync --upgrade`), run `make dep-audit` to check for CVEs
|
||||||
5. **Audit logs**: Monitor authentication events
|
5. **Audit logs**: Monitor authentication events
|
||||||
|
6. **Run `make check` before pushing**: Validates quality, security, and tests in one command
|
||||||
|
|
||||||
---
|
---
|
||||||
|
|
||||||
@@ -643,7 +697,11 @@ logging.basicConfig(level=logging.INFO)
|
|||||||
**Built with modern Python tooling:**
|
**Built with modern Python tooling:**
|
||||||
- 🚀 **uv** - 10-100x faster dependency management
|
- 🚀 **uv** - 10-100x faster dependency management
|
||||||
- ⚡ **Ruff** - 10-100x faster linting & formatting
|
- ⚡ **Ruff** - 10-100x faster linting & formatting
|
||||||
- 🔍 **mypy** - Static type checking
|
- 🔍 **Pyright** - Static type checking (strict mode)
|
||||||
- ✅ **pytest** - Comprehensive test suite
|
- ✅ **pytest** - Comprehensive test suite
|
||||||
|
- 🔒 **pip-audit** - Dependency vulnerability scanning
|
||||||
|
- 🔑 **detect-secrets** - Hardcoded secrets detection
|
||||||
|
- 📜 **pip-licenses** - License compliance checking
|
||||||
|
- 🪝 **pre-commit** - Automated git hooks
|
||||||
|
|
||||||
**All configured in a single `pyproject.toml` file!**
|
**All configured in a single `pyproject.toml` file!**
|
||||||
|
|||||||
@@ -2,6 +2,13 @@
|
|||||||
script_location = app/alembic
|
script_location = app/alembic
|
||||||
sqlalchemy.url = postgresql://postgres:postgres@db:5432/app
|
sqlalchemy.url = postgresql://postgres:postgres@db:5432/app
|
||||||
|
|
||||||
|
# Use sequential naming: 0001_message.py, 0002_message.py, etc.
|
||||||
|
# The rev_id is still used internally but filename is cleaner
|
||||||
|
file_template = %%(rev)s_%%(slug)s
|
||||||
|
|
||||||
|
# Allow specifying custom revision IDs via --rev-id flag
|
||||||
|
revision_environment = true
|
||||||
|
|
||||||
[loggers]
|
[loggers]
|
||||||
keys = root,sqlalchemy,alembic
|
keys = root,sqlalchemy,alembic
|
||||||
|
|
||||||
|
|||||||
@@ -22,6 +22,25 @@ from app.models import *
|
|||||||
# access to the values within the .ini file in use.
|
# access to the values within the .ini file in use.
|
||||||
config = context.config
|
config = context.config
|
||||||
|
|
||||||
|
|
||||||
|
def include_object(object, name, type_, reflected, compare_to):
|
||||||
|
"""
|
||||||
|
Filter objects for autogenerate.
|
||||||
|
|
||||||
|
Skip comparing functional indexes (like LOWER(column)) and partial indexes
|
||||||
|
(with WHERE clauses) as Alembic cannot reliably detect these from models.
|
||||||
|
These should be managed manually via dedicated performance migrations.
|
||||||
|
|
||||||
|
Convention: Any index starting with "ix_perf_" is automatically excluded.
|
||||||
|
This allows adding new performance indexes without updating this file.
|
||||||
|
"""
|
||||||
|
if type_ == "index" and name:
|
||||||
|
# Convention-based: any index prefixed with ix_perf_ is manual
|
||||||
|
if name.startswith("ix_perf_"):
|
||||||
|
return False
|
||||||
|
return True
|
||||||
|
|
||||||
|
|
||||||
# Interpret the config file for Python logging.
|
# Interpret the config file for Python logging.
|
||||||
# This line sets up loggers basically.
|
# This line sets up loggers basically.
|
||||||
if config.config_file_name is not None:
|
if config.config_file_name is not None:
|
||||||
@@ -100,6 +119,8 @@ def run_migrations_offline() -> None:
|
|||||||
target_metadata=target_metadata,
|
target_metadata=target_metadata,
|
||||||
literal_binds=True,
|
literal_binds=True,
|
||||||
dialect_opts={"paramstyle": "named"},
|
dialect_opts={"paramstyle": "named"},
|
||||||
|
compare_type=True,
|
||||||
|
include_object=include_object,
|
||||||
)
|
)
|
||||||
|
|
||||||
with context.begin_transaction():
|
with context.begin_transaction():
|
||||||
@@ -123,7 +144,12 @@ def run_migrations_online() -> None:
|
|||||||
)
|
)
|
||||||
|
|
||||||
with connectable.connect() as connection:
|
with connectable.connect() as connection:
|
||||||
context.configure(connection=connection, target_metadata=target_metadata)
|
context.configure(
|
||||||
|
connection=connection,
|
||||||
|
target_metadata=target_metadata,
|
||||||
|
compare_type=True,
|
||||||
|
include_object=include_object,
|
||||||
|
)
|
||||||
|
|
||||||
with context.begin_transaction():
|
with context.begin_transaction():
|
||||||
context.run_migrations()
|
context.run_migrations()
|
||||||
|
|||||||
446
backend/app/alembic/versions/0001_initial_models.py
Normal file
446
backend/app/alembic/versions/0001_initial_models.py
Normal file
@@ -0,0 +1,446 @@
|
|||||||
|
"""initial models
|
||||||
|
|
||||||
|
Revision ID: 0001
|
||||||
|
Revises:
|
||||||
|
Create Date: 2025-11-27 09:08:09.464506
|
||||||
|
|
||||||
|
"""
|
||||||
|
|
||||||
|
from collections.abc import Sequence
|
||||||
|
|
||||||
|
import sqlalchemy as sa
|
||||||
|
from alembic import op
|
||||||
|
from sqlalchemy.dialects import postgresql
|
||||||
|
|
||||||
|
# revision identifiers, used by Alembic.
|
||||||
|
revision: str = "0001"
|
||||||
|
down_revision: str | None = None
|
||||||
|
branch_labels: str | Sequence[str] | None = None
|
||||||
|
depends_on: str | Sequence[str] | None = None
|
||||||
|
|
||||||
|
|
||||||
|
def upgrade() -> None:
|
||||||
|
# ### commands auto generated by Alembic - please adjust! ###
|
||||||
|
op.create_table(
|
||||||
|
"oauth_states",
|
||||||
|
sa.Column("state", sa.String(length=255), nullable=False),
|
||||||
|
sa.Column("code_verifier", sa.String(length=128), nullable=True),
|
||||||
|
sa.Column("nonce", sa.String(length=255), nullable=True),
|
||||||
|
sa.Column("provider", sa.String(length=50), nullable=False),
|
||||||
|
sa.Column("redirect_uri", sa.String(length=500), nullable=True),
|
||||||
|
sa.Column("user_id", sa.UUID(), nullable=True),
|
||||||
|
sa.Column("expires_at", sa.DateTime(timezone=True), nullable=False),
|
||||||
|
sa.Column("id", sa.UUID(), nullable=False),
|
||||||
|
sa.Column("created_at", sa.DateTime(timezone=True), nullable=False),
|
||||||
|
sa.Column("updated_at", sa.DateTime(timezone=True), nullable=False),
|
||||||
|
sa.PrimaryKeyConstraint("id"),
|
||||||
|
)
|
||||||
|
op.create_index(
|
||||||
|
op.f("ix_oauth_states_state"), "oauth_states", ["state"], unique=True
|
||||||
|
)
|
||||||
|
op.create_table(
|
||||||
|
"organizations",
|
||||||
|
sa.Column("name", sa.String(length=255), nullable=False),
|
||||||
|
sa.Column("slug", sa.String(length=255), nullable=False),
|
||||||
|
sa.Column("description", sa.Text(), nullable=True),
|
||||||
|
sa.Column("is_active", sa.Boolean(), nullable=False),
|
||||||
|
sa.Column("settings", postgresql.JSONB(astext_type=sa.Text()), nullable=True),
|
||||||
|
sa.Column("id", sa.UUID(), nullable=False),
|
||||||
|
sa.Column("created_at", sa.DateTime(timezone=True), nullable=False),
|
||||||
|
sa.Column("updated_at", sa.DateTime(timezone=True), nullable=False),
|
||||||
|
sa.PrimaryKeyConstraint("id"),
|
||||||
|
)
|
||||||
|
op.create_index(
|
||||||
|
op.f("ix_organizations_is_active"), "organizations", ["is_active"], unique=False
|
||||||
|
)
|
||||||
|
op.create_index(
|
||||||
|
op.f("ix_organizations_name"), "organizations", ["name"], unique=False
|
||||||
|
)
|
||||||
|
op.create_index(
|
||||||
|
"ix_organizations_name_active",
|
||||||
|
"organizations",
|
||||||
|
["name", "is_active"],
|
||||||
|
unique=False,
|
||||||
|
)
|
||||||
|
op.create_index(
|
||||||
|
op.f("ix_organizations_slug"), "organizations", ["slug"], unique=True
|
||||||
|
)
|
||||||
|
op.create_index(
|
||||||
|
"ix_organizations_slug_active",
|
||||||
|
"organizations",
|
||||||
|
["slug", "is_active"],
|
||||||
|
unique=False,
|
||||||
|
)
|
||||||
|
op.create_table(
|
||||||
|
"users",
|
||||||
|
sa.Column("email", sa.String(length=255), nullable=False),
|
||||||
|
sa.Column("password_hash", sa.String(length=255), nullable=True),
|
||||||
|
sa.Column("first_name", sa.String(length=100), nullable=False),
|
||||||
|
sa.Column("last_name", sa.String(length=100), nullable=True),
|
||||||
|
sa.Column("phone_number", sa.String(length=20), nullable=True),
|
||||||
|
sa.Column("is_active", sa.Boolean(), nullable=False),
|
||||||
|
sa.Column("is_superuser", sa.Boolean(), nullable=False),
|
||||||
|
sa.Column(
|
||||||
|
"preferences", postgresql.JSONB(astext_type=sa.Text()), nullable=True
|
||||||
|
),
|
||||||
|
sa.Column("locale", sa.String(length=10), nullable=True),
|
||||||
|
sa.Column("deleted_at", sa.DateTime(timezone=True), nullable=True),
|
||||||
|
sa.Column("id", sa.UUID(), nullable=False),
|
||||||
|
sa.Column("created_at", sa.DateTime(timezone=True), nullable=False),
|
||||||
|
sa.Column("updated_at", sa.DateTime(timezone=True), nullable=False),
|
||||||
|
sa.PrimaryKeyConstraint("id"),
|
||||||
|
)
|
||||||
|
op.create_index(op.f("ix_users_deleted_at"), "users", ["deleted_at"], unique=False)
|
||||||
|
op.create_index(op.f("ix_users_email"), "users", ["email"], unique=True)
|
||||||
|
op.create_index(op.f("ix_users_is_active"), "users", ["is_active"], unique=False)
|
||||||
|
op.create_index(
|
||||||
|
op.f("ix_users_is_superuser"), "users", ["is_superuser"], unique=False
|
||||||
|
)
|
||||||
|
op.create_index(op.f("ix_users_locale"), "users", ["locale"], unique=False)
|
||||||
|
op.create_table(
|
||||||
|
"oauth_accounts",
|
||||||
|
sa.Column("user_id", sa.UUID(), nullable=False),
|
||||||
|
sa.Column("provider", sa.String(length=50), nullable=False),
|
||||||
|
sa.Column("provider_user_id", sa.String(length=255), nullable=False),
|
||||||
|
sa.Column("provider_email", sa.String(length=255), nullable=True),
|
||||||
|
sa.Column("access_token_encrypted", sa.String(length=2048), nullable=True),
|
||||||
|
sa.Column("refresh_token_encrypted", sa.String(length=2048), nullable=True),
|
||||||
|
sa.Column("token_expires_at", sa.DateTime(timezone=True), nullable=True),
|
||||||
|
sa.Column("id", sa.UUID(), nullable=False),
|
||||||
|
sa.Column("created_at", sa.DateTime(timezone=True), nullable=False),
|
||||||
|
sa.Column("updated_at", sa.DateTime(timezone=True), nullable=False),
|
||||||
|
sa.ForeignKeyConstraint(["user_id"], ["users.id"], ondelete="CASCADE"),
|
||||||
|
sa.PrimaryKeyConstraint("id"),
|
||||||
|
sa.UniqueConstraint(
|
||||||
|
"provider", "provider_user_id", name="uq_oauth_provider_user"
|
||||||
|
),
|
||||||
|
)
|
||||||
|
op.create_index(
|
||||||
|
op.f("ix_oauth_accounts_provider"), "oauth_accounts", ["provider"], unique=False
|
||||||
|
)
|
||||||
|
op.create_index(
|
||||||
|
op.f("ix_oauth_accounts_provider_email"),
|
||||||
|
"oauth_accounts",
|
||||||
|
["provider_email"],
|
||||||
|
unique=False,
|
||||||
|
)
|
||||||
|
op.create_index(
|
||||||
|
op.f("ix_oauth_accounts_user_id"), "oauth_accounts", ["user_id"], unique=False
|
||||||
|
)
|
||||||
|
op.create_index(
|
||||||
|
"ix_oauth_accounts_user_provider",
|
||||||
|
"oauth_accounts",
|
||||||
|
["user_id", "provider"],
|
||||||
|
unique=False,
|
||||||
|
)
|
||||||
|
op.create_table(
|
||||||
|
"oauth_clients",
|
||||||
|
sa.Column("client_id", sa.String(length=64), nullable=False),
|
||||||
|
sa.Column("client_secret_hash", sa.String(length=255), nullable=True),
|
||||||
|
sa.Column("client_name", sa.String(length=255), nullable=False),
|
||||||
|
sa.Column("client_description", sa.String(length=1000), nullable=True),
|
||||||
|
sa.Column("client_type", sa.String(length=20), nullable=False),
|
||||||
|
sa.Column(
|
||||||
|
"redirect_uris", postgresql.JSONB(astext_type=sa.Text()), nullable=False
|
||||||
|
),
|
||||||
|
sa.Column(
|
||||||
|
"allowed_scopes", postgresql.JSONB(astext_type=sa.Text()), nullable=False
|
||||||
|
),
|
||||||
|
sa.Column("access_token_lifetime", sa.String(length=10), nullable=False),
|
||||||
|
sa.Column("refresh_token_lifetime", sa.String(length=10), nullable=False),
|
||||||
|
sa.Column("is_active", sa.Boolean(), nullable=False),
|
||||||
|
sa.Column("owner_user_id", sa.UUID(), nullable=True),
|
||||||
|
sa.Column("mcp_server_url", sa.String(length=2048), nullable=True),
|
||||||
|
sa.Column("id", sa.UUID(), nullable=False),
|
||||||
|
sa.Column("created_at", sa.DateTime(timezone=True), nullable=False),
|
||||||
|
sa.Column("updated_at", sa.DateTime(timezone=True), nullable=False),
|
||||||
|
sa.ForeignKeyConstraint(["owner_user_id"], ["users.id"], ondelete="SET NULL"),
|
||||||
|
sa.PrimaryKeyConstraint("id"),
|
||||||
|
)
|
||||||
|
op.create_index(
|
||||||
|
op.f("ix_oauth_clients_client_id"), "oauth_clients", ["client_id"], unique=True
|
||||||
|
)
|
||||||
|
op.create_index(
|
||||||
|
op.f("ix_oauth_clients_is_active"), "oauth_clients", ["is_active"], unique=False
|
||||||
|
)
|
||||||
|
op.create_table(
|
||||||
|
"user_organizations",
|
||||||
|
sa.Column("user_id", sa.UUID(), nullable=False),
|
||||||
|
sa.Column("organization_id", sa.UUID(), nullable=False),
|
||||||
|
sa.Column(
|
||||||
|
"role",
|
||||||
|
sa.Enum("OWNER", "ADMIN", "MEMBER", "GUEST", name="organizationrole"),
|
||||||
|
nullable=False,
|
||||||
|
),
|
||||||
|
sa.Column("is_active", sa.Boolean(), nullable=False),
|
||||||
|
sa.Column("custom_permissions", sa.String(length=500), nullable=True),
|
||||||
|
sa.Column("created_at", sa.DateTime(timezone=True), nullable=False),
|
||||||
|
sa.Column("updated_at", sa.DateTime(timezone=True), nullable=False),
|
||||||
|
sa.ForeignKeyConstraint(
|
||||||
|
["organization_id"], ["organizations.id"], ondelete="CASCADE"
|
||||||
|
),
|
||||||
|
sa.ForeignKeyConstraint(["user_id"], ["users.id"], ondelete="CASCADE"),
|
||||||
|
sa.PrimaryKeyConstraint("user_id", "organization_id"),
|
||||||
|
)
|
||||||
|
op.create_index(
|
||||||
|
"ix_user_org_org_active",
|
||||||
|
"user_organizations",
|
||||||
|
["organization_id", "is_active"],
|
||||||
|
unique=False,
|
||||||
|
)
|
||||||
|
op.create_index("ix_user_org_role", "user_organizations", ["role"], unique=False)
|
||||||
|
op.create_index(
|
||||||
|
"ix_user_org_user_active",
|
||||||
|
"user_organizations",
|
||||||
|
["user_id", "is_active"],
|
||||||
|
unique=False,
|
||||||
|
)
|
||||||
|
op.create_index(
|
||||||
|
op.f("ix_user_organizations_is_active"),
|
||||||
|
"user_organizations",
|
||||||
|
["is_active"],
|
||||||
|
unique=False,
|
||||||
|
)
|
||||||
|
op.create_table(
|
||||||
|
"user_sessions",
|
||||||
|
sa.Column("user_id", sa.UUID(), nullable=False),
|
||||||
|
sa.Column("refresh_token_jti", sa.String(length=255), nullable=False),
|
||||||
|
sa.Column("device_name", sa.String(length=255), nullable=True),
|
||||||
|
sa.Column("device_id", sa.String(length=255), nullable=True),
|
||||||
|
sa.Column("ip_address", sa.String(length=45), nullable=True),
|
||||||
|
sa.Column("user_agent", sa.String(length=500), nullable=True),
|
||||||
|
sa.Column("last_used_at", sa.DateTime(timezone=True), nullable=False),
|
||||||
|
sa.Column("expires_at", sa.DateTime(timezone=True), nullable=False),
|
||||||
|
sa.Column("is_active", sa.Boolean(), nullable=False),
|
||||||
|
sa.Column("location_city", sa.String(length=100), nullable=True),
|
||||||
|
sa.Column("location_country", sa.String(length=100), nullable=True),
|
||||||
|
sa.Column("id", sa.UUID(), nullable=False),
|
||||||
|
sa.Column("created_at", sa.DateTime(timezone=True), nullable=False),
|
||||||
|
sa.Column("updated_at", sa.DateTime(timezone=True), nullable=False),
|
||||||
|
sa.ForeignKeyConstraint(["user_id"], ["users.id"], ondelete="CASCADE"),
|
||||||
|
sa.PrimaryKeyConstraint("id"),
|
||||||
|
)
|
||||||
|
op.create_index(
|
||||||
|
op.f("ix_user_sessions_is_active"), "user_sessions", ["is_active"], unique=False
|
||||||
|
)
|
||||||
|
op.create_index(
|
||||||
|
"ix_user_sessions_jti_active",
|
||||||
|
"user_sessions",
|
||||||
|
["refresh_token_jti", "is_active"],
|
||||||
|
unique=False,
|
||||||
|
)
|
||||||
|
op.create_index(
|
||||||
|
op.f("ix_user_sessions_refresh_token_jti"),
|
||||||
|
"user_sessions",
|
||||||
|
["refresh_token_jti"],
|
||||||
|
unique=True,
|
||||||
|
)
|
||||||
|
op.create_index(
|
||||||
|
"ix_user_sessions_user_active",
|
||||||
|
"user_sessions",
|
||||||
|
["user_id", "is_active"],
|
||||||
|
unique=False,
|
||||||
|
)
|
||||||
|
op.create_index(
|
||||||
|
op.f("ix_user_sessions_user_id"), "user_sessions", ["user_id"], unique=False
|
||||||
|
)
|
||||||
|
op.create_table(
|
||||||
|
"oauth_authorization_codes",
|
||||||
|
sa.Column("code", sa.String(length=128), nullable=False),
|
||||||
|
sa.Column("client_id", sa.String(length=64), nullable=False),
|
||||||
|
sa.Column("user_id", sa.UUID(), nullable=False),
|
||||||
|
sa.Column("redirect_uri", sa.String(length=2048), nullable=False),
|
||||||
|
sa.Column("scope", sa.String(length=1000), nullable=False),
|
||||||
|
sa.Column("code_challenge", sa.String(length=128), nullable=True),
|
||||||
|
sa.Column("code_challenge_method", sa.String(length=10), nullable=True),
|
||||||
|
sa.Column("state", sa.String(length=256), nullable=True),
|
||||||
|
sa.Column("nonce", sa.String(length=256), nullable=True),
|
||||||
|
sa.Column("expires_at", sa.DateTime(timezone=True), nullable=False),
|
||||||
|
sa.Column("used", sa.Boolean(), nullable=False),
|
||||||
|
sa.Column("id", sa.UUID(), nullable=False),
|
||||||
|
sa.Column("created_at", sa.DateTime(timezone=True), nullable=False),
|
||||||
|
sa.Column("updated_at", sa.DateTime(timezone=True), nullable=False),
|
||||||
|
sa.ForeignKeyConstraint(
|
||||||
|
["client_id"], ["oauth_clients.client_id"], ondelete="CASCADE"
|
||||||
|
),
|
||||||
|
sa.ForeignKeyConstraint(["user_id"], ["users.id"], ondelete="CASCADE"),
|
||||||
|
sa.PrimaryKeyConstraint("id"),
|
||||||
|
)
|
||||||
|
op.create_index(
|
||||||
|
"ix_oauth_authorization_codes_client_user",
|
||||||
|
"oauth_authorization_codes",
|
||||||
|
["client_id", "user_id"],
|
||||||
|
unique=False,
|
||||||
|
)
|
||||||
|
op.create_index(
|
||||||
|
op.f("ix_oauth_authorization_codes_code"),
|
||||||
|
"oauth_authorization_codes",
|
||||||
|
["code"],
|
||||||
|
unique=True,
|
||||||
|
)
|
||||||
|
op.create_index(
|
||||||
|
"ix_oauth_authorization_codes_expires_at",
|
||||||
|
"oauth_authorization_codes",
|
||||||
|
["expires_at"],
|
||||||
|
unique=False,
|
||||||
|
)
|
||||||
|
op.create_table(
|
||||||
|
"oauth_consents",
|
||||||
|
sa.Column("user_id", sa.UUID(), nullable=False),
|
||||||
|
sa.Column("client_id", sa.String(length=64), nullable=False),
|
||||||
|
sa.Column("granted_scopes", sa.String(length=1000), nullable=False),
|
||||||
|
sa.Column("id", sa.UUID(), nullable=False),
|
||||||
|
sa.Column("created_at", sa.DateTime(timezone=True), nullable=False),
|
||||||
|
sa.Column("updated_at", sa.DateTime(timezone=True), nullable=False),
|
||||||
|
sa.ForeignKeyConstraint(
|
||||||
|
["client_id"], ["oauth_clients.client_id"], ondelete="CASCADE"
|
||||||
|
),
|
||||||
|
sa.ForeignKeyConstraint(["user_id"], ["users.id"], ondelete="CASCADE"),
|
||||||
|
sa.PrimaryKeyConstraint("id"),
|
||||||
|
)
|
||||||
|
op.create_index(
|
||||||
|
"ix_oauth_consents_user_client",
|
||||||
|
"oauth_consents",
|
||||||
|
["user_id", "client_id"],
|
||||||
|
unique=True,
|
||||||
|
)
|
||||||
|
op.create_table(
|
||||||
|
"oauth_provider_refresh_tokens",
|
||||||
|
sa.Column("token_hash", sa.String(length=64), nullable=False),
|
||||||
|
sa.Column("jti", sa.String(length=64), nullable=False),
|
||||||
|
sa.Column("client_id", sa.String(length=64), nullable=False),
|
||||||
|
sa.Column("user_id", sa.UUID(), nullable=False),
|
||||||
|
sa.Column("scope", sa.String(length=1000), nullable=False),
|
||||||
|
sa.Column("expires_at", sa.DateTime(timezone=True), nullable=False),
|
||||||
|
sa.Column("revoked", sa.Boolean(), nullable=False),
|
||||||
|
sa.Column("last_used_at", sa.DateTime(timezone=True), nullable=True),
|
||||||
|
sa.Column("device_info", sa.String(length=500), nullable=True),
|
||||||
|
sa.Column("ip_address", sa.String(length=45), nullable=True),
|
||||||
|
sa.Column("id", sa.UUID(), nullable=False),
|
||||||
|
sa.Column("created_at", sa.DateTime(timezone=True), nullable=False),
|
||||||
|
sa.Column("updated_at", sa.DateTime(timezone=True), nullable=False),
|
||||||
|
sa.ForeignKeyConstraint(
|
||||||
|
["client_id"], ["oauth_clients.client_id"], ondelete="CASCADE"
|
||||||
|
),
|
||||||
|
sa.ForeignKeyConstraint(["user_id"], ["users.id"], ondelete="CASCADE"),
|
||||||
|
sa.PrimaryKeyConstraint("id"),
|
||||||
|
)
|
||||||
|
op.create_index(
|
||||||
|
"ix_oauth_provider_refresh_tokens_client_user",
|
||||||
|
"oauth_provider_refresh_tokens",
|
||||||
|
["client_id", "user_id"],
|
||||||
|
unique=False,
|
||||||
|
)
|
||||||
|
op.create_index(
|
||||||
|
"ix_oauth_provider_refresh_tokens_expires_at",
|
||||||
|
"oauth_provider_refresh_tokens",
|
||||||
|
["expires_at"],
|
||||||
|
unique=False,
|
||||||
|
)
|
||||||
|
op.create_index(
|
||||||
|
op.f("ix_oauth_provider_refresh_tokens_jti"),
|
||||||
|
"oauth_provider_refresh_tokens",
|
||||||
|
["jti"],
|
||||||
|
unique=True,
|
||||||
|
)
|
||||||
|
op.create_index(
|
||||||
|
op.f("ix_oauth_provider_refresh_tokens_revoked"),
|
||||||
|
"oauth_provider_refresh_tokens",
|
||||||
|
["revoked"],
|
||||||
|
unique=False,
|
||||||
|
)
|
||||||
|
op.create_index(
|
||||||
|
op.f("ix_oauth_provider_refresh_tokens_token_hash"),
|
||||||
|
"oauth_provider_refresh_tokens",
|
||||||
|
["token_hash"],
|
||||||
|
unique=True,
|
||||||
|
)
|
||||||
|
op.create_index(
|
||||||
|
"ix_oauth_provider_refresh_tokens_user_revoked",
|
||||||
|
"oauth_provider_refresh_tokens",
|
||||||
|
["user_id", "revoked"],
|
||||||
|
unique=False,
|
||||||
|
)
|
||||||
|
# ### end Alembic commands ###
|
||||||
|
|
||||||
|
|
||||||
|
def downgrade() -> None:
|
||||||
|
# ### commands auto generated by Alembic - please adjust! ###
|
||||||
|
op.drop_index(
|
||||||
|
"ix_oauth_provider_refresh_tokens_user_revoked",
|
||||||
|
table_name="oauth_provider_refresh_tokens",
|
||||||
|
)
|
||||||
|
op.drop_index(
|
||||||
|
op.f("ix_oauth_provider_refresh_tokens_token_hash"),
|
||||||
|
table_name="oauth_provider_refresh_tokens",
|
||||||
|
)
|
||||||
|
op.drop_index(
|
||||||
|
op.f("ix_oauth_provider_refresh_tokens_revoked"),
|
||||||
|
table_name="oauth_provider_refresh_tokens",
|
||||||
|
)
|
||||||
|
op.drop_index(
|
||||||
|
op.f("ix_oauth_provider_refresh_tokens_jti"),
|
||||||
|
table_name="oauth_provider_refresh_tokens",
|
||||||
|
)
|
||||||
|
op.drop_index(
|
||||||
|
"ix_oauth_provider_refresh_tokens_expires_at",
|
||||||
|
table_name="oauth_provider_refresh_tokens",
|
||||||
|
)
|
||||||
|
op.drop_index(
|
||||||
|
"ix_oauth_provider_refresh_tokens_client_user",
|
||||||
|
table_name="oauth_provider_refresh_tokens",
|
||||||
|
)
|
||||||
|
op.drop_table("oauth_provider_refresh_tokens")
|
||||||
|
op.drop_index("ix_oauth_consents_user_client", table_name="oauth_consents")
|
||||||
|
op.drop_table("oauth_consents")
|
||||||
|
op.drop_index(
|
||||||
|
"ix_oauth_authorization_codes_expires_at",
|
||||||
|
table_name="oauth_authorization_codes",
|
||||||
|
)
|
||||||
|
op.drop_index(
|
||||||
|
op.f("ix_oauth_authorization_codes_code"),
|
||||||
|
table_name="oauth_authorization_codes",
|
||||||
|
)
|
||||||
|
op.drop_index(
|
||||||
|
"ix_oauth_authorization_codes_client_user",
|
||||||
|
table_name="oauth_authorization_codes",
|
||||||
|
)
|
||||||
|
op.drop_table("oauth_authorization_codes")
|
||||||
|
op.drop_index(op.f("ix_user_sessions_user_id"), table_name="user_sessions")
|
||||||
|
op.drop_index("ix_user_sessions_user_active", table_name="user_sessions")
|
||||||
|
op.drop_index(
|
||||||
|
op.f("ix_user_sessions_refresh_token_jti"), table_name="user_sessions"
|
||||||
|
)
|
||||||
|
op.drop_index("ix_user_sessions_jti_active", table_name="user_sessions")
|
||||||
|
op.drop_index(op.f("ix_user_sessions_is_active"), table_name="user_sessions")
|
||||||
|
op.drop_table("user_sessions")
|
||||||
|
op.drop_index(
|
||||||
|
op.f("ix_user_organizations_is_active"), table_name="user_organizations"
|
||||||
|
)
|
||||||
|
op.drop_index("ix_user_org_user_active", table_name="user_organizations")
|
||||||
|
op.drop_index("ix_user_org_role", table_name="user_organizations")
|
||||||
|
op.drop_index("ix_user_org_org_active", table_name="user_organizations")
|
||||||
|
op.drop_table("user_organizations")
|
||||||
|
op.drop_index(op.f("ix_oauth_clients_is_active"), table_name="oauth_clients")
|
||||||
|
op.drop_index(op.f("ix_oauth_clients_client_id"), table_name="oauth_clients")
|
||||||
|
op.drop_table("oauth_clients")
|
||||||
|
op.drop_index("ix_oauth_accounts_user_provider", table_name="oauth_accounts")
|
||||||
|
op.drop_index(op.f("ix_oauth_accounts_user_id"), table_name="oauth_accounts")
|
||||||
|
op.drop_index(op.f("ix_oauth_accounts_provider_email"), table_name="oauth_accounts")
|
||||||
|
op.drop_index(op.f("ix_oauth_accounts_provider"), table_name="oauth_accounts")
|
||||||
|
op.drop_table("oauth_accounts")
|
||||||
|
op.drop_index(op.f("ix_users_locale"), table_name="users")
|
||||||
|
op.drop_index(op.f("ix_users_is_superuser"), table_name="users")
|
||||||
|
op.drop_index(op.f("ix_users_is_active"), table_name="users")
|
||||||
|
op.drop_index(op.f("ix_users_email"), table_name="users")
|
||||||
|
op.drop_index(op.f("ix_users_deleted_at"), table_name="users")
|
||||||
|
op.drop_table("users")
|
||||||
|
op.drop_index("ix_organizations_slug_active", table_name="organizations")
|
||||||
|
op.drop_index(op.f("ix_organizations_slug"), table_name="organizations")
|
||||||
|
op.drop_index("ix_organizations_name_active", table_name="organizations")
|
||||||
|
op.drop_index(op.f("ix_organizations_name"), table_name="organizations")
|
||||||
|
op.drop_index(op.f("ix_organizations_is_active"), table_name="organizations")
|
||||||
|
op.drop_table("organizations")
|
||||||
|
op.drop_index(op.f("ix_oauth_states_state"), table_name="oauth_states")
|
||||||
|
op.drop_table("oauth_states")
|
||||||
|
# ### end Alembic commands ###
|
||||||
127
backend/app/alembic/versions/0002_add_performance_indexes.py
Normal file
127
backend/app/alembic/versions/0002_add_performance_indexes.py
Normal file
@@ -0,0 +1,127 @@
|
|||||||
|
"""Add performance indexes
|
||||||
|
|
||||||
|
Revision ID: 0002
|
||||||
|
Revises: 0001
|
||||||
|
Create Date: 2025-11-27
|
||||||
|
|
||||||
|
Performance indexes that Alembic cannot auto-detect:
|
||||||
|
- Functional indexes (LOWER expressions)
|
||||||
|
- Partial indexes (WHERE clauses)
|
||||||
|
|
||||||
|
These indexes use the ix_perf_ prefix and are excluded from autogenerate
|
||||||
|
via the include_object() function in env.py.
|
||||||
|
"""
|
||||||
|
|
||||||
|
from collections.abc import Sequence
|
||||||
|
|
||||||
|
import sqlalchemy as sa
|
||||||
|
from alembic import op
|
||||||
|
|
||||||
|
# revision identifiers, used by Alembic.
|
||||||
|
revision: str = "0002"
|
||||||
|
down_revision: str | None = "0001"
|
||||||
|
branch_labels: str | Sequence[str] | None = None
|
||||||
|
depends_on: str | Sequence[str] | None = None
|
||||||
|
|
||||||
|
|
||||||
|
def upgrade() -> None:
|
||||||
|
# ==========================================================================
|
||||||
|
# USERS TABLE - Performance indexes for authentication
|
||||||
|
# ==========================================================================
|
||||||
|
|
||||||
|
# Case-insensitive email lookup for login/registration
|
||||||
|
# Query: SELECT * FROM users WHERE LOWER(email) = LOWER(:email) AND deleted_at IS NULL
|
||||||
|
# Impact: High - every login, registration check, password reset
|
||||||
|
op.create_index(
|
||||||
|
"ix_perf_users_email_lower",
|
||||||
|
"users",
|
||||||
|
[sa.text("LOWER(email)")],
|
||||||
|
unique=False,
|
||||||
|
postgresql_where=sa.text("deleted_at IS NULL"),
|
||||||
|
)
|
||||||
|
|
||||||
|
# Active users lookup (non-soft-deleted)
|
||||||
|
# Query: SELECT * FROM users WHERE deleted_at IS NULL AND ...
|
||||||
|
# Impact: Medium - user listings, admin queries
|
||||||
|
op.create_index(
|
||||||
|
"ix_perf_users_active",
|
||||||
|
"users",
|
||||||
|
["is_active"],
|
||||||
|
unique=False,
|
||||||
|
postgresql_where=sa.text("deleted_at IS NULL"),
|
||||||
|
)
|
||||||
|
|
||||||
|
# ==========================================================================
|
||||||
|
# ORGANIZATIONS TABLE - Performance indexes for multi-tenant lookups
|
||||||
|
# ==========================================================================
|
||||||
|
|
||||||
|
# Case-insensitive slug lookup for URL routing
|
||||||
|
# Query: SELECT * FROM organizations WHERE LOWER(slug) = LOWER(:slug) AND is_active = true
|
||||||
|
# Impact: Medium - every organization page load
|
||||||
|
op.create_index(
|
||||||
|
"ix_perf_organizations_slug_lower",
|
||||||
|
"organizations",
|
||||||
|
[sa.text("LOWER(slug)")],
|
||||||
|
unique=False,
|
||||||
|
postgresql_where=sa.text("is_active = true"),
|
||||||
|
)
|
||||||
|
|
||||||
|
# ==========================================================================
|
||||||
|
# USER SESSIONS TABLE - Performance indexes for session management
|
||||||
|
# ==========================================================================
|
||||||
|
|
||||||
|
# Expired session cleanup
|
||||||
|
# Query: SELECT * FROM user_sessions WHERE expires_at < NOW() AND is_active = true
|
||||||
|
# Impact: Medium - background cleanup jobs
|
||||||
|
op.create_index(
|
||||||
|
"ix_perf_user_sessions_expires",
|
||||||
|
"user_sessions",
|
||||||
|
["expires_at"],
|
||||||
|
unique=False,
|
||||||
|
postgresql_where=sa.text("is_active = true"),
|
||||||
|
)
|
||||||
|
|
||||||
|
# ==========================================================================
|
||||||
|
# OAUTH PROVIDER TOKENS - Performance indexes for token management
|
||||||
|
# ==========================================================================
|
||||||
|
|
||||||
|
# Expired refresh token cleanup
|
||||||
|
# Query: SELECT * FROM oauth_provider_refresh_tokens WHERE expires_at < NOW() AND revoked = false
|
||||||
|
# Impact: Medium - OAuth token cleanup, validation
|
||||||
|
op.create_index(
|
||||||
|
"ix_perf_oauth_refresh_tokens_expires",
|
||||||
|
"oauth_provider_refresh_tokens",
|
||||||
|
["expires_at"],
|
||||||
|
unique=False,
|
||||||
|
postgresql_where=sa.text("revoked = false"),
|
||||||
|
)
|
||||||
|
|
||||||
|
# ==========================================================================
|
||||||
|
# OAUTH AUTHORIZATION CODES - Performance indexes for auth flow
|
||||||
|
# ==========================================================================
|
||||||
|
|
||||||
|
# Expired authorization code cleanup
|
||||||
|
# Query: DELETE FROM oauth_authorization_codes WHERE expires_at < NOW() AND used = false
|
||||||
|
# Impact: Low-Medium - OAuth cleanup jobs
|
||||||
|
op.create_index(
|
||||||
|
"ix_perf_oauth_auth_codes_expires",
|
||||||
|
"oauth_authorization_codes",
|
||||||
|
["expires_at"],
|
||||||
|
unique=False,
|
||||||
|
postgresql_where=sa.text("used = false"),
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def downgrade() -> None:
|
||||||
|
# Drop indexes in reverse order
|
||||||
|
op.drop_index(
|
||||||
|
"ix_perf_oauth_auth_codes_expires", table_name="oauth_authorization_codes"
|
||||||
|
)
|
||||||
|
op.drop_index(
|
||||||
|
"ix_perf_oauth_refresh_tokens_expires",
|
||||||
|
table_name="oauth_provider_refresh_tokens",
|
||||||
|
)
|
||||||
|
op.drop_index("ix_perf_user_sessions_expires", table_name="user_sessions")
|
||||||
|
op.drop_index("ix_perf_organizations_slug_lower", table_name="organizations")
|
||||||
|
op.drop_index("ix_perf_users_active", table_name="users")
|
||||||
|
op.drop_index("ix_perf_users_email_lower", table_name="users")
|
||||||
@@ -0,0 +1,35 @@
|
|||||||
|
"""rename oauth account token fields drop encrypted suffix
|
||||||
|
|
||||||
|
Revision ID: 0003
|
||||||
|
Revises: 0002
|
||||||
|
Create Date: 2026-02-27 01:03:18.869178
|
||||||
|
|
||||||
|
"""
|
||||||
|
|
||||||
|
from collections.abc import Sequence
|
||||||
|
|
||||||
|
from alembic import op
|
||||||
|
|
||||||
|
# revision identifiers, used by Alembic.
|
||||||
|
revision: str = "0003"
|
||||||
|
down_revision: str | None = "0002"
|
||||||
|
branch_labels: str | Sequence[str] | None = None
|
||||||
|
depends_on: str | Sequence[str] | None = None
|
||||||
|
|
||||||
|
|
||||||
|
def upgrade() -> None:
|
||||||
|
op.alter_column(
|
||||||
|
"oauth_accounts", "access_token_encrypted", new_column_name="access_token"
|
||||||
|
)
|
||||||
|
op.alter_column(
|
||||||
|
"oauth_accounts", "refresh_token_encrypted", new_column_name="refresh_token"
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def downgrade() -> None:
|
||||||
|
op.alter_column(
|
||||||
|
"oauth_accounts", "access_token", new_column_name="access_token_encrypted"
|
||||||
|
)
|
||||||
|
op.alter_column(
|
||||||
|
"oauth_accounts", "refresh_token", new_column_name="refresh_token_encrypted"
|
||||||
|
)
|
||||||
@@ -1,78 +0,0 @@
|
|||||||
"""add_performance_indexes
|
|
||||||
|
|
||||||
Revision ID: 1174fffbe3e4
|
|
||||||
Revises: fbf6318a8a36
|
|
||||||
Create Date: 2025-11-01 04:15:25.367010
|
|
||||||
|
|
||||||
"""
|
|
||||||
|
|
||||||
from collections.abc import Sequence
|
|
||||||
|
|
||||||
import sqlalchemy as sa
|
|
||||||
from alembic import op
|
|
||||||
|
|
||||||
# revision identifiers, used by Alembic.
|
|
||||||
revision: str = "1174fffbe3e4"
|
|
||||||
down_revision: str | None = "fbf6318a8a36"
|
|
||||||
branch_labels: str | Sequence[str] | None = None
|
|
||||||
depends_on: str | Sequence[str] | None = None
|
|
||||||
|
|
||||||
|
|
||||||
def upgrade() -> None:
|
|
||||||
"""Add performance indexes for optimized queries."""
|
|
||||||
|
|
||||||
# Index for session cleanup queries
|
|
||||||
# Optimizes: DELETE WHERE is_active = FALSE AND expires_at < now AND created_at < cutoff
|
|
||||||
op.create_index(
|
|
||||||
"ix_user_sessions_cleanup",
|
|
||||||
"user_sessions",
|
|
||||||
["is_active", "expires_at", "created_at"],
|
|
||||||
unique=False,
|
|
||||||
postgresql_where=sa.text("is_active = false"),
|
|
||||||
)
|
|
||||||
|
|
||||||
# Index for user search queries (basic trigram support without pg_trgm extension)
|
|
||||||
# Optimizes: WHERE email ILIKE '%search%' OR first_name ILIKE '%search%'
|
|
||||||
# Note: For better performance, consider enabling pg_trgm extension
|
|
||||||
op.create_index(
|
|
||||||
"ix_users_email_lower",
|
|
||||||
"users",
|
|
||||||
[sa.text("LOWER(email)")],
|
|
||||||
unique=False,
|
|
||||||
postgresql_where=sa.text("deleted_at IS NULL"),
|
|
||||||
)
|
|
||||||
|
|
||||||
op.create_index(
|
|
||||||
"ix_users_first_name_lower",
|
|
||||||
"users",
|
|
||||||
[sa.text("LOWER(first_name)")],
|
|
||||||
unique=False,
|
|
||||||
postgresql_where=sa.text("deleted_at IS NULL"),
|
|
||||||
)
|
|
||||||
|
|
||||||
op.create_index(
|
|
||||||
"ix_users_last_name_lower",
|
|
||||||
"users",
|
|
||||||
[sa.text("LOWER(last_name)")],
|
|
||||||
unique=False,
|
|
||||||
postgresql_where=sa.text("deleted_at IS NULL"),
|
|
||||||
)
|
|
||||||
|
|
||||||
# Index for organization search
|
|
||||||
op.create_index(
|
|
||||||
"ix_organizations_name_lower",
|
|
||||||
"organizations",
|
|
||||||
[sa.text("LOWER(name)")],
|
|
||||||
unique=False,
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
def downgrade() -> None:
|
|
||||||
"""Remove performance indexes."""
|
|
||||||
|
|
||||||
# Drop indexes in reverse order
|
|
||||||
op.drop_index("ix_organizations_name_lower", table_name="organizations")
|
|
||||||
op.drop_index("ix_users_last_name_lower", table_name="users")
|
|
||||||
op.drop_index("ix_users_first_name_lower", table_name="users")
|
|
||||||
op.drop_index("ix_users_email_lower", table_name="users")
|
|
||||||
op.drop_index("ix_user_sessions_cleanup", table_name="user_sessions")
|
|
||||||
@@ -1,36 +0,0 @@
|
|||||||
"""add_soft_delete_to_users
|
|
||||||
|
|
||||||
Revision ID: 2d0fcec3b06d
|
|
||||||
Revises: 9e4f2a1b8c7d
|
|
||||||
Create Date: 2025-10-30 16:40:21.000021
|
|
||||||
|
|
||||||
"""
|
|
||||||
|
|
||||||
from collections.abc import Sequence
|
|
||||||
|
|
||||||
import sqlalchemy as sa
|
|
||||||
from alembic import op
|
|
||||||
|
|
||||||
# revision identifiers, used by Alembic.
|
|
||||||
revision: str = "2d0fcec3b06d"
|
|
||||||
down_revision: str | None = "9e4f2a1b8c7d"
|
|
||||||
branch_labels: str | Sequence[str] | None = None
|
|
||||||
depends_on: str | Sequence[str] | None = None
|
|
||||||
|
|
||||||
|
|
||||||
def upgrade() -> None:
|
|
||||||
# Add deleted_at column for soft deletes
|
|
||||||
op.add_column(
|
|
||||||
"users", sa.Column("deleted_at", sa.DateTime(timezone=True), nullable=True)
|
|
||||||
)
|
|
||||||
|
|
||||||
# Add index on deleted_at for efficient queries
|
|
||||||
op.create_index("ix_users_deleted_at", "users", ["deleted_at"])
|
|
||||||
|
|
||||||
|
|
||||||
def downgrade() -> None:
|
|
||||||
# Remove index
|
|
||||||
op.drop_index("ix_users_deleted_at", table_name="users")
|
|
||||||
|
|
||||||
# Remove column
|
|
||||||
op.drop_column("users", "deleted_at")
|
|
||||||
@@ -1,46 +0,0 @@
|
|||||||
"""Add all initial models
|
|
||||||
|
|
||||||
Revision ID: 38bf9e7e74b3
|
|
||||||
Revises: 7396957cbe80
|
|
||||||
Create Date: 2025-02-28 09:19:33.212278
|
|
||||||
|
|
||||||
"""
|
|
||||||
|
|
||||||
from collections.abc import Sequence
|
|
||||||
|
|
||||||
import sqlalchemy as sa
|
|
||||||
from alembic import op
|
|
||||||
|
|
||||||
# revision identifiers, used by Alembic.
|
|
||||||
revision: str = "38bf9e7e74b3"
|
|
||||||
down_revision: str | None = "7396957cbe80"
|
|
||||||
branch_labels: str | Sequence[str] | None = None
|
|
||||||
depends_on: str | Sequence[str] | None = None
|
|
||||||
|
|
||||||
|
|
||||||
def upgrade() -> None:
|
|
||||||
op.create_table(
|
|
||||||
"users",
|
|
||||||
sa.Column("email", sa.String(), nullable=False),
|
|
||||||
sa.Column("password_hash", sa.String(), nullable=False),
|
|
||||||
sa.Column("first_name", sa.String(), nullable=False),
|
|
||||||
sa.Column("last_name", sa.String(), nullable=True),
|
|
||||||
sa.Column("phone_number", sa.String(), nullable=True),
|
|
||||||
sa.Column("is_active", sa.Boolean(), nullable=False),
|
|
||||||
sa.Column("is_superuser", sa.Boolean(), nullable=False),
|
|
||||||
sa.Column("preferences", sa.JSON(), nullable=True),
|
|
||||||
sa.Column("id", sa.UUID(), nullable=False),
|
|
||||||
sa.Column("created_at", sa.DateTime(timezone=True), nullable=False),
|
|
||||||
sa.Column("updated_at", sa.DateTime(timezone=True), nullable=False),
|
|
||||||
sa.PrimaryKeyConstraint("id"),
|
|
||||||
)
|
|
||||||
op.create_index(op.f("ix_users_email"), "users", ["email"], unique=True)
|
|
||||||
|
|
||||||
# ### end Alembic commands ###
|
|
||||||
|
|
||||||
|
|
||||||
def downgrade() -> None:
|
|
||||||
# ### commands auto generated by Alembic - please adjust! ###
|
|
||||||
op.drop_index(op.f("ix_users_email"), table_name="users")
|
|
||||||
op.drop_table("users")
|
|
||||||
# ### end Alembic commands ###
|
|
||||||
@@ -1,89 +0,0 @@
|
|||||||
"""add_user_sessions_table
|
|
||||||
|
|
||||||
Revision ID: 549b50ea888d
|
|
||||||
Revises: b76c725fc3cf
|
|
||||||
Create Date: 2025-10-31 07:41:18.729544
|
|
||||||
|
|
||||||
"""
|
|
||||||
|
|
||||||
from collections.abc import Sequence
|
|
||||||
|
|
||||||
import sqlalchemy as sa
|
|
||||||
from alembic import op
|
|
||||||
|
|
||||||
# revision identifiers, used by Alembic.
|
|
||||||
revision: str = "549b50ea888d"
|
|
||||||
down_revision: str | None = "b76c725fc3cf"
|
|
||||||
branch_labels: str | Sequence[str] | None = None
|
|
||||||
depends_on: str | Sequence[str] | None = None
|
|
||||||
|
|
||||||
|
|
||||||
def upgrade() -> None:
|
|
||||||
# Create user_sessions table for per-device session management
|
|
||||||
op.create_table(
|
|
||||||
"user_sessions",
|
|
||||||
sa.Column("id", sa.UUID(), nullable=False),
|
|
||||||
sa.Column("user_id", sa.UUID(), nullable=False),
|
|
||||||
sa.Column("refresh_token_jti", sa.String(length=255), nullable=False),
|
|
||||||
sa.Column("device_name", sa.String(length=255), nullable=True),
|
|
||||||
sa.Column("device_id", sa.String(length=255), nullable=True),
|
|
||||||
sa.Column("ip_address", sa.String(length=45), nullable=True),
|
|
||||||
sa.Column("user_agent", sa.String(length=500), nullable=True),
|
|
||||||
sa.Column("last_used_at", sa.DateTime(timezone=True), nullable=False),
|
|
||||||
sa.Column("expires_at", sa.DateTime(timezone=True), nullable=False),
|
|
||||||
sa.Column("is_active", sa.Boolean(), nullable=False, server_default="true"),
|
|
||||||
sa.Column("location_city", sa.String(length=100), nullable=True),
|
|
||||||
sa.Column("location_country", sa.String(length=100), nullable=True),
|
|
||||||
sa.Column("created_at", sa.DateTime(timezone=True), nullable=False),
|
|
||||||
sa.Column("updated_at", sa.DateTime(timezone=True), nullable=False),
|
|
||||||
sa.PrimaryKeyConstraint("id"),
|
|
||||||
)
|
|
||||||
|
|
||||||
# Create foreign key to users table
|
|
||||||
op.create_foreign_key(
|
|
||||||
"fk_user_sessions_user_id",
|
|
||||||
"user_sessions",
|
|
||||||
"users",
|
|
||||||
["user_id"],
|
|
||||||
["id"],
|
|
||||||
ondelete="CASCADE",
|
|
||||||
)
|
|
||||||
|
|
||||||
# Create indexes for performance
|
|
||||||
# 1. Lookup session by refresh token JTI (most common query)
|
|
||||||
op.create_index(
|
|
||||||
"ix_user_sessions_jti", "user_sessions", ["refresh_token_jti"], unique=True
|
|
||||||
)
|
|
||||||
|
|
||||||
# 2. Lookup sessions by user ID
|
|
||||||
op.create_index("ix_user_sessions_user_id", "user_sessions", ["user_id"])
|
|
||||||
|
|
||||||
# 3. Composite index for active sessions by user
|
|
||||||
op.create_index(
|
|
||||||
"ix_user_sessions_user_active", "user_sessions", ["user_id", "is_active"]
|
|
||||||
)
|
|
||||||
|
|
||||||
# 4. Index on expires_at for cleanup job
|
|
||||||
op.create_index("ix_user_sessions_expires_at", "user_sessions", ["expires_at"])
|
|
||||||
|
|
||||||
# 5. Composite index for active session lookup by JTI
|
|
||||||
op.create_index(
|
|
||||||
"ix_user_sessions_jti_active",
|
|
||||||
"user_sessions",
|
|
||||||
["refresh_token_jti", "is_active"],
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
def downgrade() -> None:
|
|
||||||
# Drop indexes first
|
|
||||||
op.drop_index("ix_user_sessions_jti_active", table_name="user_sessions")
|
|
||||||
op.drop_index("ix_user_sessions_expires_at", table_name="user_sessions")
|
|
||||||
op.drop_index("ix_user_sessions_user_active", table_name="user_sessions")
|
|
||||||
op.drop_index("ix_user_sessions_user_id", table_name="user_sessions")
|
|
||||||
op.drop_index("ix_user_sessions_jti", table_name="user_sessions")
|
|
||||||
|
|
||||||
# Drop foreign key
|
|
||||||
op.drop_constraint("fk_user_sessions_user_id", "user_sessions", type_="foreignkey")
|
|
||||||
|
|
||||||
# Drop table
|
|
||||||
op.drop_table("user_sessions")
|
|
||||||
@@ -1,23 +0,0 @@
|
|||||||
"""Initial empty migration
|
|
||||||
|
|
||||||
Revision ID: 7396957cbe80
|
|
||||||
Revises:
|
|
||||||
Create Date: 2025-02-27 12:47:46.445313
|
|
||||||
|
|
||||||
"""
|
|
||||||
|
|
||||||
from collections.abc import Sequence
|
|
||||||
|
|
||||||
# revision identifiers, used by Alembic.
|
|
||||||
revision: str = "7396957cbe80"
|
|
||||||
down_revision: str | None = None
|
|
||||||
branch_labels: str | Sequence[str] | None = None
|
|
||||||
depends_on: str | Sequence[str] | None = None
|
|
||||||
|
|
||||||
|
|
||||||
def upgrade() -> None:
|
|
||||||
pass
|
|
||||||
|
|
||||||
|
|
||||||
def downgrade() -> None:
|
|
||||||
pass
|
|
||||||
@@ -1,116 +0,0 @@
|
|||||||
"""Add missing indexes and fix column types
|
|
||||||
|
|
||||||
Revision ID: 9e4f2a1b8c7d
|
|
||||||
Revises: 38bf9e7e74b3
|
|
||||||
Create Date: 2025-10-30 10:00:00.000000
|
|
||||||
|
|
||||||
"""
|
|
||||||
|
|
||||||
from collections.abc import Sequence
|
|
||||||
|
|
||||||
import sqlalchemy as sa
|
|
||||||
from alembic import op
|
|
||||||
|
|
||||||
# revision identifiers, used by Alembic.
|
|
||||||
revision: str = "9e4f2a1b8c7d"
|
|
||||||
down_revision: str | None = "38bf9e7e74b3"
|
|
||||||
branch_labels: str | Sequence[str] | None = None
|
|
||||||
depends_on: str | Sequence[str] | None = None
|
|
||||||
|
|
||||||
|
|
||||||
def upgrade() -> None:
|
|
||||||
# Add missing indexes for is_active and is_superuser
|
|
||||||
op.create_index(op.f("ix_users_is_active"), "users", ["is_active"], unique=False)
|
|
||||||
op.create_index(
|
|
||||||
op.f("ix_users_is_superuser"), "users", ["is_superuser"], unique=False
|
|
||||||
)
|
|
||||||
|
|
||||||
# Fix column types to match model definitions with explicit lengths
|
|
||||||
op.alter_column(
|
|
||||||
"users",
|
|
||||||
"email",
|
|
||||||
existing_type=sa.String(),
|
|
||||||
type_=sa.String(length=255),
|
|
||||||
nullable=False,
|
|
||||||
)
|
|
||||||
|
|
||||||
op.alter_column(
|
|
||||||
"users",
|
|
||||||
"password_hash",
|
|
||||||
existing_type=sa.String(),
|
|
||||||
type_=sa.String(length=255),
|
|
||||||
nullable=False,
|
|
||||||
)
|
|
||||||
|
|
||||||
op.alter_column(
|
|
||||||
"users",
|
|
||||||
"first_name",
|
|
||||||
existing_type=sa.String(),
|
|
||||||
type_=sa.String(length=100),
|
|
||||||
nullable=False,
|
|
||||||
server_default="user",
|
|
||||||
) # Add server default
|
|
||||||
|
|
||||||
op.alter_column(
|
|
||||||
"users",
|
|
||||||
"last_name",
|
|
||||||
existing_type=sa.String(),
|
|
||||||
type_=sa.String(length=100),
|
|
||||||
nullable=True,
|
|
||||||
)
|
|
||||||
|
|
||||||
op.alter_column(
|
|
||||||
"users",
|
|
||||||
"phone_number",
|
|
||||||
existing_type=sa.String(),
|
|
||||||
type_=sa.String(length=20),
|
|
||||||
nullable=True,
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
def downgrade() -> None:
|
|
||||||
# Revert column types
|
|
||||||
op.alter_column(
|
|
||||||
"users",
|
|
||||||
"phone_number",
|
|
||||||
existing_type=sa.String(length=20),
|
|
||||||
type_=sa.String(),
|
|
||||||
nullable=True,
|
|
||||||
)
|
|
||||||
|
|
||||||
op.alter_column(
|
|
||||||
"users",
|
|
||||||
"last_name",
|
|
||||||
existing_type=sa.String(length=100),
|
|
||||||
type_=sa.String(),
|
|
||||||
nullable=True,
|
|
||||||
)
|
|
||||||
|
|
||||||
op.alter_column(
|
|
||||||
"users",
|
|
||||||
"first_name",
|
|
||||||
existing_type=sa.String(length=100),
|
|
||||||
type_=sa.String(),
|
|
||||||
nullable=False,
|
|
||||||
server_default=None,
|
|
||||||
) # Remove server default
|
|
||||||
|
|
||||||
op.alter_column(
|
|
||||||
"users",
|
|
||||||
"password_hash",
|
|
||||||
existing_type=sa.String(length=255),
|
|
||||||
type_=sa.String(),
|
|
||||||
nullable=False,
|
|
||||||
)
|
|
||||||
|
|
||||||
op.alter_column(
|
|
||||||
"users",
|
|
||||||
"email",
|
|
||||||
existing_type=sa.String(length=255),
|
|
||||||
type_=sa.String(),
|
|
||||||
nullable=False,
|
|
||||||
)
|
|
||||||
|
|
||||||
# Drop indexes
|
|
||||||
op.drop_index(op.f("ix_users_is_superuser"), table_name="users")
|
|
||||||
op.drop_index(op.f("ix_users_is_active"), table_name="users")
|
|
||||||
@@ -1,48 +0,0 @@
|
|||||||
"""add_composite_indexes
|
|
||||||
|
|
||||||
Revision ID: b76c725fc3cf
|
|
||||||
Revises: 2d0fcec3b06d
|
|
||||||
Create Date: 2025-10-30 16:41:33.273135
|
|
||||||
|
|
||||||
"""
|
|
||||||
|
|
||||||
from collections.abc import Sequence
|
|
||||||
|
|
||||||
import sqlalchemy as sa
|
|
||||||
from alembic import op
|
|
||||||
|
|
||||||
# revision identifiers, used by Alembic.
|
|
||||||
revision: str = "b76c725fc3cf"
|
|
||||||
down_revision: str | None = "2d0fcec3b06d"
|
|
||||||
branch_labels: str | Sequence[str] | None = None
|
|
||||||
depends_on: str | Sequence[str] | None = None
|
|
||||||
|
|
||||||
|
|
||||||
def upgrade() -> None:
|
|
||||||
# Add composite indexes for common query patterns
|
|
||||||
|
|
||||||
# Composite index for filtering active users by role
|
|
||||||
op.create_index(
|
|
||||||
"ix_users_active_superuser",
|
|
||||||
"users",
|
|
||||||
["is_active", "is_superuser"],
|
|
||||||
postgresql_where=sa.text("deleted_at IS NULL"),
|
|
||||||
)
|
|
||||||
|
|
||||||
# Composite index for sorting active users by creation date
|
|
||||||
op.create_index(
|
|
||||||
"ix_users_active_created",
|
|
||||||
"users",
|
|
||||||
["is_active", "created_at"],
|
|
||||||
postgresql_where=sa.text("deleted_at IS NULL"),
|
|
||||||
)
|
|
||||||
|
|
||||||
# Composite index for email lookup of non-deleted users
|
|
||||||
op.create_index("ix_users_email_not_deleted", "users", ["email", "deleted_at"])
|
|
||||||
|
|
||||||
|
|
||||||
def downgrade() -> None:
|
|
||||||
# Remove composite indexes
|
|
||||||
op.drop_index("ix_users_email_not_deleted", table_name="users")
|
|
||||||
op.drop_index("ix_users_active_created", table_name="users")
|
|
||||||
op.drop_index("ix_users_active_superuser", table_name="users")
|
|
||||||
@@ -1,42 +0,0 @@
|
|||||||
"""add user locale preference column
|
|
||||||
|
|
||||||
Revision ID: c8e9f3a2d1b4
|
|
||||||
Revises: b76c725fc3cf
|
|
||||||
Create Date: 2025-11-17 18:00:00.000000
|
|
||||||
|
|
||||||
"""
|
|
||||||
|
|
||||||
from collections.abc import Sequence
|
|
||||||
|
|
||||||
import sqlalchemy as sa
|
|
||||||
from alembic import op
|
|
||||||
|
|
||||||
# revision identifiers, used by Alembic.
|
|
||||||
revision: str = "c8e9f3a2d1b4"
|
|
||||||
down_revision: str | None = "b76c725fc3cf"
|
|
||||||
branch_labels: str | Sequence[str] | None = None
|
|
||||||
depends_on: str | Sequence[str] | None = None
|
|
||||||
|
|
||||||
|
|
||||||
def upgrade() -> None:
|
|
||||||
# Add locale column to users table
|
|
||||||
# VARCHAR(10) supports BCP 47 format (e.g., "en", "it", "en-US", "it-IT")
|
|
||||||
# Nullable: NULL means "not set yet", will use Accept-Language header fallback
|
|
||||||
# Indexed: For analytics queries and filtering by locale
|
|
||||||
op.add_column(
|
|
||||||
"users",
|
|
||||||
sa.Column("locale", sa.String(length=10), nullable=True)
|
|
||||||
)
|
|
||||||
|
|
||||||
# Create index on locale column for performance
|
|
||||||
op.create_index(
|
|
||||||
"ix_users_locale",
|
|
||||||
"users",
|
|
||||||
["locale"],
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
def downgrade() -> None:
|
|
||||||
# Remove locale index and column
|
|
||||||
op.drop_index("ix_users_locale", table_name="users")
|
|
||||||
op.drop_column("users", "locale")
|
|
||||||
@@ -1,127 +0,0 @@
|
|||||||
"""add_organizations_and_user_organizations
|
|
||||||
|
|
||||||
Revision ID: fbf6318a8a36
|
|
||||||
Revises: 549b50ea888d
|
|
||||||
Create Date: 2025-10-31 12:08:05.141353
|
|
||||||
|
|
||||||
"""
|
|
||||||
|
|
||||||
from collections.abc import Sequence
|
|
||||||
|
|
||||||
import sqlalchemy as sa
|
|
||||||
from alembic import op
|
|
||||||
|
|
||||||
# revision identifiers, used by Alembic.
|
|
||||||
revision: str = "fbf6318a8a36"
|
|
||||||
down_revision: str | None = "549b50ea888d"
|
|
||||||
branch_labels: str | Sequence[str] | None = None
|
|
||||||
depends_on: str | Sequence[str] | None = None
|
|
||||||
|
|
||||||
|
|
||||||
def upgrade() -> None:
|
|
||||||
# Create organizations table
|
|
||||||
op.create_table(
|
|
||||||
"organizations",
|
|
||||||
sa.Column("id", sa.UUID(), nullable=False),
|
|
||||||
sa.Column("name", sa.String(length=255), nullable=False),
|
|
||||||
sa.Column("slug", sa.String(length=255), nullable=False),
|
|
||||||
sa.Column("description", sa.Text(), nullable=True),
|
|
||||||
sa.Column("is_active", sa.Boolean(), nullable=False, server_default="true"),
|
|
||||||
sa.Column("settings", sa.JSON(), nullable=True),
|
|
||||||
sa.Column("created_at", sa.DateTime(timezone=True), nullable=False),
|
|
||||||
sa.Column("updated_at", sa.DateTime(timezone=True), nullable=False),
|
|
||||||
sa.PrimaryKeyConstraint("id"),
|
|
||||||
)
|
|
||||||
|
|
||||||
# Create indexes for organizations
|
|
||||||
op.create_index("ix_organizations_name", "organizations", ["name"])
|
|
||||||
op.create_index("ix_organizations_slug", "organizations", ["slug"], unique=True)
|
|
||||||
op.create_index("ix_organizations_is_active", "organizations", ["is_active"])
|
|
||||||
op.create_index(
|
|
||||||
"ix_organizations_name_active", "organizations", ["name", "is_active"]
|
|
||||||
)
|
|
||||||
op.create_index(
|
|
||||||
"ix_organizations_slug_active", "organizations", ["slug", "is_active"]
|
|
||||||
)
|
|
||||||
|
|
||||||
# Create user_organizations junction table
|
|
||||||
op.create_table(
|
|
||||||
"user_organizations",
|
|
||||||
sa.Column("user_id", sa.UUID(), nullable=False),
|
|
||||||
sa.Column("organization_id", sa.UUID(), nullable=False),
|
|
||||||
sa.Column(
|
|
||||||
"role",
|
|
||||||
sa.Enum("OWNER", "ADMIN", "MEMBER", "GUEST", name="organizationrole"),
|
|
||||||
nullable=False,
|
|
||||||
server_default="MEMBER",
|
|
||||||
),
|
|
||||||
sa.Column("is_active", sa.Boolean(), nullable=False, server_default="true"),
|
|
||||||
sa.Column("custom_permissions", sa.String(length=500), nullable=True),
|
|
||||||
sa.Column("created_at", sa.DateTime(timezone=True), nullable=False),
|
|
||||||
sa.Column("updated_at", sa.DateTime(timezone=True), nullable=False),
|
|
||||||
sa.PrimaryKeyConstraint("user_id", "organization_id"),
|
|
||||||
)
|
|
||||||
|
|
||||||
# Create foreign keys
|
|
||||||
op.create_foreign_key(
|
|
||||||
"fk_user_organizations_user_id",
|
|
||||||
"user_organizations",
|
|
||||||
"users",
|
|
||||||
["user_id"],
|
|
||||||
["id"],
|
|
||||||
ondelete="CASCADE",
|
|
||||||
)
|
|
||||||
op.create_foreign_key(
|
|
||||||
"fk_user_organizations_organization_id",
|
|
||||||
"user_organizations",
|
|
||||||
"organizations",
|
|
||||||
["organization_id"],
|
|
||||||
["id"],
|
|
||||||
ondelete="CASCADE",
|
|
||||||
)
|
|
||||||
|
|
||||||
# Create indexes for user_organizations
|
|
||||||
op.create_index("ix_user_organizations_role", "user_organizations", ["role"])
|
|
||||||
op.create_index(
|
|
||||||
"ix_user_organizations_is_active", "user_organizations", ["is_active"]
|
|
||||||
)
|
|
||||||
op.create_index(
|
|
||||||
"ix_user_org_user_active", "user_organizations", ["user_id", "is_active"]
|
|
||||||
)
|
|
||||||
op.create_index(
|
|
||||||
"ix_user_org_org_active", "user_organizations", ["organization_id", "is_active"]
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
def downgrade() -> None:
|
|
||||||
# Drop indexes for user_organizations
|
|
||||||
op.drop_index("ix_user_org_org_active", table_name="user_organizations")
|
|
||||||
op.drop_index("ix_user_org_user_active", table_name="user_organizations")
|
|
||||||
op.drop_index("ix_user_organizations_is_active", table_name="user_organizations")
|
|
||||||
op.drop_index("ix_user_organizations_role", table_name="user_organizations")
|
|
||||||
|
|
||||||
# Drop foreign keys
|
|
||||||
op.drop_constraint(
|
|
||||||
"fk_user_organizations_organization_id",
|
|
||||||
"user_organizations",
|
|
||||||
type_="foreignkey",
|
|
||||||
)
|
|
||||||
op.drop_constraint(
|
|
||||||
"fk_user_organizations_user_id", "user_organizations", type_="foreignkey"
|
|
||||||
)
|
|
||||||
|
|
||||||
# Drop user_organizations table
|
|
||||||
op.drop_table("user_organizations")
|
|
||||||
|
|
||||||
# Drop indexes for organizations
|
|
||||||
op.drop_index("ix_organizations_slug_active", table_name="organizations")
|
|
||||||
op.drop_index("ix_organizations_name_active", table_name="organizations")
|
|
||||||
op.drop_index("ix_organizations_is_active", table_name="organizations")
|
|
||||||
op.drop_index("ix_organizations_slug", table_name="organizations")
|
|
||||||
op.drop_index("ix_organizations_name", table_name="organizations")
|
|
||||||
|
|
||||||
# Drop organizations table
|
|
||||||
op.drop_table("organizations")
|
|
||||||
|
|
||||||
# Drop enum type
|
|
||||||
op.execute("DROP TYPE IF EXISTS organizationrole")
|
|
||||||
@@ -1,12 +1,12 @@
|
|||||||
from fastapi import Depends, Header, HTTPException, status
|
from fastapi import Depends, Header, HTTPException, status
|
||||||
from fastapi.security import OAuth2PasswordBearer
|
from fastapi.security import OAuth2PasswordBearer
|
||||||
from fastapi.security.utils import get_authorization_scheme_param
|
from fastapi.security.utils import get_authorization_scheme_param
|
||||||
from sqlalchemy import select
|
|
||||||
from sqlalchemy.ext.asyncio import AsyncSession
|
from sqlalchemy.ext.asyncio import AsyncSession
|
||||||
|
|
||||||
from app.core.auth import TokenExpiredError, TokenInvalidError, get_token_data
|
from app.core.auth import TokenExpiredError, TokenInvalidError, get_token_data
|
||||||
from app.core.database import get_db
|
from app.core.database import get_db
|
||||||
from app.models.user import User
|
from app.models.user import User
|
||||||
|
from app.repositories.user import user_repo
|
||||||
|
|
||||||
# OAuth2 configuration
|
# OAuth2 configuration
|
||||||
oauth2_scheme = OAuth2PasswordBearer(tokenUrl="/api/v1/auth/login")
|
oauth2_scheme = OAuth2PasswordBearer(tokenUrl="/api/v1/auth/login")
|
||||||
@@ -32,9 +32,8 @@ async def get_current_user(
|
|||||||
# Decode token and get user ID
|
# Decode token and get user ID
|
||||||
token_data = get_token_data(token)
|
token_data = get_token_data(token)
|
||||||
|
|
||||||
# Get user from database
|
# Get user from database via repository
|
||||||
result = await db.execute(select(User).where(User.id == token_data.user_id))
|
user = await user_repo.get(db, id=str(token_data.user_id))
|
||||||
user = result.scalar_one_or_none()
|
|
||||||
|
|
||||||
if not user:
|
if not user:
|
||||||
raise HTTPException(
|
raise HTTPException(
|
||||||
@@ -144,8 +143,7 @@ async def get_optional_current_user(
|
|||||||
|
|
||||||
try:
|
try:
|
||||||
token_data = get_token_data(token)
|
token_data = get_token_data(token)
|
||||||
result = await db.execute(select(User).where(User.id == token_data.user_id))
|
user = await user_repo.get(db, id=str(token_data.user_id))
|
||||||
user = result.scalar_one_or_none()
|
|
||||||
if not user or not user.is_active:
|
if not user or not user.is_active:
|
||||||
return None
|
return None
|
||||||
return user
|
return user
|
||||||
|
|||||||
@@ -117,8 +117,9 @@ async def get_locale(
|
|||||||
if current_user and current_user.locale:
|
if current_user and current_user.locale:
|
||||||
# Validate that saved locale is still supported
|
# Validate that saved locale is still supported
|
||||||
# (in case SUPPORTED_LOCALES changed after user set preference)
|
# (in case SUPPORTED_LOCALES changed after user set preference)
|
||||||
if current_user.locale in SUPPORTED_LOCALES:
|
locale_value = str(current_user.locale)
|
||||||
return current_user.locale
|
if locale_value in SUPPORTED_LOCALES:
|
||||||
|
return locale_value
|
||||||
|
|
||||||
# Priority 2: Accept-Language header
|
# Priority 2: Accept-Language header
|
||||||
accept_language = request.headers.get("accept-language", "")
|
accept_language = request.headers.get("accept-language", "")
|
||||||
|
|||||||
@@ -15,9 +15,9 @@ from sqlalchemy.ext.asyncio import AsyncSession
|
|||||||
|
|
||||||
from app.api.dependencies.auth import get_current_user
|
from app.api.dependencies.auth import get_current_user
|
||||||
from app.core.database import get_db
|
from app.core.database import get_db
|
||||||
from app.crud.organization import organization as organization_crud
|
|
||||||
from app.models.user import User
|
from app.models.user import User
|
||||||
from app.models.user_organization import OrganizationRole
|
from app.models.user_organization import OrganizationRole
|
||||||
|
from app.services.organization_service import organization_service
|
||||||
|
|
||||||
|
|
||||||
def require_superuser(current_user: User = Depends(get_current_user)) -> User:
|
def require_superuser(current_user: User = Depends(get_current_user)) -> User:
|
||||||
@@ -81,7 +81,7 @@ class OrganizationPermission:
|
|||||||
return current_user
|
return current_user
|
||||||
|
|
||||||
# Get user's role in organization
|
# Get user's role in organization
|
||||||
user_role = await organization_crud.get_user_role_in_org(
|
user_role = await organization_service.get_user_role_in_org(
|
||||||
db, user_id=current_user.id, organization_id=organization_id
|
db, user_id=current_user.id, organization_id=organization_id
|
||||||
)
|
)
|
||||||
|
|
||||||
@@ -123,7 +123,7 @@ async def require_org_membership(
|
|||||||
if current_user.is_superuser:
|
if current_user.is_superuser:
|
||||||
return current_user
|
return current_user
|
||||||
|
|
||||||
user_role = await organization_crud.get_user_role_in_org(
|
user_role = await organization_service.get_user_role_in_org(
|
||||||
db, user_id=current_user.id, organization_id=organization_id
|
db, user_id=current_user.id, organization_id=organization_id
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|||||||
41
backend/app/api/dependencies/services.py
Normal file
41
backend/app/api/dependencies/services.py
Normal file
@@ -0,0 +1,41 @@
|
|||||||
|
# app/api/dependencies/services.py
|
||||||
|
"""FastAPI dependency functions for service singletons."""
|
||||||
|
|
||||||
|
from app.services import oauth_provider_service
|
||||||
|
from app.services.auth_service import AuthService
|
||||||
|
from app.services.oauth_service import OAuthService
|
||||||
|
from app.services.organization_service import OrganizationService, organization_service
|
||||||
|
from app.services.session_service import SessionService, session_service
|
||||||
|
from app.services.user_service import UserService, user_service
|
||||||
|
|
||||||
|
|
||||||
|
def get_auth_service() -> AuthService:
|
||||||
|
"""Return the AuthService singleton for dependency injection."""
|
||||||
|
from app.services.auth_service import AuthService as _AuthService
|
||||||
|
|
||||||
|
return _AuthService()
|
||||||
|
|
||||||
|
|
||||||
|
def get_user_service() -> UserService:
|
||||||
|
"""Return the UserService singleton for dependency injection."""
|
||||||
|
return user_service
|
||||||
|
|
||||||
|
|
||||||
|
def get_organization_service() -> OrganizationService:
|
||||||
|
"""Return the OrganizationService singleton for dependency injection."""
|
||||||
|
return organization_service
|
||||||
|
|
||||||
|
|
||||||
|
def get_session_service() -> SessionService:
|
||||||
|
"""Return the SessionService singleton for dependency injection."""
|
||||||
|
return session_service
|
||||||
|
|
||||||
|
|
||||||
|
def get_oauth_service() -> OAuthService:
|
||||||
|
"""Return OAuthService for dependency injection."""
|
||||||
|
return OAuthService()
|
||||||
|
|
||||||
|
|
||||||
|
def get_oauth_provider_service():
|
||||||
|
"""Return the oauth_provider_service module for dependency injection."""
|
||||||
|
return oauth_provider_service
|
||||||
@@ -1,9 +1,21 @@
|
|||||||
from fastapi import APIRouter
|
from fastapi import APIRouter
|
||||||
|
|
||||||
from app.api.routes import admin, auth, organizations, sessions, users
|
from app.api.routes import (
|
||||||
|
admin,
|
||||||
|
auth,
|
||||||
|
oauth,
|
||||||
|
oauth_provider,
|
||||||
|
organizations,
|
||||||
|
sessions,
|
||||||
|
users,
|
||||||
|
)
|
||||||
|
|
||||||
api_router = APIRouter()
|
api_router = APIRouter()
|
||||||
api_router.include_router(auth.router, prefix="/auth", tags=["Authentication"])
|
api_router.include_router(auth.router, prefix="/auth", tags=["Authentication"])
|
||||||
|
api_router.include_router(oauth.router, prefix="/oauth", tags=["OAuth"])
|
||||||
|
api_router.include_router(
|
||||||
|
oauth_provider.router, prefix="/oauth", tags=["OAuth Provider"]
|
||||||
|
)
|
||||||
api_router.include_router(users.router, prefix="/users", tags=["Users"])
|
api_router.include_router(users.router, prefix="/users", tags=["Users"])
|
||||||
api_router.include_router(sessions.router, prefix="/sessions", tags=["Sessions"])
|
api_router.include_router(sessions.router, prefix="/sessions", tags=["Sessions"])
|
||||||
api_router.include_router(admin.router, prefix="/admin", tags=["Admin"])
|
api_router.include_router(admin.router, prefix="/admin", tags=["Admin"])
|
||||||
|
|||||||
@@ -7,6 +7,7 @@ for managing the application.
|
|||||||
"""
|
"""
|
||||||
|
|
||||||
import logging
|
import logging
|
||||||
|
from datetime import UTC, datetime, timedelta
|
||||||
from enum import Enum
|
from enum import Enum
|
||||||
from typing import Any
|
from typing import Any
|
||||||
from uuid import UUID
|
from uuid import UUID
|
||||||
@@ -23,9 +24,7 @@ from app.core.exceptions import (
|
|||||||
ErrorCode,
|
ErrorCode,
|
||||||
NotFoundError,
|
NotFoundError,
|
||||||
)
|
)
|
||||||
from app.crud.organization import organization as organization_crud
|
from app.core.repository_exceptions import DuplicateEntryError
|
||||||
from app.crud.session import session as session_crud
|
|
||||||
from app.crud.user import user as user_crud
|
|
||||||
from app.models.user import User
|
from app.models.user import User
|
||||||
from app.models.user_organization import OrganizationRole
|
from app.models.user_organization import OrganizationRole
|
||||||
from app.schemas.common import (
|
from app.schemas.common import (
|
||||||
@@ -43,6 +42,9 @@ from app.schemas.organizations import (
|
|||||||
)
|
)
|
||||||
from app.schemas.sessions import AdminSessionResponse
|
from app.schemas.sessions import AdminSessionResponse
|
||||||
from app.schemas.users import UserCreate, UserResponse, UserUpdate
|
from app.schemas.users import UserCreate, UserResponse, UserUpdate
|
||||||
|
from app.services.organization_service import organization_service
|
||||||
|
from app.services.session_service import session_service
|
||||||
|
from app.services.user_service import user_service
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
@@ -63,7 +65,7 @@ class BulkUserAction(BaseModel):
|
|||||||
|
|
||||||
action: BulkAction = Field(..., description="Action to perform on selected users")
|
action: BulkAction = Field(..., description="Action to perform on selected users")
|
||||||
user_ids: list[UUID] = Field(
|
user_ids: list[UUID] = Field(
|
||||||
..., min_items=1, max_items=100, description="List of user IDs (max 100)"
|
..., min_length=1, max_length=100, description="List of user IDs (max 100)"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
@@ -80,6 +82,186 @@ class BulkActionResult(BaseModel):
|
|||||||
# ===== User Management Endpoints =====
|
# ===== User Management Endpoints =====
|
||||||
|
|
||||||
|
|
||||||
|
class UserGrowthData(BaseModel):
|
||||||
|
date: str
|
||||||
|
total_users: int
|
||||||
|
active_users: int
|
||||||
|
|
||||||
|
|
||||||
|
class OrgDistributionData(BaseModel):
|
||||||
|
name: str
|
||||||
|
value: int
|
||||||
|
|
||||||
|
|
||||||
|
class RegistrationActivityData(BaseModel):
|
||||||
|
date: str
|
||||||
|
registrations: int
|
||||||
|
|
||||||
|
|
||||||
|
class UserStatusData(BaseModel):
|
||||||
|
name: str
|
||||||
|
value: int
|
||||||
|
|
||||||
|
|
||||||
|
class AdminStatsResponse(BaseModel):
|
||||||
|
user_growth: list[UserGrowthData]
|
||||||
|
organization_distribution: list[OrgDistributionData]
|
||||||
|
registration_activity: list[RegistrationActivityData]
|
||||||
|
user_status: list[UserStatusData]
|
||||||
|
|
||||||
|
|
||||||
|
def _generate_demo_stats() -> AdminStatsResponse: # pragma: no cover
|
||||||
|
"""Generate demo statistics for empty databases."""
|
||||||
|
from random import randint
|
||||||
|
|
||||||
|
# Demo user growth (last 30 days)
|
||||||
|
user_growth = []
|
||||||
|
total = 10
|
||||||
|
for i in range(29, -1, -1):
|
||||||
|
date = datetime.now(UTC) - timedelta(days=i)
|
||||||
|
total += randint(0, 3) # noqa: S311
|
||||||
|
user_growth.append(
|
||||||
|
UserGrowthData(
|
||||||
|
date=date.strftime("%b %d"),
|
||||||
|
total_users=total,
|
||||||
|
active_users=int(total * 0.85),
|
||||||
|
)
|
||||||
|
)
|
||||||
|
|
||||||
|
# Demo organization distribution
|
||||||
|
org_dist = [
|
||||||
|
OrgDistributionData(name="Engineering", value=12),
|
||||||
|
OrgDistributionData(name="Product", value=8),
|
||||||
|
OrgDistributionData(name="Sales", value=15),
|
||||||
|
OrgDistributionData(name="Marketing", value=6),
|
||||||
|
OrgDistributionData(name="Support", value=5),
|
||||||
|
OrgDistributionData(name="Operations", value=4),
|
||||||
|
]
|
||||||
|
|
||||||
|
# Demo registration activity (last 14 days)
|
||||||
|
registration_activity = []
|
||||||
|
for i in range(13, -1, -1):
|
||||||
|
date = datetime.now(UTC) - timedelta(days=i)
|
||||||
|
registration_activity.append(
|
||||||
|
RegistrationActivityData(
|
||||||
|
date=date.strftime("%b %d"),
|
||||||
|
registrations=randint(0, 5), # noqa: S311
|
||||||
|
)
|
||||||
|
)
|
||||||
|
|
||||||
|
# Demo user status
|
||||||
|
user_status = [
|
||||||
|
UserStatusData(name="Active", value=45),
|
||||||
|
UserStatusData(name="Inactive", value=5),
|
||||||
|
]
|
||||||
|
|
||||||
|
return AdminStatsResponse(
|
||||||
|
user_growth=user_growth,
|
||||||
|
organization_distribution=org_dist,
|
||||||
|
registration_activity=registration_activity,
|
||||||
|
user_status=user_status,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
@router.get(
|
||||||
|
"/stats",
|
||||||
|
response_model=AdminStatsResponse,
|
||||||
|
summary="Admin: Get Dashboard Stats",
|
||||||
|
description="Get aggregated statistics for the admin dashboard (admin only)",
|
||||||
|
operation_id="admin_get_stats",
|
||||||
|
)
|
||||||
|
async def admin_get_stats(
|
||||||
|
admin: User = Depends(require_superuser),
|
||||||
|
db: AsyncSession = Depends(get_db),
|
||||||
|
) -> Any:
|
||||||
|
"""Get admin dashboard statistics with real data from database."""
|
||||||
|
from app.core.config import settings
|
||||||
|
|
||||||
|
stats = await user_service.get_stats(db)
|
||||||
|
total_users = stats["total_users"]
|
||||||
|
active_count = stats["active_count"]
|
||||||
|
inactive_count = stats["inactive_count"]
|
||||||
|
all_users = stats["all_users"]
|
||||||
|
|
||||||
|
# If database is essentially empty (only admin user), return demo data
|
||||||
|
if total_users <= 1 and settings.DEMO_MODE: # pragma: no cover
|
||||||
|
logger.info("Returning demo stats data (empty database in demo mode)")
|
||||||
|
return _generate_demo_stats()
|
||||||
|
|
||||||
|
# 1. User Growth (Last 30 days)
|
||||||
|
user_growth = []
|
||||||
|
for i in range(29, -1, -1):
|
||||||
|
date = datetime.now(UTC) - timedelta(days=i)
|
||||||
|
date_start = date.replace(hour=0, minute=0, second=0, microsecond=0, tzinfo=UTC)
|
||||||
|
date_end = date_start + timedelta(days=1)
|
||||||
|
|
||||||
|
total_users_on_date = sum(
|
||||||
|
1
|
||||||
|
for u in all_users
|
||||||
|
if u.created_at and u.created_at.replace(tzinfo=UTC) < date_end
|
||||||
|
)
|
||||||
|
active_users_on_date = sum(
|
||||||
|
1
|
||||||
|
for u in all_users
|
||||||
|
if u.created_at
|
||||||
|
and u.created_at.replace(tzinfo=UTC) < date_end
|
||||||
|
and u.is_active
|
||||||
|
)
|
||||||
|
|
||||||
|
user_growth.append(
|
||||||
|
UserGrowthData(
|
||||||
|
date=date.strftime("%b %d"),
|
||||||
|
total_users=total_users_on_date,
|
||||||
|
active_users=active_users_on_date,
|
||||||
|
)
|
||||||
|
)
|
||||||
|
|
||||||
|
# 2. Organization Distribution - Top 6 organizations by member count
|
||||||
|
org_rows = await organization_service.get_org_distribution(db, limit=6)
|
||||||
|
org_dist = [OrgDistributionData(name=r["name"], value=r["value"]) for r in org_rows]
|
||||||
|
|
||||||
|
# 3. User Registration Activity (Last 14 days)
|
||||||
|
registration_activity = []
|
||||||
|
for i in range(13, -1, -1):
|
||||||
|
date = datetime.now(UTC) - timedelta(days=i)
|
||||||
|
date_start = date.replace(hour=0, minute=0, second=0, microsecond=0, tzinfo=UTC)
|
||||||
|
date_end = date_start + timedelta(days=1)
|
||||||
|
|
||||||
|
day_registrations = sum(
|
||||||
|
1
|
||||||
|
for u in all_users
|
||||||
|
if u.created_at
|
||||||
|
and date_start <= u.created_at.replace(tzinfo=UTC) < date_end
|
||||||
|
)
|
||||||
|
|
||||||
|
registration_activity.append(
|
||||||
|
RegistrationActivityData(
|
||||||
|
date=date.strftime("%b %d"),
|
||||||
|
registrations=day_registrations,
|
||||||
|
)
|
||||||
|
)
|
||||||
|
|
||||||
|
# 4. User Status - Active vs Inactive
|
||||||
|
logger.info(
|
||||||
|
"User status counts - Active: %s, Inactive: %s", active_count, inactive_count
|
||||||
|
)
|
||||||
|
|
||||||
|
user_status = [
|
||||||
|
UserStatusData(name="Active", value=active_count),
|
||||||
|
UserStatusData(name="Inactive", value=inactive_count),
|
||||||
|
]
|
||||||
|
|
||||||
|
return AdminStatsResponse(
|
||||||
|
user_growth=user_growth,
|
||||||
|
organization_distribution=org_dist,
|
||||||
|
registration_activity=registration_activity,
|
||||||
|
user_status=user_status,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
# ===== User Management Endpoints =====
|
||||||
|
|
||||||
|
|
||||||
@router.get(
|
@router.get(
|
||||||
"/users",
|
"/users",
|
||||||
response_model=PaginatedResponse[UserResponse],
|
response_model=PaginatedResponse[UserResponse],
|
||||||
@@ -110,7 +292,7 @@ async def admin_list_users(
|
|||||||
filters["is_superuser"] = is_superuser
|
filters["is_superuser"] = is_superuser
|
||||||
|
|
||||||
# Get users with search
|
# Get users with search
|
||||||
users, total = await user_crud.get_multi_with_total(
|
users, total = await user_service.list_users(
|
||||||
db,
|
db,
|
||||||
skip=pagination.offset,
|
skip=pagination.offset,
|
||||||
limit=pagination.limit,
|
limit=pagination.limit,
|
||||||
@@ -130,7 +312,7 @@ async def admin_list_users(
|
|||||||
return PaginatedResponse(data=users, pagination=pagination_meta)
|
return PaginatedResponse(data=users, pagination=pagination_meta)
|
||||||
|
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.error(f"Error listing users (admin): {e!s}", exc_info=True)
|
logger.exception("Error listing users (admin): %s", e)
|
||||||
raise
|
raise
|
||||||
|
|
||||||
|
|
||||||
@@ -153,14 +335,14 @@ async def admin_create_user(
|
|||||||
Allows setting is_superuser and other fields.
|
Allows setting is_superuser and other fields.
|
||||||
"""
|
"""
|
||||||
try:
|
try:
|
||||||
user = await user_crud.create(db, obj_in=user_in)
|
user = await user_service.create_user(db, user_in)
|
||||||
logger.info(f"Admin {admin.email} created user {user.email}")
|
logger.info("Admin %s created user %s", admin.email, user.email)
|
||||||
return user
|
return user
|
||||||
except ValueError as e:
|
except DuplicateEntryError as e:
|
||||||
logger.warning(f"Failed to create user: {e!s}")
|
logger.warning("Failed to create user: %s", e)
|
||||||
raise NotFoundError(message=str(e), error_code=ErrorCode.USER_ALREADY_EXISTS)
|
raise DuplicateError(message=str(e), error_code=ErrorCode.USER_ALREADY_EXISTS)
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.error(f"Error creating user (admin): {e!s}", exc_info=True)
|
logger.exception("Error creating user (admin): %s", e)
|
||||||
raise
|
raise
|
||||||
|
|
||||||
|
|
||||||
@@ -177,11 +359,7 @@ async def admin_get_user(
|
|||||||
db: AsyncSession = Depends(get_db),
|
db: AsyncSession = Depends(get_db),
|
||||||
) -> Any:
|
) -> Any:
|
||||||
"""Get detailed information about a specific user."""
|
"""Get detailed information about a specific user."""
|
||||||
user = await user_crud.get(db, id=user_id)
|
user = await user_service.get_user(db, str(user_id))
|
||||||
if not user:
|
|
||||||
raise NotFoundError(
|
|
||||||
message=f"User {user_id} not found", error_code=ErrorCode.USER_NOT_FOUND
|
|
||||||
)
|
|
||||||
return user
|
return user
|
||||||
|
|
||||||
|
|
||||||
@@ -200,20 +378,13 @@ async def admin_update_user(
|
|||||||
) -> Any:
|
) -> Any:
|
||||||
"""Update user information with admin privileges."""
|
"""Update user information with admin privileges."""
|
||||||
try:
|
try:
|
||||||
user = await user_crud.get(db, id=user_id)
|
user = await user_service.get_user(db, str(user_id))
|
||||||
if not user:
|
updated_user = await user_service.update_user(db, user=user, obj_in=user_in)
|
||||||
raise NotFoundError(
|
logger.info("Admin %s updated user %s", admin.email, updated_user.email)
|
||||||
message=f"User {user_id} not found", error_code=ErrorCode.USER_NOT_FOUND
|
|
||||||
)
|
|
||||||
|
|
||||||
updated_user = await user_crud.update(db, db_obj=user, obj_in=user_in)
|
|
||||||
logger.info(f"Admin {admin.email} updated user {updated_user.email}")
|
|
||||||
return updated_user
|
return updated_user
|
||||||
|
|
||||||
except NotFoundError:
|
|
||||||
raise
|
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.error(f"Error updating user (admin): {e!s}", exc_info=True)
|
logger.exception("Error updating user (admin): %s", e)
|
||||||
raise
|
raise
|
||||||
|
|
||||||
|
|
||||||
@@ -231,11 +402,7 @@ async def admin_delete_user(
|
|||||||
) -> Any:
|
) -> Any:
|
||||||
"""Soft delete a user (sets deleted_at timestamp)."""
|
"""Soft delete a user (sets deleted_at timestamp)."""
|
||||||
try:
|
try:
|
||||||
user = await user_crud.get(db, id=user_id)
|
user = await user_service.get_user(db, str(user_id))
|
||||||
if not user:
|
|
||||||
raise NotFoundError(
|
|
||||||
message=f"User {user_id} not found", error_code=ErrorCode.USER_NOT_FOUND
|
|
||||||
)
|
|
||||||
|
|
||||||
# Prevent deleting yourself
|
# Prevent deleting yourself
|
||||||
if user.id == admin.id:
|
if user.id == admin.id:
|
||||||
@@ -245,17 +412,15 @@ async def admin_delete_user(
|
|||||||
error_code=ErrorCode.OPERATION_FORBIDDEN,
|
error_code=ErrorCode.OPERATION_FORBIDDEN,
|
||||||
)
|
)
|
||||||
|
|
||||||
await user_crud.soft_delete(db, id=user_id)
|
await user_service.soft_delete_user(db, str(user_id))
|
||||||
logger.info(f"Admin {admin.email} deleted user {user.email}")
|
logger.info("Admin %s deleted user %s", admin.email, user.email)
|
||||||
|
|
||||||
return MessageResponse(
|
return MessageResponse(
|
||||||
success=True, message=f"User {user.email} has been deleted"
|
success=True, message=f"User {user.email} has been deleted"
|
||||||
)
|
)
|
||||||
|
|
||||||
except NotFoundError:
|
|
||||||
raise
|
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.error(f"Error deleting user (admin): {e!s}", exc_info=True)
|
logger.exception("Error deleting user (admin): %s", e)
|
||||||
raise
|
raise
|
||||||
|
|
||||||
|
|
||||||
@@ -273,23 +438,16 @@ async def admin_activate_user(
|
|||||||
) -> Any:
|
) -> Any:
|
||||||
"""Activate a user account."""
|
"""Activate a user account."""
|
||||||
try:
|
try:
|
||||||
user = await user_crud.get(db, id=user_id)
|
user = await user_service.get_user(db, str(user_id))
|
||||||
if not user:
|
await user_service.update_user(db, user=user, obj_in={"is_active": True})
|
||||||
raise NotFoundError(
|
logger.info("Admin %s activated user %s", admin.email, user.email)
|
||||||
message=f"User {user_id} not found", error_code=ErrorCode.USER_NOT_FOUND
|
|
||||||
)
|
|
||||||
|
|
||||||
await user_crud.update(db, db_obj=user, obj_in={"is_active": True})
|
|
||||||
logger.info(f"Admin {admin.email} activated user {user.email}")
|
|
||||||
|
|
||||||
return MessageResponse(
|
return MessageResponse(
|
||||||
success=True, message=f"User {user.email} has been activated"
|
success=True, message=f"User {user.email} has been activated"
|
||||||
)
|
)
|
||||||
|
|
||||||
except NotFoundError:
|
|
||||||
raise
|
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.error(f"Error activating user (admin): {e!s}", exc_info=True)
|
logger.exception("Error activating user (admin): %s", e)
|
||||||
raise
|
raise
|
||||||
|
|
||||||
|
|
||||||
@@ -307,11 +465,7 @@ async def admin_deactivate_user(
|
|||||||
) -> Any:
|
) -> Any:
|
||||||
"""Deactivate a user account."""
|
"""Deactivate a user account."""
|
||||||
try:
|
try:
|
||||||
user = await user_crud.get(db, id=user_id)
|
user = await user_service.get_user(db, str(user_id))
|
||||||
if not user:
|
|
||||||
raise NotFoundError(
|
|
||||||
message=f"User {user_id} not found", error_code=ErrorCode.USER_NOT_FOUND
|
|
||||||
)
|
|
||||||
|
|
||||||
# Prevent deactivating yourself
|
# Prevent deactivating yourself
|
||||||
if user.id == admin.id:
|
if user.id == admin.id:
|
||||||
@@ -321,17 +475,15 @@ async def admin_deactivate_user(
|
|||||||
error_code=ErrorCode.OPERATION_FORBIDDEN,
|
error_code=ErrorCode.OPERATION_FORBIDDEN,
|
||||||
)
|
)
|
||||||
|
|
||||||
await user_crud.update(db, db_obj=user, obj_in={"is_active": False})
|
await user_service.update_user(db, user=user, obj_in={"is_active": False})
|
||||||
logger.info(f"Admin {admin.email} deactivated user {user.email}")
|
logger.info("Admin %s deactivated user %s", admin.email, user.email)
|
||||||
|
|
||||||
return MessageResponse(
|
return MessageResponse(
|
||||||
success=True, message=f"User {user.email} has been deactivated"
|
success=True, message=f"User {user.email} has been deactivated"
|
||||||
)
|
)
|
||||||
|
|
||||||
except NotFoundError:
|
|
||||||
raise
|
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.error(f"Error deactivating user (admin): {e!s}", exc_info=True)
|
logger.exception("Error deactivating user (admin): %s", e)
|
||||||
raise
|
raise
|
||||||
|
|
||||||
|
|
||||||
@@ -356,19 +508,19 @@ async def admin_bulk_user_action(
|
|||||||
try:
|
try:
|
||||||
# Use efficient bulk operations instead of loop
|
# Use efficient bulk operations instead of loop
|
||||||
if bulk_action.action == BulkAction.ACTIVATE:
|
if bulk_action.action == BulkAction.ACTIVATE:
|
||||||
affected_count = await user_crud.bulk_update_status(
|
affected_count = await user_service.bulk_update_status(
|
||||||
db, user_ids=bulk_action.user_ids, is_active=True
|
db, user_ids=bulk_action.user_ids, is_active=True
|
||||||
)
|
)
|
||||||
elif bulk_action.action == BulkAction.DEACTIVATE:
|
elif bulk_action.action == BulkAction.DEACTIVATE:
|
||||||
affected_count = await user_crud.bulk_update_status(
|
affected_count = await user_service.bulk_update_status(
|
||||||
db, user_ids=bulk_action.user_ids, is_active=False
|
db, user_ids=bulk_action.user_ids, is_active=False
|
||||||
)
|
)
|
||||||
elif bulk_action.action == BulkAction.DELETE:
|
elif bulk_action.action == BulkAction.DELETE:
|
||||||
# bulk_soft_delete automatically excludes the admin user
|
# bulk_soft_delete automatically excludes the admin user
|
||||||
affected_count = await user_crud.bulk_soft_delete(
|
affected_count = await user_service.bulk_soft_delete(
|
||||||
db, user_ids=bulk_action.user_ids, exclude_user_id=admin.id
|
db, user_ids=bulk_action.user_ids, exclude_user_id=admin.id
|
||||||
)
|
)
|
||||||
else:
|
else: # pragma: no cover
|
||||||
raise ValueError(f"Unsupported bulk action: {bulk_action.action}")
|
raise ValueError(f"Unsupported bulk action: {bulk_action.action}")
|
||||||
|
|
||||||
# Calculate failed count (requested - affected)
|
# Calculate failed count (requested - affected)
|
||||||
@@ -376,8 +528,11 @@ async def admin_bulk_user_action(
|
|||||||
failed_count = requested_count - affected_count
|
failed_count = requested_count - affected_count
|
||||||
|
|
||||||
logger.info(
|
logger.info(
|
||||||
f"Admin {admin.email} performed bulk {bulk_action.action.value} "
|
"Admin %s performed bulk %s on %s users (%s skipped/failed)",
|
||||||
f"on {affected_count} users ({failed_count} skipped/failed)"
|
admin.email,
|
||||||
|
bulk_action.action.value,
|
||||||
|
affected_count,
|
||||||
|
failed_count,
|
||||||
)
|
)
|
||||||
|
|
||||||
return BulkActionResult(
|
return BulkActionResult(
|
||||||
@@ -388,8 +543,8 @@ async def admin_bulk_user_action(
|
|||||||
failed_ids=None, # Bulk operations don't track individual failures
|
failed_ids=None, # Bulk operations don't track individual failures
|
||||||
)
|
)
|
||||||
|
|
||||||
except Exception as e:
|
except Exception as e: # pragma: no cover
|
||||||
logger.error(f"Error in bulk user action: {e!s}", exc_info=True)
|
logger.exception("Error in bulk user action: %s", e)
|
||||||
raise
|
raise
|
||||||
|
|
||||||
|
|
||||||
@@ -413,7 +568,7 @@ async def admin_list_organizations(
|
|||||||
"""List all organizations with filtering and search."""
|
"""List all organizations with filtering and search."""
|
||||||
try:
|
try:
|
||||||
# Use optimized method that gets member counts in single query (no N+1)
|
# Use optimized method that gets member counts in single query (no N+1)
|
||||||
orgs_with_data, total = await organization_crud.get_multi_with_member_counts(
|
orgs_with_data, total = await organization_service.get_multi_with_member_counts(
|
||||||
db,
|
db,
|
||||||
skip=pagination.offset,
|
skip=pagination.offset,
|
||||||
limit=pagination.limit,
|
limit=pagination.limit,
|
||||||
@@ -450,7 +605,7 @@ async def admin_list_organizations(
|
|||||||
return PaginatedResponse(data=orgs_with_count, pagination=pagination_meta)
|
return PaginatedResponse(data=orgs_with_count, pagination=pagination_meta)
|
||||||
|
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.error(f"Error listing organizations (admin): {e!s}", exc_info=True)
|
logger.exception("Error listing organizations (admin): %s", e)
|
||||||
raise
|
raise
|
||||||
|
|
||||||
|
|
||||||
@@ -469,8 +624,8 @@ async def admin_create_organization(
|
|||||||
) -> Any:
|
) -> Any:
|
||||||
"""Create a new organization."""
|
"""Create a new organization."""
|
||||||
try:
|
try:
|
||||||
org = await organization_crud.create(db, obj_in=org_in)
|
org = await organization_service.create_organization(db, obj_in=org_in)
|
||||||
logger.info(f"Admin {admin.email} created organization {org.name}")
|
logger.info("Admin %s created organization %s", admin.email, org.name)
|
||||||
|
|
||||||
# Add member count
|
# Add member count
|
||||||
org_dict = {
|
org_dict = {
|
||||||
@@ -486,11 +641,11 @@ async def admin_create_organization(
|
|||||||
}
|
}
|
||||||
return OrganizationResponse(**org_dict)
|
return OrganizationResponse(**org_dict)
|
||||||
|
|
||||||
except ValueError as e:
|
except DuplicateEntryError as e:
|
||||||
logger.warning(f"Failed to create organization: {e!s}")
|
logger.warning("Failed to create organization: %s", e)
|
||||||
raise NotFoundError(message=str(e), error_code=ErrorCode.ALREADY_EXISTS)
|
raise DuplicateError(message=str(e), error_code=ErrorCode.ALREADY_EXISTS)
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.error(f"Error creating organization (admin): {e!s}", exc_info=True)
|
logger.exception("Error creating organization (admin): %s", e)
|
||||||
raise
|
raise
|
||||||
|
|
||||||
|
|
||||||
@@ -507,12 +662,7 @@ async def admin_get_organization(
|
|||||||
db: AsyncSession = Depends(get_db),
|
db: AsyncSession = Depends(get_db),
|
||||||
) -> Any:
|
) -> Any:
|
||||||
"""Get detailed information about a specific organization."""
|
"""Get detailed information about a specific organization."""
|
||||||
org = await organization_crud.get(db, id=org_id)
|
org = await organization_service.get_organization(db, str(org_id))
|
||||||
if not org:
|
|
||||||
raise NotFoundError(
|
|
||||||
message=f"Organization {org_id} not found", error_code=ErrorCode.NOT_FOUND
|
|
||||||
)
|
|
||||||
|
|
||||||
org_dict = {
|
org_dict = {
|
||||||
"id": org.id,
|
"id": org.id,
|
||||||
"name": org.name,
|
"name": org.name,
|
||||||
@@ -522,7 +672,7 @@ async def admin_get_organization(
|
|||||||
"settings": org.settings,
|
"settings": org.settings,
|
||||||
"created_at": org.created_at,
|
"created_at": org.created_at,
|
||||||
"updated_at": org.updated_at,
|
"updated_at": org.updated_at,
|
||||||
"member_count": await organization_crud.get_member_count(
|
"member_count": await organization_service.get_member_count(
|
||||||
db, organization_id=org.id
|
db, organization_id=org.id
|
||||||
),
|
),
|
||||||
}
|
}
|
||||||
@@ -544,15 +694,11 @@ async def admin_update_organization(
|
|||||||
) -> Any:
|
) -> Any:
|
||||||
"""Update organization information."""
|
"""Update organization information."""
|
||||||
try:
|
try:
|
||||||
org = await organization_crud.get(db, id=org_id)
|
org = await organization_service.get_organization(db, str(org_id))
|
||||||
if not org:
|
updated_org = await organization_service.update_organization(
|
||||||
raise NotFoundError(
|
db, org=org, obj_in=org_in
|
||||||
message=f"Organization {org_id} not found",
|
)
|
||||||
error_code=ErrorCode.NOT_FOUND,
|
logger.info("Admin %s updated organization %s", admin.email, updated_org.name)
|
||||||
)
|
|
||||||
|
|
||||||
updated_org = await organization_crud.update(db, db_obj=org, obj_in=org_in)
|
|
||||||
logger.info(f"Admin {admin.email} updated organization {updated_org.name}")
|
|
||||||
|
|
||||||
org_dict = {
|
org_dict = {
|
||||||
"id": updated_org.id,
|
"id": updated_org.id,
|
||||||
@@ -563,16 +709,14 @@ async def admin_update_organization(
|
|||||||
"settings": updated_org.settings,
|
"settings": updated_org.settings,
|
||||||
"created_at": updated_org.created_at,
|
"created_at": updated_org.created_at,
|
||||||
"updated_at": updated_org.updated_at,
|
"updated_at": updated_org.updated_at,
|
||||||
"member_count": await organization_crud.get_member_count(
|
"member_count": await organization_service.get_member_count(
|
||||||
db, organization_id=updated_org.id
|
db, organization_id=updated_org.id
|
||||||
),
|
),
|
||||||
}
|
}
|
||||||
return OrganizationResponse(**org_dict)
|
return OrganizationResponse(**org_dict)
|
||||||
|
|
||||||
except NotFoundError:
|
|
||||||
raise
|
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.error(f"Error updating organization (admin): {e!s}", exc_info=True)
|
logger.exception("Error updating organization (admin): %s", e)
|
||||||
raise
|
raise
|
||||||
|
|
||||||
|
|
||||||
@@ -590,24 +734,16 @@ async def admin_delete_organization(
|
|||||||
) -> Any:
|
) -> Any:
|
||||||
"""Delete an organization and all its relationships."""
|
"""Delete an organization and all its relationships."""
|
||||||
try:
|
try:
|
||||||
org = await organization_crud.get(db, id=org_id)
|
org = await organization_service.get_organization(db, str(org_id))
|
||||||
if not org:
|
await organization_service.remove_organization(db, str(org_id))
|
||||||
raise NotFoundError(
|
logger.info("Admin %s deleted organization %s", admin.email, org.name)
|
||||||
message=f"Organization {org_id} not found",
|
|
||||||
error_code=ErrorCode.NOT_FOUND,
|
|
||||||
)
|
|
||||||
|
|
||||||
await organization_crud.remove(db, id=org_id)
|
|
||||||
logger.info(f"Admin {admin.email} deleted organization {org.name}")
|
|
||||||
|
|
||||||
return MessageResponse(
|
return MessageResponse(
|
||||||
success=True, message=f"Organization {org.name} has been deleted"
|
success=True, message=f"Organization {org.name} has been deleted"
|
||||||
)
|
)
|
||||||
|
|
||||||
except NotFoundError:
|
|
||||||
raise
|
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.error(f"Error deleting organization (admin): {e!s}", exc_info=True)
|
logger.exception("Error deleting organization (admin): %s", e)
|
||||||
raise
|
raise
|
||||||
|
|
||||||
|
|
||||||
@@ -627,14 +763,8 @@ async def admin_list_organization_members(
|
|||||||
) -> Any:
|
) -> Any:
|
||||||
"""List all members of an organization."""
|
"""List all members of an organization."""
|
||||||
try:
|
try:
|
||||||
org = await organization_crud.get(db, id=org_id)
|
await organization_service.get_organization(db, str(org_id)) # validates exists
|
||||||
if not org:
|
members, total = await organization_service.get_organization_members(
|
||||||
raise NotFoundError(
|
|
||||||
message=f"Organization {org_id} not found",
|
|
||||||
error_code=ErrorCode.NOT_FOUND,
|
|
||||||
)
|
|
||||||
|
|
||||||
members, total = await organization_crud.get_organization_members(
|
|
||||||
db,
|
db,
|
||||||
organization_id=org_id,
|
organization_id=org_id,
|
||||||
skip=pagination.offset,
|
skip=pagination.offset,
|
||||||
@@ -657,9 +787,7 @@ async def admin_list_organization_members(
|
|||||||
except NotFoundError:
|
except NotFoundError:
|
||||||
raise
|
raise
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.error(
|
logger.exception("Error listing organization members (admin): %s", e)
|
||||||
f"Error listing organization members (admin): {e!s}", exc_info=True
|
|
||||||
)
|
|
||||||
raise
|
raise
|
||||||
|
|
||||||
|
|
||||||
@@ -687,45 +815,32 @@ async def admin_add_organization_member(
|
|||||||
) -> Any:
|
) -> Any:
|
||||||
"""Add a user to an organization."""
|
"""Add a user to an organization."""
|
||||||
try:
|
try:
|
||||||
org = await organization_crud.get(db, id=org_id)
|
org = await organization_service.get_organization(db, str(org_id))
|
||||||
if not org:
|
user = await user_service.get_user(db, str(request.user_id))
|
||||||
raise NotFoundError(
|
|
||||||
message=f"Organization {org_id} not found",
|
|
||||||
error_code=ErrorCode.NOT_FOUND,
|
|
||||||
)
|
|
||||||
|
|
||||||
user = await user_crud.get(db, id=request.user_id)
|
await organization_service.add_member(
|
||||||
if not user:
|
|
||||||
raise NotFoundError(
|
|
||||||
message=f"User {request.user_id} not found",
|
|
||||||
error_code=ErrorCode.USER_NOT_FOUND,
|
|
||||||
)
|
|
||||||
|
|
||||||
await organization_crud.add_user(
|
|
||||||
db, organization_id=org_id, user_id=request.user_id, role=request.role
|
db, organization_id=org_id, user_id=request.user_id, role=request.role
|
||||||
)
|
)
|
||||||
|
|
||||||
logger.info(
|
logger.info(
|
||||||
f"Admin {admin.email} added user {user.email} to organization {org.name} "
|
"Admin %s added user %s to organization %s with role %s",
|
||||||
f"with role {request.role.value}"
|
admin.email,
|
||||||
|
user.email,
|
||||||
|
org.name,
|
||||||
|
request.role.value,
|
||||||
)
|
)
|
||||||
|
|
||||||
return MessageResponse(
|
return MessageResponse(
|
||||||
success=True, message=f"User {user.email} added to organization {org.name}"
|
success=True, message=f"User {user.email} added to organization {org.name}"
|
||||||
)
|
)
|
||||||
|
|
||||||
except ValueError as e:
|
except DuplicateEntryError as e:
|
||||||
logger.warning(f"Failed to add user to organization: {e!s}")
|
logger.warning("Failed to add user to organization: %s", e)
|
||||||
# Use DuplicateError for "already exists" scenarios
|
|
||||||
raise DuplicateError(
|
raise DuplicateError(
|
||||||
message=str(e), error_code=ErrorCode.USER_ALREADY_EXISTS, field="user_id"
|
message=str(e), error_code=ErrorCode.USER_ALREADY_EXISTS, field="user_id"
|
||||||
)
|
)
|
||||||
except NotFoundError:
|
|
||||||
raise
|
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.error(
|
logger.exception("Error adding member to organization (admin): %s", e)
|
||||||
f"Error adding member to organization (admin): {e!s}", exc_info=True
|
|
||||||
)
|
|
||||||
raise
|
raise
|
||||||
|
|
||||||
|
|
||||||
@@ -744,20 +859,10 @@ async def admin_remove_organization_member(
|
|||||||
) -> Any:
|
) -> Any:
|
||||||
"""Remove a user from an organization."""
|
"""Remove a user from an organization."""
|
||||||
try:
|
try:
|
||||||
org = await organization_crud.get(db, id=org_id)
|
org = await organization_service.get_organization(db, str(org_id))
|
||||||
if not org:
|
user = await user_service.get_user(db, str(user_id))
|
||||||
raise NotFoundError(
|
|
||||||
message=f"Organization {org_id} not found",
|
|
||||||
error_code=ErrorCode.NOT_FOUND,
|
|
||||||
)
|
|
||||||
|
|
||||||
user = await user_crud.get(db, id=user_id)
|
success = await organization_service.remove_member(
|
||||||
if not user:
|
|
||||||
raise NotFoundError(
|
|
||||||
message=f"User {user_id} not found", error_code=ErrorCode.USER_NOT_FOUND
|
|
||||||
)
|
|
||||||
|
|
||||||
success = await organization_crud.remove_user(
|
|
||||||
db, organization_id=org_id, user_id=user_id
|
db, organization_id=org_id, user_id=user_id
|
||||||
)
|
)
|
||||||
|
|
||||||
@@ -768,7 +873,10 @@ async def admin_remove_organization_member(
|
|||||||
)
|
)
|
||||||
|
|
||||||
logger.info(
|
logger.info(
|
||||||
f"Admin {admin.email} removed user {user.email} from organization {org.name}"
|
"Admin %s removed user %s from organization %s",
|
||||||
|
admin.email,
|
||||||
|
user.email,
|
||||||
|
org.name,
|
||||||
)
|
)
|
||||||
|
|
||||||
return MessageResponse(
|
return MessageResponse(
|
||||||
@@ -778,10 +886,8 @@ async def admin_remove_organization_member(
|
|||||||
|
|
||||||
except NotFoundError:
|
except NotFoundError:
|
||||||
raise
|
raise
|
||||||
except Exception as e:
|
except Exception as e: # pragma: no cover
|
||||||
logger.error(
|
logger.exception("Error removing member from organization (admin): %s", e)
|
||||||
f"Error removing member from organization (admin): {e!s}", exc_info=True
|
|
||||||
)
|
|
||||||
raise
|
raise
|
||||||
|
|
||||||
|
|
||||||
@@ -811,7 +917,7 @@ async def admin_list_sessions(
|
|||||||
"""List all sessions across all users with filtering and pagination."""
|
"""List all sessions across all users with filtering and pagination."""
|
||||||
try:
|
try:
|
||||||
# Get sessions with user info (eager loaded to prevent N+1)
|
# Get sessions with user info (eager loaded to prevent N+1)
|
||||||
sessions, total = await session_crud.get_all_sessions(
|
sessions, total = await session_service.get_all_sessions(
|
||||||
db,
|
db,
|
||||||
skip=pagination.offset,
|
skip=pagination.offset,
|
||||||
limit=pagination.limit,
|
limit=pagination.limit,
|
||||||
@@ -850,7 +956,10 @@ async def admin_list_sessions(
|
|||||||
session_responses.append(session_response)
|
session_responses.append(session_response)
|
||||||
|
|
||||||
logger.info(
|
logger.info(
|
||||||
f"Admin {admin.email} listed {len(session_responses)} sessions (total: {total})"
|
"Admin %s listed %s sessions (total: %s)",
|
||||||
|
admin.email,
|
||||||
|
len(session_responses),
|
||||||
|
total,
|
||||||
)
|
)
|
||||||
|
|
||||||
pagination_meta = create_pagination_meta(
|
pagination_meta = create_pagination_meta(
|
||||||
@@ -862,6 +971,6 @@ async def admin_list_sessions(
|
|||||||
|
|
||||||
return PaginatedResponse(data=session_responses, pagination=pagination_meta)
|
return PaginatedResponse(data=session_responses, pagination=pagination_meta)
|
||||||
|
|
||||||
except Exception as e:
|
except Exception as e: # pragma: no cover
|
||||||
logger.error(f"Error listing sessions (admin): {e!s}", exc_info=True)
|
logger.exception("Error listing sessions (admin): %s", e)
|
||||||
raise
|
raise
|
||||||
|
|||||||
@@ -15,16 +15,14 @@ from app.core.auth import (
|
|||||||
TokenExpiredError,
|
TokenExpiredError,
|
||||||
TokenInvalidError,
|
TokenInvalidError,
|
||||||
decode_token,
|
decode_token,
|
||||||
get_password_hash,
|
|
||||||
)
|
)
|
||||||
from app.core.database import get_db
|
from app.core.database import get_db
|
||||||
from app.core.exceptions import (
|
from app.core.exceptions import (
|
||||||
AuthenticationError as AuthError,
|
AuthenticationError as AuthError,
|
||||||
DatabaseError,
|
DatabaseError,
|
||||||
|
DuplicateError,
|
||||||
ErrorCode,
|
ErrorCode,
|
||||||
)
|
)
|
||||||
from app.crud.session import session as session_crud
|
|
||||||
from app.crud.user import user as user_crud
|
|
||||||
from app.models.user import User
|
from app.models.user import User
|
||||||
from app.schemas.common import MessageResponse
|
from app.schemas.common import MessageResponse
|
||||||
from app.schemas.sessions import LogoutRequest, SessionCreate
|
from app.schemas.sessions import LogoutRequest, SessionCreate
|
||||||
@@ -39,6 +37,8 @@ from app.schemas.users import (
|
|||||||
)
|
)
|
||||||
from app.services.auth_service import AuthenticationError, AuthService
|
from app.services.auth_service import AuthenticationError, AuthService
|
||||||
from app.services.email_service import email_service
|
from app.services.email_service import email_service
|
||||||
|
from app.services.session_service import session_service
|
||||||
|
from app.services.user_service import user_service
|
||||||
from app.utils.device import extract_device_info
|
from app.utils.device import extract_device_info
|
||||||
from app.utils.security import create_password_reset_token, verify_password_reset_token
|
from app.utils.security import create_password_reset_token, verify_password_reset_token
|
||||||
|
|
||||||
@@ -91,17 +91,18 @@ async def _create_login_session(
|
|||||||
location_country=device_info.location_country,
|
location_country=device_info.location_country,
|
||||||
)
|
)
|
||||||
|
|
||||||
await session_crud.create_session(db, obj_in=session_data)
|
await session_service.create_session(db, obj_in=session_data)
|
||||||
|
|
||||||
logger.info(
|
logger.info(
|
||||||
f"{login_type.capitalize()} successful: {user.email} from {device_info.device_name} "
|
"%s successful: %s from %s (IP: %s)",
|
||||||
f"(IP: {device_info.ip_address})"
|
login_type.capitalize(),
|
||||||
|
user.email,
|
||||||
|
device_info.device_name,
|
||||||
|
device_info.ip_address,
|
||||||
)
|
)
|
||||||
except Exception as session_err:
|
except Exception as session_err:
|
||||||
# Log but don't fail login if session creation fails
|
# Log but don't fail login if session creation fails
|
||||||
logger.error(
|
logger.exception("Failed to create session for %s: %s", user.email, session_err)
|
||||||
f"Failed to create session for {user.email}: {session_err!s}", exc_info=True
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
@router.post(
|
@router.post(
|
||||||
@@ -123,15 +124,21 @@ async def register_user(
|
|||||||
try:
|
try:
|
||||||
user = await AuthService.create_user(db, user_data)
|
user = await AuthService.create_user(db, user_data)
|
||||||
return user
|
return user
|
||||||
except AuthenticationError as e:
|
except DuplicateError:
|
||||||
# SECURITY: Don't reveal if email exists - generic error message
|
# SECURITY: Don't reveal if email exists - generic error message
|
||||||
logger.warning(f"Registration failed: {e!s}")
|
logger.warning("Registration failed: duplicate email %s", user_data.email)
|
||||||
|
raise HTTPException(
|
||||||
|
status_code=status.HTTP_400_BAD_REQUEST,
|
||||||
|
detail="Registration failed. Please check your information and try again.",
|
||||||
|
)
|
||||||
|
except AuthError as e:
|
||||||
|
logger.warning("Registration failed: %s", e)
|
||||||
raise HTTPException(
|
raise HTTPException(
|
||||||
status_code=status.HTTP_400_BAD_REQUEST,
|
status_code=status.HTTP_400_BAD_REQUEST,
|
||||||
detail="Registration failed. Please check your information and try again.",
|
detail="Registration failed. Please check your information and try again.",
|
||||||
)
|
)
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.error(f"Unexpected error during registration: {e!s}", exc_info=True)
|
logger.exception("Unexpected error during registration: %s", e)
|
||||||
raise DatabaseError(
|
raise DatabaseError(
|
||||||
message="An unexpected error occurred. Please try again later.",
|
message="An unexpected error occurred. Please try again later.",
|
||||||
error_code=ErrorCode.INTERNAL_ERROR,
|
error_code=ErrorCode.INTERNAL_ERROR,
|
||||||
@@ -159,7 +166,7 @@ async def login(
|
|||||||
|
|
||||||
# Explicitly check for None result and raise correct exception
|
# Explicitly check for None result and raise correct exception
|
||||||
if user is None:
|
if user is None:
|
||||||
logger.warning(f"Invalid login attempt for: {login_data.email}")
|
logger.warning("Invalid login attempt for: %s", login_data.email)
|
||||||
raise AuthError(
|
raise AuthError(
|
||||||
message="Invalid email or password",
|
message="Invalid email or password",
|
||||||
error_code=ErrorCode.INVALID_CREDENTIALS,
|
error_code=ErrorCode.INVALID_CREDENTIALS,
|
||||||
@@ -175,14 +182,11 @@ async def login(
|
|||||||
|
|
||||||
except AuthenticationError as e:
|
except AuthenticationError as e:
|
||||||
# Handle specific authentication errors like inactive accounts
|
# Handle specific authentication errors like inactive accounts
|
||||||
logger.warning(f"Authentication failed: {e!s}")
|
logger.warning("Authentication failed: %s", e)
|
||||||
raise AuthError(message=str(e), error_code=ErrorCode.INVALID_CREDENTIALS)
|
raise AuthError(message=str(e), error_code=ErrorCode.INVALID_CREDENTIALS)
|
||||||
except AuthError:
|
|
||||||
# Re-raise custom auth exceptions without modification
|
|
||||||
raise
|
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
# Handle unexpected errors
|
# Handle unexpected errors
|
||||||
logger.error(f"Unexpected error during login: {e!s}", exc_info=True)
|
logger.exception("Unexpected error during login: %s", e)
|
||||||
raise DatabaseError(
|
raise DatabaseError(
|
||||||
message="An unexpected error occurred. Please try again later.",
|
message="An unexpected error occurred. Please try again later.",
|
||||||
error_code=ErrorCode.INTERNAL_ERROR,
|
error_code=ErrorCode.INTERNAL_ERROR,
|
||||||
@@ -224,13 +228,10 @@ async def login_oauth(
|
|||||||
# Return full token response with user data
|
# Return full token response with user data
|
||||||
return tokens
|
return tokens
|
||||||
except AuthenticationError as e:
|
except AuthenticationError as e:
|
||||||
logger.warning(f"OAuth authentication failed: {e!s}")
|
logger.warning("OAuth authentication failed: %s", e)
|
||||||
raise AuthError(message=str(e), error_code=ErrorCode.INVALID_CREDENTIALS)
|
raise AuthError(message=str(e), error_code=ErrorCode.INVALID_CREDENTIALS)
|
||||||
except AuthError:
|
|
||||||
# Re-raise custom auth exceptions without modification
|
|
||||||
raise
|
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.error(f"Unexpected error during OAuth login: {e!s}", exc_info=True)
|
logger.exception("Unexpected error during OAuth login: %s", e)
|
||||||
raise DatabaseError(
|
raise DatabaseError(
|
||||||
message="An unexpected error occurred. Please try again later.",
|
message="An unexpected error occurred. Please try again later.",
|
||||||
error_code=ErrorCode.INTERNAL_ERROR,
|
error_code=ErrorCode.INTERNAL_ERROR,
|
||||||
@@ -259,11 +260,12 @@ async def refresh_token(
|
|||||||
)
|
)
|
||||||
|
|
||||||
# Check if session exists and is active
|
# Check if session exists and is active
|
||||||
session = await session_crud.get_active_by_jti(db, jti=refresh_payload.jti)
|
session = await session_service.get_active_by_jti(db, jti=refresh_payload.jti)
|
||||||
|
|
||||||
if not session:
|
if not session:
|
||||||
logger.warning(
|
logger.warning(
|
||||||
f"Refresh token used for inactive or non-existent session: {refresh_payload.jti}"
|
"Refresh token used for inactive or non-existent session: %s",
|
||||||
|
refresh_payload.jti,
|
||||||
)
|
)
|
||||||
raise HTTPException(
|
raise HTTPException(
|
||||||
status_code=status.HTTP_401_UNAUTHORIZED,
|
status_code=status.HTTP_401_UNAUTHORIZED,
|
||||||
@@ -279,16 +281,14 @@ async def refresh_token(
|
|||||||
|
|
||||||
# Update session with new refresh token JTI and expiration
|
# Update session with new refresh token JTI and expiration
|
||||||
try:
|
try:
|
||||||
await session_crud.update_refresh_token(
|
await session_service.update_refresh_token(
|
||||||
db,
|
db,
|
||||||
session=session,
|
session=session,
|
||||||
new_jti=new_refresh_payload.jti,
|
new_jti=new_refresh_payload.jti,
|
||||||
new_expires_at=datetime.fromtimestamp(new_refresh_payload.exp, tz=UTC),
|
new_expires_at=datetime.fromtimestamp(new_refresh_payload.exp, tz=UTC),
|
||||||
)
|
)
|
||||||
except Exception as session_err:
|
except Exception as session_err:
|
||||||
logger.error(
|
logger.exception("Failed to update session %s: %s", session.id, session_err)
|
||||||
f"Failed to update session {session.id}: {session_err!s}", exc_info=True
|
|
||||||
)
|
|
||||||
# Continue anyway - tokens are already issued
|
# Continue anyway - tokens are already issued
|
||||||
|
|
||||||
return tokens
|
return tokens
|
||||||
@@ -311,7 +311,7 @@ async def refresh_token(
|
|||||||
# Re-raise HTTP exceptions (like session revoked)
|
# Re-raise HTTP exceptions (like session revoked)
|
||||||
raise
|
raise
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.error(f"Unexpected error during token refresh: {e!s}")
|
logger.error("Unexpected error during token refresh: %s", e)
|
||||||
raise HTTPException(
|
raise HTTPException(
|
||||||
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
|
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
|
||||||
detail="An unexpected error occurred. Please try again later.",
|
detail="An unexpected error occurred. Please try again later.",
|
||||||
@@ -347,7 +347,7 @@ async def request_password_reset(
|
|||||||
"""
|
"""
|
||||||
try:
|
try:
|
||||||
# Look up user by email
|
# Look up user by email
|
||||||
user = await user_crud.get_by_email(db, email=reset_request.email)
|
user = await user_service.get_by_email(db, email=reset_request.email)
|
||||||
|
|
||||||
# Only send email if user exists and is active
|
# Only send email if user exists and is active
|
||||||
if user and user.is_active:
|
if user and user.is_active:
|
||||||
@@ -358,11 +358,12 @@ async def request_password_reset(
|
|||||||
await email_service.send_password_reset_email(
|
await email_service.send_password_reset_email(
|
||||||
to_email=user.email, reset_token=reset_token, user_name=user.first_name
|
to_email=user.email, reset_token=reset_token, user_name=user.first_name
|
||||||
)
|
)
|
||||||
logger.info(f"Password reset requested for {user.email}")
|
logger.info("Password reset requested for %s", user.email)
|
||||||
else:
|
else:
|
||||||
# Log attempt but don't reveal if email exists
|
# Log attempt but don't reveal if email exists
|
||||||
logger.warning(
|
logger.warning(
|
||||||
f"Password reset requested for non-existent or inactive email: {reset_request.email}"
|
"Password reset requested for non-existent or inactive email: %s",
|
||||||
|
reset_request.email,
|
||||||
)
|
)
|
||||||
|
|
||||||
# Always return success to prevent email enumeration
|
# Always return success to prevent email enumeration
|
||||||
@@ -371,7 +372,7 @@ async def request_password_reset(
|
|||||||
message="If your email is registered, you will receive a password reset link shortly",
|
message="If your email is registered, you will receive a password reset link shortly",
|
||||||
)
|
)
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.error(f"Error processing password reset request: {e!s}", exc_info=True)
|
logger.exception("Error processing password reset request: %s", e)
|
||||||
# Still return success to prevent information leakage
|
# Still return success to prevent information leakage
|
||||||
return MessageResponse(
|
return MessageResponse(
|
||||||
success=True,
|
success=True,
|
||||||
@@ -412,40 +413,34 @@ async def confirm_password_reset(
|
|||||||
detail="Invalid or expired password reset token",
|
detail="Invalid or expired password reset token",
|
||||||
)
|
)
|
||||||
|
|
||||||
# Look up user
|
# Reset password via service (validates user exists and is active)
|
||||||
user = await user_crud.get_by_email(db, email=email)
|
try:
|
||||||
|
user = await AuthService.reset_password(
|
||||||
if not user:
|
db, email=email, new_password=reset_confirm.new_password
|
||||||
raise HTTPException(
|
|
||||||
status_code=status.HTTP_404_NOT_FOUND, detail="User not found"
|
|
||||||
)
|
)
|
||||||
|
except AuthenticationError as e:
|
||||||
if not user.is_active:
|
err_msg = str(e)
|
||||||
raise HTTPException(
|
if "inactive" in err_msg.lower():
|
||||||
status_code=status.HTTP_400_BAD_REQUEST,
|
raise HTTPException(
|
||||||
detail="User account is inactive",
|
status_code=status.HTTP_400_BAD_REQUEST, detail=err_msg
|
||||||
)
|
)
|
||||||
|
raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail=err_msg)
|
||||||
# Update password
|
|
||||||
user.password_hash = get_password_hash(reset_confirm.new_password)
|
|
||||||
db.add(user)
|
|
||||||
await db.commit()
|
|
||||||
|
|
||||||
# SECURITY: Invalidate all existing sessions after password reset
|
# SECURITY: Invalidate all existing sessions after password reset
|
||||||
# This prevents stolen sessions from being used after password change
|
# This prevents stolen sessions from being used after password change
|
||||||
from app.crud.session import session as session_crud
|
|
||||||
|
|
||||||
try:
|
try:
|
||||||
deactivated_count = await session_crud.deactivate_all_user_sessions(
|
deactivated_count = await session_service.deactivate_all_user_sessions(
|
||||||
db, user_id=str(user.id)
|
db, user_id=str(user.id)
|
||||||
)
|
)
|
||||||
logger.info(
|
logger.info(
|
||||||
f"Password reset successful for {user.email}, invalidated {deactivated_count} sessions"
|
"Password reset successful for %s, invalidated %s sessions",
|
||||||
|
user.email,
|
||||||
|
deactivated_count,
|
||||||
)
|
)
|
||||||
except Exception as session_error:
|
except Exception as session_error:
|
||||||
# Log but don't fail password reset if session invalidation fails
|
# Log but don't fail password reset if session invalidation fails
|
||||||
logger.error(
|
logger.error(
|
||||||
f"Failed to invalidate sessions after password reset: {session_error!s}"
|
"Failed to invalidate sessions after password reset: %s", session_error
|
||||||
)
|
)
|
||||||
|
|
||||||
return MessageResponse(
|
return MessageResponse(
|
||||||
@@ -456,7 +451,7 @@ async def confirm_password_reset(
|
|||||||
except HTTPException:
|
except HTTPException:
|
||||||
raise
|
raise
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.error(f"Error confirming password reset: {e!s}", exc_info=True)
|
logger.exception("Error confirming password reset: %s", e)
|
||||||
await db.rollback()
|
await db.rollback()
|
||||||
raise HTTPException(
|
raise HTTPException(
|
||||||
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
|
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
|
||||||
@@ -506,19 +501,21 @@ async def logout(
|
|||||||
)
|
)
|
||||||
except (TokenExpiredError, TokenInvalidError) as e:
|
except (TokenExpiredError, TokenInvalidError) as e:
|
||||||
# Even if token is expired/invalid, try to deactivate session
|
# Even if token is expired/invalid, try to deactivate session
|
||||||
logger.warning(f"Logout with invalid/expired token: {e!s}")
|
logger.warning("Logout with invalid/expired token: %s", e)
|
||||||
# Don't fail - return success anyway
|
# Don't fail - return success anyway
|
||||||
return MessageResponse(success=True, message="Logged out successfully")
|
return MessageResponse(success=True, message="Logged out successfully")
|
||||||
|
|
||||||
# Find the session by JTI
|
# Find the session by JTI
|
||||||
session = await session_crud.get_by_jti(db, jti=refresh_payload.jti)
|
session = await session_service.get_by_jti(db, jti=refresh_payload.jti)
|
||||||
|
|
||||||
if session:
|
if session:
|
||||||
# Verify session belongs to current user (security check)
|
# Verify session belongs to current user (security check)
|
||||||
if str(session.user_id) != str(current_user.id):
|
if str(session.user_id) != str(current_user.id):
|
||||||
logger.warning(
|
logger.warning(
|
||||||
f"User {current_user.id} attempted to logout session {session.id} "
|
"User %s attempted to logout session %s belonging to user %s",
|
||||||
f"belonging to user {session.user_id}"
|
current_user.id,
|
||||||
|
session.id,
|
||||||
|
session.user_id,
|
||||||
)
|
)
|
||||||
raise HTTPException(
|
raise HTTPException(
|
||||||
status_code=status.HTTP_403_FORBIDDEN,
|
status_code=status.HTTP_403_FORBIDDEN,
|
||||||
@@ -526,17 +523,20 @@ async def logout(
|
|||||||
)
|
)
|
||||||
|
|
||||||
# Deactivate the session
|
# Deactivate the session
|
||||||
await session_crud.deactivate(db, session_id=str(session.id))
|
await session_service.deactivate(db, session_id=str(session.id))
|
||||||
|
|
||||||
logger.info(
|
logger.info(
|
||||||
f"User {current_user.id} logged out from {session.device_name} "
|
"User %s logged out from %s (session %s)",
|
||||||
f"(session {session.id})"
|
current_user.id,
|
||||||
|
session.device_name,
|
||||||
|
session.id,
|
||||||
)
|
)
|
||||||
else:
|
else:
|
||||||
# Session not found - maybe already deleted or never existed
|
# Session not found - maybe already deleted or never existed
|
||||||
# Return success anyway (idempotent)
|
# Return success anyway (idempotent)
|
||||||
logger.info(
|
logger.info(
|
||||||
f"Logout requested for non-existent session (JTI: {refresh_payload.jti})"
|
"Logout requested for non-existent session (JTI: %s)",
|
||||||
|
refresh_payload.jti,
|
||||||
)
|
)
|
||||||
|
|
||||||
return MessageResponse(success=True, message="Logged out successfully")
|
return MessageResponse(success=True, message="Logged out successfully")
|
||||||
@@ -544,9 +544,7 @@ async def logout(
|
|||||||
except HTTPException:
|
except HTTPException:
|
||||||
raise
|
raise
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.error(
|
logger.exception("Error during logout for user %s: %s", current_user.id, e)
|
||||||
f"Error during logout for user {current_user.id}: {e!s}", exc_info=True
|
|
||||||
)
|
|
||||||
# Don't expose error details
|
# Don't expose error details
|
||||||
return MessageResponse(success=True, message="Logged out successfully")
|
return MessageResponse(success=True, message="Logged out successfully")
|
||||||
|
|
||||||
@@ -584,12 +582,12 @@ async def logout_all(
|
|||||||
"""
|
"""
|
||||||
try:
|
try:
|
||||||
# Deactivate all sessions for this user
|
# Deactivate all sessions for this user
|
||||||
count = await session_crud.deactivate_all_user_sessions(
|
count = await session_service.deactivate_all_user_sessions(
|
||||||
db, user_id=str(current_user.id)
|
db, user_id=str(current_user.id)
|
||||||
)
|
)
|
||||||
|
|
||||||
logger.info(
|
logger.info(
|
||||||
f"User {current_user.id} logged out from all devices ({count} sessions)"
|
"User %s logged out from all devices (%s sessions)", current_user.id, count
|
||||||
)
|
)
|
||||||
|
|
||||||
return MessageResponse(
|
return MessageResponse(
|
||||||
@@ -598,9 +596,7 @@ async def logout_all(
|
|||||||
)
|
)
|
||||||
|
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.error(
|
logger.exception("Error during logout-all for user %s: %s", current_user.id, e)
|
||||||
f"Error during logout-all for user {current_user.id}: {e!s}", exc_info=True
|
|
||||||
)
|
|
||||||
await db.rollback()
|
await db.rollback()
|
||||||
raise HTTPException(
|
raise HTTPException(
|
||||||
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
|
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
|
||||||
|
|||||||
434
backend/app/api/routes/oauth.py
Normal file
434
backend/app/api/routes/oauth.py
Normal file
@@ -0,0 +1,434 @@
|
|||||||
|
# app/api/routes/oauth.py
|
||||||
|
"""
|
||||||
|
OAuth routes for social authentication.
|
||||||
|
|
||||||
|
Endpoints:
|
||||||
|
- GET /oauth/providers - List enabled OAuth providers
|
||||||
|
- GET /oauth/authorize/{provider} - Get authorization URL
|
||||||
|
- POST /oauth/callback/{provider} - Handle OAuth callback
|
||||||
|
- GET /oauth/accounts - List linked OAuth accounts
|
||||||
|
- DELETE /oauth/accounts/{provider} - Unlink an OAuth account
|
||||||
|
"""
|
||||||
|
|
||||||
|
import logging
|
||||||
|
import os
|
||||||
|
from datetime import UTC, datetime
|
||||||
|
from typing import Any
|
||||||
|
|
||||||
|
from fastapi import APIRouter, Depends, HTTPException, Query, Request, status
|
||||||
|
from slowapi import Limiter
|
||||||
|
from slowapi.util import get_remote_address
|
||||||
|
from sqlalchemy.ext.asyncio import AsyncSession
|
||||||
|
|
||||||
|
from app.api.dependencies.auth import get_current_user, get_optional_current_user
|
||||||
|
from app.core.auth import decode_token
|
||||||
|
from app.core.config import settings
|
||||||
|
from app.core.database import get_db
|
||||||
|
from app.core.exceptions import AuthenticationError as AuthError
|
||||||
|
from app.models.user import User
|
||||||
|
from app.schemas.oauth import (
|
||||||
|
OAuthAccountsListResponse,
|
||||||
|
OAuthCallbackRequest,
|
||||||
|
OAuthCallbackResponse,
|
||||||
|
OAuthProvidersResponse,
|
||||||
|
OAuthUnlinkResponse,
|
||||||
|
)
|
||||||
|
from app.schemas.sessions import SessionCreate
|
||||||
|
from app.schemas.users import Token
|
||||||
|
from app.services.oauth_service import OAuthService
|
||||||
|
from app.services.session_service import session_service
|
||||||
|
from app.utils.device import extract_device_info
|
||||||
|
|
||||||
|
router = APIRouter()
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
# Initialize limiter for this router
|
||||||
|
limiter = Limiter(key_func=get_remote_address)
|
||||||
|
|
||||||
|
# Use higher rate limits in test environment
|
||||||
|
IS_TEST = os.getenv("IS_TEST", "False") == "True"
|
||||||
|
RATE_MULTIPLIER = 100 if IS_TEST else 1
|
||||||
|
|
||||||
|
|
||||||
|
async def _create_oauth_login_session(
|
||||||
|
db: AsyncSession,
|
||||||
|
request: Request,
|
||||||
|
user: User,
|
||||||
|
tokens: Token,
|
||||||
|
provider: str,
|
||||||
|
) -> None:
|
||||||
|
"""
|
||||||
|
Create a session record for successful OAuth login.
|
||||||
|
|
||||||
|
This is a best-effort operation - login succeeds even if session creation fails.
|
||||||
|
"""
|
||||||
|
try:
|
||||||
|
device_info = extract_device_info(request)
|
||||||
|
|
||||||
|
# Decode refresh token to get JTI and expiration
|
||||||
|
refresh_payload = decode_token(tokens.refresh_token, verify_type="refresh")
|
||||||
|
|
||||||
|
session_data = SessionCreate(
|
||||||
|
user_id=user.id,
|
||||||
|
refresh_token_jti=refresh_payload.jti,
|
||||||
|
device_name=device_info.device_name or f"OAuth ({provider})",
|
||||||
|
device_id=device_info.device_id,
|
||||||
|
ip_address=device_info.ip_address,
|
||||||
|
user_agent=device_info.user_agent,
|
||||||
|
last_used_at=datetime.now(UTC),
|
||||||
|
expires_at=datetime.fromtimestamp(refresh_payload.exp, tz=UTC),
|
||||||
|
location_city=device_info.location_city,
|
||||||
|
location_country=device_info.location_country,
|
||||||
|
)
|
||||||
|
|
||||||
|
await session_service.create_session(db, obj_in=session_data)
|
||||||
|
|
||||||
|
logger.info(
|
||||||
|
"OAuth login successful: %s via %s from %s (IP: %s)",
|
||||||
|
user.email,
|
||||||
|
provider,
|
||||||
|
device_info.device_name,
|
||||||
|
device_info.ip_address,
|
||||||
|
)
|
||||||
|
except Exception as session_err:
|
||||||
|
# Log but don't fail login if session creation fails
|
||||||
|
logger.exception(
|
||||||
|
"Failed to create session for OAuth login %s: %s", user.email, session_err
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
@router.get(
|
||||||
|
"/providers",
|
||||||
|
response_model=OAuthProvidersResponse,
|
||||||
|
summary="List OAuth Providers",
|
||||||
|
description="""
|
||||||
|
Get list of enabled OAuth providers for the login/register UI.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
List of enabled providers with display info.
|
||||||
|
""",
|
||||||
|
operation_id="list_oauth_providers",
|
||||||
|
)
|
||||||
|
async def list_providers() -> Any:
|
||||||
|
"""
|
||||||
|
Get list of enabled OAuth providers.
|
||||||
|
|
||||||
|
This endpoint is public (no authentication required) as it's needed
|
||||||
|
for the login/register UI to display available social login options.
|
||||||
|
"""
|
||||||
|
return OAuthService.get_enabled_providers()
|
||||||
|
|
||||||
|
|
||||||
|
@router.get(
|
||||||
|
"/authorize/{provider}",
|
||||||
|
response_model=dict,
|
||||||
|
summary="Get OAuth Authorization URL",
|
||||||
|
description="""
|
||||||
|
Get the authorization URL to redirect the user to the OAuth provider.
|
||||||
|
|
||||||
|
The frontend should redirect the user to the returned URL.
|
||||||
|
After authentication, the provider will redirect back to the callback URL.
|
||||||
|
|
||||||
|
**Rate Limit**: 10 requests/minute
|
||||||
|
""",
|
||||||
|
operation_id="get_oauth_authorization_url",
|
||||||
|
)
|
||||||
|
@limiter.limit(f"{10 * RATE_MULTIPLIER}/minute")
|
||||||
|
async def get_authorization_url(
|
||||||
|
request: Request,
|
||||||
|
provider: str,
|
||||||
|
redirect_uri: str = Query(
|
||||||
|
..., description="Frontend callback URL after OAuth completes"
|
||||||
|
),
|
||||||
|
current_user: User | None = Depends(get_optional_current_user),
|
||||||
|
db: AsyncSession = Depends(get_db),
|
||||||
|
) -> Any:
|
||||||
|
"""
|
||||||
|
Get OAuth authorization URL.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
provider: OAuth provider (google, github)
|
||||||
|
redirect_uri: Frontend callback URL
|
||||||
|
current_user: Current user (optional, for account linking)
|
||||||
|
db: Database session
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
dict with authorization_url and state
|
||||||
|
"""
|
||||||
|
if not settings.OAUTH_ENABLED:
|
||||||
|
raise HTTPException(
|
||||||
|
status_code=status.HTTP_400_BAD_REQUEST,
|
||||||
|
detail="OAuth is not enabled",
|
||||||
|
)
|
||||||
|
|
||||||
|
try:
|
||||||
|
# If user is logged in, this is an account linking flow
|
||||||
|
user_id = str(current_user.id) if current_user else None
|
||||||
|
|
||||||
|
url, state = await OAuthService.create_authorization_url(
|
||||||
|
db,
|
||||||
|
provider=provider,
|
||||||
|
redirect_uri=redirect_uri,
|
||||||
|
user_id=user_id,
|
||||||
|
)
|
||||||
|
|
||||||
|
return {
|
||||||
|
"authorization_url": url,
|
||||||
|
"state": state,
|
||||||
|
}
|
||||||
|
|
||||||
|
except AuthError as e:
|
||||||
|
logger.warning("OAuth authorization failed: %s", e)
|
||||||
|
raise HTTPException(
|
||||||
|
status_code=status.HTTP_400_BAD_REQUEST,
|
||||||
|
detail=str(e),
|
||||||
|
)
|
||||||
|
except Exception as e:
|
||||||
|
logger.exception("OAuth authorization error: %s", e)
|
||||||
|
raise HTTPException(
|
||||||
|
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
|
||||||
|
detail="Failed to create authorization URL",
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
@router.post(
|
||||||
|
"/callback/{provider}",
|
||||||
|
response_model=OAuthCallbackResponse,
|
||||||
|
summary="OAuth Callback",
|
||||||
|
description="""
|
||||||
|
Handle OAuth callback from provider.
|
||||||
|
|
||||||
|
The frontend should call this endpoint with the code and state
|
||||||
|
parameters received from the OAuth provider redirect.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
JWT tokens for the authenticated user.
|
||||||
|
|
||||||
|
**Rate Limit**: 10 requests/minute
|
||||||
|
""",
|
||||||
|
operation_id="handle_oauth_callback",
|
||||||
|
)
|
||||||
|
@limiter.limit(f"{10 * RATE_MULTIPLIER}/minute")
|
||||||
|
async def handle_callback(
|
||||||
|
request: Request,
|
||||||
|
provider: str,
|
||||||
|
callback_data: OAuthCallbackRequest,
|
||||||
|
redirect_uri: str = Query(
|
||||||
|
..., description="Must match the redirect_uri used in authorization"
|
||||||
|
),
|
||||||
|
db: AsyncSession = Depends(get_db),
|
||||||
|
) -> Any:
|
||||||
|
"""
|
||||||
|
Handle OAuth callback.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
provider: OAuth provider (google, github)
|
||||||
|
callback_data: Code and state from provider
|
||||||
|
redirect_uri: Original redirect URI (for validation)
|
||||||
|
db: Database session
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
OAuthCallbackResponse with tokens
|
||||||
|
"""
|
||||||
|
if not settings.OAUTH_ENABLED:
|
||||||
|
raise HTTPException(
|
||||||
|
status_code=status.HTTP_400_BAD_REQUEST,
|
||||||
|
detail="OAuth is not enabled",
|
||||||
|
)
|
||||||
|
|
||||||
|
try:
|
||||||
|
result = await OAuthService.handle_callback(
|
||||||
|
db,
|
||||||
|
code=callback_data.code,
|
||||||
|
state=callback_data.state,
|
||||||
|
redirect_uri=redirect_uri,
|
||||||
|
)
|
||||||
|
|
||||||
|
# Create session for the login (need to get the user first)
|
||||||
|
# Note: This requires fetching the user from the token
|
||||||
|
# For now, we skip session creation here as the result doesn't include user info
|
||||||
|
# The session will be created on next request if needed
|
||||||
|
|
||||||
|
return result
|
||||||
|
|
||||||
|
except AuthError as e:
|
||||||
|
logger.warning("OAuth callback failed: %s", e)
|
||||||
|
raise HTTPException(
|
||||||
|
status_code=status.HTTP_401_UNAUTHORIZED,
|
||||||
|
detail=str(e),
|
||||||
|
)
|
||||||
|
except Exception as e:
|
||||||
|
logger.exception("OAuth callback error: %s", e)
|
||||||
|
raise HTTPException(
|
||||||
|
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
|
||||||
|
detail="OAuth authentication failed",
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
@router.get(
|
||||||
|
"/accounts",
|
||||||
|
response_model=OAuthAccountsListResponse,
|
||||||
|
summary="List Linked OAuth Accounts",
|
||||||
|
description="""
|
||||||
|
Get list of OAuth accounts linked to the current user.
|
||||||
|
|
||||||
|
Requires authentication.
|
||||||
|
""",
|
||||||
|
operation_id="list_oauth_accounts",
|
||||||
|
)
|
||||||
|
async def list_accounts(
|
||||||
|
current_user: User = Depends(get_current_user),
|
||||||
|
db: AsyncSession = Depends(get_db),
|
||||||
|
) -> Any:
|
||||||
|
"""
|
||||||
|
List OAuth accounts linked to the current user.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
current_user: Current authenticated user
|
||||||
|
db: Database session
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
List of linked OAuth accounts
|
||||||
|
"""
|
||||||
|
accounts = await OAuthService.get_user_accounts(db, user_id=current_user.id)
|
||||||
|
return OAuthAccountsListResponse(accounts=accounts)
|
||||||
|
|
||||||
|
|
||||||
|
@router.delete(
|
||||||
|
"/accounts/{provider}",
|
||||||
|
response_model=OAuthUnlinkResponse,
|
||||||
|
summary="Unlink OAuth Account",
|
||||||
|
description="""
|
||||||
|
Unlink an OAuth provider from the current user.
|
||||||
|
|
||||||
|
The user must have either a password set or another OAuth provider
|
||||||
|
linked to ensure they can still log in.
|
||||||
|
|
||||||
|
**Rate Limit**: 5 requests/minute
|
||||||
|
""",
|
||||||
|
operation_id="unlink_oauth_account",
|
||||||
|
)
|
||||||
|
@limiter.limit(f"{5 * RATE_MULTIPLIER}/minute")
|
||||||
|
async def unlink_account(
|
||||||
|
request: Request,
|
||||||
|
provider: str,
|
||||||
|
current_user: User = Depends(get_current_user),
|
||||||
|
db: AsyncSession = Depends(get_db),
|
||||||
|
) -> Any:
|
||||||
|
"""
|
||||||
|
Unlink an OAuth provider from the current user.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
provider: Provider to unlink (google, github)
|
||||||
|
current_user: Current authenticated user
|
||||||
|
db: Database session
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Success message
|
||||||
|
"""
|
||||||
|
try:
|
||||||
|
await OAuthService.unlink_provider(
|
||||||
|
db,
|
||||||
|
user=current_user,
|
||||||
|
provider=provider,
|
||||||
|
)
|
||||||
|
|
||||||
|
return OAuthUnlinkResponse(
|
||||||
|
success=True,
|
||||||
|
message=f"{provider.capitalize()} account unlinked successfully",
|
||||||
|
)
|
||||||
|
|
||||||
|
except AuthError as e:
|
||||||
|
logger.warning("OAuth unlink failed for %s: %s", current_user.email, e)
|
||||||
|
raise HTTPException(
|
||||||
|
status_code=status.HTTP_400_BAD_REQUEST,
|
||||||
|
detail=str(e),
|
||||||
|
)
|
||||||
|
except Exception as e:
|
||||||
|
logger.exception("OAuth unlink error: %s", e)
|
||||||
|
raise HTTPException(
|
||||||
|
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
|
||||||
|
detail="Failed to unlink OAuth account",
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
@router.post(
|
||||||
|
"/link/{provider}",
|
||||||
|
response_model=dict,
|
||||||
|
summary="Start Account Linking",
|
||||||
|
description="""
|
||||||
|
Start the OAuth flow to link a new provider to the current user.
|
||||||
|
|
||||||
|
This is a convenience endpoint that redirects to /authorize/{provider}
|
||||||
|
with the current user context.
|
||||||
|
|
||||||
|
**Rate Limit**: 10 requests/minute
|
||||||
|
""",
|
||||||
|
operation_id="start_oauth_link",
|
||||||
|
)
|
||||||
|
@limiter.limit(f"{10 * RATE_MULTIPLIER}/minute")
|
||||||
|
async def start_link(
|
||||||
|
request: Request,
|
||||||
|
provider: str,
|
||||||
|
redirect_uri: str = Query(
|
||||||
|
..., description="Frontend callback URL after OAuth completes"
|
||||||
|
),
|
||||||
|
current_user: User = Depends(get_current_user),
|
||||||
|
db: AsyncSession = Depends(get_db),
|
||||||
|
) -> Any:
|
||||||
|
"""
|
||||||
|
Start OAuth account linking flow.
|
||||||
|
|
||||||
|
This endpoint requires authentication and will initiate an OAuth flow
|
||||||
|
to link a new provider to the current user's account.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
provider: OAuth provider to link (google, github)
|
||||||
|
redirect_uri: Frontend callback URL
|
||||||
|
current_user: Current authenticated user
|
||||||
|
db: Database session
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
dict with authorization_url and state
|
||||||
|
"""
|
||||||
|
if not settings.OAUTH_ENABLED:
|
||||||
|
raise HTTPException(
|
||||||
|
status_code=status.HTTP_400_BAD_REQUEST,
|
||||||
|
detail="OAuth is not enabled",
|
||||||
|
)
|
||||||
|
|
||||||
|
# Check if user already has this provider linked
|
||||||
|
existing = await OAuthService.get_user_account_by_provider(
|
||||||
|
db, user_id=current_user.id, provider=provider
|
||||||
|
)
|
||||||
|
if existing:
|
||||||
|
raise HTTPException(
|
||||||
|
status_code=status.HTTP_400_BAD_REQUEST,
|
||||||
|
detail=f"You already have a {provider} account linked",
|
||||||
|
)
|
||||||
|
|
||||||
|
try:
|
||||||
|
url, state = await OAuthService.create_authorization_url(
|
||||||
|
db,
|
||||||
|
provider=provider,
|
||||||
|
redirect_uri=redirect_uri,
|
||||||
|
user_id=str(current_user.id),
|
||||||
|
)
|
||||||
|
|
||||||
|
return {
|
||||||
|
"authorization_url": url,
|
||||||
|
"state": state,
|
||||||
|
}
|
||||||
|
|
||||||
|
except AuthError as e:
|
||||||
|
logger.warning("OAuth link authorization failed: %s", e)
|
||||||
|
raise HTTPException(
|
||||||
|
status_code=status.HTTP_400_BAD_REQUEST,
|
||||||
|
detail=str(e),
|
||||||
|
)
|
||||||
|
except Exception as e:
|
||||||
|
logger.exception("OAuth link error: %s", e)
|
||||||
|
raise HTTPException(
|
||||||
|
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
|
||||||
|
detail="Failed to create authorization URL",
|
||||||
|
)
|
||||||
824
backend/app/api/routes/oauth_provider.py
Normal file
824
backend/app/api/routes/oauth_provider.py
Normal file
@@ -0,0 +1,824 @@
|
|||||||
|
# app/api/routes/oauth_provider.py
|
||||||
|
"""
|
||||||
|
OAuth Provider routes (Authorization Server mode) for MCP integration.
|
||||||
|
|
||||||
|
Implements OAuth 2.0 Authorization Server endpoints:
|
||||||
|
- GET /.well-known/oauth-authorization-server - Server metadata (RFC 8414)
|
||||||
|
- GET /oauth/provider/authorize - Authorization endpoint
|
||||||
|
- POST /oauth/provider/token - Token endpoint
|
||||||
|
- POST /oauth/provider/revoke - Token revocation (RFC 7009)
|
||||||
|
- POST /oauth/provider/introspect - Token introspection (RFC 7662)
|
||||||
|
- Client management endpoints
|
||||||
|
|
||||||
|
Security features:
|
||||||
|
- PKCE required for public clients (S256)
|
||||||
|
- CSRF protection via state parameter
|
||||||
|
- Secure token handling
|
||||||
|
- Rate limiting on sensitive endpoints
|
||||||
|
"""
|
||||||
|
|
||||||
|
import logging
|
||||||
|
from typing import Any
|
||||||
|
from urllib.parse import urlencode
|
||||||
|
|
||||||
|
from fastapi import APIRouter, Depends, Form, HTTPException, Query, Request, status
|
||||||
|
from fastapi.responses import RedirectResponse
|
||||||
|
from slowapi import Limiter
|
||||||
|
from slowapi.util import get_remote_address
|
||||||
|
from sqlalchemy.ext.asyncio import AsyncSession
|
||||||
|
|
||||||
|
from app.api.dependencies.auth import (
|
||||||
|
get_current_active_user,
|
||||||
|
get_current_superuser,
|
||||||
|
get_optional_current_user,
|
||||||
|
)
|
||||||
|
from app.core.config import settings
|
||||||
|
from app.core.database import get_db
|
||||||
|
from app.models.user import User
|
||||||
|
from app.schemas.oauth import (
|
||||||
|
OAuthClientCreate,
|
||||||
|
OAuthClientResponse,
|
||||||
|
OAuthServerMetadata,
|
||||||
|
OAuthTokenIntrospectionResponse,
|
||||||
|
OAuthTokenResponse,
|
||||||
|
)
|
||||||
|
from app.services import oauth_provider_service as provider_service
|
||||||
|
|
||||||
|
router = APIRouter()
|
||||||
|
# Separate router for RFC 8414 well-known endpoint (registered at root level)
|
||||||
|
wellknown_router = APIRouter()
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
limiter = Limiter(key_func=get_remote_address)
|
||||||
|
|
||||||
|
|
||||||
|
def require_provider_enabled():
|
||||||
|
"""Dependency to check if OAuth provider mode is enabled."""
|
||||||
|
if not settings.OAUTH_PROVIDER_ENABLED:
|
||||||
|
raise HTTPException(
|
||||||
|
status_code=status.HTTP_404_NOT_FOUND,
|
||||||
|
detail="OAuth provider mode is not enabled. Set OAUTH_PROVIDER_ENABLED=true",
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
# ============================================================================
|
||||||
|
# Server Metadata (RFC 8414)
|
||||||
|
# ============================================================================
|
||||||
|
|
||||||
|
|
||||||
|
@wellknown_router.get(
|
||||||
|
"/.well-known/oauth-authorization-server",
|
||||||
|
response_model=OAuthServerMetadata,
|
||||||
|
summary="OAuth Server Metadata",
|
||||||
|
description="""
|
||||||
|
OAuth 2.0 Authorization Server Metadata (RFC 8414).
|
||||||
|
|
||||||
|
Returns server metadata including supported endpoints, scopes,
|
||||||
|
and capabilities. MCP clients use this to discover the server.
|
||||||
|
|
||||||
|
Note: This endpoint is at the root level per RFC 8414.
|
||||||
|
""",
|
||||||
|
operation_id="get_oauth_server_metadata",
|
||||||
|
tags=["OAuth Provider"],
|
||||||
|
)
|
||||||
|
async def get_server_metadata(
|
||||||
|
_: None = Depends(require_provider_enabled),
|
||||||
|
) -> OAuthServerMetadata:
|
||||||
|
"""Get OAuth 2.0 server metadata."""
|
||||||
|
base_url = settings.OAUTH_ISSUER.rstrip("/")
|
||||||
|
|
||||||
|
return OAuthServerMetadata(
|
||||||
|
issuer=base_url,
|
||||||
|
authorization_endpoint=f"{base_url}/api/v1/oauth/provider/authorize",
|
||||||
|
token_endpoint=f"{base_url}/api/v1/oauth/provider/token",
|
||||||
|
revocation_endpoint=f"{base_url}/api/v1/oauth/provider/revoke",
|
||||||
|
introspection_endpoint=f"{base_url}/api/v1/oauth/provider/introspect",
|
||||||
|
registration_endpoint=None, # Dynamic registration not supported
|
||||||
|
scopes_supported=[
|
||||||
|
"openid",
|
||||||
|
"profile",
|
||||||
|
"email",
|
||||||
|
"read:users",
|
||||||
|
"write:users",
|
||||||
|
"read:organizations",
|
||||||
|
"write:organizations",
|
||||||
|
"admin",
|
||||||
|
],
|
||||||
|
response_types_supported=["code"],
|
||||||
|
grant_types_supported=["authorization_code", "refresh_token"],
|
||||||
|
code_challenge_methods_supported=["S256"],
|
||||||
|
token_endpoint_auth_methods_supported=[
|
||||||
|
"client_secret_basic",
|
||||||
|
"client_secret_post",
|
||||||
|
"none", # For public clients with PKCE
|
||||||
|
],
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
# ============================================================================
|
||||||
|
# Authorization Endpoint
|
||||||
|
# ============================================================================
|
||||||
|
|
||||||
|
|
||||||
|
@router.get(
|
||||||
|
"/provider/authorize",
|
||||||
|
summary="Authorization Endpoint",
|
||||||
|
description="""
|
||||||
|
OAuth 2.0 Authorization Endpoint.
|
||||||
|
|
||||||
|
Initiates the authorization code flow:
|
||||||
|
1. Validates client and parameters
|
||||||
|
2. Checks if user is authenticated (redirects to login if not)
|
||||||
|
3. Checks existing consent
|
||||||
|
4. Redirects to consent page if needed
|
||||||
|
5. Issues authorization code and redirects back to client
|
||||||
|
|
||||||
|
Required parameters:
|
||||||
|
- response_type: Must be "code"
|
||||||
|
- client_id: Registered client ID
|
||||||
|
- redirect_uri: Must match registered URI
|
||||||
|
|
||||||
|
Recommended parameters:
|
||||||
|
- state: CSRF protection
|
||||||
|
- code_challenge + code_challenge_method: PKCE (required for public clients)
|
||||||
|
- scope: Requested permissions
|
||||||
|
""",
|
||||||
|
operation_id="oauth_provider_authorize",
|
||||||
|
tags=["OAuth Provider"],
|
||||||
|
)
|
||||||
|
@limiter.limit("30/minute")
|
||||||
|
async def authorize(
|
||||||
|
request: Request,
|
||||||
|
response_type: str = Query(..., description="Must be 'code'"),
|
||||||
|
client_id: str = Query(..., description="OAuth client ID"),
|
||||||
|
redirect_uri: str = Query(..., description="Redirect URI"),
|
||||||
|
scope: str = Query(default="", description="Requested scopes (space-separated)"),
|
||||||
|
state: str = Query(default="", description="CSRF state parameter"),
|
||||||
|
code_challenge: str | None = Query(default=None, description="PKCE code challenge"),
|
||||||
|
code_challenge_method: str | None = Query(
|
||||||
|
default=None, description="PKCE method (S256)"
|
||||||
|
),
|
||||||
|
nonce: str | None = Query(default=None, description="OpenID Connect nonce"),
|
||||||
|
db: AsyncSession = Depends(get_db),
|
||||||
|
_: None = Depends(require_provider_enabled),
|
||||||
|
current_user: User | None = Depends(get_optional_current_user),
|
||||||
|
) -> Any:
|
||||||
|
"""
|
||||||
|
Authorization endpoint - initiates OAuth flow.
|
||||||
|
|
||||||
|
If user is not authenticated, redirects to login with return URL.
|
||||||
|
If user has not consented, redirects to consent page.
|
||||||
|
If all checks pass, generates code and redirects to client.
|
||||||
|
"""
|
||||||
|
# Validate response_type
|
||||||
|
if response_type != "code":
|
||||||
|
raise HTTPException(
|
||||||
|
status_code=status.HTTP_400_BAD_REQUEST,
|
||||||
|
detail="invalid_request: response_type must be 'code'",
|
||||||
|
)
|
||||||
|
|
||||||
|
# Validate PKCE method if provided - ONLY S256 is allowed (RFC 7636 Section 4.3)
|
||||||
|
# "plain" method provides no security benefit and MUST NOT be used
|
||||||
|
if code_challenge_method and code_challenge_method != "S256":
|
||||||
|
raise HTTPException(
|
||||||
|
status_code=status.HTTP_400_BAD_REQUEST,
|
||||||
|
detail="invalid_request: code_challenge_method must be 'S256' (plain is not supported)",
|
||||||
|
)
|
||||||
|
|
||||||
|
# Validate client
|
||||||
|
try:
|
||||||
|
client = await provider_service.get_client(db, client_id)
|
||||||
|
if not client:
|
||||||
|
raise provider_service.InvalidClientError("Unknown client_id")
|
||||||
|
provider_service.validate_redirect_uri(client, redirect_uri)
|
||||||
|
except provider_service.OAuthProviderError as e:
|
||||||
|
# For client/redirect errors, we can't safely redirect - show error
|
||||||
|
raise HTTPException(
|
||||||
|
status_code=status.HTTP_400_BAD_REQUEST,
|
||||||
|
detail=f"{e.error}: {e.error_description}",
|
||||||
|
)
|
||||||
|
|
||||||
|
# Validate and filter scopes
|
||||||
|
try:
|
||||||
|
requested_scopes = provider_service.parse_scope(scope)
|
||||||
|
valid_scopes = provider_service.validate_scopes(client, requested_scopes)
|
||||||
|
except provider_service.InvalidScopeError as e:
|
||||||
|
# Redirect with error
|
||||||
|
scope_error_params: dict[str, str] = {"error": e.error}
|
||||||
|
if e.error_description:
|
||||||
|
scope_error_params["error_description"] = e.error_description
|
||||||
|
if state:
|
||||||
|
scope_error_params["state"] = state
|
||||||
|
return RedirectResponse(
|
||||||
|
url=f"{redirect_uri}?{urlencode(scope_error_params)}",
|
||||||
|
status_code=status.HTTP_302_FOUND,
|
||||||
|
)
|
||||||
|
|
||||||
|
# Public clients MUST use PKCE
|
||||||
|
if client.client_type == "public":
|
||||||
|
if not code_challenge or code_challenge_method != "S256":
|
||||||
|
pkce_error_params: dict[str, str] = {
|
||||||
|
"error": "invalid_request",
|
||||||
|
"error_description": "PKCE with S256 is required for public clients",
|
||||||
|
}
|
||||||
|
if state:
|
||||||
|
pkce_error_params["state"] = state
|
||||||
|
return RedirectResponse(
|
||||||
|
url=f"{redirect_uri}?{urlencode(pkce_error_params)}",
|
||||||
|
status_code=status.HTTP_302_FOUND,
|
||||||
|
)
|
||||||
|
|
||||||
|
# If user is not authenticated, redirect to login
|
||||||
|
if not current_user:
|
||||||
|
# Store authorization request in session and redirect to login
|
||||||
|
# The frontend will handle the return URL
|
||||||
|
login_url = f"{settings.FRONTEND_URL}/login"
|
||||||
|
return_params = urlencode(
|
||||||
|
{
|
||||||
|
"oauth_authorize": "true",
|
||||||
|
"client_id": client_id,
|
||||||
|
"redirect_uri": redirect_uri,
|
||||||
|
"scope": " ".join(valid_scopes),
|
||||||
|
"state": state,
|
||||||
|
"code_challenge": code_challenge or "",
|
||||||
|
"code_challenge_method": code_challenge_method or "",
|
||||||
|
"nonce": nonce or "",
|
||||||
|
}
|
||||||
|
)
|
||||||
|
return RedirectResponse(
|
||||||
|
url=f"{login_url}?return_to=/auth/consent?{return_params}",
|
||||||
|
status_code=status.HTTP_302_FOUND,
|
||||||
|
)
|
||||||
|
|
||||||
|
# Check if user has already consented
|
||||||
|
has_consent = await provider_service.check_consent(
|
||||||
|
db, current_user.id, client_id, valid_scopes
|
||||||
|
)
|
||||||
|
|
||||||
|
if not has_consent:
|
||||||
|
# Redirect to consent page
|
||||||
|
consent_params = urlencode(
|
||||||
|
{
|
||||||
|
"client_id": client_id,
|
||||||
|
"client_name": client.client_name,
|
||||||
|
"redirect_uri": redirect_uri,
|
||||||
|
"scope": " ".join(valid_scopes),
|
||||||
|
"state": state,
|
||||||
|
"code_challenge": code_challenge or "",
|
||||||
|
"code_challenge_method": code_challenge_method or "",
|
||||||
|
"nonce": nonce or "",
|
||||||
|
}
|
||||||
|
)
|
||||||
|
return RedirectResponse(
|
||||||
|
url=f"{settings.FRONTEND_URL}/auth/consent?{consent_params}",
|
||||||
|
status_code=status.HTTP_302_FOUND,
|
||||||
|
)
|
||||||
|
|
||||||
|
# User is authenticated and has consented - issue authorization code
|
||||||
|
try:
|
||||||
|
code = await provider_service.create_authorization_code(
|
||||||
|
db=db,
|
||||||
|
client=client,
|
||||||
|
user=current_user,
|
||||||
|
redirect_uri=redirect_uri,
|
||||||
|
scope=" ".join(valid_scopes),
|
||||||
|
code_challenge=code_challenge,
|
||||||
|
code_challenge_method=code_challenge_method,
|
||||||
|
state=state,
|
||||||
|
nonce=nonce,
|
||||||
|
)
|
||||||
|
except provider_service.OAuthProviderError as e:
|
||||||
|
error_params: dict[str, str] = {"error": e.error}
|
||||||
|
if e.error_description:
|
||||||
|
error_params["error_description"] = e.error_description
|
||||||
|
if state:
|
||||||
|
error_params["state"] = state
|
||||||
|
return RedirectResponse(
|
||||||
|
url=f"{redirect_uri}?{urlencode(error_params)}",
|
||||||
|
status_code=status.HTTP_302_FOUND,
|
||||||
|
)
|
||||||
|
|
||||||
|
# Success - redirect with code
|
||||||
|
success_params = {"code": code}
|
||||||
|
if state:
|
||||||
|
success_params["state"] = state
|
||||||
|
|
||||||
|
return RedirectResponse(
|
||||||
|
url=f"{redirect_uri}?{urlencode(success_params)}",
|
||||||
|
status_code=status.HTTP_302_FOUND,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
@router.post(
|
||||||
|
"/provider/authorize/consent",
|
||||||
|
summary="Submit Authorization Consent",
|
||||||
|
description="""
|
||||||
|
Submit user consent for OAuth authorization.
|
||||||
|
|
||||||
|
Called by the consent page after user approves or denies.
|
||||||
|
""",
|
||||||
|
operation_id="oauth_provider_consent",
|
||||||
|
tags=["OAuth Provider"],
|
||||||
|
)
|
||||||
|
@limiter.limit("30/minute")
|
||||||
|
async def submit_consent(
|
||||||
|
request: Request,
|
||||||
|
approved: bool = Form(..., description="Whether user approved"),
|
||||||
|
client_id: str = Form(..., description="OAuth client ID"),
|
||||||
|
redirect_uri: str = Form(..., description="Redirect URI"),
|
||||||
|
scope: str = Form(default="", description="Granted scopes"),
|
||||||
|
state: str = Form(default="", description="CSRF state parameter"),
|
||||||
|
code_challenge: str | None = Form(default=None),
|
||||||
|
code_challenge_method: str | None = Form(default=None),
|
||||||
|
nonce: str | None = Form(default=None),
|
||||||
|
db: AsyncSession = Depends(get_db),
|
||||||
|
_: None = Depends(require_provider_enabled),
|
||||||
|
current_user: User = Depends(get_current_active_user),
|
||||||
|
) -> Any:
|
||||||
|
"""Process consent form submission."""
|
||||||
|
# Validate client
|
||||||
|
try:
|
||||||
|
client = await provider_service.get_client(db, client_id)
|
||||||
|
if not client:
|
||||||
|
raise provider_service.InvalidClientError("Unknown client_id")
|
||||||
|
provider_service.validate_redirect_uri(client, redirect_uri)
|
||||||
|
except provider_service.OAuthProviderError as e:
|
||||||
|
raise HTTPException(
|
||||||
|
status_code=status.HTTP_400_BAD_REQUEST,
|
||||||
|
detail=f"{e.error}: {e.error_description}",
|
||||||
|
)
|
||||||
|
|
||||||
|
# If user denied, redirect with error
|
||||||
|
if not approved:
|
||||||
|
denied_params: dict[str, str] = {
|
||||||
|
"error": "access_denied",
|
||||||
|
"error_description": "User denied authorization",
|
||||||
|
}
|
||||||
|
if state:
|
||||||
|
denied_params["state"] = state
|
||||||
|
return RedirectResponse(
|
||||||
|
url=f"{redirect_uri}?{urlencode(denied_params)}",
|
||||||
|
status_code=status.HTTP_302_FOUND,
|
||||||
|
)
|
||||||
|
|
||||||
|
# Parse and validate scopes
|
||||||
|
granted_scopes = provider_service.parse_scope(scope)
|
||||||
|
valid_scopes = provider_service.validate_scopes(client, granted_scopes)
|
||||||
|
|
||||||
|
# Record consent
|
||||||
|
await provider_service.grant_consent(db, current_user.id, client_id, valid_scopes)
|
||||||
|
|
||||||
|
# Generate authorization code
|
||||||
|
try:
|
||||||
|
code = await provider_service.create_authorization_code(
|
||||||
|
db=db,
|
||||||
|
client=client,
|
||||||
|
user=current_user,
|
||||||
|
redirect_uri=redirect_uri,
|
||||||
|
scope=" ".join(valid_scopes),
|
||||||
|
code_challenge=code_challenge,
|
||||||
|
code_challenge_method=code_challenge_method,
|
||||||
|
state=state,
|
||||||
|
nonce=nonce,
|
||||||
|
)
|
||||||
|
except provider_service.OAuthProviderError as e:
|
||||||
|
error_params: dict[str, str] = {"error": e.error}
|
||||||
|
if e.error_description:
|
||||||
|
error_params["error_description"] = e.error_description
|
||||||
|
if state:
|
||||||
|
error_params["state"] = state
|
||||||
|
return RedirectResponse(
|
||||||
|
url=f"{redirect_uri}?{urlencode(error_params)}",
|
||||||
|
status_code=status.HTTP_302_FOUND,
|
||||||
|
)
|
||||||
|
|
||||||
|
# Success
|
||||||
|
success_params = {"code": code}
|
||||||
|
if state:
|
||||||
|
success_params["state"] = state
|
||||||
|
|
||||||
|
return RedirectResponse(
|
||||||
|
url=f"{redirect_uri}?{urlencode(success_params)}",
|
||||||
|
status_code=status.HTTP_302_FOUND,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
# ============================================================================
|
||||||
|
# Token Endpoint
|
||||||
|
# ============================================================================
|
||||||
|
|
||||||
|
|
||||||
|
@router.post(
|
||||||
|
"/provider/token",
|
||||||
|
response_model=OAuthTokenResponse,
|
||||||
|
summary="Token Endpoint",
|
||||||
|
description="""
|
||||||
|
OAuth 2.0 Token Endpoint.
|
||||||
|
|
||||||
|
Supports:
|
||||||
|
- authorization_code: Exchange code for tokens
|
||||||
|
- refresh_token: Refresh access token
|
||||||
|
|
||||||
|
Client authentication:
|
||||||
|
- Confidential clients: client_secret (Basic auth or POST body)
|
||||||
|
- Public clients: No secret, but PKCE code_verifier required
|
||||||
|
""",
|
||||||
|
operation_id="oauth_provider_token",
|
||||||
|
tags=["OAuth Provider"],
|
||||||
|
)
|
||||||
|
@limiter.limit("60/minute")
|
||||||
|
async def token(
|
||||||
|
request: Request,
|
||||||
|
grant_type: str = Form(..., description="Grant type"),
|
||||||
|
code: str | None = Form(default=None, description="Authorization code"),
|
||||||
|
redirect_uri: str | None = Form(default=None, description="Redirect URI"),
|
||||||
|
client_id: str | None = Form(default=None, description="Client ID"),
|
||||||
|
client_secret: str | None = Form(default=None, description="Client secret"),
|
||||||
|
code_verifier: str | None = Form(default=None, description="PKCE code verifier"),
|
||||||
|
refresh_token: str | None = Form(default=None, description="Refresh token"),
|
||||||
|
scope: str | None = Form(default=None, description="Scope (for refresh)"),
|
||||||
|
db: AsyncSession = Depends(get_db),
|
||||||
|
_: None = Depends(require_provider_enabled),
|
||||||
|
) -> OAuthTokenResponse:
|
||||||
|
"""Token endpoint - exchange code for tokens or refresh."""
|
||||||
|
# Extract client credentials from Basic auth if not in body
|
||||||
|
if not client_id:
|
||||||
|
auth_header = request.headers.get("Authorization", "")
|
||||||
|
if auth_header.startswith("Basic "):
|
||||||
|
import base64
|
||||||
|
|
||||||
|
try:
|
||||||
|
decoded = base64.b64decode(auth_header[6:]).decode()
|
||||||
|
client_id, client_secret = decoded.split(":", 1)
|
||||||
|
except Exception as e:
|
||||||
|
# Log malformed Basic auth for security monitoring
|
||||||
|
logger.warning(
|
||||||
|
"Malformed Basic auth header in token request: %s", type(e).__name__
|
||||||
|
)
|
||||||
|
# Fall back to form body
|
||||||
|
|
||||||
|
if not client_id:
|
||||||
|
raise HTTPException(
|
||||||
|
status_code=status.HTTP_401_UNAUTHORIZED,
|
||||||
|
detail="invalid_client: client_id required",
|
||||||
|
headers={"WWW-Authenticate": "Basic"},
|
||||||
|
)
|
||||||
|
|
||||||
|
# Get device info
|
||||||
|
device_info = request.headers.get("User-Agent", "")[:500]
|
||||||
|
ip_address = get_remote_address(request)
|
||||||
|
|
||||||
|
try:
|
||||||
|
if grant_type == "authorization_code":
|
||||||
|
if not code:
|
||||||
|
raise provider_service.InvalidRequestError("code required")
|
||||||
|
if not redirect_uri:
|
||||||
|
raise provider_service.InvalidRequestError("redirect_uri required")
|
||||||
|
|
||||||
|
result = await provider_service.exchange_authorization_code(
|
||||||
|
db=db,
|
||||||
|
code=code,
|
||||||
|
client_id=client_id,
|
||||||
|
redirect_uri=redirect_uri,
|
||||||
|
code_verifier=code_verifier,
|
||||||
|
client_secret=client_secret,
|
||||||
|
device_info=device_info,
|
||||||
|
ip_address=ip_address,
|
||||||
|
)
|
||||||
|
|
||||||
|
elif grant_type == "refresh_token":
|
||||||
|
if not refresh_token:
|
||||||
|
raise provider_service.InvalidRequestError("refresh_token required")
|
||||||
|
|
||||||
|
result = await provider_service.refresh_tokens(
|
||||||
|
db=db,
|
||||||
|
refresh_token=refresh_token,
|
||||||
|
client_id=client_id,
|
||||||
|
client_secret=client_secret,
|
||||||
|
scope=scope,
|
||||||
|
device_info=device_info,
|
||||||
|
ip_address=ip_address,
|
||||||
|
)
|
||||||
|
|
||||||
|
else:
|
||||||
|
raise HTTPException(
|
||||||
|
status_code=status.HTTP_400_BAD_REQUEST,
|
||||||
|
detail="unsupported_grant_type: Must be authorization_code or refresh_token",
|
||||||
|
)
|
||||||
|
|
||||||
|
return OAuthTokenResponse(**result)
|
||||||
|
|
||||||
|
except provider_service.InvalidClientError as e:
|
||||||
|
raise HTTPException(
|
||||||
|
status_code=status.HTTP_401_UNAUTHORIZED,
|
||||||
|
detail=f"{e.error}: {e.error_description}",
|
||||||
|
headers={"WWW-Authenticate": "Basic"},
|
||||||
|
)
|
||||||
|
except provider_service.OAuthProviderError as e:
|
||||||
|
raise HTTPException(
|
||||||
|
status_code=status.HTTP_400_BAD_REQUEST,
|
||||||
|
detail=f"{e.error}: {e.error_description}",
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
# ============================================================================
|
||||||
|
# Token Revocation (RFC 7009)
|
||||||
|
# ============================================================================
|
||||||
|
|
||||||
|
|
||||||
|
@router.post(
|
||||||
|
"/provider/revoke",
|
||||||
|
status_code=status.HTTP_200_OK,
|
||||||
|
summary="Token Revocation Endpoint",
|
||||||
|
description="""
|
||||||
|
OAuth 2.0 Token Revocation Endpoint (RFC 7009).
|
||||||
|
|
||||||
|
Revokes an access token or refresh token.
|
||||||
|
Always returns 200 OK (even if token is invalid) per spec.
|
||||||
|
""",
|
||||||
|
operation_id="oauth_provider_revoke",
|
||||||
|
tags=["OAuth Provider"],
|
||||||
|
)
|
||||||
|
@limiter.limit("30/minute")
|
||||||
|
async def revoke(
|
||||||
|
request: Request,
|
||||||
|
token: str = Form(..., description="Token to revoke"),
|
||||||
|
token_type_hint: str | None = Form(
|
||||||
|
default=None, description="Token type hint (access_token, refresh_token)"
|
||||||
|
),
|
||||||
|
client_id: str | None = Form(default=None, description="Client ID"),
|
||||||
|
client_secret: str | None = Form(default=None, description="Client secret"),
|
||||||
|
db: AsyncSession = Depends(get_db),
|
||||||
|
_: None = Depends(require_provider_enabled),
|
||||||
|
) -> dict[str, str]:
|
||||||
|
"""Revoke a token."""
|
||||||
|
# Extract client credentials from Basic auth if not in body
|
||||||
|
if not client_id:
|
||||||
|
auth_header = request.headers.get("Authorization", "")
|
||||||
|
if auth_header.startswith("Basic "):
|
||||||
|
import base64
|
||||||
|
|
||||||
|
try:
|
||||||
|
decoded = base64.b64decode(auth_header[6:]).decode()
|
||||||
|
client_id, client_secret = decoded.split(":", 1)
|
||||||
|
except Exception as e:
|
||||||
|
# Log malformed Basic auth for security monitoring
|
||||||
|
logger.warning(
|
||||||
|
"Malformed Basic auth header in revoke request: %s",
|
||||||
|
type(e).__name__,
|
||||||
|
)
|
||||||
|
# Fall back to form body
|
||||||
|
|
||||||
|
try:
|
||||||
|
await provider_service.revoke_token(
|
||||||
|
db=db,
|
||||||
|
token=token,
|
||||||
|
token_type_hint=token_type_hint,
|
||||||
|
client_id=client_id,
|
||||||
|
client_secret=client_secret,
|
||||||
|
)
|
||||||
|
except provider_service.InvalidClientError:
|
||||||
|
# Per RFC 7009, we should return 200 OK even for errors
|
||||||
|
# But client authentication errors can return 401
|
||||||
|
raise HTTPException(
|
||||||
|
status_code=status.HTTP_401_UNAUTHORIZED,
|
||||||
|
detail="invalid_client",
|
||||||
|
headers={"WWW-Authenticate": "Basic"},
|
||||||
|
)
|
||||||
|
except Exception as e:
|
||||||
|
# Log but don't expose errors per RFC 7009
|
||||||
|
logger.warning("Token revocation error: %s", e)
|
||||||
|
|
||||||
|
# Always return 200 OK per RFC 7009
|
||||||
|
return {"status": "ok"}
|
||||||
|
|
||||||
|
|
||||||
|
# ============================================================================
|
||||||
|
# Token Introspection (RFC 7662)
|
||||||
|
# ============================================================================
|
||||||
|
|
||||||
|
|
||||||
|
@router.post(
|
||||||
|
"/provider/introspect",
|
||||||
|
response_model=OAuthTokenIntrospectionResponse,
|
||||||
|
summary="Token Introspection Endpoint",
|
||||||
|
description="""
|
||||||
|
OAuth 2.0 Token Introspection Endpoint (RFC 7662).
|
||||||
|
|
||||||
|
Allows resource servers to query the authorization server
|
||||||
|
to determine the active state and metadata of a token.
|
||||||
|
""",
|
||||||
|
operation_id="oauth_provider_introspect",
|
||||||
|
tags=["OAuth Provider"],
|
||||||
|
)
|
||||||
|
@limiter.limit("120/minute")
|
||||||
|
async def introspect(
|
||||||
|
request: Request,
|
||||||
|
token: str = Form(..., description="Token to introspect"),
|
||||||
|
token_type_hint: str | None = Form(
|
||||||
|
default=None, description="Token type hint (access_token, refresh_token)"
|
||||||
|
),
|
||||||
|
client_id: str | None = Form(default=None, description="Client ID"),
|
||||||
|
client_secret: str | None = Form(default=None, description="Client secret"),
|
||||||
|
db: AsyncSession = Depends(get_db),
|
||||||
|
_: None = Depends(require_provider_enabled),
|
||||||
|
) -> OAuthTokenIntrospectionResponse:
|
||||||
|
"""Introspect a token."""
|
||||||
|
# Extract client credentials from Basic auth if not in body
|
||||||
|
if not client_id:
|
||||||
|
auth_header = request.headers.get("Authorization", "")
|
||||||
|
if auth_header.startswith("Basic "):
|
||||||
|
import base64
|
||||||
|
|
||||||
|
try:
|
||||||
|
decoded = base64.b64decode(auth_header[6:]).decode()
|
||||||
|
client_id, client_secret = decoded.split(":", 1)
|
||||||
|
except Exception as e:
|
||||||
|
# Log malformed Basic auth for security monitoring
|
||||||
|
logger.warning(
|
||||||
|
"Malformed Basic auth header in introspect request: %s",
|
||||||
|
type(e).__name__,
|
||||||
|
)
|
||||||
|
# Fall back to form body
|
||||||
|
|
||||||
|
try:
|
||||||
|
result = await provider_service.introspect_token(
|
||||||
|
db=db,
|
||||||
|
token=token,
|
||||||
|
token_type_hint=token_type_hint,
|
||||||
|
client_id=client_id,
|
||||||
|
client_secret=client_secret,
|
||||||
|
)
|
||||||
|
return OAuthTokenIntrospectionResponse(**result)
|
||||||
|
except provider_service.InvalidClientError:
|
||||||
|
raise HTTPException(
|
||||||
|
status_code=status.HTTP_401_UNAUTHORIZED,
|
||||||
|
detail="invalid_client",
|
||||||
|
headers={"WWW-Authenticate": "Basic"},
|
||||||
|
)
|
||||||
|
except Exception as e:
|
||||||
|
logger.warning("Token introspection error: %s", e)
|
||||||
|
return OAuthTokenIntrospectionResponse(active=False) # pyright: ignore[reportCallIssue]
|
||||||
|
|
||||||
|
|
||||||
|
# ============================================================================
|
||||||
|
# Client Management (Admin)
|
||||||
|
# ============================================================================
|
||||||
|
|
||||||
|
|
||||||
|
@router.post(
|
||||||
|
"/provider/clients",
|
||||||
|
response_model=dict,
|
||||||
|
summary="Register OAuth Client",
|
||||||
|
description="""
|
||||||
|
Register a new OAuth client (admin only).
|
||||||
|
|
||||||
|
Creates an MCP client that can authenticate against this API.
|
||||||
|
Returns client_id and client_secret (for confidential clients).
|
||||||
|
|
||||||
|
**Important:** Store the client_secret securely - it won't be shown again!
|
||||||
|
""",
|
||||||
|
operation_id="register_oauth_client",
|
||||||
|
tags=["OAuth Provider Admin"],
|
||||||
|
)
|
||||||
|
async def register_client(
|
||||||
|
client_name: str = Form(..., description="Client application name"),
|
||||||
|
redirect_uris: str = Form(..., description="Comma-separated redirect URIs"),
|
||||||
|
client_type: str = Form(default="public", description="public or confidential"),
|
||||||
|
scopes: str = Form(
|
||||||
|
default="openid profile email",
|
||||||
|
description="Allowed scopes (space-separated)",
|
||||||
|
),
|
||||||
|
mcp_server_url: str | None = Form(default=None, description="MCP server URL"),
|
||||||
|
db: AsyncSession = Depends(get_db),
|
||||||
|
_: None = Depends(require_provider_enabled),
|
||||||
|
current_user: User = Depends(get_current_superuser),
|
||||||
|
) -> dict:
|
||||||
|
"""Register a new OAuth client."""
|
||||||
|
# Parse redirect URIs
|
||||||
|
uris = [uri.strip() for uri in redirect_uris.split(",") if uri.strip()]
|
||||||
|
if not uris:
|
||||||
|
raise HTTPException(
|
||||||
|
status_code=status.HTTP_400_BAD_REQUEST,
|
||||||
|
detail="At least one redirect_uri is required",
|
||||||
|
)
|
||||||
|
|
||||||
|
# Parse scopes
|
||||||
|
allowed_scopes = [s.strip() for s in scopes.split() if s.strip()]
|
||||||
|
|
||||||
|
client_data = OAuthClientCreate(
|
||||||
|
client_name=client_name,
|
||||||
|
client_description=None,
|
||||||
|
redirect_uris=uris,
|
||||||
|
allowed_scopes=allowed_scopes,
|
||||||
|
client_type=client_type,
|
||||||
|
)
|
||||||
|
|
||||||
|
client, secret = await provider_service.register_client(db, client_data)
|
||||||
|
|
||||||
|
# Update MCP server URL if provided
|
||||||
|
if mcp_server_url:
|
||||||
|
client.mcp_server_url = mcp_server_url
|
||||||
|
await db.commit()
|
||||||
|
|
||||||
|
result = {
|
||||||
|
"client_id": client.client_id,
|
||||||
|
"client_name": client.client_name,
|
||||||
|
"client_type": client.client_type,
|
||||||
|
"redirect_uris": client.redirect_uris,
|
||||||
|
"allowed_scopes": client.allowed_scopes,
|
||||||
|
}
|
||||||
|
|
||||||
|
if secret:
|
||||||
|
result["client_secret"] = secret
|
||||||
|
result["warning"] = (
|
||||||
|
"Store the client_secret securely! It will not be shown again."
|
||||||
|
)
|
||||||
|
|
||||||
|
return result
|
||||||
|
|
||||||
|
|
||||||
|
@router.get(
|
||||||
|
"/provider/clients",
|
||||||
|
response_model=list[OAuthClientResponse],
|
||||||
|
summary="List OAuth Clients",
|
||||||
|
description="List all registered OAuth clients (admin only).",
|
||||||
|
operation_id="list_oauth_clients",
|
||||||
|
tags=["OAuth Provider Admin"],
|
||||||
|
)
|
||||||
|
async def list_clients(
|
||||||
|
db: AsyncSession = Depends(get_db),
|
||||||
|
_: None = Depends(require_provider_enabled),
|
||||||
|
current_user: User = Depends(get_current_superuser),
|
||||||
|
) -> list[OAuthClientResponse]:
|
||||||
|
"""List all OAuth clients."""
|
||||||
|
clients = await provider_service.list_clients(db)
|
||||||
|
return [OAuthClientResponse.model_validate(c) for c in clients]
|
||||||
|
|
||||||
|
|
||||||
|
@router.delete(
|
||||||
|
"/provider/clients/{client_id}",
|
||||||
|
status_code=status.HTTP_204_NO_CONTENT,
|
||||||
|
summary="Delete OAuth Client",
|
||||||
|
description="Delete an OAuth client (admin only). Revokes all tokens.",
|
||||||
|
operation_id="delete_oauth_client",
|
||||||
|
tags=["OAuth Provider Admin"],
|
||||||
|
)
|
||||||
|
async def delete_client(
|
||||||
|
client_id: str,
|
||||||
|
db: AsyncSession = Depends(get_db),
|
||||||
|
_: None = Depends(require_provider_enabled),
|
||||||
|
current_user: User = Depends(get_current_superuser),
|
||||||
|
) -> None:
|
||||||
|
"""Delete an OAuth client."""
|
||||||
|
client = await provider_service.get_client(db, client_id)
|
||||||
|
if not client:
|
||||||
|
raise HTTPException(
|
||||||
|
status_code=status.HTTP_404_NOT_FOUND,
|
||||||
|
detail="Client not found",
|
||||||
|
)
|
||||||
|
|
||||||
|
await provider_service.delete_client_by_id(db, client_id=client_id)
|
||||||
|
|
||||||
|
|
||||||
|
# ============================================================================
|
||||||
|
# User Consent Management
|
||||||
|
# ============================================================================
|
||||||
|
|
||||||
|
|
||||||
|
@router.get(
|
||||||
|
"/provider/consents",
|
||||||
|
summary="List My Consents",
|
||||||
|
description="List OAuth applications the current user has authorized.",
|
||||||
|
operation_id="list_my_oauth_consents",
|
||||||
|
tags=["OAuth Provider"],
|
||||||
|
)
|
||||||
|
async def list_my_consents(
|
||||||
|
db: AsyncSession = Depends(get_db),
|
||||||
|
_: None = Depends(require_provider_enabled),
|
||||||
|
current_user: User = Depends(get_current_active_user),
|
||||||
|
) -> list[dict]:
|
||||||
|
"""List applications the user has authorized."""
|
||||||
|
return await provider_service.list_user_consents(db, user_id=current_user.id)
|
||||||
|
|
||||||
|
|
||||||
|
@router.delete(
|
||||||
|
"/provider/consents/{client_id}",
|
||||||
|
status_code=status.HTTP_204_NO_CONTENT,
|
||||||
|
summary="Revoke My Consent",
|
||||||
|
description="Revoke authorization for an OAuth application. Also revokes all tokens.",
|
||||||
|
operation_id="revoke_my_oauth_consent",
|
||||||
|
tags=["OAuth Provider"],
|
||||||
|
)
|
||||||
|
async def revoke_my_consent(
|
||||||
|
client_id: str,
|
||||||
|
db: AsyncSession = Depends(get_db),
|
||||||
|
_: None = Depends(require_provider_enabled),
|
||||||
|
current_user: User = Depends(get_current_active_user),
|
||||||
|
) -> None:
|
||||||
|
"""Revoke consent for an application."""
|
||||||
|
revoked = await provider_service.revoke_consent(db, current_user.id, client_id)
|
||||||
|
if not revoked:
|
||||||
|
raise HTTPException(
|
||||||
|
status_code=status.HTTP_404_NOT_FOUND,
|
||||||
|
detail="No consent found for this client",
|
||||||
|
)
|
||||||
@@ -15,8 +15,6 @@ from sqlalchemy.ext.asyncio import AsyncSession
|
|||||||
from app.api.dependencies.auth import get_current_user
|
from app.api.dependencies.auth import get_current_user
|
||||||
from app.api.dependencies.permissions import require_org_admin, require_org_membership
|
from app.api.dependencies.permissions import require_org_admin, require_org_membership
|
||||||
from app.core.database import get_db
|
from app.core.database import get_db
|
||||||
from app.core.exceptions import ErrorCode, NotFoundError
|
|
||||||
from app.crud.organization import organization as organization_crud
|
|
||||||
from app.models.user import User
|
from app.models.user import User
|
||||||
from app.schemas.common import (
|
from app.schemas.common import (
|
||||||
PaginatedResponse,
|
PaginatedResponse,
|
||||||
@@ -28,6 +26,7 @@ from app.schemas.organizations import (
|
|||||||
OrganizationResponse,
|
OrganizationResponse,
|
||||||
OrganizationUpdate,
|
OrganizationUpdate,
|
||||||
)
|
)
|
||||||
|
from app.services.organization_service import organization_service
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
@@ -54,7 +53,7 @@ async def get_my_organizations(
|
|||||||
"""
|
"""
|
||||||
try:
|
try:
|
||||||
# Get all org data in single query with JOIN and subquery
|
# Get all org data in single query with JOIN and subquery
|
||||||
orgs_data = await organization_crud.get_user_organizations_with_details(
|
orgs_data = await organization_service.get_user_organizations_with_details(
|
||||||
db, user_id=current_user.id, is_active=is_active
|
db, user_id=current_user.id, is_active=is_active
|
||||||
)
|
)
|
||||||
|
|
||||||
@@ -78,7 +77,7 @@ async def get_my_organizations(
|
|||||||
return orgs_with_data
|
return orgs_with_data
|
||||||
|
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.error(f"Error getting user organizations: {e!s}", exc_info=True)
|
logger.exception("Error getting user organizations: %s", e)
|
||||||
raise
|
raise
|
||||||
|
|
||||||
|
|
||||||
@@ -100,13 +99,7 @@ async def get_organization(
|
|||||||
User must be a member of the organization.
|
User must be a member of the organization.
|
||||||
"""
|
"""
|
||||||
try:
|
try:
|
||||||
org = await organization_crud.get(db, id=organization_id)
|
org = await organization_service.get_organization(db, str(organization_id))
|
||||||
if not org: # pragma: no cover - Permission check prevents this (see docs/UNREACHABLE_DEFENSIVE_CODE_ANALYSIS.md)
|
|
||||||
raise NotFoundError(
|
|
||||||
detail=f"Organization {organization_id} not found",
|
|
||||||
error_code=ErrorCode.NOT_FOUND,
|
|
||||||
)
|
|
||||||
|
|
||||||
org_dict = {
|
org_dict = {
|
||||||
"id": org.id,
|
"id": org.id,
|
||||||
"name": org.name,
|
"name": org.name,
|
||||||
@@ -116,16 +109,14 @@ async def get_organization(
|
|||||||
"settings": org.settings,
|
"settings": org.settings,
|
||||||
"created_at": org.created_at,
|
"created_at": org.created_at,
|
||||||
"updated_at": org.updated_at,
|
"updated_at": org.updated_at,
|
||||||
"member_count": await organization_crud.get_member_count(
|
"member_count": await organization_service.get_member_count(
|
||||||
db, organization_id=org.id
|
db, organization_id=org.id
|
||||||
),
|
),
|
||||||
}
|
}
|
||||||
return OrganizationResponse(**org_dict)
|
return OrganizationResponse(**org_dict)
|
||||||
|
|
||||||
except NotFoundError: # pragma: no cover - See above
|
|
||||||
raise
|
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.error(f"Error getting organization: {e!s}", exc_info=True)
|
logger.exception("Error getting organization: %s", e)
|
||||||
raise
|
raise
|
||||||
|
|
||||||
|
|
||||||
@@ -149,7 +140,7 @@ async def get_organization_members(
|
|||||||
User must be a member of the organization to view members.
|
User must be a member of the organization to view members.
|
||||||
"""
|
"""
|
||||||
try:
|
try:
|
||||||
members, total = await organization_crud.get_organization_members(
|
members, total = await organization_service.get_organization_members(
|
||||||
db,
|
db,
|
||||||
organization_id=organization_id,
|
organization_id=organization_id,
|
||||||
skip=pagination.offset,
|
skip=pagination.offset,
|
||||||
@@ -169,7 +160,7 @@ async def get_organization_members(
|
|||||||
return PaginatedResponse(data=member_responses, pagination=pagination_meta)
|
return PaginatedResponse(data=member_responses, pagination=pagination_meta)
|
||||||
|
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.error(f"Error getting organization members: {e!s}", exc_info=True)
|
logger.exception("Error getting organization members: %s", e)
|
||||||
raise
|
raise
|
||||||
|
|
||||||
|
|
||||||
@@ -192,16 +183,12 @@ async def update_organization(
|
|||||||
Requires owner or admin role in the organization.
|
Requires owner or admin role in the organization.
|
||||||
"""
|
"""
|
||||||
try:
|
try:
|
||||||
org = await organization_crud.get(db, id=organization_id)
|
org = await organization_service.get_organization(db, str(organization_id))
|
||||||
if not org: # pragma: no cover - Permission check prevents this (see docs/UNREACHABLE_DEFENSIVE_CODE_ANALYSIS.md)
|
updated_org = await organization_service.update_organization(
|
||||||
raise NotFoundError(
|
db, org=org, obj_in=org_in
|
||||||
detail=f"Organization {organization_id} not found",
|
)
|
||||||
error_code=ErrorCode.NOT_FOUND,
|
|
||||||
)
|
|
||||||
|
|
||||||
updated_org = await organization_crud.update(db, db_obj=org, obj_in=org_in)
|
|
||||||
logger.info(
|
logger.info(
|
||||||
f"User {current_user.email} updated organization {updated_org.name}"
|
"User %s updated organization %s", current_user.email, updated_org.name
|
||||||
)
|
)
|
||||||
|
|
||||||
org_dict = {
|
org_dict = {
|
||||||
@@ -213,14 +200,12 @@ async def update_organization(
|
|||||||
"settings": updated_org.settings,
|
"settings": updated_org.settings,
|
||||||
"created_at": updated_org.created_at,
|
"created_at": updated_org.created_at,
|
||||||
"updated_at": updated_org.updated_at,
|
"updated_at": updated_org.updated_at,
|
||||||
"member_count": await organization_crud.get_member_count(
|
"member_count": await organization_service.get_member_count(
|
||||||
db, organization_id=updated_org.id
|
db, organization_id=updated_org.id
|
||||||
),
|
),
|
||||||
}
|
}
|
||||||
return OrganizationResponse(**org_dict)
|
return OrganizationResponse(**org_dict)
|
||||||
|
|
||||||
except NotFoundError: # pragma: no cover - See above
|
|
||||||
raise
|
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.error(f"Error updating organization: {e!s}", exc_info=True)
|
logger.exception("Error updating organization: %s", e)
|
||||||
raise
|
raise
|
||||||
|
|||||||
@@ -17,10 +17,10 @@ from app.api.dependencies.auth import get_current_user
|
|||||||
from app.core.auth import decode_token
|
from app.core.auth import decode_token
|
||||||
from app.core.database import get_db
|
from app.core.database import get_db
|
||||||
from app.core.exceptions import AuthorizationError, ErrorCode, NotFoundError
|
from app.core.exceptions import AuthorizationError, ErrorCode, NotFoundError
|
||||||
from app.crud.session import session as session_crud
|
|
||||||
from app.models.user import User
|
from app.models.user import User
|
||||||
from app.schemas.common import MessageResponse
|
from app.schemas.common import MessageResponse
|
||||||
from app.schemas.sessions import SessionListResponse, SessionResponse
|
from app.schemas.sessions import SessionListResponse, SessionResponse
|
||||||
|
from app.services.session_service import session_service
|
||||||
|
|
||||||
router = APIRouter()
|
router = APIRouter()
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
@@ -60,7 +60,7 @@ async def list_my_sessions(
|
|||||||
"""
|
"""
|
||||||
try:
|
try:
|
||||||
# Get all active sessions for user
|
# Get all active sessions for user
|
||||||
sessions = await session_crud.get_user_sessions(
|
sessions = await session_service.get_user_sessions(
|
||||||
db, user_id=str(current_user.id), active_only=True
|
db, user_id=str(current_user.id), active_only=True
|
||||||
)
|
)
|
||||||
|
|
||||||
@@ -74,9 +74,7 @@ async def list_my_sessions(
|
|||||||
# For now, we'll mark current based on most recent activity
|
# For now, we'll mark current based on most recent activity
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
# Optional token parsing - silently ignore failures
|
# Optional token parsing - silently ignore failures
|
||||||
logger.debug(
|
logger.debug("Failed to decode access token for session marking: %s", e)
|
||||||
f"Failed to decode access token for session marking: {e!s}"
|
|
||||||
)
|
|
||||||
|
|
||||||
# Convert to response format
|
# Convert to response format
|
||||||
session_responses = []
|
session_responses = []
|
||||||
@@ -98,7 +96,7 @@ async def list_my_sessions(
|
|||||||
session_responses.append(session_response)
|
session_responses.append(session_response)
|
||||||
|
|
||||||
logger.info(
|
logger.info(
|
||||||
f"User {current_user.id} listed {len(session_responses)} active sessions"
|
"User %s listed %s active sessions", current_user.id, len(session_responses)
|
||||||
)
|
)
|
||||||
|
|
||||||
return SessionListResponse(
|
return SessionListResponse(
|
||||||
@@ -106,9 +104,7 @@ async def list_my_sessions(
|
|||||||
)
|
)
|
||||||
|
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.error(
|
logger.exception("Error listing sessions for user %s: %s", current_user.id, e)
|
||||||
f"Error listing sessions for user {current_user.id}: {e!s}", exc_info=True
|
|
||||||
)
|
|
||||||
raise HTTPException(
|
raise HTTPException(
|
||||||
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
|
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
|
||||||
detail="Failed to retrieve sessions",
|
detail="Failed to retrieve sessions",
|
||||||
@@ -150,7 +146,7 @@ async def revoke_session(
|
|||||||
"""
|
"""
|
||||||
try:
|
try:
|
||||||
# Get the session
|
# Get the session
|
||||||
session = await session_crud.get(db, id=str(session_id))
|
session = await session_service.get_session(db, str(session_id))
|
||||||
|
|
||||||
if not session:
|
if not session:
|
||||||
raise NotFoundError(
|
raise NotFoundError(
|
||||||
@@ -161,8 +157,10 @@ async def revoke_session(
|
|||||||
# Verify session belongs to current user
|
# Verify session belongs to current user
|
||||||
if str(session.user_id) != str(current_user.id):
|
if str(session.user_id) != str(current_user.id):
|
||||||
logger.warning(
|
logger.warning(
|
||||||
f"User {current_user.id} attempted to revoke session {session_id} "
|
"User %s attempted to revoke session %s belonging to user %s",
|
||||||
f"belonging to user {session.user_id}"
|
current_user.id,
|
||||||
|
session_id,
|
||||||
|
session.user_id,
|
||||||
)
|
)
|
||||||
raise AuthorizationError(
|
raise AuthorizationError(
|
||||||
message="You can only revoke your own sessions",
|
message="You can only revoke your own sessions",
|
||||||
@@ -170,11 +168,13 @@ async def revoke_session(
|
|||||||
)
|
)
|
||||||
|
|
||||||
# Deactivate the session
|
# Deactivate the session
|
||||||
await session_crud.deactivate(db, session_id=str(session_id))
|
await session_service.deactivate(db, session_id=str(session_id))
|
||||||
|
|
||||||
logger.info(
|
logger.info(
|
||||||
f"User {current_user.id} revoked session {session_id} "
|
"User %s revoked session %s (%s)",
|
||||||
f"({session.device_name})"
|
current_user.id,
|
||||||
|
session_id,
|
||||||
|
session.device_name,
|
||||||
)
|
)
|
||||||
|
|
||||||
return MessageResponse(
|
return MessageResponse(
|
||||||
@@ -185,7 +185,7 @@ async def revoke_session(
|
|||||||
except (NotFoundError, AuthorizationError):
|
except (NotFoundError, AuthorizationError):
|
||||||
raise
|
raise
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.error(f"Error revoking session {session_id}: {e!s}", exc_info=True)
|
logger.exception("Error revoking session %s: %s", session_id, e)
|
||||||
raise HTTPException(
|
raise HTTPException(
|
||||||
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
|
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
|
||||||
detail="Failed to revoke session",
|
detail="Failed to revoke session",
|
||||||
@@ -224,12 +224,12 @@ async def cleanup_expired_sessions(
|
|||||||
"""
|
"""
|
||||||
try:
|
try:
|
||||||
# Use optimized bulk DELETE instead of N individual deletes
|
# Use optimized bulk DELETE instead of N individual deletes
|
||||||
deleted_count = await session_crud.cleanup_expired_for_user(
|
deleted_count = await session_service.cleanup_expired_for_user(
|
||||||
db, user_id=str(current_user.id)
|
db, user_id=str(current_user.id)
|
||||||
)
|
)
|
||||||
|
|
||||||
logger.info(
|
logger.info(
|
||||||
f"User {current_user.id} cleaned up {deleted_count} expired sessions"
|
"User %s cleaned up %s expired sessions", current_user.id, deleted_count
|
||||||
)
|
)
|
||||||
|
|
||||||
return MessageResponse(
|
return MessageResponse(
|
||||||
@@ -237,9 +237,8 @@ async def cleanup_expired_sessions(
|
|||||||
)
|
)
|
||||||
|
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.error(
|
logger.exception(
|
||||||
f"Error cleaning up sessions for user {current_user.id}: {e!s}",
|
"Error cleaning up sessions for user %s: %s", current_user.id, e
|
||||||
exc_info=True,
|
|
||||||
)
|
)
|
||||||
await db.rollback()
|
await db.rollback()
|
||||||
raise HTTPException(
|
raise HTTPException(
|
||||||
|
|||||||
@@ -1,5 +1,5 @@
|
|||||||
"""
|
"""
|
||||||
User management endpoints for CRUD operations.
|
User management endpoints for database operations.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
import logging
|
import logging
|
||||||
@@ -13,8 +13,7 @@ from sqlalchemy.ext.asyncio import AsyncSession
|
|||||||
|
|
||||||
from app.api.dependencies.auth import get_current_superuser, get_current_user
|
from app.api.dependencies.auth import get_current_superuser, get_current_user
|
||||||
from app.core.database import get_db
|
from app.core.database import get_db
|
||||||
from app.core.exceptions import AuthorizationError, ErrorCode, NotFoundError
|
from app.core.exceptions import AuthorizationError, ErrorCode
|
||||||
from app.crud.user import user as user_crud
|
|
||||||
from app.models.user import User
|
from app.models.user import User
|
||||||
from app.schemas.common import (
|
from app.schemas.common import (
|
||||||
MessageResponse,
|
MessageResponse,
|
||||||
@@ -25,6 +24,7 @@ from app.schemas.common import (
|
|||||||
)
|
)
|
||||||
from app.schemas.users import PasswordChange, UserResponse, UserUpdate
|
from app.schemas.users import PasswordChange, UserResponse, UserUpdate
|
||||||
from app.services.auth_service import AuthenticationError, AuthService
|
from app.services.auth_service import AuthenticationError, AuthService
|
||||||
|
from app.services.user_service import user_service
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
@@ -71,7 +71,7 @@ async def list_users(
|
|||||||
filters["is_superuser"] = is_superuser
|
filters["is_superuser"] = is_superuser
|
||||||
|
|
||||||
# Get paginated users with total count
|
# Get paginated users with total count
|
||||||
users, total = await user_crud.get_multi_with_total(
|
users, total = await user_service.list_users(
|
||||||
db,
|
db,
|
||||||
skip=pagination.offset,
|
skip=pagination.offset,
|
||||||
limit=pagination.limit,
|
limit=pagination.limit,
|
||||||
@@ -90,7 +90,7 @@ async def list_users(
|
|||||||
|
|
||||||
return PaginatedResponse(data=users, pagination=pagination_meta)
|
return PaginatedResponse(data=users, pagination=pagination_meta)
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.error(f"Error listing users: {e!s}", exc_info=True)
|
logger.exception("Error listing users: %s", e)
|
||||||
raise
|
raise
|
||||||
|
|
||||||
|
|
||||||
@@ -107,7 +107,9 @@ async def list_users(
|
|||||||
""",
|
""",
|
||||||
operation_id="get_current_user_profile",
|
operation_id="get_current_user_profile",
|
||||||
)
|
)
|
||||||
def get_current_user_profile(current_user: User = Depends(get_current_user)) -> Any:
|
async def get_current_user_profile(
|
||||||
|
current_user: User = Depends(get_current_user),
|
||||||
|
) -> Any:
|
||||||
"""Get current user's profile."""
|
"""Get current user's profile."""
|
||||||
return current_user
|
return current_user
|
||||||
|
|
||||||
@@ -138,18 +140,16 @@ async def update_current_user(
|
|||||||
Users cannot elevate their own permissions (protected by UserUpdate schema validator).
|
Users cannot elevate their own permissions (protected by UserUpdate schema validator).
|
||||||
"""
|
"""
|
||||||
try:
|
try:
|
||||||
updated_user = await user_crud.update(
|
updated_user = await user_service.update_user(
|
||||||
db, db_obj=current_user, obj_in=user_update
|
db, user=current_user, obj_in=user_update
|
||||||
)
|
)
|
||||||
logger.info(f"User {current_user.id} updated their profile")
|
logger.info("User %s updated their profile", current_user.id)
|
||||||
return updated_user
|
return updated_user
|
||||||
except ValueError as e:
|
except ValueError as e:
|
||||||
logger.error(f"Error updating user {current_user.id}: {e!s}")
|
logger.error("Error updating user %s: %s", current_user.id, e)
|
||||||
raise
|
raise
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.error(
|
logger.exception("Unexpected error updating user %s: %s", current_user.id, e)
|
||||||
f"Unexpected error updating user {current_user.id}: {e!s}", exc_info=True
|
|
||||||
)
|
|
||||||
raise
|
raise
|
||||||
|
|
||||||
|
|
||||||
@@ -182,7 +182,9 @@ async def get_user_by_id(
|
|||||||
# Check permissions
|
# Check permissions
|
||||||
if str(user_id) != str(current_user.id) and not current_user.is_superuser:
|
if str(user_id) != str(current_user.id) and not current_user.is_superuser:
|
||||||
logger.warning(
|
logger.warning(
|
||||||
f"User {current_user.id} attempted to access user {user_id} without permission"
|
"User %s attempted to access user %s without permission",
|
||||||
|
current_user.id,
|
||||||
|
user_id,
|
||||||
)
|
)
|
||||||
raise AuthorizationError(
|
raise AuthorizationError(
|
||||||
message="Not enough permissions to view this user",
|
message="Not enough permissions to view this user",
|
||||||
@@ -190,13 +192,7 @@ async def get_user_by_id(
|
|||||||
)
|
)
|
||||||
|
|
||||||
# Get user
|
# Get user
|
||||||
user = await user_crud.get(db, id=str(user_id))
|
user = await user_service.get_user(db, str(user_id))
|
||||||
if not user:
|
|
||||||
raise NotFoundError(
|
|
||||||
message=f"User with id {user_id} not found",
|
|
||||||
error_code=ErrorCode.USER_NOT_FOUND,
|
|
||||||
)
|
|
||||||
|
|
||||||
return user
|
return user
|
||||||
|
|
||||||
|
|
||||||
@@ -233,7 +229,9 @@ async def update_user(
|
|||||||
|
|
||||||
if not is_own_profile and not current_user.is_superuser:
|
if not is_own_profile and not current_user.is_superuser:
|
||||||
logger.warning(
|
logger.warning(
|
||||||
f"User {current_user.id} attempted to update user {user_id} without permission"
|
"User %s attempted to update user %s without permission",
|
||||||
|
current_user.id,
|
||||||
|
user_id,
|
||||||
)
|
)
|
||||||
raise AuthorizationError(
|
raise AuthorizationError(
|
||||||
message="Not enough permissions to update this user",
|
message="Not enough permissions to update this user",
|
||||||
@@ -241,22 +239,17 @@ async def update_user(
|
|||||||
)
|
)
|
||||||
|
|
||||||
# Get user
|
# Get user
|
||||||
user = await user_crud.get(db, id=str(user_id))
|
user = await user_service.get_user(db, str(user_id))
|
||||||
if not user:
|
|
||||||
raise NotFoundError(
|
|
||||||
message=f"User with id {user_id} not found",
|
|
||||||
error_code=ErrorCode.USER_NOT_FOUND,
|
|
||||||
)
|
|
||||||
|
|
||||||
try:
|
try:
|
||||||
updated_user = await user_crud.update(db, db_obj=user, obj_in=user_update)
|
updated_user = await user_service.update_user(db, user=user, obj_in=user_update)
|
||||||
logger.info(f"User {user_id} updated by {current_user.id}")
|
logger.info("User %s updated by %s", user_id, current_user.id)
|
||||||
return updated_user
|
return updated_user
|
||||||
except ValueError as e:
|
except ValueError as e:
|
||||||
logger.error(f"Error updating user {user_id}: {e!s}")
|
logger.error("Error updating user %s: %s", user_id, e)
|
||||||
raise
|
raise
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.error(f"Unexpected error updating user {user_id}: {e!s}", exc_info=True)
|
logger.exception("Unexpected error updating user %s: %s", user_id, e)
|
||||||
raise
|
raise
|
||||||
|
|
||||||
|
|
||||||
@@ -296,19 +289,19 @@ async def change_current_user_password(
|
|||||||
)
|
)
|
||||||
|
|
||||||
if success:
|
if success:
|
||||||
logger.info(f"User {current_user.id} changed their password")
|
logger.info("User %s changed their password", current_user.id)
|
||||||
return MessageResponse(
|
return MessageResponse(
|
||||||
success=True, message="Password changed successfully"
|
success=True, message="Password changed successfully"
|
||||||
)
|
)
|
||||||
except AuthenticationError as e:
|
except AuthenticationError as e:
|
||||||
logger.warning(
|
logger.warning(
|
||||||
f"Failed password change attempt for user {current_user.id}: {e!s}"
|
"Failed password change attempt for user %s: %s", current_user.id, e
|
||||||
)
|
)
|
||||||
raise AuthorizationError(
|
raise AuthorizationError(
|
||||||
message=str(e), error_code=ErrorCode.INVALID_CREDENTIALS
|
message=str(e), error_code=ErrorCode.INVALID_CREDENTIALS
|
||||||
)
|
)
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.error(f"Error changing password for user {current_user.id}: {e!s}")
|
logger.error("Error changing password for user %s: %s", current_user.id, e)
|
||||||
raise
|
raise
|
||||||
|
|
||||||
|
|
||||||
@@ -346,24 +339,19 @@ async def delete_user(
|
|||||||
error_code=ErrorCode.INSUFFICIENT_PERMISSIONS,
|
error_code=ErrorCode.INSUFFICIENT_PERMISSIONS,
|
||||||
)
|
)
|
||||||
|
|
||||||
# Get user
|
# Get user (raises NotFoundError if not found)
|
||||||
user = await user_crud.get(db, id=str(user_id))
|
await user_service.get_user(db, str(user_id))
|
||||||
if not user:
|
|
||||||
raise NotFoundError(
|
|
||||||
message=f"User with id {user_id} not found",
|
|
||||||
error_code=ErrorCode.USER_NOT_FOUND,
|
|
||||||
)
|
|
||||||
|
|
||||||
try:
|
try:
|
||||||
# Use soft delete instead of hard delete
|
# Use soft delete instead of hard delete
|
||||||
await user_crud.soft_delete(db, id=str(user_id))
|
await user_service.soft_delete_user(db, str(user_id))
|
||||||
logger.info(f"User {user_id} soft-deleted by {current_user.id}")
|
logger.info("User %s soft-deleted by %s", user_id, current_user.id)
|
||||||
return MessageResponse(
|
return MessageResponse(
|
||||||
success=True, message=f"User {user_id} deleted successfully"
|
success=True, message=f"User {user_id} deleted successfully"
|
||||||
)
|
)
|
||||||
except ValueError as e:
|
except ValueError as e:
|
||||||
logger.error(f"Error deleting user {user_id}: {e!s}")
|
logger.error("Error deleting user %s: %s", user_id, e)
|
||||||
raise
|
raise
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.error(f"Unexpected error deleting user {user_id}: {e!s}", exc_info=True)
|
logger.exception("Unexpected error deleting user %s: %s", user_id, e)
|
||||||
raise
|
raise
|
||||||
|
|||||||
@@ -1,23 +1,21 @@
|
|||||||
import asyncio
|
import asyncio
|
||||||
import logging
|
|
||||||
import uuid
|
import uuid
|
||||||
from datetime import UTC, datetime, timedelta
|
from datetime import UTC, datetime, timedelta
|
||||||
from functools import partial
|
from functools import partial
|
||||||
from typing import Any
|
from typing import Any
|
||||||
|
|
||||||
from jose import JWTError, jwt
|
import bcrypt
|
||||||
from passlib.context import CryptContext
|
import jwt
|
||||||
|
from jwt.exceptions import (
|
||||||
|
ExpiredSignatureError,
|
||||||
|
InvalidTokenError,
|
||||||
|
MissingRequiredClaimError,
|
||||||
|
)
|
||||||
from pydantic import ValidationError
|
from pydantic import ValidationError
|
||||||
|
|
||||||
from app.core.config import settings
|
from app.core.config import settings
|
||||||
from app.schemas.users import TokenData, TokenPayload
|
from app.schemas.users import TokenData, TokenPayload
|
||||||
|
|
||||||
# Suppress passlib bcrypt warnings about ident
|
|
||||||
logging.getLogger("passlib").setLevel(logging.ERROR)
|
|
||||||
|
|
||||||
# Password hashing context
|
|
||||||
pwd_context = CryptContext(schemes=["bcrypt"], deprecated="auto")
|
|
||||||
|
|
||||||
|
|
||||||
# Custom exceptions for auth
|
# Custom exceptions for auth
|
||||||
class AuthError(Exception):
|
class AuthError(Exception):
|
||||||
@@ -37,13 +35,16 @@ class TokenMissingClaimError(AuthError):
|
|||||||
|
|
||||||
|
|
||||||
def verify_password(plain_password: str, hashed_password: str) -> bool:
|
def verify_password(plain_password: str, hashed_password: str) -> bool:
|
||||||
"""Verify a password against a hash."""
|
"""Verify a password against a bcrypt hash."""
|
||||||
return pwd_context.verify(plain_password, hashed_password)
|
return bcrypt.checkpw(
|
||||||
|
plain_password.encode("utf-8"), hashed_password.encode("utf-8")
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
def get_password_hash(password: str) -> str:
|
def get_password_hash(password: str) -> str:
|
||||||
"""Generate a password hash."""
|
"""Generate a bcrypt password hash."""
|
||||||
return pwd_context.hash(password)
|
salt = bcrypt.gensalt()
|
||||||
|
return bcrypt.hashpw(password.encode("utf-8"), salt).decode("utf-8")
|
||||||
|
|
||||||
|
|
||||||
async def verify_password_async(plain_password: str, hashed_password: str) -> bool:
|
async def verify_password_async(plain_password: str, hashed_password: str) -> bool:
|
||||||
@@ -60,9 +61,9 @@ async def verify_password_async(plain_password: str, hashed_password: str) -> bo
|
|||||||
Returns:
|
Returns:
|
||||||
True if password matches, False otherwise
|
True if password matches, False otherwise
|
||||||
"""
|
"""
|
||||||
loop = asyncio.get_event_loop()
|
loop = asyncio.get_running_loop()
|
||||||
return await loop.run_in_executor(
|
return await loop.run_in_executor(
|
||||||
None, partial(pwd_context.verify, plain_password, hashed_password)
|
None, partial(verify_password, plain_password, hashed_password)
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
@@ -80,8 +81,8 @@ async def get_password_hash_async(password: str) -> str:
|
|||||||
Returns:
|
Returns:
|
||||||
Hashed password string
|
Hashed password string
|
||||||
"""
|
"""
|
||||||
loop = asyncio.get_event_loop()
|
loop = asyncio.get_running_loop()
|
||||||
return await loop.run_in_executor(None, pwd_context.hash, password)
|
return await loop.run_in_executor(None, get_password_hash, password)
|
||||||
|
|
||||||
|
|
||||||
def create_access_token(
|
def create_access_token(
|
||||||
@@ -121,11 +122,7 @@ def create_access_token(
|
|||||||
to_encode.update(claims)
|
to_encode.update(claims)
|
||||||
|
|
||||||
# Create the JWT
|
# Create the JWT
|
||||||
encoded_jwt = jwt.encode(
|
return jwt.encode(to_encode, settings.SECRET_KEY, algorithm=settings.ALGORITHM)
|
||||||
to_encode, settings.SECRET_KEY, algorithm=settings.ALGORITHM
|
|
||||||
)
|
|
||||||
|
|
||||||
return encoded_jwt
|
|
||||||
|
|
||||||
|
|
||||||
def create_refresh_token(
|
def create_refresh_token(
|
||||||
@@ -154,11 +151,7 @@ def create_refresh_token(
|
|||||||
"type": "refresh",
|
"type": "refresh",
|
||||||
}
|
}
|
||||||
|
|
||||||
encoded_jwt = jwt.encode(
|
return jwt.encode(to_encode, settings.SECRET_KEY, algorithm=settings.ALGORITHM)
|
||||||
to_encode, settings.SECRET_KEY, algorithm=settings.ALGORITHM
|
|
||||||
)
|
|
||||||
|
|
||||||
return encoded_jwt
|
|
||||||
|
|
||||||
|
|
||||||
def decode_token(token: str, verify_type: str | None = None) -> TokenPayload:
|
def decode_token(token: str, verify_type: str | None = None) -> TokenPayload:
|
||||||
@@ -198,7 +191,7 @@ def decode_token(token: str, verify_type: str | None = None) -> TokenPayload:
|
|||||||
|
|
||||||
# Reject weak or unexpected algorithms
|
# Reject weak or unexpected algorithms
|
||||||
# NOTE: These are defensive checks that provide defense-in-depth.
|
# NOTE: These are defensive checks that provide defense-in-depth.
|
||||||
# The python-jose library rejects these tokens BEFORE we reach here,
|
# PyJWT rejects these tokens BEFORE we reach here,
|
||||||
# but we keep these checks in case the library changes or is misconfigured.
|
# but we keep these checks in case the library changes or is misconfigured.
|
||||||
# Coverage: Marked as pragma since library catches first (see tests/core/test_auth_security.py)
|
# Coverage: Marked as pragma since library catches first (see tests/core/test_auth_security.py)
|
||||||
if token_algorithm == "NONE": # pragma: no cover
|
if token_algorithm == "NONE": # pragma: no cover
|
||||||
@@ -219,10 +212,11 @@ def decode_token(token: str, verify_type: str | None = None) -> TokenPayload:
|
|||||||
token_data = TokenPayload(**payload)
|
token_data = TokenPayload(**payload)
|
||||||
return token_data
|
return token_data
|
||||||
|
|
||||||
except JWTError as e:
|
except ExpiredSignatureError:
|
||||||
# Check if the error is due to an expired token
|
raise TokenExpiredError("Token has expired")
|
||||||
if "expired" in str(e).lower():
|
except MissingRequiredClaimError as e:
|
||||||
raise TokenExpiredError("Token has expired")
|
raise TokenMissingClaimError(f"Token missing required claim: {e}")
|
||||||
|
except InvalidTokenError:
|
||||||
raise TokenInvalidError("Invalid authentication token")
|
raise TokenInvalidError("Invalid authentication token")
|
||||||
except ValidationError:
|
except ValidationError:
|
||||||
raise TokenInvalidError("Invalid token payload")
|
raise TokenInvalidError("Invalid token payload")
|
||||||
|
|||||||
@@ -5,7 +5,7 @@ from pydantic_settings import BaseSettings
|
|||||||
|
|
||||||
|
|
||||||
class Settings(BaseSettings):
|
class Settings(BaseSettings):
|
||||||
PROJECT_NAME: str = "App"
|
PROJECT_NAME: str = "PragmaStack"
|
||||||
VERSION: str = "1.0.0"
|
VERSION: str = "1.0.0"
|
||||||
API_V1_STR: str = "/api/v1"
|
API_V1_STR: str = "/api/v1"
|
||||||
|
|
||||||
@@ -14,6 +14,10 @@ class Settings(BaseSettings):
|
|||||||
default="development",
|
default="development",
|
||||||
description="Environment: development, staging, or production",
|
description="Environment: development, staging, or production",
|
||||||
)
|
)
|
||||||
|
DEMO_MODE: bool = Field(
|
||||||
|
default=False,
|
||||||
|
description="Enable demo mode (relaxed security, demo users)",
|
||||||
|
)
|
||||||
|
|
||||||
# Security: Content Security Policy
|
# Security: Content Security Policy
|
||||||
# Set to False to disable CSP entirely (not recommended)
|
# Set to False to disable CSP entirely (not recommended)
|
||||||
@@ -72,6 +76,60 @@ class Settings(BaseSettings):
|
|||||||
description="Frontend application URL for email links",
|
description="Frontend application URL for email links",
|
||||||
)
|
)
|
||||||
|
|
||||||
|
# OAuth Configuration
|
||||||
|
OAUTH_ENABLED: bool = Field(
|
||||||
|
default=False,
|
||||||
|
description="Enable OAuth authentication (social login)",
|
||||||
|
)
|
||||||
|
OAUTH_AUTO_LINK_BY_EMAIL: bool = Field(
|
||||||
|
default=True,
|
||||||
|
description="Automatically link OAuth accounts to existing users with matching email",
|
||||||
|
)
|
||||||
|
OAUTH_STATE_EXPIRE_MINUTES: int = Field(
|
||||||
|
default=10,
|
||||||
|
description="OAuth state parameter expiration time in minutes",
|
||||||
|
)
|
||||||
|
|
||||||
|
# Google OAuth
|
||||||
|
OAUTH_GOOGLE_CLIENT_ID: str | None = Field(
|
||||||
|
default=None,
|
||||||
|
description="Google OAuth client ID from Google Cloud Console",
|
||||||
|
)
|
||||||
|
OAUTH_GOOGLE_CLIENT_SECRET: str | None = Field(
|
||||||
|
default=None,
|
||||||
|
description="Google OAuth client secret from Google Cloud Console",
|
||||||
|
)
|
||||||
|
|
||||||
|
# GitHub OAuth
|
||||||
|
OAUTH_GITHUB_CLIENT_ID: str | None = Field(
|
||||||
|
default=None,
|
||||||
|
description="GitHub OAuth client ID from GitHub Developer Settings",
|
||||||
|
)
|
||||||
|
OAUTH_GITHUB_CLIENT_SECRET: str | None = Field(
|
||||||
|
default=None,
|
||||||
|
description="GitHub OAuth client secret from GitHub Developer Settings",
|
||||||
|
)
|
||||||
|
|
||||||
|
# OAuth Provider Mode (for MCP clients - skeleton)
|
||||||
|
OAUTH_PROVIDER_ENABLED: bool = Field(
|
||||||
|
default=False,
|
||||||
|
description="Enable OAuth provider mode (act as authorization server for MCP clients)",
|
||||||
|
)
|
||||||
|
OAUTH_ISSUER: str = Field(
|
||||||
|
default="http://localhost:8000",
|
||||||
|
description="OAuth issuer URL (your API base URL)",
|
||||||
|
)
|
||||||
|
|
||||||
|
@property
|
||||||
|
def enabled_oauth_providers(self) -> list[str]:
|
||||||
|
"""Get list of enabled OAuth providers based on configured credentials."""
|
||||||
|
providers = []
|
||||||
|
if self.OAUTH_GOOGLE_CLIENT_ID and self.OAUTH_GOOGLE_CLIENT_SECRET:
|
||||||
|
providers.append("google")
|
||||||
|
if self.OAUTH_GITHUB_CLIENT_ID and self.OAUTH_GITHUB_CLIENT_SECRET:
|
||||||
|
providers.append("github")
|
||||||
|
return providers
|
||||||
|
|
||||||
# Admin user
|
# Admin user
|
||||||
FIRST_SUPERUSER_EMAIL: str | None = Field(
|
FIRST_SUPERUSER_EMAIL: str | None = Field(
|
||||||
default=None, description="Email for first superuser account"
|
default=None, description="Email for first superuser account"
|
||||||
@@ -110,11 +168,21 @@ class Settings(BaseSettings):
|
|||||||
|
|
||||||
@field_validator("FIRST_SUPERUSER_PASSWORD")
|
@field_validator("FIRST_SUPERUSER_PASSWORD")
|
||||||
@classmethod
|
@classmethod
|
||||||
def validate_superuser_password(cls, v: str | None) -> str | None:
|
def validate_superuser_password(cls, v: str | None, info) -> str | None:
|
||||||
"""Validate superuser password strength."""
|
"""Validate superuser password strength."""
|
||||||
if v is None:
|
if v is None:
|
||||||
return v
|
return v
|
||||||
|
|
||||||
|
# Get environment from values if available
|
||||||
|
values_data = info.data if info.data else {}
|
||||||
|
demo_mode = values_data.get("DEMO_MODE", False)
|
||||||
|
|
||||||
|
if demo_mode:
|
||||||
|
# In demo mode, allow specific weak passwords for demo accounts
|
||||||
|
demo_passwords = {"Demo123!", "Admin123!"}
|
||||||
|
if v in demo_passwords:
|
||||||
|
return v
|
||||||
|
|
||||||
if len(v) < 12:
|
if len(v) < 12:
|
||||||
raise ValueError("FIRST_SUPERUSER_PASSWORD must be at least 12 characters")
|
raise ValueError("FIRST_SUPERUSER_PASSWORD must be at least 12 characters")
|
||||||
|
|
||||||
|
|||||||
@@ -128,8 +128,8 @@ async def async_transaction_scope() -> AsyncGenerator[AsyncSession, None]:
|
|||||||
|
|
||||||
Usage:
|
Usage:
|
||||||
async with async_transaction_scope() as db:
|
async with async_transaction_scope() as db:
|
||||||
user = await user_crud.create(db, obj_in=user_create)
|
user = await user_repo.create(db, obj_in=user_create)
|
||||||
profile = await profile_crud.create(db, obj_in=profile_create)
|
profile = await profile_repo.create(db, obj_in=profile_create)
|
||||||
# Both operations committed together
|
# Both operations committed together
|
||||||
"""
|
"""
|
||||||
async with SessionLocal() as session:
|
async with SessionLocal() as session:
|
||||||
@@ -139,7 +139,7 @@ async def async_transaction_scope() -> AsyncGenerator[AsyncSession, None]:
|
|||||||
logger.debug("Async transaction committed successfully")
|
logger.debug("Async transaction committed successfully")
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
await session.rollback()
|
await session.rollback()
|
||||||
logger.error(f"Async transaction failed, rolling back: {e!s}")
|
logger.error("Async transaction failed, rolling back: %s", e)
|
||||||
raise
|
raise
|
||||||
finally:
|
finally:
|
||||||
await session.close()
|
await session.close()
|
||||||
@@ -155,7 +155,7 @@ async def check_async_database_health() -> bool:
|
|||||||
await db.execute(text("SELECT 1"))
|
await db.execute(text("SELECT 1"))
|
||||||
return True
|
return True
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.error(f"Async database health check failed: {e!s}")
|
logger.error("Async database health check failed: %s", e)
|
||||||
return False
|
return False
|
||||||
|
|
||||||
|
|
||||||
|
|||||||
366
backend/app/core/demo_data.json
Normal file
366
backend/app/core/demo_data.json
Normal file
@@ -0,0 +1,366 @@
|
|||||||
|
{
|
||||||
|
"organizations": [
|
||||||
|
{
|
||||||
|
"name": "Acme Corp",
|
||||||
|
"slug": "acme-corp",
|
||||||
|
"description": "A leading provider of coyote-catching equipment."
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"name": "Globex Corporation",
|
||||||
|
"slug": "globex",
|
||||||
|
"description": "We own the East Coast."
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"name": "Soylent Corp",
|
||||||
|
"slug": "soylent",
|
||||||
|
"description": "Making food for the future."
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"name": "Initech",
|
||||||
|
"slug": "initech",
|
||||||
|
"description": "Software for the soul."
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"name": "Umbrella Corporation",
|
||||||
|
"slug": "umbrella",
|
||||||
|
"description": "Our business is life itself."
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"name": "Massive Dynamic",
|
||||||
|
"slug": "massive-dynamic",
|
||||||
|
"description": "What don't we do?"
|
||||||
|
}
|
||||||
|
],
|
||||||
|
"users": [
|
||||||
|
{
|
||||||
|
"email": "demo@example.com",
|
||||||
|
"password": "DemoPass1234!",
|
||||||
|
"first_name": "Demo",
|
||||||
|
"last_name": "User",
|
||||||
|
"is_superuser": false,
|
||||||
|
"organization_slug": "acme-corp",
|
||||||
|
"role": "member",
|
||||||
|
"is_active": true
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"email": "alice@acme.com",
|
||||||
|
"password": "Demo123!",
|
||||||
|
"first_name": "Alice",
|
||||||
|
"last_name": "Smith",
|
||||||
|
"is_superuser": false,
|
||||||
|
"organization_slug": "acme-corp",
|
||||||
|
"role": "admin",
|
||||||
|
"is_active": true
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"email": "bob@acme.com",
|
||||||
|
"password": "Demo123!",
|
||||||
|
"first_name": "Bob",
|
||||||
|
"last_name": "Jones",
|
||||||
|
"is_superuser": false,
|
||||||
|
"organization_slug": "acme-corp",
|
||||||
|
"role": "member",
|
||||||
|
"is_active": true
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"email": "charlie@acme.com",
|
||||||
|
"password": "Demo123!",
|
||||||
|
"first_name": "Charlie",
|
||||||
|
"last_name": "Brown",
|
||||||
|
"is_superuser": false,
|
||||||
|
"organization_slug": "acme-corp",
|
||||||
|
"role": "member",
|
||||||
|
"is_active": false
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"email": "diana@acme.com",
|
||||||
|
"password": "Demo123!",
|
||||||
|
"first_name": "Diana",
|
||||||
|
"last_name": "Prince",
|
||||||
|
"is_superuser": false,
|
||||||
|
"organization_slug": "acme-corp",
|
||||||
|
"role": "member",
|
||||||
|
"is_active": true
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"email": "carol@globex.com",
|
||||||
|
"password": "Demo123!",
|
||||||
|
"first_name": "Carol",
|
||||||
|
"last_name": "Williams",
|
||||||
|
"is_superuser": false,
|
||||||
|
"organization_slug": "globex",
|
||||||
|
"role": "owner",
|
||||||
|
"is_active": true
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"email": "dan@globex.com",
|
||||||
|
"password": "Demo123!",
|
||||||
|
"first_name": "Dan",
|
||||||
|
"last_name": "Miller",
|
||||||
|
"is_superuser": false,
|
||||||
|
"organization_slug": "globex",
|
||||||
|
"role": "member",
|
||||||
|
"is_active": true
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"email": "ellen@globex.com",
|
||||||
|
"password": "Demo123!",
|
||||||
|
"first_name": "Ellen",
|
||||||
|
"last_name": "Ripley",
|
||||||
|
"is_superuser": false,
|
||||||
|
"organization_slug": "globex",
|
||||||
|
"role": "member",
|
||||||
|
"is_active": true
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"email": "fred@globex.com",
|
||||||
|
"password": "Demo123!",
|
||||||
|
"first_name": "Fred",
|
||||||
|
"last_name": "Flintstone",
|
||||||
|
"is_superuser": false,
|
||||||
|
"organization_slug": "globex",
|
||||||
|
"role": "member",
|
||||||
|
"is_active": true
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"email": "dave@soylent.com",
|
||||||
|
"password": "Demo123!",
|
||||||
|
"first_name": "Dave",
|
||||||
|
"last_name": "Brown",
|
||||||
|
"is_superuser": false,
|
||||||
|
"organization_slug": "soylent",
|
||||||
|
"role": "member",
|
||||||
|
"is_active": true
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"email": "gina@soylent.com",
|
||||||
|
"password": "Demo123!",
|
||||||
|
"first_name": "Gina",
|
||||||
|
"last_name": "Torres",
|
||||||
|
"is_superuser": false,
|
||||||
|
"organization_slug": "soylent",
|
||||||
|
"role": "member",
|
||||||
|
"is_active": true
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"email": "harry@soylent.com",
|
||||||
|
"password": "Demo123!",
|
||||||
|
"first_name": "Harry",
|
||||||
|
"last_name": "Potter",
|
||||||
|
"is_superuser": false,
|
||||||
|
"organization_slug": "soylent",
|
||||||
|
"role": "admin",
|
||||||
|
"is_active": true
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"email": "eve@initech.com",
|
||||||
|
"password": "Demo123!",
|
||||||
|
"first_name": "Eve",
|
||||||
|
"last_name": "Davis",
|
||||||
|
"is_superuser": false,
|
||||||
|
"organization_slug": "initech",
|
||||||
|
"role": "admin",
|
||||||
|
"is_active": true
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"email": "iris@initech.com",
|
||||||
|
"password": "Demo123!",
|
||||||
|
"first_name": "Iris",
|
||||||
|
"last_name": "West",
|
||||||
|
"is_superuser": false,
|
||||||
|
"organization_slug": "initech",
|
||||||
|
"role": "member",
|
||||||
|
"is_active": true
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"email": "jack@initech.com",
|
||||||
|
"password": "Demo123!",
|
||||||
|
"first_name": "Jack",
|
||||||
|
"last_name": "Sparrow",
|
||||||
|
"is_superuser": false,
|
||||||
|
"organization_slug": "initech",
|
||||||
|
"role": "member",
|
||||||
|
"is_active": false
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"email": "frank@umbrella.com",
|
||||||
|
"password": "Demo123!",
|
||||||
|
"first_name": "Frank",
|
||||||
|
"last_name": "Miller",
|
||||||
|
"is_superuser": false,
|
||||||
|
"organization_slug": "umbrella",
|
||||||
|
"role": "member",
|
||||||
|
"is_active": true
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"email": "george@umbrella.com",
|
||||||
|
"password": "Demo123!",
|
||||||
|
"first_name": "George",
|
||||||
|
"last_name": "Costanza",
|
||||||
|
"is_superuser": false,
|
||||||
|
"organization_slug": "umbrella",
|
||||||
|
"role": "member",
|
||||||
|
"is_active": false
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"email": "kate@umbrella.com",
|
||||||
|
"password": "Demo123!",
|
||||||
|
"first_name": "Kate",
|
||||||
|
"last_name": "Bishop",
|
||||||
|
"is_superuser": false,
|
||||||
|
"organization_slug": "umbrella",
|
||||||
|
"role": "member",
|
||||||
|
"is_active": true
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"email": "leo@massive.com",
|
||||||
|
"password": "Demo123!",
|
||||||
|
"first_name": "Leo",
|
||||||
|
"last_name": "Messi",
|
||||||
|
"is_superuser": false,
|
||||||
|
"organization_slug": "massive-dynamic",
|
||||||
|
"role": "owner",
|
||||||
|
"is_active": true
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"email": "mary@massive.com",
|
||||||
|
"password": "Demo123!",
|
||||||
|
"first_name": "Mary",
|
||||||
|
"last_name": "Jane",
|
||||||
|
"is_superuser": false,
|
||||||
|
"organization_slug": "massive-dynamic",
|
||||||
|
"role": "member",
|
||||||
|
"is_active": true
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"email": "nathan@massive.com",
|
||||||
|
"password": "Demo123!",
|
||||||
|
"first_name": "Nathan",
|
||||||
|
"last_name": "Drake",
|
||||||
|
"is_superuser": false,
|
||||||
|
"organization_slug": "massive-dynamic",
|
||||||
|
"role": "member",
|
||||||
|
"is_active": true
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"email": "olivia@massive.com",
|
||||||
|
"password": "Demo123!",
|
||||||
|
"first_name": "Olivia",
|
||||||
|
"last_name": "Dunham",
|
||||||
|
"is_superuser": false,
|
||||||
|
"organization_slug": "massive-dynamic",
|
||||||
|
"role": "admin",
|
||||||
|
"is_active": true
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"email": "peter@massive.com",
|
||||||
|
"password": "Demo123!",
|
||||||
|
"first_name": "Peter",
|
||||||
|
"last_name": "Parker",
|
||||||
|
"is_superuser": false,
|
||||||
|
"organization_slug": "massive-dynamic",
|
||||||
|
"role": "member",
|
||||||
|
"is_active": true
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"email": "quinn@massive.com",
|
||||||
|
"password": "Demo123!",
|
||||||
|
"first_name": "Quinn",
|
||||||
|
"last_name": "Mallory",
|
||||||
|
"is_superuser": false,
|
||||||
|
"organization_slug": "massive-dynamic",
|
||||||
|
"role": "member",
|
||||||
|
"is_active": true
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"email": "grace@example.com",
|
||||||
|
"password": "Demo123!",
|
||||||
|
"first_name": "Grace",
|
||||||
|
"last_name": "Hopper",
|
||||||
|
"is_superuser": false,
|
||||||
|
"organization_slug": null,
|
||||||
|
"role": null,
|
||||||
|
"is_active": true
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"email": "heidi@example.com",
|
||||||
|
"password": "Demo123!",
|
||||||
|
"first_name": "Heidi",
|
||||||
|
"last_name": "Klum",
|
||||||
|
"is_superuser": false,
|
||||||
|
"organization_slug": null,
|
||||||
|
"role": null,
|
||||||
|
"is_active": true
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"email": "ivan@example.com",
|
||||||
|
"password": "Demo123!",
|
||||||
|
"first_name": "Ivan",
|
||||||
|
"last_name": "Drago",
|
||||||
|
"is_superuser": false,
|
||||||
|
"organization_slug": null,
|
||||||
|
"role": null,
|
||||||
|
"is_active": false
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"email": "rachel@example.com",
|
||||||
|
"password": "Demo123!",
|
||||||
|
"first_name": "Rachel",
|
||||||
|
"last_name": "Green",
|
||||||
|
"is_superuser": false,
|
||||||
|
"organization_slug": null,
|
||||||
|
"role": null,
|
||||||
|
"is_active": true
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"email": "sam@example.com",
|
||||||
|
"password": "Demo123!",
|
||||||
|
"first_name": "Sam",
|
||||||
|
"last_name": "Wilson",
|
||||||
|
"is_superuser": false,
|
||||||
|
"organization_slug": null,
|
||||||
|
"role": null,
|
||||||
|
"is_active": true
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"email": "tony@example.com",
|
||||||
|
"password": "Demo123!",
|
||||||
|
"first_name": "Tony",
|
||||||
|
"last_name": "Stark",
|
||||||
|
"is_superuser": false,
|
||||||
|
"organization_slug": null,
|
||||||
|
"role": null,
|
||||||
|
"is_active": true
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"email": "una@example.com",
|
||||||
|
"password": "Demo123!",
|
||||||
|
"first_name": "Una",
|
||||||
|
"last_name": "Chin-Riley",
|
||||||
|
"is_superuser": false,
|
||||||
|
"organization_slug": null,
|
||||||
|
"role": null,
|
||||||
|
"is_active": false
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"email": "victor@example.com",
|
||||||
|
"password": "Demo123!",
|
||||||
|
"first_name": "Victor",
|
||||||
|
"last_name": "Von Doom",
|
||||||
|
"is_superuser": false,
|
||||||
|
"organization_slug": null,
|
||||||
|
"role": null,
|
||||||
|
"is_active": true
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"email": "wanda@example.com",
|
||||||
|
"password": "Demo123!",
|
||||||
|
"first_name": "Wanda",
|
||||||
|
"last_name": "Maximoff",
|
||||||
|
"is_superuser": false,
|
||||||
|
"organization_slug": null,
|
||||||
|
"role": null,
|
||||||
|
"is_active": true
|
||||||
|
}
|
||||||
|
]
|
||||||
|
}
|
||||||
@@ -143,8 +143,11 @@ async def api_exception_handler(request: Request, exc: APIException) -> JSONResp
|
|||||||
Returns a standardized error response with error code and message.
|
Returns a standardized error response with error code and message.
|
||||||
"""
|
"""
|
||||||
logger.warning(
|
logger.warning(
|
||||||
f"API exception: {exc.error_code} - {exc.message} "
|
"API exception: %s - %s (status: %s, path: %s)",
|
||||||
f"(status: {exc.status_code}, path: {request.url.path})"
|
exc.error_code,
|
||||||
|
exc.message,
|
||||||
|
exc.status_code,
|
||||||
|
request.url.path,
|
||||||
)
|
)
|
||||||
|
|
||||||
error_response = ErrorResponse(
|
error_response = ErrorResponse(
|
||||||
@@ -186,7 +189,9 @@ async def validation_exception_handler(
|
|||||||
)
|
)
|
||||||
)
|
)
|
||||||
|
|
||||||
logger.warning(f"Validation error: {len(errors)} errors (path: {request.url.path})")
|
logger.warning(
|
||||||
|
"Validation error: %s errors (path: %s)", len(errors), request.url.path
|
||||||
|
)
|
||||||
|
|
||||||
error_response = ErrorResponse(errors=errors)
|
error_response = ErrorResponse(errors=errors)
|
||||||
|
|
||||||
@@ -218,11 +223,14 @@ async def http_exception_handler(request: Request, exc: HTTPException) -> JSONRe
|
|||||||
)
|
)
|
||||||
|
|
||||||
logger.warning(
|
logger.warning(
|
||||||
f"HTTP exception: {exc.status_code} - {exc.detail} (path: {request.url.path})"
|
"HTTP exception: %s - %s (path: %s)",
|
||||||
|
exc.status_code,
|
||||||
|
exc.detail,
|
||||||
|
request.url.path,
|
||||||
)
|
)
|
||||||
|
|
||||||
error_response = ErrorResponse(
|
error_response = ErrorResponse(
|
||||||
errors=[ErrorDetail(code=error_code, message=str(exc.detail))]
|
errors=[ErrorDetail(code=error_code, message=str(exc.detail), field=None)]
|
||||||
)
|
)
|
||||||
|
|
||||||
return JSONResponse(
|
return JSONResponse(
|
||||||
@@ -239,10 +247,11 @@ async def unhandled_exception_handler(request: Request, exc: Exception) -> JSONR
|
|||||||
Logs the full exception and returns a generic error response to avoid
|
Logs the full exception and returns a generic error response to avoid
|
||||||
leaking sensitive information in production.
|
leaking sensitive information in production.
|
||||||
"""
|
"""
|
||||||
logger.error(
|
logger.exception(
|
||||||
f"Unhandled exception: {type(exc).__name__} - {exc!s} "
|
"Unhandled exception: %s - %s (path: %s)",
|
||||||
f"(path: {request.url.path})",
|
type(exc).__name__,
|
||||||
exc_info=True,
|
exc,
|
||||||
|
request.url.path,
|
||||||
)
|
)
|
||||||
|
|
||||||
# In production, don't expose internal error details
|
# In production, don't expose internal error details
|
||||||
@@ -254,7 +263,7 @@ async def unhandled_exception_handler(request: Request, exc: Exception) -> JSONR
|
|||||||
message = f"{type(exc).__name__}: {exc!s}"
|
message = f"{type(exc).__name__}: {exc!s}"
|
||||||
|
|
||||||
error_response = ErrorResponse(
|
error_response = ErrorResponse(
|
||||||
errors=[ErrorDetail(code=ErrorCode.INTERNAL_ERROR, message=message)]
|
errors=[ErrorDetail(code=ErrorCode.INTERNAL_ERROR, message=message, field=None)]
|
||||||
)
|
)
|
||||||
|
|
||||||
return JSONResponse(
|
return JSONResponse(
|
||||||
|
|||||||
26
backend/app/core/repository_exceptions.py
Normal file
26
backend/app/core/repository_exceptions.py
Normal file
@@ -0,0 +1,26 @@
|
|||||||
|
"""
|
||||||
|
Custom exceptions for the repository layer.
|
||||||
|
|
||||||
|
These exceptions allow services and routes to handle database-level errors
|
||||||
|
with proper semantics, without leaking SQLAlchemy internals.
|
||||||
|
"""
|
||||||
|
|
||||||
|
|
||||||
|
class RepositoryError(Exception):
|
||||||
|
"""Base for all repository-layer errors."""
|
||||||
|
|
||||||
|
|
||||||
|
class DuplicateEntryError(RepositoryError):
|
||||||
|
"""Raised on unique constraint violations. Maps to HTTP 409 Conflict."""
|
||||||
|
|
||||||
|
|
||||||
|
class IntegrityConstraintError(RepositoryError):
|
||||||
|
"""Raised on FK or check constraint violations."""
|
||||||
|
|
||||||
|
|
||||||
|
class RecordNotFoundError(RepositoryError):
|
||||||
|
"""Raised when an expected record doesn't exist."""
|
||||||
|
|
||||||
|
|
||||||
|
class InvalidInputError(RepositoryError):
|
||||||
|
"""Raised on bad pagination params, invalid UUIDs, or other invalid inputs."""
|
||||||
@@ -1,6 +0,0 @@
|
|||||||
# app/crud/__init__.py
|
|
||||||
from .organization import organization
|
|
||||||
from .session import session as session_crud
|
|
||||||
from .user import user
|
|
||||||
|
|
||||||
__all__ = ["organization", "session_crud", "user"]
|
|
||||||
@@ -6,12 +6,20 @@ Creates the first superuser if configured and doesn't already exist.
|
|||||||
"""
|
"""
|
||||||
|
|
||||||
import asyncio
|
import asyncio
|
||||||
|
import json
|
||||||
import logging
|
import logging
|
||||||
|
import random
|
||||||
|
from datetime import UTC, datetime, timedelta
|
||||||
|
from pathlib import Path
|
||||||
|
|
||||||
|
from sqlalchemy import select, text
|
||||||
|
|
||||||
from app.core.config import settings
|
from app.core.config import settings
|
||||||
from app.core.database import SessionLocal, engine
|
from app.core.database import SessionLocal, engine
|
||||||
from app.crud.user import user as user_crud
|
from app.models.organization import Organization
|
||||||
from app.models.user import User
|
from app.models.user import User
|
||||||
|
from app.models.user_organization import UserOrganization
|
||||||
|
from app.repositories.user import user_repo as user_repo
|
||||||
from app.schemas.users import UserCreate
|
from app.schemas.users import UserCreate
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
@@ -26,21 +34,27 @@ async def init_db() -> User | None:
|
|||||||
"""
|
"""
|
||||||
# Use default values if not set in environment variables
|
# Use default values if not set in environment variables
|
||||||
superuser_email = settings.FIRST_SUPERUSER_EMAIL or "admin@example.com"
|
superuser_email = settings.FIRST_SUPERUSER_EMAIL or "admin@example.com"
|
||||||
superuser_password = settings.FIRST_SUPERUSER_PASSWORD or "AdminPassword123!"
|
|
||||||
|
default_password = "AdminPassword123!"
|
||||||
|
if settings.DEMO_MODE:
|
||||||
|
default_password = "AdminPass1234!"
|
||||||
|
|
||||||
|
superuser_password = settings.FIRST_SUPERUSER_PASSWORD or default_password
|
||||||
|
|
||||||
if not settings.FIRST_SUPERUSER_EMAIL or not settings.FIRST_SUPERUSER_PASSWORD:
|
if not settings.FIRST_SUPERUSER_EMAIL or not settings.FIRST_SUPERUSER_PASSWORD:
|
||||||
logger.warning(
|
logger.warning(
|
||||||
"First superuser credentials not configured in settings. "
|
"First superuser credentials not configured in settings. "
|
||||||
f"Using defaults: {superuser_email}"
|
"Using defaults: %s",
|
||||||
|
superuser_email,
|
||||||
)
|
)
|
||||||
|
|
||||||
async with SessionLocal() as session:
|
async with SessionLocal() as session:
|
||||||
try:
|
try:
|
||||||
# Check if superuser already exists
|
# Check if superuser already exists
|
||||||
existing_user = await user_crud.get_by_email(session, email=superuser_email)
|
existing_user = await user_repo.get_by_email(session, email=superuser_email)
|
||||||
|
|
||||||
if existing_user:
|
if existing_user:
|
||||||
logger.info(f"Superuser already exists: {existing_user.email}")
|
logger.info("Superuser already exists: %s", existing_user.email)
|
||||||
return existing_user
|
return existing_user
|
||||||
|
|
||||||
# Create superuser if doesn't exist
|
# Create superuser if doesn't exist
|
||||||
@@ -52,19 +66,143 @@ async def init_db() -> User | None:
|
|||||||
is_superuser=True,
|
is_superuser=True,
|
||||||
)
|
)
|
||||||
|
|
||||||
user = await user_crud.create(session, obj_in=user_in)
|
user = await user_repo.create(session, obj_in=user_in)
|
||||||
await session.commit()
|
await session.commit()
|
||||||
await session.refresh(user)
|
await session.refresh(user)
|
||||||
|
|
||||||
logger.info(f"Created first superuser: {user.email}")
|
logger.info("Created first superuser: %s", user.email)
|
||||||
|
|
||||||
|
# Create demo data if in demo mode
|
||||||
|
if settings.DEMO_MODE:
|
||||||
|
await load_demo_data(session)
|
||||||
|
|
||||||
return user
|
return user
|
||||||
|
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
await session.rollback()
|
await session.rollback()
|
||||||
logger.error(f"Error initializing database: {e}")
|
logger.error("Error initializing database: %s", e)
|
||||||
raise
|
raise
|
||||||
|
|
||||||
|
|
||||||
|
def _load_json_file(path: Path):
|
||||||
|
with open(path) as f:
|
||||||
|
return json.load(f)
|
||||||
|
|
||||||
|
|
||||||
|
async def load_demo_data(session):
|
||||||
|
"""Load demo data from JSON file."""
|
||||||
|
demo_data_path = Path(__file__).parent / "core" / "demo_data.json"
|
||||||
|
if not demo_data_path.exists():
|
||||||
|
logger.warning("Demo data file not found: %s", demo_data_path)
|
||||||
|
return
|
||||||
|
|
||||||
|
try:
|
||||||
|
# Use asyncio.to_thread to avoid blocking the event loop
|
||||||
|
data = await asyncio.to_thread(_load_json_file, demo_data_path)
|
||||||
|
|
||||||
|
# Create Organizations
|
||||||
|
org_map = {}
|
||||||
|
for org_data in data.get("organizations", []):
|
||||||
|
# Check if org exists
|
||||||
|
result = await session.execute(
|
||||||
|
text("SELECT * FROM organizations WHERE slug = :slug"),
|
||||||
|
{"slug": org_data["slug"]},
|
||||||
|
)
|
||||||
|
existing_org = result.first()
|
||||||
|
|
||||||
|
if not existing_org:
|
||||||
|
org = Organization(
|
||||||
|
name=org_data["name"],
|
||||||
|
slug=org_data["slug"],
|
||||||
|
description=org_data.get("description"),
|
||||||
|
is_active=True,
|
||||||
|
)
|
||||||
|
session.add(org)
|
||||||
|
await session.flush() # Flush to get ID
|
||||||
|
org_map[org.slug] = org
|
||||||
|
logger.info("Created demo organization: %s", org.name)
|
||||||
|
else:
|
||||||
|
# We can't easily get the ORM object from raw SQL result for map without querying again or mapping
|
||||||
|
# So let's just query it properly if we need it for relationships
|
||||||
|
# But for simplicity in this script, let's just assume we created it or it exists.
|
||||||
|
# To properly map for users, we need the ID.
|
||||||
|
# Let's use a simpler approach: just try to create, if slug conflict, skip.
|
||||||
|
pass
|
||||||
|
|
||||||
|
# Re-query all orgs to build map for users
|
||||||
|
result = await session.execute(select(Organization))
|
||||||
|
orgs = result.scalars().all()
|
||||||
|
org_map = {org.slug: org for org in orgs}
|
||||||
|
|
||||||
|
# Create Users
|
||||||
|
for user_data in data.get("users", []):
|
||||||
|
existing_user = await user_repo.get_by_email(
|
||||||
|
session, email=user_data["email"]
|
||||||
|
)
|
||||||
|
if not existing_user:
|
||||||
|
# Create user
|
||||||
|
user_in = UserCreate(
|
||||||
|
email=user_data["email"],
|
||||||
|
password=user_data["password"],
|
||||||
|
first_name=user_data["first_name"],
|
||||||
|
last_name=user_data["last_name"],
|
||||||
|
is_superuser=user_data["is_superuser"],
|
||||||
|
is_active=user_data.get("is_active", True),
|
||||||
|
)
|
||||||
|
user = await user_repo.create(session, obj_in=user_in)
|
||||||
|
|
||||||
|
# Randomize created_at for demo data (last 30 days)
|
||||||
|
# This makes the charts look more realistic
|
||||||
|
days_ago = random.randint(0, 30) # noqa: S311
|
||||||
|
random_time = datetime.now(UTC) - timedelta(days=days_ago)
|
||||||
|
# Add some random hours/minutes variation
|
||||||
|
random_time = random_time.replace(
|
||||||
|
hour=random.randint(0, 23), # noqa: S311
|
||||||
|
minute=random.randint(0, 59), # noqa: S311
|
||||||
|
)
|
||||||
|
|
||||||
|
# Update the timestamp and is_active directly in the database
|
||||||
|
# We do this to ensure the values are persisted correctly
|
||||||
|
await session.execute(
|
||||||
|
text(
|
||||||
|
"UPDATE users SET created_at = :created_at, is_active = :is_active WHERE id = :user_id"
|
||||||
|
),
|
||||||
|
{
|
||||||
|
"created_at": random_time,
|
||||||
|
"is_active": user_data.get("is_active", True),
|
||||||
|
"user_id": user.id,
|
||||||
|
},
|
||||||
|
)
|
||||||
|
|
||||||
|
logger.info(
|
||||||
|
"Created demo user: %s (created %s days ago, active=%s)",
|
||||||
|
user.email,
|
||||||
|
days_ago,
|
||||||
|
user_data.get("is_active", True),
|
||||||
|
)
|
||||||
|
|
||||||
|
# Add to organization if specified
|
||||||
|
org_slug = user_data.get("organization_slug")
|
||||||
|
role = user_data.get("role")
|
||||||
|
if org_slug and org_slug in org_map and role:
|
||||||
|
org = org_map[org_slug]
|
||||||
|
# Check if membership exists (it shouldn't for new user)
|
||||||
|
member = UserOrganization(
|
||||||
|
user_id=user.id, organization_id=org.id, role=role
|
||||||
|
)
|
||||||
|
session.add(member)
|
||||||
|
logger.info("Added %s to %s as %s", user.email, org.name, role)
|
||||||
|
else:
|
||||||
|
logger.info("Demo user already exists: %s", existing_user.email)
|
||||||
|
|
||||||
|
await session.commit()
|
||||||
|
logger.info("Demo data loaded successfully")
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
logger.error("Error loading demo data: %s", e)
|
||||||
|
raise
|
||||||
|
|
||||||
|
|
||||||
async def main():
|
async def main():
|
||||||
"""Main entry point for database initialization."""
|
"""Main entry point for database initialization."""
|
||||||
# Configure logging to show info logs
|
# Configure logging to show info logs
|
||||||
|
|||||||
@@ -1,7 +1,7 @@
|
|||||||
import logging
|
import logging
|
||||||
import os
|
import os
|
||||||
from contextlib import asynccontextmanager
|
from contextlib import asynccontextmanager
|
||||||
from datetime import datetime
|
from datetime import UTC, datetime
|
||||||
from typing import Any
|
from typing import Any
|
||||||
|
|
||||||
from apscheduler.schedulers.asyncio import AsyncIOScheduler
|
from apscheduler.schedulers.asyncio import AsyncIOScheduler
|
||||||
@@ -14,8 +14,9 @@ from slowapi.errors import RateLimitExceeded
|
|||||||
from slowapi.util import get_remote_address
|
from slowapi.util import get_remote_address
|
||||||
|
|
||||||
from app.api.main import api_router
|
from app.api.main import api_router
|
||||||
|
from app.api.routes.oauth_provider import wellknown_router as oauth_wellknown_router
|
||||||
from app.core.config import settings
|
from app.core.config import settings
|
||||||
from app.core.database import check_database_health
|
from app.core.database import check_database_health, close_async_db
|
||||||
from app.core.exceptions import (
|
from app.core.exceptions import (
|
||||||
APIException,
|
APIException,
|
||||||
api_exception_handler,
|
api_exception_handler,
|
||||||
@@ -71,6 +72,7 @@ async def lifespan(app: FastAPI):
|
|||||||
if os.getenv("IS_TEST", "False") != "True":
|
if os.getenv("IS_TEST", "False") != "True":
|
||||||
scheduler.shutdown()
|
scheduler.shutdown()
|
||||||
logger.info("Scheduled jobs stopped")
|
logger.info("Scheduled jobs stopped")
|
||||||
|
await close_async_db()
|
||||||
|
|
||||||
|
|
||||||
logger.info("Starting app!!!")
|
logger.info("Starting app!!!")
|
||||||
@@ -293,7 +295,7 @@ async def health_check() -> JSONResponse:
|
|||||||
"""
|
"""
|
||||||
health_status: dict[str, Any] = {
|
health_status: dict[str, Any] = {
|
||||||
"status": "healthy",
|
"status": "healthy",
|
||||||
"timestamp": datetime.utcnow().isoformat() + "Z",
|
"timestamp": datetime.now(UTC).isoformat().replace("+00:00", "Z"),
|
||||||
"version": settings.VERSION,
|
"version": settings.VERSION,
|
||||||
"environment": settings.ENVIRONMENT,
|
"environment": settings.ENVIRONMENT,
|
||||||
"checks": {},
|
"checks": {},
|
||||||
@@ -318,9 +320,13 @@ async def health_check() -> JSONResponse:
|
|||||||
"message": f"Database connection failed: {e!s}",
|
"message": f"Database connection failed: {e!s}",
|
||||||
}
|
}
|
||||||
response_status = status.HTTP_503_SERVICE_UNAVAILABLE
|
response_status = status.HTTP_503_SERVICE_UNAVAILABLE
|
||||||
logger.error(f"Health check failed - database error: {e}")
|
logger.error("Health check failed - database error: %s", e)
|
||||||
|
|
||||||
return JSONResponse(status_code=response_status, content=health_status)
|
return JSONResponse(status_code=response_status, content=health_status)
|
||||||
|
|
||||||
|
|
||||||
app.include_router(api_router, prefix=settings.API_V1_STR)
|
app.include_router(api_router, prefix=settings.API_V1_STR)
|
||||||
|
|
||||||
|
# OAuth 2.0 well-known endpoint at root level per RFC 8414
|
||||||
|
# This allows MCP clients to discover the OAuth server metadata at /.well-known/oauth-authorization-server
|
||||||
|
app.include_router(oauth_wellknown_router)
|
||||||
|
|||||||
@@ -7,6 +7,15 @@ Imports all models to ensure they're registered with SQLAlchemy.
|
|||||||
from app.core.database import Base
|
from app.core.database import Base
|
||||||
|
|
||||||
from .base import TimestampMixin, UUIDMixin
|
from .base import TimestampMixin, UUIDMixin
|
||||||
|
|
||||||
|
# OAuth models (client mode - authenticate via Google/GitHub)
|
||||||
|
from .oauth_account import OAuthAccount
|
||||||
|
|
||||||
|
# OAuth provider models (server mode - act as authorization server for MCP)
|
||||||
|
from .oauth_authorization_code import OAuthAuthorizationCode
|
||||||
|
from .oauth_client import OAuthClient
|
||||||
|
from .oauth_provider_token import OAuthConsent, OAuthProviderRefreshToken
|
||||||
|
from .oauth_state import OAuthState
|
||||||
from .organization import Organization
|
from .organization import Organization
|
||||||
|
|
||||||
# Import models
|
# Import models
|
||||||
@@ -16,6 +25,12 @@ from .user_session import UserSession
|
|||||||
|
|
||||||
__all__ = [
|
__all__ = [
|
||||||
"Base",
|
"Base",
|
||||||
|
"OAuthAccount",
|
||||||
|
"OAuthAuthorizationCode",
|
||||||
|
"OAuthClient",
|
||||||
|
"OAuthConsent",
|
||||||
|
"OAuthProviderRefreshToken",
|
||||||
|
"OAuthState",
|
||||||
"Organization",
|
"Organization",
|
||||||
"OrganizationRole",
|
"OrganizationRole",
|
||||||
"TimestampMixin",
|
"TimestampMixin",
|
||||||
|
|||||||
55
backend/app/models/oauth_account.py
Executable file
55
backend/app/models/oauth_account.py
Executable file
@@ -0,0 +1,55 @@
|
|||||||
|
"""OAuth account model for linking external OAuth providers to users."""
|
||||||
|
|
||||||
|
from sqlalchemy import Column, DateTime, ForeignKey, Index, String, UniqueConstraint
|
||||||
|
from sqlalchemy.dialects.postgresql import UUID
|
||||||
|
from sqlalchemy.orm import relationship
|
||||||
|
|
||||||
|
from .base import Base, TimestampMixin, UUIDMixin
|
||||||
|
|
||||||
|
|
||||||
|
class OAuthAccount(Base, UUIDMixin, TimestampMixin):
|
||||||
|
"""
|
||||||
|
Links OAuth provider accounts to users.
|
||||||
|
|
||||||
|
Supports multiple OAuth providers per user (e.g., user can have both
|
||||||
|
Google and GitHub connected). Each provider account is uniquely identified
|
||||||
|
by (provider, provider_user_id).
|
||||||
|
"""
|
||||||
|
|
||||||
|
__tablename__ = "oauth_accounts"
|
||||||
|
|
||||||
|
# Link to user
|
||||||
|
user_id = Column(
|
||||||
|
UUID(as_uuid=True),
|
||||||
|
ForeignKey("users.id", ondelete="CASCADE"),
|
||||||
|
nullable=False,
|
||||||
|
index=True,
|
||||||
|
)
|
||||||
|
|
||||||
|
# OAuth provider identification
|
||||||
|
provider = Column(
|
||||||
|
String(50), nullable=False, index=True
|
||||||
|
) # google, github, microsoft
|
||||||
|
provider_user_id = Column(String(255), nullable=False) # Provider's unique user ID
|
||||||
|
provider_email = Column(
|
||||||
|
String(255), nullable=True, index=True
|
||||||
|
) # Email from provider (for reference)
|
||||||
|
|
||||||
|
# Optional: store provider tokens for API access
|
||||||
|
# TODO: Encrypt these at rest in production (requires key management infrastructure)
|
||||||
|
access_token = Column(String(2048), nullable=True)
|
||||||
|
refresh_token = Column(String(2048), nullable=True)
|
||||||
|
token_expires_at = Column(DateTime(timezone=True), nullable=True)
|
||||||
|
|
||||||
|
# Relationship
|
||||||
|
user = relationship("User", back_populates="oauth_accounts")
|
||||||
|
|
||||||
|
__table_args__ = (
|
||||||
|
# Each provider account can only be linked to one user
|
||||||
|
UniqueConstraint("provider", "provider_user_id", name="uq_oauth_provider_user"),
|
||||||
|
# Index for finding all OAuth accounts for a user + provider
|
||||||
|
Index("ix_oauth_accounts_user_provider", "user_id", "provider"),
|
||||||
|
)
|
||||||
|
|
||||||
|
def __repr__(self):
|
||||||
|
return f"<OAuthAccount {self.provider}:{self.provider_user_id}>"
|
||||||
100
backend/app/models/oauth_authorization_code.py
Executable file
100
backend/app/models/oauth_authorization_code.py
Executable file
@@ -0,0 +1,100 @@
|
|||||||
|
"""OAuth authorization code model for OAuth provider mode."""
|
||||||
|
|
||||||
|
from datetime import UTC, datetime
|
||||||
|
|
||||||
|
from sqlalchemy import Boolean, Column, DateTime, ForeignKey, Index, String
|
||||||
|
from sqlalchemy.dialects.postgresql import UUID
|
||||||
|
from sqlalchemy.orm import relationship
|
||||||
|
|
||||||
|
from .base import Base, TimestampMixin, UUIDMixin
|
||||||
|
|
||||||
|
|
||||||
|
class OAuthAuthorizationCode(Base, UUIDMixin, TimestampMixin):
|
||||||
|
"""
|
||||||
|
OAuth 2.0 Authorization Code for the authorization code flow.
|
||||||
|
|
||||||
|
Authorization codes are:
|
||||||
|
- Single-use (marked as used after exchange)
|
||||||
|
- Short-lived (10 minutes default)
|
||||||
|
- Bound to specific client, user, redirect_uri
|
||||||
|
- Support PKCE (code_challenge/code_challenge_method)
|
||||||
|
|
||||||
|
Security considerations:
|
||||||
|
- Code must be cryptographically random (64 chars, URL-safe)
|
||||||
|
- Must validate redirect_uri matches exactly
|
||||||
|
- Must verify PKCE code_verifier for public clients
|
||||||
|
- Must be consumed within expiration time
|
||||||
|
|
||||||
|
Performance indexes (defined in migration 0002_add_performance_indexes.py):
|
||||||
|
- ix_perf_oauth_auth_codes_expires: expires_at WHERE used = false
|
||||||
|
"""
|
||||||
|
|
||||||
|
__tablename__ = "oauth_authorization_codes"
|
||||||
|
|
||||||
|
# The authorization code (cryptographically random, URL-safe)
|
||||||
|
code = Column(String(128), unique=True, nullable=False, index=True)
|
||||||
|
|
||||||
|
# Client that requested the code
|
||||||
|
client_id = Column(
|
||||||
|
String(64),
|
||||||
|
ForeignKey("oauth_clients.client_id", ondelete="CASCADE"),
|
||||||
|
nullable=False,
|
||||||
|
)
|
||||||
|
|
||||||
|
# User who authorized the request
|
||||||
|
user_id = Column(
|
||||||
|
UUID(as_uuid=True),
|
||||||
|
ForeignKey("users.id", ondelete="CASCADE"),
|
||||||
|
nullable=False,
|
||||||
|
)
|
||||||
|
|
||||||
|
# Redirect URI (must match exactly on token exchange)
|
||||||
|
redirect_uri = Column(String(2048), nullable=False)
|
||||||
|
|
||||||
|
# Granted scopes (space-separated)
|
||||||
|
scope = Column(String(1000), nullable=False, default="")
|
||||||
|
|
||||||
|
# PKCE support (required for public clients)
|
||||||
|
code_challenge = Column(String(128), nullable=True)
|
||||||
|
code_challenge_method = Column(String(10), nullable=True) # "S256" or "plain"
|
||||||
|
|
||||||
|
# State parameter (for CSRF protection, returned to client)
|
||||||
|
state = Column(String(256), nullable=True)
|
||||||
|
|
||||||
|
# Nonce (for OpenID Connect, included in ID token)
|
||||||
|
nonce = Column(String(256), nullable=True)
|
||||||
|
|
||||||
|
# Expiration (codes are short-lived, typically 10 minutes)
|
||||||
|
expires_at = Column(DateTime(timezone=True), nullable=False)
|
||||||
|
|
||||||
|
# Single-use flag (set to True after successful exchange)
|
||||||
|
used = Column(Boolean, default=False, nullable=False)
|
||||||
|
|
||||||
|
# Relationships
|
||||||
|
client = relationship("OAuthClient", backref="authorization_codes")
|
||||||
|
user = relationship("User", backref="oauth_authorization_codes")
|
||||||
|
|
||||||
|
# Indexes for efficient cleanup queries
|
||||||
|
__table_args__ = (
|
||||||
|
Index("ix_oauth_authorization_codes_expires_at", "expires_at"),
|
||||||
|
Index("ix_oauth_authorization_codes_client_user", "client_id", "user_id"),
|
||||||
|
)
|
||||||
|
|
||||||
|
def __repr__(self):
|
||||||
|
return f"<OAuthAuthorizationCode {self.code[:8]}... for {self.client_id}>"
|
||||||
|
|
||||||
|
@property
|
||||||
|
def is_expired(self) -> bool:
|
||||||
|
"""Check if the authorization code has expired."""
|
||||||
|
# Use timezone-aware comparison (datetime.utcnow() is deprecated)
|
||||||
|
now = datetime.now(UTC)
|
||||||
|
expires_at = self.expires_at
|
||||||
|
# Handle both timezone-aware and naive datetimes from DB
|
||||||
|
if expires_at.tzinfo is None:
|
||||||
|
expires_at = expires_at.replace(tzinfo=UTC)
|
||||||
|
return bool(now > expires_at)
|
||||||
|
|
||||||
|
@property
|
||||||
|
def is_valid(self) -> bool:
|
||||||
|
"""Check if the authorization code is valid (not used, not expired)."""
|
||||||
|
return not self.used and not self.is_expired
|
||||||
67
backend/app/models/oauth_client.py
Executable file
67
backend/app/models/oauth_client.py
Executable file
@@ -0,0 +1,67 @@
|
|||||||
|
"""OAuth client model for OAuth provider mode (MCP clients)."""
|
||||||
|
|
||||||
|
from sqlalchemy import Boolean, Column, ForeignKey, String
|
||||||
|
from sqlalchemy.dialects.postgresql import JSONB, UUID
|
||||||
|
from sqlalchemy.orm import relationship
|
||||||
|
|
||||||
|
from .base import Base, TimestampMixin, UUIDMixin
|
||||||
|
|
||||||
|
|
||||||
|
class OAuthClient(Base, UUIDMixin, TimestampMixin):
|
||||||
|
"""
|
||||||
|
Registered OAuth clients (for OAuth provider mode).
|
||||||
|
|
||||||
|
This model stores third-party applications that can authenticate
|
||||||
|
against this API using OAuth 2.0. Used for MCP (Model Context Protocol)
|
||||||
|
client authentication and API access.
|
||||||
|
|
||||||
|
NOTE: This is a skeleton implementation. The full OAuth provider
|
||||||
|
functionality (authorization endpoint, token endpoint, etc.) can be
|
||||||
|
expanded when needed.
|
||||||
|
"""
|
||||||
|
|
||||||
|
__tablename__ = "oauth_clients"
|
||||||
|
|
||||||
|
# Client credentials
|
||||||
|
client_id = Column(String(64), unique=True, nullable=False, index=True)
|
||||||
|
client_secret_hash = Column(
|
||||||
|
String(255), nullable=True
|
||||||
|
) # NULL for public clients (PKCE)
|
||||||
|
|
||||||
|
# Client metadata
|
||||||
|
client_name = Column(String(255), nullable=False)
|
||||||
|
client_description = Column(String(1000), nullable=True)
|
||||||
|
|
||||||
|
# Client type: "public" (SPA, mobile) or "confidential" (server-side)
|
||||||
|
client_type = Column(String(20), nullable=False, default="public")
|
||||||
|
|
||||||
|
# Allowed redirect URIs (JSON array)
|
||||||
|
redirect_uris = Column(JSONB, nullable=False, default=list)
|
||||||
|
|
||||||
|
# Allowed scopes (JSON array of scope names)
|
||||||
|
allowed_scopes = Column(JSONB, nullable=False, default=list)
|
||||||
|
|
||||||
|
# Token lifetimes (in seconds)
|
||||||
|
access_token_lifetime = Column(String(10), nullable=False, default="3600") # 1 hour
|
||||||
|
refresh_token_lifetime = Column(
|
||||||
|
String(10), nullable=False, default="604800"
|
||||||
|
) # 7 days
|
||||||
|
|
||||||
|
# Status
|
||||||
|
is_active = Column(Boolean, default=True, nullable=False, index=True)
|
||||||
|
|
||||||
|
# Optional: owner user (for user-registered applications)
|
||||||
|
owner_user_id = Column(
|
||||||
|
UUID(as_uuid=True),
|
||||||
|
ForeignKey("users.id", ondelete="SET NULL"),
|
||||||
|
nullable=True,
|
||||||
|
)
|
||||||
|
|
||||||
|
# MCP-specific: URL of the MCP server this client represents
|
||||||
|
mcp_server_url = Column(String(2048), nullable=True)
|
||||||
|
|
||||||
|
# Relationship
|
||||||
|
owner = relationship("User", backref="owned_oauth_clients")
|
||||||
|
|
||||||
|
def __repr__(self):
|
||||||
|
return f"<OAuthClient {self.client_name} ({self.client_id[:8]}...)>"
|
||||||
162
backend/app/models/oauth_provider_token.py
Executable file
162
backend/app/models/oauth_provider_token.py
Executable file
@@ -0,0 +1,162 @@
|
|||||||
|
"""OAuth provider token models for OAuth provider mode."""
|
||||||
|
|
||||||
|
from datetime import UTC, datetime
|
||||||
|
|
||||||
|
from sqlalchemy import Boolean, Column, DateTime, ForeignKey, Index, String
|
||||||
|
from sqlalchemy.dialects.postgresql import UUID
|
||||||
|
from sqlalchemy.orm import relationship
|
||||||
|
|
||||||
|
from .base import Base, TimestampMixin, UUIDMixin
|
||||||
|
|
||||||
|
|
||||||
|
class OAuthProviderRefreshToken(Base, UUIDMixin, TimestampMixin):
|
||||||
|
"""
|
||||||
|
OAuth 2.0 Refresh Token for the OAuth provider.
|
||||||
|
|
||||||
|
Refresh tokens are:
|
||||||
|
- Opaque (stored as hash in DB, actual token given to client)
|
||||||
|
- Long-lived (configurable, default 30 days)
|
||||||
|
- Revocable (via revoked flag or deletion)
|
||||||
|
- Bound to specific client, user, and scope
|
||||||
|
|
||||||
|
Access tokens are JWTs and not stored in DB (self-contained).
|
||||||
|
This model only tracks refresh tokens for revocation support.
|
||||||
|
|
||||||
|
Security considerations:
|
||||||
|
- Store token hash, not plaintext
|
||||||
|
- Support token rotation (new refresh token on use)
|
||||||
|
- Track last used time for security auditing
|
||||||
|
- Support revocation by user, client, or admin
|
||||||
|
|
||||||
|
Performance indexes (defined in migration 0002_add_performance_indexes.py):
|
||||||
|
- ix_perf_oauth_refresh_tokens_expires: expires_at WHERE revoked = false
|
||||||
|
"""
|
||||||
|
|
||||||
|
__tablename__ = "oauth_provider_refresh_tokens"
|
||||||
|
|
||||||
|
# Hash of the refresh token (SHA-256)
|
||||||
|
# We store hash, not plaintext, for security
|
||||||
|
token_hash = Column(String(64), unique=True, nullable=False, index=True)
|
||||||
|
|
||||||
|
# Unique token ID (JTI) - used in JWT access tokens to reference this refresh token
|
||||||
|
jti = Column(String(64), unique=True, nullable=False, index=True)
|
||||||
|
|
||||||
|
# Client that owns this token
|
||||||
|
client_id = Column(
|
||||||
|
String(64),
|
||||||
|
ForeignKey("oauth_clients.client_id", ondelete="CASCADE"),
|
||||||
|
nullable=False,
|
||||||
|
)
|
||||||
|
|
||||||
|
# User who authorized this token
|
||||||
|
user_id = Column(
|
||||||
|
UUID(as_uuid=True),
|
||||||
|
ForeignKey("users.id", ondelete="CASCADE"),
|
||||||
|
nullable=False,
|
||||||
|
)
|
||||||
|
|
||||||
|
# Granted scopes (space-separated)
|
||||||
|
scope = Column(String(1000), nullable=False, default="")
|
||||||
|
|
||||||
|
# Token expiration
|
||||||
|
expires_at = Column(DateTime(timezone=True), nullable=False)
|
||||||
|
|
||||||
|
# Revocation flag
|
||||||
|
revoked = Column(Boolean, default=False, nullable=False, index=True)
|
||||||
|
|
||||||
|
# Last used timestamp (for security auditing)
|
||||||
|
last_used_at = Column(DateTime(timezone=True), nullable=True)
|
||||||
|
|
||||||
|
# Device/session info (optional, for user visibility)
|
||||||
|
device_info = Column(String(500), nullable=True)
|
||||||
|
ip_address = Column(String(45), nullable=True)
|
||||||
|
|
||||||
|
# Relationships
|
||||||
|
client = relationship("OAuthClient", backref="refresh_tokens")
|
||||||
|
user = relationship("User", backref="oauth_provider_refresh_tokens")
|
||||||
|
|
||||||
|
# Indexes
|
||||||
|
__table_args__ = (
|
||||||
|
Index("ix_oauth_provider_refresh_tokens_expires_at", "expires_at"),
|
||||||
|
Index("ix_oauth_provider_refresh_tokens_client_user", "client_id", "user_id"),
|
||||||
|
Index(
|
||||||
|
"ix_oauth_provider_refresh_tokens_user_revoked",
|
||||||
|
"user_id",
|
||||||
|
"revoked",
|
||||||
|
),
|
||||||
|
)
|
||||||
|
|
||||||
|
def __repr__(self):
|
||||||
|
status = "revoked" if self.revoked else "active"
|
||||||
|
return f"<OAuthProviderRefreshToken {self.jti[:8]}... ({status})>"
|
||||||
|
|
||||||
|
@property
|
||||||
|
def is_expired(self) -> bool:
|
||||||
|
"""Check if the refresh token has expired."""
|
||||||
|
# Use timezone-aware comparison (datetime.utcnow() is deprecated)
|
||||||
|
now = datetime.now(UTC)
|
||||||
|
expires_at = self.expires_at
|
||||||
|
# Handle both timezone-aware and naive datetimes from DB
|
||||||
|
if expires_at.tzinfo is None:
|
||||||
|
expires_at = expires_at.replace(tzinfo=UTC)
|
||||||
|
return bool(now > expires_at)
|
||||||
|
|
||||||
|
@property
|
||||||
|
def is_valid(self) -> bool:
|
||||||
|
"""Check if the refresh token is valid (not revoked, not expired)."""
|
||||||
|
return not self.revoked and not self.is_expired
|
||||||
|
|
||||||
|
|
||||||
|
class OAuthConsent(Base, UUIDMixin, TimestampMixin):
|
||||||
|
"""
|
||||||
|
OAuth consent record - remembers user consent for a client.
|
||||||
|
|
||||||
|
When a user grants consent to an OAuth client, we store the record
|
||||||
|
so they don't have to re-consent on subsequent authorizations
|
||||||
|
(unless scopes change).
|
||||||
|
|
||||||
|
This enables a better UX - users only see consent screen once per client,
|
||||||
|
unless the client requests additional scopes.
|
||||||
|
"""
|
||||||
|
|
||||||
|
__tablename__ = "oauth_consents"
|
||||||
|
|
||||||
|
# User who granted consent
|
||||||
|
user_id = Column(
|
||||||
|
UUID(as_uuid=True),
|
||||||
|
ForeignKey("users.id", ondelete="CASCADE"),
|
||||||
|
nullable=False,
|
||||||
|
)
|
||||||
|
|
||||||
|
# Client that received consent
|
||||||
|
client_id = Column(
|
||||||
|
String(64),
|
||||||
|
ForeignKey("oauth_clients.client_id", ondelete="CASCADE"),
|
||||||
|
nullable=False,
|
||||||
|
)
|
||||||
|
|
||||||
|
# Granted scopes (space-separated)
|
||||||
|
granted_scopes = Column(String(1000), nullable=False, default="")
|
||||||
|
|
||||||
|
# Relationships
|
||||||
|
client = relationship("OAuthClient", backref="consents")
|
||||||
|
user = relationship("User", backref="oauth_consents")
|
||||||
|
|
||||||
|
# Unique constraint: one consent record per user+client
|
||||||
|
__table_args__ = (
|
||||||
|
Index(
|
||||||
|
"ix_oauth_consents_user_client",
|
||||||
|
"user_id",
|
||||||
|
"client_id",
|
||||||
|
unique=True,
|
||||||
|
),
|
||||||
|
)
|
||||||
|
|
||||||
|
def __repr__(self):
|
||||||
|
return f"<OAuthConsent user={self.user_id} client={self.client_id}>"
|
||||||
|
|
||||||
|
def has_scopes(self, requested_scopes: list[str]) -> bool:
|
||||||
|
"""Check if all requested scopes are already granted."""
|
||||||
|
granted = set(self.granted_scopes.split()) if self.granted_scopes else set()
|
||||||
|
requested = set(requested_scopes)
|
||||||
|
return requested.issubset(granted)
|
||||||
45
backend/app/models/oauth_state.py
Executable file
45
backend/app/models/oauth_state.py
Executable file
@@ -0,0 +1,45 @@
|
|||||||
|
"""OAuth state model for CSRF protection during OAuth flows."""
|
||||||
|
|
||||||
|
from sqlalchemy import Column, DateTime, String
|
||||||
|
from sqlalchemy.dialects.postgresql import UUID
|
||||||
|
|
||||||
|
from .base import Base, TimestampMixin, UUIDMixin
|
||||||
|
|
||||||
|
|
||||||
|
class OAuthState(Base, UUIDMixin, TimestampMixin):
|
||||||
|
"""
|
||||||
|
Temporary storage for OAuth state parameters.
|
||||||
|
|
||||||
|
Prevents CSRF attacks during OAuth flows by storing a random state
|
||||||
|
value that must match on callback. Also stores PKCE code_verifier
|
||||||
|
for the Authorization Code flow with PKCE.
|
||||||
|
|
||||||
|
These records are short-lived (10 minutes by default) and should
|
||||||
|
be deleted after use or expiration.
|
||||||
|
"""
|
||||||
|
|
||||||
|
__tablename__ = "oauth_states"
|
||||||
|
|
||||||
|
# Random state parameter (CSRF protection)
|
||||||
|
state = Column(String(255), unique=True, nullable=False, index=True)
|
||||||
|
|
||||||
|
# PKCE code_verifier (used to generate code_challenge)
|
||||||
|
code_verifier = Column(String(128), nullable=True)
|
||||||
|
|
||||||
|
# OIDC nonce for ID token replay protection
|
||||||
|
nonce = Column(String(255), nullable=True)
|
||||||
|
|
||||||
|
# OAuth provider (google, github, etc.)
|
||||||
|
provider = Column(String(50), nullable=False)
|
||||||
|
|
||||||
|
# Original redirect URI (for callback validation)
|
||||||
|
redirect_uri = Column(String(500), nullable=True)
|
||||||
|
|
||||||
|
# User ID if this is an account linking flow (user is already logged in)
|
||||||
|
user_id = Column(UUID(as_uuid=True), nullable=True)
|
||||||
|
|
||||||
|
# Expiration time
|
||||||
|
expires_at = Column(DateTime(timezone=True), nullable=False)
|
||||||
|
|
||||||
|
def __repr__(self):
|
||||||
|
return f"<OAuthState {self.state[:8]}... ({self.provider})>"
|
||||||
@@ -10,6 +10,9 @@ class Organization(Base, UUIDMixin, TimestampMixin):
|
|||||||
"""
|
"""
|
||||||
Organization model for multi-tenant support.
|
Organization model for multi-tenant support.
|
||||||
Users can belong to multiple organizations with different roles.
|
Users can belong to multiple organizations with different roles.
|
||||||
|
|
||||||
|
Performance indexes (defined in migration 0002_add_performance_indexes.py):
|
||||||
|
- ix_perf_organizations_slug_lower: LOWER(slug) WHERE is_active = true
|
||||||
"""
|
"""
|
||||||
|
|
||||||
__tablename__ = "organizations"
|
__tablename__ = "organizations"
|
||||||
|
|||||||
@@ -6,10 +6,19 @@ from .base import Base, TimestampMixin, UUIDMixin
|
|||||||
|
|
||||||
|
|
||||||
class User(Base, UUIDMixin, TimestampMixin):
|
class User(Base, UUIDMixin, TimestampMixin):
|
||||||
|
"""
|
||||||
|
User model for authentication and profile data.
|
||||||
|
|
||||||
|
Performance indexes (defined in migration 0002_add_performance_indexes.py):
|
||||||
|
- ix_perf_users_email_lower: LOWER(email) WHERE deleted_at IS NULL
|
||||||
|
- ix_perf_users_active: is_active WHERE deleted_at IS NULL
|
||||||
|
"""
|
||||||
|
|
||||||
__tablename__ = "users"
|
__tablename__ = "users"
|
||||||
|
|
||||||
email = Column(String(255), unique=True, nullable=False, index=True)
|
email = Column(String(255), unique=True, nullable=False, index=True)
|
||||||
password_hash = Column(String(255), nullable=False)
|
# Nullable to support OAuth-only users who never set a password
|
||||||
|
password_hash = Column(String(255), nullable=True)
|
||||||
first_name = Column(String(100), nullable=False, default="user")
|
first_name = Column(String(100), nullable=False, default="user")
|
||||||
last_name = Column(String(100), nullable=True)
|
last_name = Column(String(100), nullable=True)
|
||||||
phone_number = Column(String(20))
|
phone_number = Column(String(20))
|
||||||
@@ -23,6 +32,19 @@ class User(Base, UUIDMixin, TimestampMixin):
|
|||||||
user_organizations = relationship(
|
user_organizations = relationship(
|
||||||
"UserOrganization", back_populates="user", cascade="all, delete-orphan"
|
"UserOrganization", back_populates="user", cascade="all, delete-orphan"
|
||||||
)
|
)
|
||||||
|
oauth_accounts = relationship(
|
||||||
|
"OAuthAccount", back_populates="user", cascade="all, delete-orphan"
|
||||||
|
)
|
||||||
|
|
||||||
|
@property
|
||||||
|
def has_password(self) -> bool:
|
||||||
|
"""Check if user can login with password (not OAuth-only)."""
|
||||||
|
return self.password_hash is not None
|
||||||
|
|
||||||
|
@property
|
||||||
|
def can_remove_oauth(self) -> bool:
|
||||||
|
"""Check if user can safely remove an OAuth account link."""
|
||||||
|
return self.has_password or len(self.oauth_accounts) > 1
|
||||||
|
|
||||||
def __repr__(self):
|
def __repr__(self):
|
||||||
return f"<User {self.email}>"
|
return f"<User {self.email}>"
|
||||||
|
|||||||
@@ -44,7 +44,7 @@ class UserOrganization(Base, TimestampMixin):
|
|||||||
Enum(OrganizationRole),
|
Enum(OrganizationRole),
|
||||||
default=OrganizationRole.MEMBER,
|
default=OrganizationRole.MEMBER,
|
||||||
nullable=False,
|
nullable=False,
|
||||||
index=True,
|
# Note: index defined in __table_args__ as ix_user_org_role
|
||||||
)
|
)
|
||||||
is_active = Column(Boolean, default=True, nullable=False, index=True)
|
is_active = Column(Boolean, default=True, nullable=False, index=True)
|
||||||
|
|
||||||
|
|||||||
@@ -22,6 +22,9 @@ class UserSession(Base, UUIDMixin, TimestampMixin):
|
|||||||
|
|
||||||
Each time a user logs in from a device, a new session is created.
|
Each time a user logs in from a device, a new session is created.
|
||||||
Sessions are identified by the refresh token JTI (JWT ID).
|
Sessions are identified by the refresh token JTI (JWT ID).
|
||||||
|
|
||||||
|
Performance indexes (defined in migration 0002_add_performance_indexes.py):
|
||||||
|
- ix_perf_user_sessions_expires: expires_at WHERE is_active = true
|
||||||
"""
|
"""
|
||||||
|
|
||||||
__tablename__ = "user_sessions"
|
__tablename__ = "user_sessions"
|
||||||
@@ -73,7 +76,11 @@ class UserSession(Base, UUIDMixin, TimestampMixin):
|
|||||||
"""Check if session has expired."""
|
"""Check if session has expired."""
|
||||||
from datetime import datetime
|
from datetime import datetime
|
||||||
|
|
||||||
return self.expires_at < datetime.now(UTC)
|
now = datetime.now(UTC)
|
||||||
|
expires_at = self.expires_at
|
||||||
|
if expires_at.tzinfo is None:
|
||||||
|
expires_at = expires_at.replace(tzinfo=UTC)
|
||||||
|
return bool(expires_at < now)
|
||||||
|
|
||||||
def to_dict(self):
|
def to_dict(self):
|
||||||
"""Convert session to dictionary for serialization."""
|
"""Convert session to dictionary for serialization."""
|
||||||
|
|||||||
39
backend/app/repositories/__init__.py
Normal file
39
backend/app/repositories/__init__.py
Normal file
@@ -0,0 +1,39 @@
|
|||||||
|
# app/repositories/__init__.py
|
||||||
|
"""Repository layer — all database access goes through these classes."""
|
||||||
|
|
||||||
|
from app.repositories.oauth_account import OAuthAccountRepository, oauth_account_repo
|
||||||
|
from app.repositories.oauth_authorization_code import (
|
||||||
|
OAuthAuthorizationCodeRepository,
|
||||||
|
oauth_authorization_code_repo,
|
||||||
|
)
|
||||||
|
from app.repositories.oauth_client import OAuthClientRepository, oauth_client_repo
|
||||||
|
from app.repositories.oauth_consent import OAuthConsentRepository, oauth_consent_repo
|
||||||
|
from app.repositories.oauth_provider_token import (
|
||||||
|
OAuthProviderTokenRepository,
|
||||||
|
oauth_provider_token_repo,
|
||||||
|
)
|
||||||
|
from app.repositories.oauth_state import OAuthStateRepository, oauth_state_repo
|
||||||
|
from app.repositories.organization import OrganizationRepository, organization_repo
|
||||||
|
from app.repositories.session import SessionRepository, session_repo
|
||||||
|
from app.repositories.user import UserRepository, user_repo
|
||||||
|
|
||||||
|
__all__ = [
|
||||||
|
"OAuthAccountRepository",
|
||||||
|
"OAuthAuthorizationCodeRepository",
|
||||||
|
"OAuthClientRepository",
|
||||||
|
"OAuthConsentRepository",
|
||||||
|
"OAuthProviderTokenRepository",
|
||||||
|
"OAuthStateRepository",
|
||||||
|
"OrganizationRepository",
|
||||||
|
"SessionRepository",
|
||||||
|
"UserRepository",
|
||||||
|
"oauth_account_repo",
|
||||||
|
"oauth_authorization_code_repo",
|
||||||
|
"oauth_client_repo",
|
||||||
|
"oauth_consent_repo",
|
||||||
|
"oauth_provider_token_repo",
|
||||||
|
"oauth_state_repo",
|
||||||
|
"organization_repo",
|
||||||
|
"session_repo",
|
||||||
|
"user_repo",
|
||||||
|
]
|
||||||
180
backend/app/crud/base.py → backend/app/repositories/base.py
Executable file → Normal file
180
backend/app/crud/base.py → backend/app/repositories/base.py
Executable file → Normal file
@@ -1,6 +1,6 @@
|
|||||||
# app/crud/base_async.py
|
# app/repositories/base.py
|
||||||
"""
|
"""
|
||||||
Async CRUD operations base class using SQLAlchemy 2.0 async patterns.
|
Base repository class for async database operations using SQLAlchemy 2.0 async patterns.
|
||||||
|
|
||||||
Provides reusable create, read, update, and delete operations for all models.
|
Provides reusable create, read, update, and delete operations for all models.
|
||||||
"""
|
"""
|
||||||
@@ -18,6 +18,11 @@ from sqlalchemy.ext.asyncio import AsyncSession
|
|||||||
from sqlalchemy.orm import Load
|
from sqlalchemy.orm import Load
|
||||||
|
|
||||||
from app.core.database import Base
|
from app.core.database import Base
|
||||||
|
from app.core.repository_exceptions import (
|
||||||
|
DuplicateEntryError,
|
||||||
|
IntegrityConstraintError,
|
||||||
|
InvalidInputError,
|
||||||
|
)
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
@@ -26,16 +31,16 @@ CreateSchemaType = TypeVar("CreateSchemaType", bound=BaseModel)
|
|||||||
UpdateSchemaType = TypeVar("UpdateSchemaType", bound=BaseModel)
|
UpdateSchemaType = TypeVar("UpdateSchemaType", bound=BaseModel)
|
||||||
|
|
||||||
|
|
||||||
class CRUDBase[
|
class BaseRepository[
|
||||||
ModelType: Base,
|
ModelType: Base,
|
||||||
CreateSchemaType: BaseModel,
|
CreateSchemaType: BaseModel,
|
||||||
UpdateSchemaType: BaseModel,
|
UpdateSchemaType: BaseModel,
|
||||||
]:
|
]:
|
||||||
"""Async CRUD operations for a model."""
|
"""Async repository operations for a model."""
|
||||||
|
|
||||||
def __init__(self, model: type[ModelType]):
|
def __init__(self, model: type[ModelType]):
|
||||||
"""
|
"""
|
||||||
CRUD object with default async methods to Create, Read, Update, Delete.
|
Repository object with default async methods to Create, Read, Update, Delete.
|
||||||
|
|
||||||
Parameters:
|
Parameters:
|
||||||
model: A SQLAlchemy model class
|
model: A SQLAlchemy model class
|
||||||
@@ -56,26 +61,19 @@ class CRUDBase[
|
|||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
Model instance or None if not found
|
Model instance or None if not found
|
||||||
|
|
||||||
Example:
|
|
||||||
# Eager load user relationship
|
|
||||||
from sqlalchemy.orm import joinedload
|
|
||||||
session = await session_crud.get(db, id=session_id, options=[joinedload(UserSession.user)])
|
|
||||||
"""
|
"""
|
||||||
# Validate UUID format and convert to UUID object if string
|
|
||||||
try:
|
try:
|
||||||
if isinstance(id, uuid.UUID):
|
if isinstance(id, uuid.UUID):
|
||||||
uuid_obj = id
|
uuid_obj = id
|
||||||
else:
|
else:
|
||||||
uuid_obj = uuid.UUID(str(id))
|
uuid_obj = uuid.UUID(str(id))
|
||||||
except (ValueError, AttributeError, TypeError) as e:
|
except (ValueError, AttributeError, TypeError) as e:
|
||||||
logger.warning(f"Invalid UUID format: {id} - {e!s}")
|
logger.warning("Invalid UUID format: %s - %s", id, e)
|
||||||
return None
|
return None
|
||||||
|
|
||||||
try:
|
try:
|
||||||
query = select(self.model).where(self.model.id == uuid_obj)
|
query = select(self.model).where(self.model.id == uuid_obj)
|
||||||
|
|
||||||
# Apply eager loading options if provided
|
|
||||||
if options:
|
if options:
|
||||||
for option in options:
|
for option in options:
|
||||||
query = query.options(option)
|
query = query.options(option)
|
||||||
@@ -83,7 +81,9 @@ class CRUDBase[
|
|||||||
result = await db.execute(query)
|
result = await db.execute(query)
|
||||||
return result.scalar_one_or_none()
|
return result.scalar_one_or_none()
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.error(f"Error retrieving {self.model.__name__} with id {id}: {e!s}")
|
logger.error(
|
||||||
|
"Error retrieving %s with id %s: %s", self.model.__name__, id, e
|
||||||
|
)
|
||||||
raise
|
raise
|
||||||
|
|
||||||
async def get_multi(
|
async def get_multi(
|
||||||
@@ -96,28 +96,17 @@ class CRUDBase[
|
|||||||
) -> list[ModelType]:
|
) -> list[ModelType]:
|
||||||
"""
|
"""
|
||||||
Get multiple records with pagination validation and optional eager loading.
|
Get multiple records with pagination validation and optional eager loading.
|
||||||
|
|
||||||
Args:
|
|
||||||
db: Database session
|
|
||||||
skip: Number of records to skip
|
|
||||||
limit: Maximum number of records to return
|
|
||||||
options: Optional list of SQLAlchemy load options for eager loading
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
List of model instances
|
|
||||||
"""
|
"""
|
||||||
# Validate pagination parameters
|
|
||||||
if skip < 0:
|
if skip < 0:
|
||||||
raise ValueError("skip must be non-negative")
|
raise InvalidInputError("skip must be non-negative")
|
||||||
if limit < 0:
|
if limit < 0:
|
||||||
raise ValueError("limit must be non-negative")
|
raise InvalidInputError("limit must be non-negative")
|
||||||
if limit > 1000:
|
if limit > 1000:
|
||||||
raise ValueError("Maximum limit is 1000")
|
raise InvalidInputError("Maximum limit is 1000")
|
||||||
|
|
||||||
try:
|
try:
|
||||||
query = select(self.model).offset(skip).limit(limit)
|
query = select(self.model).order_by(self.model.id).offset(skip).limit(limit)
|
||||||
|
|
||||||
# Apply eager loading options if provided
|
|
||||||
if options:
|
if options:
|
||||||
for option in options:
|
for option in options:
|
||||||
query = query.options(option)
|
query = query.options(option)
|
||||||
@@ -126,7 +115,7 @@ class CRUDBase[
|
|||||||
return list(result.scalars().all())
|
return list(result.scalars().all())
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.error(
|
logger.error(
|
||||||
f"Error retrieving multiple {self.model.__name__} records: {e!s}"
|
"Error retrieving multiple %s records: %s", self.model.__name__, e
|
||||||
)
|
)
|
||||||
raise
|
raise
|
||||||
|
|
||||||
@@ -136,9 +125,8 @@ class CRUDBase[
|
|||||||
"""Create a new record with error handling.
|
"""Create a new record with error handling.
|
||||||
|
|
||||||
NOTE: This method is defensive code that's never called in practice.
|
NOTE: This method is defensive code that's never called in practice.
|
||||||
All CRUD subclasses (CRUDUser, CRUDOrganization, CRUDSession) override this method
|
All repository subclasses override this method with their own implementations.
|
||||||
with their own implementations, so the base implementation and its exception handlers
|
Marked as pragma: no cover to avoid false coverage gaps.
|
||||||
are never executed. Marked as pragma: no cover to avoid false coverage gaps.
|
|
||||||
"""
|
"""
|
||||||
try: # pragma: no cover
|
try: # pragma: no cover
|
||||||
obj_in_data = jsonable_encoder(obj_in)
|
obj_in_data = jsonable_encoder(obj_in)
|
||||||
@@ -152,22 +140,24 @@ class CRUDBase[
|
|||||||
error_msg = str(e.orig) if hasattr(e, "orig") else str(e)
|
error_msg = str(e.orig) if hasattr(e, "orig") else str(e)
|
||||||
if "unique" in error_msg.lower() or "duplicate" in error_msg.lower():
|
if "unique" in error_msg.lower() or "duplicate" in error_msg.lower():
|
||||||
logger.warning(
|
logger.warning(
|
||||||
f"Duplicate entry attempted for {self.model.__name__}: {error_msg}"
|
"Duplicate entry attempted for %s: %s",
|
||||||
|
self.model.__name__,
|
||||||
|
error_msg,
|
||||||
)
|
)
|
||||||
raise ValueError(
|
raise DuplicateEntryError(
|
||||||
f"A {self.model.__name__} with this data already exists"
|
f"A {self.model.__name__} with this data already exists"
|
||||||
)
|
)
|
||||||
logger.error(f"Integrity error creating {self.model.__name__}: {error_msg}")
|
logger.error(
|
||||||
raise ValueError(f"Database integrity error: {error_msg}")
|
"Integrity error creating %s: %s", self.model.__name__, error_msg
|
||||||
|
)
|
||||||
|
raise IntegrityConstraintError(f"Database integrity error: {error_msg}")
|
||||||
except (OperationalError, DataError) as e: # pragma: no cover
|
except (OperationalError, DataError) as e: # pragma: no cover
|
||||||
await db.rollback()
|
await db.rollback()
|
||||||
logger.error(f"Database error creating {self.model.__name__}: {e!s}")
|
logger.error("Database error creating %s: %s", self.model.__name__, e)
|
||||||
raise ValueError(f"Database operation failed: {e!s}")
|
raise IntegrityConstraintError(f"Database operation failed: {e!s}")
|
||||||
except Exception as e: # pragma: no cover
|
except Exception as e: # pragma: no cover
|
||||||
await db.rollback()
|
await db.rollback()
|
||||||
logger.error(
|
logger.exception("Unexpected error creating %s: %s", self.model.__name__, e)
|
||||||
f"Unexpected error creating {self.model.__name__}: {e!s}", exc_info=True
|
|
||||||
)
|
|
||||||
raise
|
raise
|
||||||
|
|
||||||
async def update(
|
async def update(
|
||||||
@@ -198,34 +188,35 @@ class CRUDBase[
|
|||||||
error_msg = str(e.orig) if hasattr(e, "orig") else str(e)
|
error_msg = str(e.orig) if hasattr(e, "orig") else str(e)
|
||||||
if "unique" in error_msg.lower() or "duplicate" in error_msg.lower():
|
if "unique" in error_msg.lower() or "duplicate" in error_msg.lower():
|
||||||
logger.warning(
|
logger.warning(
|
||||||
f"Duplicate entry attempted for {self.model.__name__}: {error_msg}"
|
"Duplicate entry attempted for %s: %s",
|
||||||
|
self.model.__name__,
|
||||||
|
error_msg,
|
||||||
)
|
)
|
||||||
raise ValueError(
|
raise DuplicateEntryError(
|
||||||
f"A {self.model.__name__} with this data already exists"
|
f"A {self.model.__name__} with this data already exists"
|
||||||
)
|
)
|
||||||
logger.error(f"Integrity error updating {self.model.__name__}: {error_msg}")
|
logger.error(
|
||||||
raise ValueError(f"Database integrity error: {error_msg}")
|
"Integrity error updating %s: %s", self.model.__name__, error_msg
|
||||||
|
)
|
||||||
|
raise IntegrityConstraintError(f"Database integrity error: {error_msg}")
|
||||||
except (OperationalError, DataError) as e:
|
except (OperationalError, DataError) as e:
|
||||||
await db.rollback()
|
await db.rollback()
|
||||||
logger.error(f"Database error updating {self.model.__name__}: {e!s}")
|
logger.error("Database error updating %s: %s", self.model.__name__, e)
|
||||||
raise ValueError(f"Database operation failed: {e!s}")
|
raise IntegrityConstraintError(f"Database operation failed: {e!s}")
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
await db.rollback()
|
await db.rollback()
|
||||||
logger.error(
|
logger.exception("Unexpected error updating %s: %s", self.model.__name__, e)
|
||||||
f"Unexpected error updating {self.model.__name__}: {e!s}", exc_info=True
|
|
||||||
)
|
|
||||||
raise
|
raise
|
||||||
|
|
||||||
async def remove(self, db: AsyncSession, *, id: str) -> ModelType | None:
|
async def remove(self, db: AsyncSession, *, id: str) -> ModelType | None:
|
||||||
"""Delete a record with error handling and null check."""
|
"""Delete a record with error handling and null check."""
|
||||||
# Validate UUID format and convert to UUID object if string
|
|
||||||
try:
|
try:
|
||||||
if isinstance(id, uuid.UUID):
|
if isinstance(id, uuid.UUID):
|
||||||
uuid_obj = id
|
uuid_obj = id
|
||||||
else:
|
else:
|
||||||
uuid_obj = uuid.UUID(str(id))
|
uuid_obj = uuid.UUID(str(id))
|
||||||
except (ValueError, AttributeError, TypeError) as e:
|
except (ValueError, AttributeError, TypeError) as e:
|
||||||
logger.warning(f"Invalid UUID format for deletion: {id} - {e!s}")
|
logger.warning("Invalid UUID format for deletion: %s - %s", id, e)
|
||||||
return None
|
return None
|
||||||
|
|
||||||
try:
|
try:
|
||||||
@@ -236,7 +227,7 @@ class CRUDBase[
|
|||||||
|
|
||||||
if obj is None:
|
if obj is None:
|
||||||
logger.warning(
|
logger.warning(
|
||||||
f"{self.model.__name__} with id {id} not found for deletion"
|
"%s with id %s not found for deletion", self.model.__name__, id
|
||||||
)
|
)
|
||||||
return None
|
return None
|
||||||
|
|
||||||
@@ -246,15 +237,16 @@ class CRUDBase[
|
|||||||
except IntegrityError as e:
|
except IntegrityError as e:
|
||||||
await db.rollback()
|
await db.rollback()
|
||||||
error_msg = str(e.orig) if hasattr(e, "orig") else str(e)
|
error_msg = str(e.orig) if hasattr(e, "orig") else str(e)
|
||||||
logger.error(f"Integrity error deleting {self.model.__name__}: {error_msg}")
|
logger.error(
|
||||||
raise ValueError(
|
"Integrity error deleting %s: %s", self.model.__name__, error_msg
|
||||||
|
)
|
||||||
|
raise IntegrityConstraintError(
|
||||||
f"Cannot delete {self.model.__name__}: referenced by other records"
|
f"Cannot delete {self.model.__name__}: referenced by other records"
|
||||||
)
|
)
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
await db.rollback()
|
await db.rollback()
|
||||||
logger.error(
|
logger.exception(
|
||||||
f"Error deleting {self.model.__name__} with id {id}: {e!s}",
|
"Error deleting %s with id %s: %s", self.model.__name__, id, e
|
||||||
exc_info=True,
|
|
||||||
)
|
)
|
||||||
raise
|
raise
|
||||||
|
|
||||||
@@ -267,65 +259,53 @@ class CRUDBase[
|
|||||||
sort_by: str | None = None,
|
sort_by: str | None = None,
|
||||||
sort_order: str = "asc",
|
sort_order: str = "asc",
|
||||||
filters: dict[str, Any] | None = None,
|
filters: dict[str, Any] | None = None,
|
||||||
) -> tuple[list[ModelType], int]:
|
) -> tuple[list[ModelType], int]: # pragma: no cover
|
||||||
"""
|
"""
|
||||||
Get multiple records with total count, filtering, and sorting.
|
Get multiple records with total count, filtering, and sorting.
|
||||||
|
|
||||||
Args:
|
NOTE: This method is defensive code that's never called in practice.
|
||||||
db: Database session
|
All repository subclasses override this method with their own implementations.
|
||||||
skip: Number of records to skip
|
Marked as pragma: no cover to avoid false coverage gaps.
|
||||||
limit: Maximum number of records to return
|
|
||||||
sort_by: Field name to sort by (must be a valid model attribute)
|
|
||||||
sort_order: Sort order ("asc" or "desc")
|
|
||||||
filters: Dictionary of filters (field_name: value)
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
Tuple of (items, total_count)
|
|
||||||
"""
|
"""
|
||||||
# Validate pagination parameters
|
|
||||||
if skip < 0:
|
if skip < 0:
|
||||||
raise ValueError("skip must be non-negative")
|
raise InvalidInputError("skip must be non-negative")
|
||||||
if limit < 0:
|
if limit < 0:
|
||||||
raise ValueError("limit must be non-negative")
|
raise InvalidInputError("limit must be non-negative")
|
||||||
if limit > 1000:
|
if limit > 1000:
|
||||||
raise ValueError("Maximum limit is 1000")
|
raise InvalidInputError("Maximum limit is 1000")
|
||||||
|
|
||||||
try:
|
try:
|
||||||
# Build base query
|
|
||||||
query = select(self.model)
|
query = select(self.model)
|
||||||
|
|
||||||
# Exclude soft-deleted records by default
|
|
||||||
if hasattr(self.model, "deleted_at"):
|
if hasattr(self.model, "deleted_at"):
|
||||||
query = query.where(self.model.deleted_at.is_(None))
|
query = query.where(self.model.deleted_at.is_(None))
|
||||||
|
|
||||||
# Apply filters
|
|
||||||
if filters:
|
if filters:
|
||||||
for field, value in filters.items():
|
for field, value in filters.items():
|
||||||
if hasattr(self.model, field) and value is not None:
|
if hasattr(self.model, field) and value is not None:
|
||||||
query = query.where(getattr(self.model, field) == value)
|
query = query.where(getattr(self.model, field) == value)
|
||||||
|
|
||||||
# Get total count (before pagination)
|
|
||||||
count_query = select(func.count()).select_from(query.alias())
|
count_query = select(func.count()).select_from(query.alias())
|
||||||
count_result = await db.execute(count_query)
|
count_result = await db.execute(count_query)
|
||||||
total = count_result.scalar_one()
|
total = count_result.scalar_one()
|
||||||
|
|
||||||
# Apply sorting
|
|
||||||
if sort_by and hasattr(self.model, sort_by):
|
if sort_by and hasattr(self.model, sort_by):
|
||||||
sort_column = getattr(self.model, sort_by)
|
sort_column = getattr(self.model, sort_by)
|
||||||
if sort_order.lower() == "desc":
|
if sort_order.lower() == "desc":
|
||||||
query = query.order_by(sort_column.desc())
|
query = query.order_by(sort_column.desc())
|
||||||
else:
|
else:
|
||||||
query = query.order_by(sort_column.asc())
|
query = query.order_by(sort_column.asc())
|
||||||
|
else:
|
||||||
|
query = query.order_by(self.model.id)
|
||||||
|
|
||||||
# Apply pagination
|
|
||||||
query = query.offset(skip).limit(limit)
|
query = query.offset(skip).limit(limit)
|
||||||
items_result = await db.execute(query)
|
items_result = await db.execute(query)
|
||||||
items = list(items_result.scalars().all())
|
items = list(items_result.scalars().all())
|
||||||
|
|
||||||
return items, total
|
return items, total
|
||||||
except Exception as e:
|
except Exception as e: # pragma: no cover
|
||||||
logger.error(
|
logger.error(
|
||||||
f"Error retrieving paginated {self.model.__name__} records: {e!s}"
|
"Error retrieving paginated %s records: %s", self.model.__name__, e
|
||||||
)
|
)
|
||||||
raise
|
raise
|
||||||
|
|
||||||
@@ -335,7 +315,7 @@ class CRUDBase[
|
|||||||
result = await db.execute(select(func.count(self.model.id)))
|
result = await db.execute(select(func.count(self.model.id)))
|
||||||
return result.scalar_one()
|
return result.scalar_one()
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.error(f"Error counting {self.model.__name__} records: {e!s}")
|
logger.error("Error counting %s records: %s", self.model.__name__, e)
|
||||||
raise
|
raise
|
||||||
|
|
||||||
async def exists(self, db: AsyncSession, id: str) -> bool:
|
async def exists(self, db: AsyncSession, id: str) -> bool:
|
||||||
@@ -351,14 +331,13 @@ class CRUDBase[
|
|||||||
"""
|
"""
|
||||||
from datetime import datetime
|
from datetime import datetime
|
||||||
|
|
||||||
# Validate UUID format and convert to UUID object if string
|
|
||||||
try:
|
try:
|
||||||
if isinstance(id, uuid.UUID):
|
if isinstance(id, uuid.UUID):
|
||||||
uuid_obj = id
|
uuid_obj = id
|
||||||
else:
|
else:
|
||||||
uuid_obj = uuid.UUID(str(id))
|
uuid_obj = uuid.UUID(str(id))
|
||||||
except (ValueError, AttributeError, TypeError) as e:
|
except (ValueError, AttributeError, TypeError) as e:
|
||||||
logger.warning(f"Invalid UUID format for soft deletion: {id} - {e!s}")
|
logger.warning("Invalid UUID format for soft deletion: %s - %s", id, e)
|
||||||
return None
|
return None
|
||||||
|
|
||||||
try:
|
try:
|
||||||
@@ -369,18 +348,16 @@ class CRUDBase[
|
|||||||
|
|
||||||
if obj is None:
|
if obj is None:
|
||||||
logger.warning(
|
logger.warning(
|
||||||
f"{self.model.__name__} with id {id} not found for soft deletion"
|
"%s with id %s not found for soft deletion", self.model.__name__, id
|
||||||
)
|
)
|
||||||
return None
|
return None
|
||||||
|
|
||||||
# Check if model supports soft deletes
|
|
||||||
if not hasattr(self.model, "deleted_at"):
|
if not hasattr(self.model, "deleted_at"):
|
||||||
logger.error(f"{self.model.__name__} does not support soft deletes")
|
logger.error("%s does not support soft deletes", self.model.__name__)
|
||||||
raise ValueError(
|
raise InvalidInputError(
|
||||||
f"{self.model.__name__} does not have a deleted_at column"
|
f"{self.model.__name__} does not have a deleted_at column"
|
||||||
)
|
)
|
||||||
|
|
||||||
# Set deleted_at timestamp
|
|
||||||
obj.deleted_at = datetime.now(UTC)
|
obj.deleted_at = datetime.now(UTC)
|
||||||
db.add(obj)
|
db.add(obj)
|
||||||
await db.commit()
|
await db.commit()
|
||||||
@@ -388,9 +365,8 @@ class CRUDBase[
|
|||||||
return obj
|
return obj
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
await db.rollback()
|
await db.rollback()
|
||||||
logger.error(
|
logger.exception(
|
||||||
f"Error soft deleting {self.model.__name__} with id {id}: {e!s}",
|
"Error soft deleting %s with id %s: %s", self.model.__name__, id, e
|
||||||
exc_info=True,
|
|
||||||
)
|
)
|
||||||
raise
|
raise
|
||||||
|
|
||||||
@@ -400,18 +376,16 @@ class CRUDBase[
|
|||||||
|
|
||||||
Only works if the model has a 'deleted_at' column.
|
Only works if the model has a 'deleted_at' column.
|
||||||
"""
|
"""
|
||||||
# Validate UUID format
|
|
||||||
try:
|
try:
|
||||||
if isinstance(id, uuid.UUID):
|
if isinstance(id, uuid.UUID):
|
||||||
uuid_obj = id
|
uuid_obj = id
|
||||||
else:
|
else:
|
||||||
uuid_obj = uuid.UUID(str(id))
|
uuid_obj = uuid.UUID(str(id))
|
||||||
except (ValueError, AttributeError, TypeError) as e:
|
except (ValueError, AttributeError, TypeError) as e:
|
||||||
logger.warning(f"Invalid UUID format for restoration: {id} - {e!s}")
|
logger.warning("Invalid UUID format for restoration: %s - %s", id, e)
|
||||||
return None
|
return None
|
||||||
|
|
||||||
try:
|
try:
|
||||||
# Find the soft-deleted record
|
|
||||||
if hasattr(self.model, "deleted_at"):
|
if hasattr(self.model, "deleted_at"):
|
||||||
result = await db.execute(
|
result = await db.execute(
|
||||||
select(self.model).where(
|
select(self.model).where(
|
||||||
@@ -420,18 +394,19 @@ class CRUDBase[
|
|||||||
)
|
)
|
||||||
obj = result.scalar_one_or_none()
|
obj = result.scalar_one_or_none()
|
||||||
else:
|
else:
|
||||||
logger.error(f"{self.model.__name__} does not support soft deletes")
|
logger.error("%s does not support soft deletes", self.model.__name__)
|
||||||
raise ValueError(
|
raise InvalidInputError(
|
||||||
f"{self.model.__name__} does not have a deleted_at column"
|
f"{self.model.__name__} does not have a deleted_at column"
|
||||||
)
|
)
|
||||||
|
|
||||||
if obj is None:
|
if obj is None:
|
||||||
logger.warning(
|
logger.warning(
|
||||||
f"Soft-deleted {self.model.__name__} with id {id} not found for restoration"
|
"Soft-deleted %s with id %s not found for restoration",
|
||||||
|
self.model.__name__,
|
||||||
|
id,
|
||||||
)
|
)
|
||||||
return None
|
return None
|
||||||
|
|
||||||
# Clear deleted_at timestamp
|
|
||||||
obj.deleted_at = None
|
obj.deleted_at = None
|
||||||
db.add(obj)
|
db.add(obj)
|
||||||
await db.commit()
|
await db.commit()
|
||||||
@@ -439,8 +414,7 @@ class CRUDBase[
|
|||||||
return obj
|
return obj
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
await db.rollback()
|
await db.rollback()
|
||||||
logger.error(
|
logger.exception(
|
||||||
f"Error restoring {self.model.__name__} with id {id}: {e!s}",
|
"Error restoring %s with id %s: %s", self.model.__name__, id, e
|
||||||
exc_info=True,
|
|
||||||
)
|
)
|
||||||
raise
|
raise
|
||||||
249
backend/app/repositories/oauth_account.py
Normal file
249
backend/app/repositories/oauth_account.py
Normal file
@@ -0,0 +1,249 @@
|
|||||||
|
# app/repositories/oauth_account.py
|
||||||
|
"""Repository for OAuthAccount model async database operations."""
|
||||||
|
|
||||||
|
import logging
|
||||||
|
from datetime import datetime
|
||||||
|
from uuid import UUID
|
||||||
|
|
||||||
|
from pydantic import BaseModel
|
||||||
|
from sqlalchemy import and_, delete, select
|
||||||
|
from sqlalchemy.exc import IntegrityError
|
||||||
|
from sqlalchemy.ext.asyncio import AsyncSession
|
||||||
|
from sqlalchemy.orm import joinedload
|
||||||
|
|
||||||
|
from app.core.repository_exceptions import DuplicateEntryError
|
||||||
|
from app.models.oauth_account import OAuthAccount
|
||||||
|
from app.repositories.base import BaseRepository
|
||||||
|
from app.schemas.oauth import OAuthAccountCreate
|
||||||
|
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
|
class EmptySchema(BaseModel):
|
||||||
|
"""Placeholder schema for repository operations that don't need update schemas."""
|
||||||
|
|
||||||
|
|
||||||
|
class OAuthAccountRepository(
|
||||||
|
BaseRepository[OAuthAccount, OAuthAccountCreate, EmptySchema]
|
||||||
|
):
|
||||||
|
"""Repository for OAuth account links."""
|
||||||
|
|
||||||
|
async def get_by_provider_id(
|
||||||
|
self,
|
||||||
|
db: AsyncSession,
|
||||||
|
*,
|
||||||
|
provider: str,
|
||||||
|
provider_user_id: str,
|
||||||
|
) -> OAuthAccount | None:
|
||||||
|
"""Get OAuth account by provider and provider user ID."""
|
||||||
|
try:
|
||||||
|
result = await db.execute(
|
||||||
|
select(OAuthAccount)
|
||||||
|
.where(
|
||||||
|
and_(
|
||||||
|
OAuthAccount.provider == provider,
|
||||||
|
OAuthAccount.provider_user_id == provider_user_id,
|
||||||
|
)
|
||||||
|
)
|
||||||
|
.options(joinedload(OAuthAccount.user))
|
||||||
|
)
|
||||||
|
return result.scalar_one_or_none()
|
||||||
|
except Exception as e: # pragma: no cover
|
||||||
|
logger.error(
|
||||||
|
"Error getting OAuth account for %s:%s: %s",
|
||||||
|
provider,
|
||||||
|
provider_user_id,
|
||||||
|
e,
|
||||||
|
)
|
||||||
|
raise
|
||||||
|
|
||||||
|
async def get_by_provider_email(
|
||||||
|
self,
|
||||||
|
db: AsyncSession,
|
||||||
|
*,
|
||||||
|
provider: str,
|
||||||
|
email: str,
|
||||||
|
) -> OAuthAccount | None:
|
||||||
|
"""Get OAuth account by provider and email."""
|
||||||
|
try:
|
||||||
|
result = await db.execute(
|
||||||
|
select(OAuthAccount)
|
||||||
|
.where(
|
||||||
|
and_(
|
||||||
|
OAuthAccount.provider == provider,
|
||||||
|
OAuthAccount.provider_email == email,
|
||||||
|
)
|
||||||
|
)
|
||||||
|
.options(joinedload(OAuthAccount.user))
|
||||||
|
)
|
||||||
|
return result.scalar_one_or_none()
|
||||||
|
except Exception as e: # pragma: no cover
|
||||||
|
logger.error(
|
||||||
|
"Error getting OAuth account for %s email %s: %s", provider, email, e
|
||||||
|
)
|
||||||
|
raise
|
||||||
|
|
||||||
|
async def get_user_accounts(
|
||||||
|
self,
|
||||||
|
db: AsyncSession,
|
||||||
|
*,
|
||||||
|
user_id: str | UUID,
|
||||||
|
) -> list[OAuthAccount]:
|
||||||
|
"""Get all OAuth accounts linked to a user."""
|
||||||
|
try:
|
||||||
|
user_uuid = UUID(str(user_id)) if isinstance(user_id, str) else user_id
|
||||||
|
|
||||||
|
result = await db.execute(
|
||||||
|
select(OAuthAccount)
|
||||||
|
.where(OAuthAccount.user_id == user_uuid)
|
||||||
|
.order_by(OAuthAccount.created_at.desc())
|
||||||
|
)
|
||||||
|
return list(result.scalars().all())
|
||||||
|
except Exception as e: # pragma: no cover
|
||||||
|
logger.error("Error getting OAuth accounts for user %s: %s", user_id, e)
|
||||||
|
raise
|
||||||
|
|
||||||
|
async def get_user_account_by_provider(
|
||||||
|
self,
|
||||||
|
db: AsyncSession,
|
||||||
|
*,
|
||||||
|
user_id: str | UUID,
|
||||||
|
provider: str,
|
||||||
|
) -> OAuthAccount | None:
|
||||||
|
"""Get a specific OAuth account for a user and provider."""
|
||||||
|
try:
|
||||||
|
user_uuid = UUID(str(user_id)) if isinstance(user_id, str) else user_id
|
||||||
|
|
||||||
|
result = await db.execute(
|
||||||
|
select(OAuthAccount).where(
|
||||||
|
and_(
|
||||||
|
OAuthAccount.user_id == user_uuid,
|
||||||
|
OAuthAccount.provider == provider,
|
||||||
|
)
|
||||||
|
)
|
||||||
|
)
|
||||||
|
return result.scalar_one_or_none()
|
||||||
|
except Exception as e: # pragma: no cover
|
||||||
|
logger.error(
|
||||||
|
"Error getting OAuth account for user %s, provider %s: %s",
|
||||||
|
user_id,
|
||||||
|
provider,
|
||||||
|
e,
|
||||||
|
)
|
||||||
|
raise
|
||||||
|
|
||||||
|
async def create_account(
|
||||||
|
self, db: AsyncSession, *, obj_in: OAuthAccountCreate
|
||||||
|
) -> OAuthAccount:
|
||||||
|
"""Create a new OAuth account link."""
|
||||||
|
try:
|
||||||
|
db_obj = OAuthAccount(
|
||||||
|
user_id=obj_in.user_id,
|
||||||
|
provider=obj_in.provider,
|
||||||
|
provider_user_id=obj_in.provider_user_id,
|
||||||
|
provider_email=obj_in.provider_email,
|
||||||
|
access_token=obj_in.access_token,
|
||||||
|
refresh_token=obj_in.refresh_token,
|
||||||
|
token_expires_at=obj_in.token_expires_at,
|
||||||
|
)
|
||||||
|
db.add(db_obj)
|
||||||
|
await db.commit()
|
||||||
|
await db.refresh(db_obj)
|
||||||
|
|
||||||
|
logger.info(
|
||||||
|
"OAuth account created: %s linked to user %s",
|
||||||
|
obj_in.provider,
|
||||||
|
obj_in.user_id,
|
||||||
|
)
|
||||||
|
return db_obj
|
||||||
|
except IntegrityError as e: # pragma: no cover
|
||||||
|
await db.rollback()
|
||||||
|
error_msg = str(e.orig) if hasattr(e, "orig") else str(e)
|
||||||
|
if "uq_oauth_provider_user" in error_msg.lower():
|
||||||
|
logger.warning(
|
||||||
|
"OAuth account already exists: %s:%s",
|
||||||
|
obj_in.provider,
|
||||||
|
obj_in.provider_user_id,
|
||||||
|
)
|
||||||
|
raise DuplicateEntryError(
|
||||||
|
f"This {obj_in.provider} account is already linked to another user"
|
||||||
|
)
|
||||||
|
logger.error("Integrity error creating OAuth account: %s", error_msg)
|
||||||
|
raise DuplicateEntryError(f"Failed to create OAuth account: {error_msg}")
|
||||||
|
except Exception as e: # pragma: no cover
|
||||||
|
await db.rollback()
|
||||||
|
logger.exception("Error creating OAuth account: %s", e)
|
||||||
|
raise
|
||||||
|
|
||||||
|
async def delete_account(
|
||||||
|
self,
|
||||||
|
db: AsyncSession,
|
||||||
|
*,
|
||||||
|
user_id: str | UUID,
|
||||||
|
provider: str,
|
||||||
|
) -> bool:
|
||||||
|
"""Delete an OAuth account link."""
|
||||||
|
try:
|
||||||
|
user_uuid = UUID(str(user_id)) if isinstance(user_id, str) else user_id
|
||||||
|
|
||||||
|
result = await db.execute(
|
||||||
|
delete(OAuthAccount).where(
|
||||||
|
and_(
|
||||||
|
OAuthAccount.user_id == user_uuid,
|
||||||
|
OAuthAccount.provider == provider,
|
||||||
|
)
|
||||||
|
)
|
||||||
|
)
|
||||||
|
await db.commit()
|
||||||
|
|
||||||
|
deleted = result.rowcount > 0
|
||||||
|
if deleted:
|
||||||
|
logger.info(
|
||||||
|
"OAuth account deleted: %s unlinked from user %s", provider, user_id
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
logger.warning(
|
||||||
|
"OAuth account not found for deletion: %s for user %s",
|
||||||
|
provider,
|
||||||
|
user_id,
|
||||||
|
)
|
||||||
|
|
||||||
|
return deleted
|
||||||
|
except Exception as e: # pragma: no cover
|
||||||
|
await db.rollback()
|
||||||
|
logger.error(
|
||||||
|
"Error deleting OAuth account %s for user %s: %s", provider, user_id, e
|
||||||
|
)
|
||||||
|
raise
|
||||||
|
|
||||||
|
async def update_tokens(
|
||||||
|
self,
|
||||||
|
db: AsyncSession,
|
||||||
|
*,
|
||||||
|
account: OAuthAccount,
|
||||||
|
access_token: str | None = None,
|
||||||
|
refresh_token: str | None = None,
|
||||||
|
token_expires_at: datetime | None = None,
|
||||||
|
) -> OAuthAccount:
|
||||||
|
"""Update OAuth tokens for an account."""
|
||||||
|
try:
|
||||||
|
if access_token is not None:
|
||||||
|
account.access_token = access_token
|
||||||
|
if refresh_token is not None:
|
||||||
|
account.refresh_token = refresh_token
|
||||||
|
if token_expires_at is not None:
|
||||||
|
account.token_expires_at = token_expires_at
|
||||||
|
|
||||||
|
db.add(account)
|
||||||
|
await db.commit()
|
||||||
|
await db.refresh(account)
|
||||||
|
|
||||||
|
return account
|
||||||
|
except Exception as e: # pragma: no cover
|
||||||
|
await db.rollback()
|
||||||
|
logger.error("Error updating OAuth tokens: %s", e)
|
||||||
|
raise
|
||||||
|
|
||||||
|
|
||||||
|
# Singleton instance
|
||||||
|
oauth_account_repo = OAuthAccountRepository(OAuthAccount)
|
||||||
108
backend/app/repositories/oauth_authorization_code.py
Normal file
108
backend/app/repositories/oauth_authorization_code.py
Normal file
@@ -0,0 +1,108 @@
|
|||||||
|
# app/repositories/oauth_authorization_code.py
|
||||||
|
"""Repository for OAuthAuthorizationCode model."""
|
||||||
|
|
||||||
|
import logging
|
||||||
|
from datetime import UTC, datetime
|
||||||
|
from uuid import UUID
|
||||||
|
|
||||||
|
from sqlalchemy import and_, delete, select, update
|
||||||
|
from sqlalchemy.ext.asyncio import AsyncSession
|
||||||
|
|
||||||
|
from app.models.oauth_authorization_code import OAuthAuthorizationCode
|
||||||
|
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
|
class OAuthAuthorizationCodeRepository:
|
||||||
|
"""Repository for OAuth 2.0 authorization codes."""
|
||||||
|
|
||||||
|
async def create_code(
|
||||||
|
self,
|
||||||
|
db: AsyncSession,
|
||||||
|
*,
|
||||||
|
code: str,
|
||||||
|
client_id: str,
|
||||||
|
user_id: UUID,
|
||||||
|
redirect_uri: str,
|
||||||
|
scope: str,
|
||||||
|
expires_at: datetime,
|
||||||
|
code_challenge: str | None = None,
|
||||||
|
code_challenge_method: str | None = None,
|
||||||
|
state: str | None = None,
|
||||||
|
nonce: str | None = None,
|
||||||
|
) -> OAuthAuthorizationCode:
|
||||||
|
"""Create and persist a new authorization code."""
|
||||||
|
auth_code = OAuthAuthorizationCode(
|
||||||
|
code=code,
|
||||||
|
client_id=client_id,
|
||||||
|
user_id=user_id,
|
||||||
|
redirect_uri=redirect_uri,
|
||||||
|
scope=scope,
|
||||||
|
code_challenge=code_challenge,
|
||||||
|
code_challenge_method=code_challenge_method,
|
||||||
|
state=state,
|
||||||
|
nonce=nonce,
|
||||||
|
expires_at=expires_at,
|
||||||
|
used=False,
|
||||||
|
)
|
||||||
|
db.add(auth_code)
|
||||||
|
await db.commit()
|
||||||
|
return auth_code
|
||||||
|
|
||||||
|
async def consume_code_atomically(
|
||||||
|
self, db: AsyncSession, *, code: str
|
||||||
|
) -> UUID | None:
|
||||||
|
"""
|
||||||
|
Atomically mark a code as used and return its UUID.
|
||||||
|
|
||||||
|
Returns the UUID if the code was found and not yet used, None otherwise.
|
||||||
|
This prevents race conditions per RFC 6749 Section 4.1.2.
|
||||||
|
"""
|
||||||
|
stmt = (
|
||||||
|
update(OAuthAuthorizationCode)
|
||||||
|
.where(
|
||||||
|
and_(
|
||||||
|
OAuthAuthorizationCode.code == code,
|
||||||
|
OAuthAuthorizationCode.used == False, # noqa: E712
|
||||||
|
)
|
||||||
|
)
|
||||||
|
.values(used=True)
|
||||||
|
.returning(OAuthAuthorizationCode.id)
|
||||||
|
)
|
||||||
|
result = await db.execute(stmt)
|
||||||
|
row_id = result.scalar_one_or_none()
|
||||||
|
if row_id is not None:
|
||||||
|
await db.commit()
|
||||||
|
return row_id
|
||||||
|
|
||||||
|
async def get_by_id(
|
||||||
|
self, db: AsyncSession, *, code_id: UUID
|
||||||
|
) -> OAuthAuthorizationCode | None:
|
||||||
|
"""Get authorization code by its UUID primary key."""
|
||||||
|
result = await db.execute(
|
||||||
|
select(OAuthAuthorizationCode).where(OAuthAuthorizationCode.id == code_id)
|
||||||
|
)
|
||||||
|
return result.scalar_one_or_none()
|
||||||
|
|
||||||
|
async def get_by_code(
|
||||||
|
self, db: AsyncSession, *, code: str
|
||||||
|
) -> OAuthAuthorizationCode | None:
|
||||||
|
"""Get authorization code by the code string value."""
|
||||||
|
result = await db.execute(
|
||||||
|
select(OAuthAuthorizationCode).where(OAuthAuthorizationCode.code == code)
|
||||||
|
)
|
||||||
|
return result.scalar_one_or_none()
|
||||||
|
|
||||||
|
async def cleanup_expired(self, db: AsyncSession) -> int:
|
||||||
|
"""Delete all expired authorization codes. Returns count deleted."""
|
||||||
|
result = await db.execute(
|
||||||
|
delete(OAuthAuthorizationCode).where(
|
||||||
|
OAuthAuthorizationCode.expires_at < datetime.now(UTC)
|
||||||
|
)
|
||||||
|
)
|
||||||
|
await db.commit()
|
||||||
|
return result.rowcount # type: ignore[attr-defined]
|
||||||
|
|
||||||
|
|
||||||
|
# Singleton instance
|
||||||
|
oauth_authorization_code_repo = OAuthAuthorizationCodeRepository()
|
||||||
201
backend/app/repositories/oauth_client.py
Normal file
201
backend/app/repositories/oauth_client.py
Normal file
@@ -0,0 +1,201 @@
|
|||||||
|
# app/repositories/oauth_client.py
|
||||||
|
"""Repository for OAuthClient model async database operations."""
|
||||||
|
|
||||||
|
import logging
|
||||||
|
import secrets
|
||||||
|
from uuid import UUID
|
||||||
|
|
||||||
|
from pydantic import BaseModel
|
||||||
|
from sqlalchemy import and_, delete, select
|
||||||
|
from sqlalchemy.exc import IntegrityError
|
||||||
|
from sqlalchemy.ext.asyncio import AsyncSession
|
||||||
|
|
||||||
|
from app.core.repository_exceptions import DuplicateEntryError
|
||||||
|
from app.models.oauth_client import OAuthClient
|
||||||
|
from app.repositories.base import BaseRepository
|
||||||
|
from app.schemas.oauth import OAuthClientCreate
|
||||||
|
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
|
class EmptySchema(BaseModel):
|
||||||
|
"""Placeholder schema for repository operations that don't need update schemas."""
|
||||||
|
|
||||||
|
|
||||||
|
class OAuthClientRepository(
|
||||||
|
BaseRepository[OAuthClient, OAuthClientCreate, EmptySchema]
|
||||||
|
):
|
||||||
|
"""Repository for OAuth clients (provider mode)."""
|
||||||
|
|
||||||
|
async def get_by_client_id(
|
||||||
|
self, db: AsyncSession, *, client_id: str
|
||||||
|
) -> OAuthClient | None:
|
||||||
|
"""Get OAuth client by client_id."""
|
||||||
|
try:
|
||||||
|
result = await db.execute(
|
||||||
|
select(OAuthClient).where(
|
||||||
|
and_(
|
||||||
|
OAuthClient.client_id == client_id,
|
||||||
|
OAuthClient.is_active == True, # noqa: E712
|
||||||
|
)
|
||||||
|
)
|
||||||
|
)
|
||||||
|
return result.scalar_one_or_none()
|
||||||
|
except Exception as e: # pragma: no cover
|
||||||
|
logger.error("Error getting OAuth client %s: %s", client_id, e)
|
||||||
|
raise
|
||||||
|
|
||||||
|
async def create_client(
|
||||||
|
self,
|
||||||
|
db: AsyncSession,
|
||||||
|
*,
|
||||||
|
obj_in: OAuthClientCreate,
|
||||||
|
owner_user_id: UUID | None = None,
|
||||||
|
) -> tuple[OAuthClient, str | None]:
|
||||||
|
"""Create a new OAuth client."""
|
||||||
|
try:
|
||||||
|
client_id = secrets.token_urlsafe(32)
|
||||||
|
|
||||||
|
client_secret = None
|
||||||
|
client_secret_hash = None
|
||||||
|
if obj_in.client_type == "confidential":
|
||||||
|
client_secret = secrets.token_urlsafe(48)
|
||||||
|
from app.core.auth import get_password_hash
|
||||||
|
|
||||||
|
client_secret_hash = get_password_hash(client_secret)
|
||||||
|
|
||||||
|
db_obj = OAuthClient(
|
||||||
|
client_id=client_id,
|
||||||
|
client_secret_hash=client_secret_hash,
|
||||||
|
client_name=obj_in.client_name,
|
||||||
|
client_description=obj_in.client_description,
|
||||||
|
client_type=obj_in.client_type,
|
||||||
|
redirect_uris=obj_in.redirect_uris,
|
||||||
|
allowed_scopes=obj_in.allowed_scopes,
|
||||||
|
owner_user_id=owner_user_id,
|
||||||
|
is_active=True,
|
||||||
|
)
|
||||||
|
db.add(db_obj)
|
||||||
|
await db.commit()
|
||||||
|
await db.refresh(db_obj)
|
||||||
|
|
||||||
|
logger.info(
|
||||||
|
"OAuth client created: %s (%s...)", obj_in.client_name, client_id[:8]
|
||||||
|
)
|
||||||
|
return db_obj, client_secret
|
||||||
|
except IntegrityError as e: # pragma: no cover
|
||||||
|
await db.rollback()
|
||||||
|
error_msg = str(e.orig) if hasattr(e, "orig") else str(e)
|
||||||
|
logger.error("Error creating OAuth client: %s", error_msg)
|
||||||
|
raise DuplicateEntryError(f"Failed to create OAuth client: {error_msg}")
|
||||||
|
except Exception as e: # pragma: no cover
|
||||||
|
await db.rollback()
|
||||||
|
logger.exception("Error creating OAuth client: %s", e)
|
||||||
|
raise
|
||||||
|
|
||||||
|
async def deactivate_client(
|
||||||
|
self, db: AsyncSession, *, client_id: str
|
||||||
|
) -> OAuthClient | None:
|
||||||
|
"""Deactivate an OAuth client."""
|
||||||
|
try:
|
||||||
|
client = await self.get_by_client_id(db, client_id=client_id)
|
||||||
|
if client is None:
|
||||||
|
return None
|
||||||
|
|
||||||
|
client.is_active = False
|
||||||
|
db.add(client)
|
||||||
|
await db.commit()
|
||||||
|
await db.refresh(client)
|
||||||
|
|
||||||
|
logger.info("OAuth client deactivated: %s", client.client_name)
|
||||||
|
return client
|
||||||
|
except Exception as e: # pragma: no cover
|
||||||
|
await db.rollback()
|
||||||
|
logger.error("Error deactivating OAuth client %s: %s", client_id, e)
|
||||||
|
raise
|
||||||
|
|
||||||
|
async def validate_redirect_uri(
|
||||||
|
self, db: AsyncSession, *, client_id: str, redirect_uri: str
|
||||||
|
) -> bool:
|
||||||
|
"""Validate that a redirect URI is allowed for a client."""
|
||||||
|
try:
|
||||||
|
client = await self.get_by_client_id(db, client_id=client_id)
|
||||||
|
if client is None:
|
||||||
|
return False
|
||||||
|
|
||||||
|
return redirect_uri in (client.redirect_uris or [])
|
||||||
|
except Exception as e: # pragma: no cover
|
||||||
|
logger.error("Error validating redirect URI: %s", e)
|
||||||
|
return False
|
||||||
|
|
||||||
|
async def verify_client_secret(
|
||||||
|
self, db: AsyncSession, *, client_id: str, client_secret: str
|
||||||
|
) -> bool:
|
||||||
|
"""Verify client credentials."""
|
||||||
|
try:
|
||||||
|
result = await db.execute(
|
||||||
|
select(OAuthClient).where(
|
||||||
|
and_(
|
||||||
|
OAuthClient.client_id == client_id,
|
||||||
|
OAuthClient.is_active == True, # noqa: E712
|
||||||
|
)
|
||||||
|
)
|
||||||
|
)
|
||||||
|
client = result.scalar_one_or_none()
|
||||||
|
|
||||||
|
if client is None or client.client_secret_hash is None:
|
||||||
|
return False
|
||||||
|
|
||||||
|
from app.core.auth import verify_password
|
||||||
|
|
||||||
|
stored_hash: str = str(client.client_secret_hash)
|
||||||
|
|
||||||
|
if stored_hash.startswith("$2"):
|
||||||
|
return verify_password(client_secret, stored_hash)
|
||||||
|
else:
|
||||||
|
import hashlib
|
||||||
|
|
||||||
|
secret_hash = hashlib.sha256(client_secret.encode()).hexdigest()
|
||||||
|
return secrets.compare_digest(stored_hash, secret_hash)
|
||||||
|
except Exception as e: # pragma: no cover
|
||||||
|
logger.error("Error verifying client secret: %s", e)
|
||||||
|
return False
|
||||||
|
|
||||||
|
async def get_all_clients(
|
||||||
|
self, db: AsyncSession, *, include_inactive: bool = False
|
||||||
|
) -> list[OAuthClient]:
|
||||||
|
"""Get all OAuth clients."""
|
||||||
|
try:
|
||||||
|
query = select(OAuthClient).order_by(OAuthClient.created_at.desc())
|
||||||
|
if not include_inactive:
|
||||||
|
query = query.where(OAuthClient.is_active == True) # noqa: E712
|
||||||
|
|
||||||
|
result = await db.execute(query)
|
||||||
|
return list(result.scalars().all())
|
||||||
|
except Exception as e: # pragma: no cover
|
||||||
|
logger.error("Error getting all OAuth clients: %s", e)
|
||||||
|
raise
|
||||||
|
|
||||||
|
async def delete_client(self, db: AsyncSession, *, client_id: str) -> bool:
|
||||||
|
"""Delete an OAuth client permanently."""
|
||||||
|
try:
|
||||||
|
result = await db.execute(
|
||||||
|
delete(OAuthClient).where(OAuthClient.client_id == client_id)
|
||||||
|
)
|
||||||
|
await db.commit()
|
||||||
|
|
||||||
|
deleted = result.rowcount > 0
|
||||||
|
if deleted:
|
||||||
|
logger.info("OAuth client deleted: %s", client_id)
|
||||||
|
else:
|
||||||
|
logger.warning("OAuth client not found for deletion: %s", client_id)
|
||||||
|
|
||||||
|
return deleted
|
||||||
|
except Exception as e: # pragma: no cover
|
||||||
|
await db.rollback()
|
||||||
|
logger.error("Error deleting OAuth client %s: %s", client_id, e)
|
||||||
|
raise
|
||||||
|
|
||||||
|
|
||||||
|
# Singleton instance
|
||||||
|
oauth_client_repo = OAuthClientRepository(OAuthClient)
|
||||||
113
backend/app/repositories/oauth_consent.py
Normal file
113
backend/app/repositories/oauth_consent.py
Normal file
@@ -0,0 +1,113 @@
|
|||||||
|
# app/repositories/oauth_consent.py
|
||||||
|
"""Repository for OAuthConsent model."""
|
||||||
|
|
||||||
|
import logging
|
||||||
|
from typing import Any
|
||||||
|
from uuid import UUID
|
||||||
|
|
||||||
|
from sqlalchemy import and_, delete, select
|
||||||
|
from sqlalchemy.ext.asyncio import AsyncSession
|
||||||
|
|
||||||
|
from app.models.oauth_client import OAuthClient
|
||||||
|
from app.models.oauth_provider_token import OAuthConsent
|
||||||
|
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
|
class OAuthConsentRepository:
|
||||||
|
"""Repository for OAuth consent records (user grants to clients)."""
|
||||||
|
|
||||||
|
async def get_consent(
|
||||||
|
self, db: AsyncSession, *, user_id: UUID, client_id: str
|
||||||
|
) -> OAuthConsent | None:
|
||||||
|
"""Get the consent record for a user-client pair, or None if not found."""
|
||||||
|
result = await db.execute(
|
||||||
|
select(OAuthConsent).where(
|
||||||
|
and_(
|
||||||
|
OAuthConsent.user_id == user_id,
|
||||||
|
OAuthConsent.client_id == client_id,
|
||||||
|
)
|
||||||
|
)
|
||||||
|
)
|
||||||
|
return result.scalar_one_or_none()
|
||||||
|
|
||||||
|
async def grant_consent(
|
||||||
|
self,
|
||||||
|
db: AsyncSession,
|
||||||
|
*,
|
||||||
|
user_id: UUID,
|
||||||
|
client_id: str,
|
||||||
|
scopes: list[str],
|
||||||
|
) -> OAuthConsent:
|
||||||
|
"""
|
||||||
|
Create or update consent for a user-client pair.
|
||||||
|
|
||||||
|
If consent already exists, the new scopes are merged with existing ones.
|
||||||
|
Returns the created or updated consent record.
|
||||||
|
"""
|
||||||
|
consent = await self.get_consent(db, user_id=user_id, client_id=client_id)
|
||||||
|
|
||||||
|
if consent:
|
||||||
|
existing = (
|
||||||
|
set(consent.granted_scopes.split()) if consent.granted_scopes else set()
|
||||||
|
)
|
||||||
|
merged = existing | set(scopes)
|
||||||
|
consent.granted_scopes = " ".join(sorted(merged)) # type: ignore[assignment]
|
||||||
|
else:
|
||||||
|
consent = OAuthConsent(
|
||||||
|
user_id=user_id,
|
||||||
|
client_id=client_id,
|
||||||
|
granted_scopes=" ".join(sorted(set(scopes))),
|
||||||
|
)
|
||||||
|
db.add(consent)
|
||||||
|
|
||||||
|
await db.commit()
|
||||||
|
await db.refresh(consent)
|
||||||
|
return consent
|
||||||
|
|
||||||
|
async def get_user_consents_with_clients(
|
||||||
|
self, db: AsyncSession, *, user_id: UUID
|
||||||
|
) -> list[dict[str, Any]]:
|
||||||
|
"""Get all consent records for a user joined with client details."""
|
||||||
|
result = await db.execute(
|
||||||
|
select(OAuthConsent, OAuthClient)
|
||||||
|
.join(OAuthClient, OAuthConsent.client_id == OAuthClient.client_id)
|
||||||
|
.where(OAuthConsent.user_id == user_id)
|
||||||
|
)
|
||||||
|
rows = result.all()
|
||||||
|
return [
|
||||||
|
{
|
||||||
|
"client_id": consent.client_id,
|
||||||
|
"client_name": client.client_name,
|
||||||
|
"client_description": client.client_description,
|
||||||
|
"granted_scopes": consent.granted_scopes.split()
|
||||||
|
if consent.granted_scopes
|
||||||
|
else [],
|
||||||
|
"granted_at": consent.created_at.isoformat(),
|
||||||
|
}
|
||||||
|
for consent, client in rows
|
||||||
|
]
|
||||||
|
|
||||||
|
async def revoke_consent(
|
||||||
|
self, db: AsyncSession, *, user_id: UUID, client_id: str
|
||||||
|
) -> bool:
|
||||||
|
"""
|
||||||
|
Delete the consent record for a user-client pair.
|
||||||
|
|
||||||
|
Returns True if a record was found and deleted.
|
||||||
|
Note: Callers are responsible for also revoking associated tokens.
|
||||||
|
"""
|
||||||
|
result = await db.execute(
|
||||||
|
delete(OAuthConsent).where(
|
||||||
|
and_(
|
||||||
|
OAuthConsent.user_id == user_id,
|
||||||
|
OAuthConsent.client_id == client_id,
|
||||||
|
)
|
||||||
|
)
|
||||||
|
)
|
||||||
|
await db.commit()
|
||||||
|
return result.rowcount > 0 # type: ignore[attr-defined]
|
||||||
|
|
||||||
|
|
||||||
|
# Singleton instance
|
||||||
|
oauth_consent_repo = OAuthConsentRepository()
|
||||||
142
backend/app/repositories/oauth_provider_token.py
Normal file
142
backend/app/repositories/oauth_provider_token.py
Normal file
@@ -0,0 +1,142 @@
|
|||||||
|
# app/repositories/oauth_provider_token.py
|
||||||
|
"""Repository for OAuthProviderRefreshToken model."""
|
||||||
|
|
||||||
|
import logging
|
||||||
|
from datetime import UTC, datetime, timedelta
|
||||||
|
from uuid import UUID
|
||||||
|
|
||||||
|
from sqlalchemy import and_, delete, select, update
|
||||||
|
from sqlalchemy.ext.asyncio import AsyncSession
|
||||||
|
|
||||||
|
from app.models.oauth_provider_token import OAuthProviderRefreshToken
|
||||||
|
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
|
class OAuthProviderTokenRepository:
|
||||||
|
"""Repository for OAuth provider refresh tokens."""
|
||||||
|
|
||||||
|
async def create_token(
|
||||||
|
self,
|
||||||
|
db: AsyncSession,
|
||||||
|
*,
|
||||||
|
token_hash: str,
|
||||||
|
jti: str,
|
||||||
|
client_id: str,
|
||||||
|
user_id: UUID,
|
||||||
|
scope: str,
|
||||||
|
expires_at: datetime,
|
||||||
|
device_info: str | None = None,
|
||||||
|
ip_address: str | None = None,
|
||||||
|
) -> OAuthProviderRefreshToken:
|
||||||
|
"""Create and persist a new refresh token record."""
|
||||||
|
token = OAuthProviderRefreshToken(
|
||||||
|
token_hash=token_hash,
|
||||||
|
jti=jti,
|
||||||
|
client_id=client_id,
|
||||||
|
user_id=user_id,
|
||||||
|
scope=scope,
|
||||||
|
expires_at=expires_at,
|
||||||
|
device_info=device_info,
|
||||||
|
ip_address=ip_address,
|
||||||
|
)
|
||||||
|
db.add(token)
|
||||||
|
await db.commit()
|
||||||
|
return token
|
||||||
|
|
||||||
|
async def get_by_token_hash(
|
||||||
|
self, db: AsyncSession, *, token_hash: str
|
||||||
|
) -> OAuthProviderRefreshToken | None:
|
||||||
|
"""Get refresh token record by SHA-256 token hash."""
|
||||||
|
result = await db.execute(
|
||||||
|
select(OAuthProviderRefreshToken).where(
|
||||||
|
OAuthProviderRefreshToken.token_hash == token_hash
|
||||||
|
)
|
||||||
|
)
|
||||||
|
return result.scalar_one_or_none()
|
||||||
|
|
||||||
|
async def get_by_jti(
|
||||||
|
self, db: AsyncSession, *, jti: str
|
||||||
|
) -> OAuthProviderRefreshToken | None:
|
||||||
|
"""Get refresh token record by JWT ID (JTI)."""
|
||||||
|
result = await db.execute(
|
||||||
|
select(OAuthProviderRefreshToken).where(
|
||||||
|
OAuthProviderRefreshToken.jti == jti
|
||||||
|
)
|
||||||
|
)
|
||||||
|
return result.scalar_one_or_none()
|
||||||
|
|
||||||
|
async def revoke(
|
||||||
|
self, db: AsyncSession, *, token: OAuthProviderRefreshToken
|
||||||
|
) -> None:
|
||||||
|
"""Mark a specific token record as revoked."""
|
||||||
|
token.revoked = True # type: ignore[assignment]
|
||||||
|
token.last_used_at = datetime.now(UTC) # type: ignore[assignment]
|
||||||
|
await db.commit()
|
||||||
|
|
||||||
|
async def revoke_all_for_user_client(
|
||||||
|
self, db: AsyncSession, *, user_id: UUID, client_id: str
|
||||||
|
) -> int:
|
||||||
|
"""
|
||||||
|
Revoke all active tokens for a specific user-client pair.
|
||||||
|
|
||||||
|
Used when security incidents are detected (e.g., authorization code reuse).
|
||||||
|
Returns the number of tokens revoked.
|
||||||
|
"""
|
||||||
|
result = await db.execute(
|
||||||
|
update(OAuthProviderRefreshToken)
|
||||||
|
.where(
|
||||||
|
and_(
|
||||||
|
OAuthProviderRefreshToken.user_id == user_id,
|
||||||
|
OAuthProviderRefreshToken.client_id == client_id,
|
||||||
|
OAuthProviderRefreshToken.revoked == False, # noqa: E712
|
||||||
|
)
|
||||||
|
)
|
||||||
|
.values(revoked=True)
|
||||||
|
)
|
||||||
|
count = result.rowcount # type: ignore[attr-defined]
|
||||||
|
if count > 0:
|
||||||
|
await db.commit()
|
||||||
|
return count
|
||||||
|
|
||||||
|
async def revoke_all_for_user(self, db: AsyncSession, *, user_id: UUID) -> int:
|
||||||
|
"""
|
||||||
|
Revoke all active tokens for a user across all clients.
|
||||||
|
|
||||||
|
Used when user changes password or logs out everywhere.
|
||||||
|
Returns the number of tokens revoked.
|
||||||
|
"""
|
||||||
|
result = await db.execute(
|
||||||
|
update(OAuthProviderRefreshToken)
|
||||||
|
.where(
|
||||||
|
and_(
|
||||||
|
OAuthProviderRefreshToken.user_id == user_id,
|
||||||
|
OAuthProviderRefreshToken.revoked == False, # noqa: E712
|
||||||
|
)
|
||||||
|
)
|
||||||
|
.values(revoked=True)
|
||||||
|
)
|
||||||
|
count = result.rowcount # type: ignore[attr-defined]
|
||||||
|
if count > 0:
|
||||||
|
await db.commit()
|
||||||
|
return count
|
||||||
|
|
||||||
|
async def cleanup_expired(self, db: AsyncSession, *, cutoff_days: int = 7) -> int:
|
||||||
|
"""
|
||||||
|
Delete expired refresh tokens older than cutoff_days.
|
||||||
|
|
||||||
|
Should be called periodically (e.g., daily).
|
||||||
|
Returns the number of tokens deleted.
|
||||||
|
"""
|
||||||
|
cutoff = datetime.now(UTC) - timedelta(days=cutoff_days)
|
||||||
|
result = await db.execute(
|
||||||
|
delete(OAuthProviderRefreshToken).where(
|
||||||
|
OAuthProviderRefreshToken.expires_at < cutoff
|
||||||
|
)
|
||||||
|
)
|
||||||
|
await db.commit()
|
||||||
|
return result.rowcount # type: ignore[attr-defined]
|
||||||
|
|
||||||
|
|
||||||
|
# Singleton instance
|
||||||
|
oauth_provider_token_repo = OAuthProviderTokenRepository()
|
||||||
113
backend/app/repositories/oauth_state.py
Normal file
113
backend/app/repositories/oauth_state.py
Normal file
@@ -0,0 +1,113 @@
|
|||||||
|
# app/repositories/oauth_state.py
|
||||||
|
"""Repository for OAuthState model async database operations."""
|
||||||
|
|
||||||
|
import logging
|
||||||
|
from datetime import UTC, datetime
|
||||||
|
|
||||||
|
from pydantic import BaseModel
|
||||||
|
from sqlalchemy import delete, select
|
||||||
|
from sqlalchemy.exc import IntegrityError
|
||||||
|
from sqlalchemy.ext.asyncio import AsyncSession
|
||||||
|
|
||||||
|
from app.core.repository_exceptions import DuplicateEntryError
|
||||||
|
from app.models.oauth_state import OAuthState
|
||||||
|
from app.repositories.base import BaseRepository
|
||||||
|
from app.schemas.oauth import OAuthStateCreate
|
||||||
|
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
|
class EmptySchema(BaseModel):
|
||||||
|
"""Placeholder schema for repository operations that don't need update schemas."""
|
||||||
|
|
||||||
|
|
||||||
|
class OAuthStateRepository(BaseRepository[OAuthState, OAuthStateCreate, EmptySchema]):
|
||||||
|
"""Repository for OAuth state (CSRF protection)."""
|
||||||
|
|
||||||
|
async def create_state(
|
||||||
|
self, db: AsyncSession, *, obj_in: OAuthStateCreate
|
||||||
|
) -> OAuthState:
|
||||||
|
"""Create a new OAuth state for CSRF protection."""
|
||||||
|
try:
|
||||||
|
db_obj = OAuthState(
|
||||||
|
state=obj_in.state,
|
||||||
|
code_verifier=obj_in.code_verifier,
|
||||||
|
nonce=obj_in.nonce,
|
||||||
|
provider=obj_in.provider,
|
||||||
|
redirect_uri=obj_in.redirect_uri,
|
||||||
|
user_id=obj_in.user_id,
|
||||||
|
expires_at=obj_in.expires_at,
|
||||||
|
)
|
||||||
|
db.add(db_obj)
|
||||||
|
await db.commit()
|
||||||
|
await db.refresh(db_obj)
|
||||||
|
|
||||||
|
logger.debug("OAuth state created for %s", obj_in.provider)
|
||||||
|
return db_obj
|
||||||
|
except IntegrityError as e: # pragma: no cover
|
||||||
|
await db.rollback()
|
||||||
|
error_msg = str(e.orig) if hasattr(e, "orig") else str(e)
|
||||||
|
logger.error("OAuth state collision: %s", error_msg)
|
||||||
|
raise DuplicateEntryError("Failed to create OAuth state, please retry")
|
||||||
|
except Exception as e: # pragma: no cover
|
||||||
|
await db.rollback()
|
||||||
|
logger.exception("Error creating OAuth state: %s", e)
|
||||||
|
raise
|
||||||
|
|
||||||
|
async def get_and_consume_state(
|
||||||
|
self, db: AsyncSession, *, state: str
|
||||||
|
) -> OAuthState | None:
|
||||||
|
"""Get and delete OAuth state (consume it)."""
|
||||||
|
try:
|
||||||
|
result = await db.execute(
|
||||||
|
select(OAuthState).where(OAuthState.state == state)
|
||||||
|
)
|
||||||
|
db_obj = result.scalar_one_or_none()
|
||||||
|
|
||||||
|
if db_obj is None:
|
||||||
|
logger.warning("OAuth state not found: %s...", state[:8])
|
||||||
|
return None
|
||||||
|
|
||||||
|
now = datetime.now(UTC)
|
||||||
|
expires_at = db_obj.expires_at
|
||||||
|
if expires_at.tzinfo is None:
|
||||||
|
expires_at = expires_at.replace(tzinfo=UTC)
|
||||||
|
|
||||||
|
if expires_at < now:
|
||||||
|
logger.warning("OAuth state expired: %s...", state[:8])
|
||||||
|
await db.delete(db_obj)
|
||||||
|
await db.commit()
|
||||||
|
return None
|
||||||
|
|
||||||
|
await db.delete(db_obj)
|
||||||
|
await db.commit()
|
||||||
|
|
||||||
|
logger.debug("OAuth state consumed: %s...", state[:8])
|
||||||
|
return db_obj
|
||||||
|
except Exception as e: # pragma: no cover
|
||||||
|
await db.rollback()
|
||||||
|
logger.error("Error consuming OAuth state: %s", e)
|
||||||
|
raise
|
||||||
|
|
||||||
|
async def cleanup_expired(self, db: AsyncSession) -> int:
|
||||||
|
"""Clean up expired OAuth states."""
|
||||||
|
try:
|
||||||
|
now = datetime.now(UTC)
|
||||||
|
|
||||||
|
stmt = delete(OAuthState).where(OAuthState.expires_at < now)
|
||||||
|
result = await db.execute(stmt)
|
||||||
|
await db.commit()
|
||||||
|
|
||||||
|
count = result.rowcount
|
||||||
|
if count > 0:
|
||||||
|
logger.info("Cleaned up %s expired OAuth states", count)
|
||||||
|
|
||||||
|
return count
|
||||||
|
except Exception as e: # pragma: no cover
|
||||||
|
await db.rollback()
|
||||||
|
logger.error("Error cleaning up expired OAuth states: %s", e)
|
||||||
|
raise
|
||||||
|
|
||||||
|
|
||||||
|
# Singleton instance
|
||||||
|
oauth_state_repo = OAuthStateRepository(OAuthState)
|
||||||
128
backend/app/crud/organization.py → backend/app/repositories/organization.py
Executable file → Normal file
128
backend/app/crud/organization.py → backend/app/repositories/organization.py
Executable file → Normal file
@@ -1,5 +1,5 @@
|
|||||||
# app/crud/organization_async.py
|
# app/repositories/organization.py
|
||||||
"""Async CRUD operations for Organization model using SQLAlchemy 2.0 patterns."""
|
"""Repository for Organization model async database operations using SQLAlchemy 2.0 patterns."""
|
||||||
|
|
||||||
import logging
|
import logging
|
||||||
from typing import Any
|
from typing import Any
|
||||||
@@ -9,10 +9,11 @@ from sqlalchemy import and_, case, func, or_, select
|
|||||||
from sqlalchemy.exc import IntegrityError
|
from sqlalchemy.exc import IntegrityError
|
||||||
from sqlalchemy.ext.asyncio import AsyncSession
|
from sqlalchemy.ext.asyncio import AsyncSession
|
||||||
|
|
||||||
from app.crud.base import CRUDBase
|
from app.core.repository_exceptions import DuplicateEntryError, IntegrityConstraintError
|
||||||
from app.models.organization import Organization
|
from app.models.organization import Organization
|
||||||
from app.models.user import User
|
from app.models.user import User
|
||||||
from app.models.user_organization import OrganizationRole, UserOrganization
|
from app.models.user_organization import OrganizationRole, UserOrganization
|
||||||
|
from app.repositories.base import BaseRepository
|
||||||
from app.schemas.organizations import (
|
from app.schemas.organizations import (
|
||||||
OrganizationCreate,
|
OrganizationCreate,
|
||||||
OrganizationUpdate,
|
OrganizationUpdate,
|
||||||
@@ -21,8 +22,10 @@ from app.schemas.organizations import (
|
|||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
class CRUDOrganization(CRUDBase[Organization, OrganizationCreate, OrganizationUpdate]):
|
class OrganizationRepository(
|
||||||
"""Async CRUD operations for Organization model."""
|
BaseRepository[Organization, OrganizationCreate, OrganizationUpdate]
|
||||||
|
):
|
||||||
|
"""Repository for Organization model."""
|
||||||
|
|
||||||
async def get_by_slug(self, db: AsyncSession, *, slug: str) -> Organization | None:
|
async def get_by_slug(self, db: AsyncSession, *, slug: str) -> Organization | None:
|
||||||
"""Get organization by slug."""
|
"""Get organization by slug."""
|
||||||
@@ -32,7 +35,7 @@ class CRUDOrganization(CRUDBase[Organization, OrganizationCreate, OrganizationUp
|
|||||||
)
|
)
|
||||||
return result.scalar_one_or_none()
|
return result.scalar_one_or_none()
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.error(f"Error getting organization by slug {slug}: {e!s}")
|
logger.error("Error getting organization by slug %s: %s", slug, e)
|
||||||
raise
|
raise
|
||||||
|
|
||||||
async def create(
|
async def create(
|
||||||
@@ -54,18 +57,20 @@ class CRUDOrganization(CRUDBase[Organization, OrganizationCreate, OrganizationUp
|
|||||||
except IntegrityError as e:
|
except IntegrityError as e:
|
||||||
await db.rollback()
|
await db.rollback()
|
||||||
error_msg = str(e.orig) if hasattr(e, "orig") else str(e)
|
error_msg = str(e.orig) if hasattr(e, "orig") else str(e)
|
||||||
if "slug" in error_msg.lower():
|
if (
|
||||||
logger.warning(f"Duplicate slug attempted: {obj_in.slug}")
|
"slug" in error_msg.lower()
|
||||||
raise ValueError(
|
or "unique" in error_msg.lower()
|
||||||
|
or "duplicate" in error_msg.lower()
|
||||||
|
):
|
||||||
|
logger.warning("Duplicate slug attempted: %s", obj_in.slug)
|
||||||
|
raise DuplicateEntryError(
|
||||||
f"Organization with slug '{obj_in.slug}' already exists"
|
f"Organization with slug '{obj_in.slug}' already exists"
|
||||||
)
|
)
|
||||||
logger.error(f"Integrity error creating organization: {error_msg}")
|
logger.error("Integrity error creating organization: %s", error_msg)
|
||||||
raise ValueError(f"Database integrity error: {error_msg}")
|
raise IntegrityConstraintError(f"Database integrity error: {error_msg}")
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
await db.rollback()
|
await db.rollback()
|
||||||
logger.error(
|
logger.exception("Unexpected error creating organization: %s", e)
|
||||||
f"Unexpected error creating organization: {e!s}", exc_info=True
|
|
||||||
)
|
|
||||||
raise
|
raise
|
||||||
|
|
||||||
async def get_multi_with_filters(
|
async def get_multi_with_filters(
|
||||||
@@ -79,16 +84,10 @@ class CRUDOrganization(CRUDBase[Organization, OrganizationCreate, OrganizationUp
|
|||||||
sort_by: str = "created_at",
|
sort_by: str = "created_at",
|
||||||
sort_order: str = "desc",
|
sort_order: str = "desc",
|
||||||
) -> tuple[list[Organization], int]:
|
) -> tuple[list[Organization], int]:
|
||||||
"""
|
"""Get multiple organizations with filtering, searching, and sorting."""
|
||||||
Get multiple organizations with filtering, searching, and sorting.
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
Tuple of (organizations list, total count)
|
|
||||||
"""
|
|
||||||
try:
|
try:
|
||||||
query = select(Organization)
|
query = select(Organization)
|
||||||
|
|
||||||
# Apply filters
|
|
||||||
if is_active is not None:
|
if is_active is not None:
|
||||||
query = query.where(Organization.is_active == is_active)
|
query = query.where(Organization.is_active == is_active)
|
||||||
|
|
||||||
@@ -100,26 +99,23 @@ class CRUDOrganization(CRUDBase[Organization, OrganizationCreate, OrganizationUp
|
|||||||
)
|
)
|
||||||
query = query.where(search_filter)
|
query = query.where(search_filter)
|
||||||
|
|
||||||
# Get total count before pagination
|
|
||||||
count_query = select(func.count()).select_from(query.alias())
|
count_query = select(func.count()).select_from(query.alias())
|
||||||
count_result = await db.execute(count_query)
|
count_result = await db.execute(count_query)
|
||||||
total = count_result.scalar_one()
|
total = count_result.scalar_one()
|
||||||
|
|
||||||
# Apply sorting
|
|
||||||
sort_column = getattr(Organization, sort_by, Organization.created_at)
|
sort_column = getattr(Organization, sort_by, Organization.created_at)
|
||||||
if sort_order == "desc":
|
if sort_order == "desc":
|
||||||
query = query.order_by(sort_column.desc())
|
query = query.order_by(sort_column.desc())
|
||||||
else:
|
else:
|
||||||
query = query.order_by(sort_column.asc())
|
query = query.order_by(sort_column.asc())
|
||||||
|
|
||||||
# Apply pagination
|
|
||||||
query = query.offset(skip).limit(limit)
|
query = query.offset(skip).limit(limit)
|
||||||
result = await db.execute(query)
|
result = await db.execute(query)
|
||||||
organizations = list(result.scalars().all())
|
organizations = list(result.scalars().all())
|
||||||
|
|
||||||
return organizations, total
|
return organizations, total
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.error(f"Error getting organizations with filters: {e!s}")
|
logger.error("Error getting organizations with filters: %s", e)
|
||||||
raise
|
raise
|
||||||
|
|
||||||
async def get_member_count(self, db: AsyncSession, *, organization_id: UUID) -> int:
|
async def get_member_count(self, db: AsyncSession, *, organization_id: UUID) -> int:
|
||||||
@@ -136,7 +132,7 @@ class CRUDOrganization(CRUDBase[Organization, OrganizationCreate, OrganizationUp
|
|||||||
return result.scalar_one() or 0
|
return result.scalar_one() or 0
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.error(
|
logger.error(
|
||||||
f"Error getting member count for organization {organization_id}: {e!s}"
|
"Error getting member count for organization %s: %s", organization_id, e
|
||||||
)
|
)
|
||||||
raise
|
raise
|
||||||
|
|
||||||
@@ -149,16 +145,8 @@ class CRUDOrganization(CRUDBase[Organization, OrganizationCreate, OrganizationUp
|
|||||||
is_active: bool | None = None,
|
is_active: bool | None = None,
|
||||||
search: str | None = None,
|
search: str | None = None,
|
||||||
) -> tuple[list[dict[str, Any]], int]:
|
) -> tuple[list[dict[str, Any]], int]:
|
||||||
"""
|
"""Get organizations with member counts in a SINGLE QUERY using JOIN and GROUP BY."""
|
||||||
Get organizations with member counts in a SINGLE QUERY using JOIN and GROUP BY.
|
|
||||||
This eliminates the N+1 query problem.
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
Tuple of (list of dicts with org and member_count, total count)
|
|
||||||
"""
|
|
||||||
try:
|
try:
|
||||||
# Build base query with LEFT JOIN and GROUP BY
|
|
||||||
# Use CASE statement to count only active members
|
|
||||||
query = (
|
query = (
|
||||||
select(
|
select(
|
||||||
Organization,
|
Organization,
|
||||||
@@ -181,10 +169,10 @@ class CRUDOrganization(CRUDBase[Organization, OrganizationCreate, OrganizationUp
|
|||||||
.group_by(Organization.id)
|
.group_by(Organization.id)
|
||||||
)
|
)
|
||||||
|
|
||||||
# Apply filters
|
|
||||||
if is_active is not None:
|
if is_active is not None:
|
||||||
query = query.where(Organization.is_active == is_active)
|
query = query.where(Organization.is_active == is_active)
|
||||||
|
|
||||||
|
search_filter = None
|
||||||
if search:
|
if search:
|
||||||
search_filter = or_(
|
search_filter = or_(
|
||||||
Organization.name.ilike(f"%{search}%"),
|
Organization.name.ilike(f"%{search}%"),
|
||||||
@@ -193,17 +181,15 @@ class CRUDOrganization(CRUDBase[Organization, OrganizationCreate, OrganizationUp
|
|||||||
)
|
)
|
||||||
query = query.where(search_filter)
|
query = query.where(search_filter)
|
||||||
|
|
||||||
# Get total count
|
|
||||||
count_query = select(func.count(Organization.id))
|
count_query = select(func.count(Organization.id))
|
||||||
if is_active is not None:
|
if is_active is not None:
|
||||||
count_query = count_query.where(Organization.is_active == is_active)
|
count_query = count_query.where(Organization.is_active == is_active)
|
||||||
if search:
|
if search_filter is not None:
|
||||||
count_query = count_query.where(search_filter)
|
count_query = count_query.where(search_filter)
|
||||||
|
|
||||||
count_result = await db.execute(count_query)
|
count_result = await db.execute(count_query)
|
||||||
total = count_result.scalar_one()
|
total = count_result.scalar_one()
|
||||||
|
|
||||||
# Apply pagination and ordering
|
|
||||||
query = (
|
query = (
|
||||||
query.order_by(Organization.created_at.desc()).offset(skip).limit(limit)
|
query.order_by(Organization.created_at.desc()).offset(skip).limit(limit)
|
||||||
)
|
)
|
||||||
@@ -211,7 +197,6 @@ class CRUDOrganization(CRUDBase[Organization, OrganizationCreate, OrganizationUp
|
|||||||
result = await db.execute(query)
|
result = await db.execute(query)
|
||||||
rows = result.all()
|
rows = result.all()
|
||||||
|
|
||||||
# Convert to list of dicts
|
|
||||||
orgs_with_counts = [
|
orgs_with_counts = [
|
||||||
{"organization": org, "member_count": member_count}
|
{"organization": org, "member_count": member_count}
|
||||||
for org, member_count in rows
|
for org, member_count in rows
|
||||||
@@ -220,9 +205,7 @@ class CRUDOrganization(CRUDBase[Organization, OrganizationCreate, OrganizationUp
|
|||||||
return orgs_with_counts, total
|
return orgs_with_counts, total
|
||||||
|
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.error(
|
logger.exception("Error getting organizations with member counts: %s", e)
|
||||||
f"Error getting organizations with member counts: {e!s}", exc_info=True
|
|
||||||
)
|
|
||||||
raise
|
raise
|
||||||
|
|
||||||
async def add_user(
|
async def add_user(
|
||||||
@@ -236,7 +219,6 @@ class CRUDOrganization(CRUDBase[Organization, OrganizationCreate, OrganizationUp
|
|||||||
) -> UserOrganization:
|
) -> UserOrganization:
|
||||||
"""Add a user to an organization with a specific role."""
|
"""Add a user to an organization with a specific role."""
|
||||||
try:
|
try:
|
||||||
# Check if relationship already exists
|
|
||||||
result = await db.execute(
|
result = await db.execute(
|
||||||
select(UserOrganization).where(
|
select(UserOrganization).where(
|
||||||
and_(
|
and_(
|
||||||
@@ -248,7 +230,6 @@ class CRUDOrganization(CRUDBase[Organization, OrganizationCreate, OrganizationUp
|
|||||||
existing = result.scalar_one_or_none()
|
existing = result.scalar_one_or_none()
|
||||||
|
|
||||||
if existing:
|
if existing:
|
||||||
# Reactivate if inactive, or raise error if already active
|
|
||||||
if not existing.is_active:
|
if not existing.is_active:
|
||||||
existing.is_active = True
|
existing.is_active = True
|
||||||
existing.role = role
|
existing.role = role
|
||||||
@@ -257,9 +238,10 @@ class CRUDOrganization(CRUDBase[Organization, OrganizationCreate, OrganizationUp
|
|||||||
await db.refresh(existing)
|
await db.refresh(existing)
|
||||||
return existing
|
return existing
|
||||||
else:
|
else:
|
||||||
raise ValueError("User is already a member of this organization")
|
raise DuplicateEntryError(
|
||||||
|
"User is already a member of this organization"
|
||||||
|
)
|
||||||
|
|
||||||
# Create new relationship
|
|
||||||
user_org = UserOrganization(
|
user_org = UserOrganization(
|
||||||
user_id=user_id,
|
user_id=user_id,
|
||||||
organization_id=organization_id,
|
organization_id=organization_id,
|
||||||
@@ -273,11 +255,11 @@ class CRUDOrganization(CRUDBase[Organization, OrganizationCreate, OrganizationUp
|
|||||||
return user_org
|
return user_org
|
||||||
except IntegrityError as e:
|
except IntegrityError as e:
|
||||||
await db.rollback()
|
await db.rollback()
|
||||||
logger.error(f"Integrity error adding user to organization: {e!s}")
|
logger.error("Integrity error adding user to organization: %s", e)
|
||||||
raise ValueError("Failed to add user to organization")
|
raise IntegrityConstraintError("Failed to add user to organization")
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
await db.rollback()
|
await db.rollback()
|
||||||
logger.error(f"Error adding user to organization: {e!s}", exc_info=True)
|
logger.exception("Error adding user to organization: %s", e)
|
||||||
raise
|
raise
|
||||||
|
|
||||||
async def remove_user(
|
async def remove_user(
|
||||||
@@ -303,7 +285,7 @@ class CRUDOrganization(CRUDBase[Organization, OrganizationCreate, OrganizationUp
|
|||||||
return True
|
return True
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
await db.rollback()
|
await db.rollback()
|
||||||
logger.error(f"Error removing user from organization: {e!s}", exc_info=True)
|
logger.exception("Error removing user from organization: %s", e)
|
||||||
raise
|
raise
|
||||||
|
|
||||||
async def update_user_role(
|
async def update_user_role(
|
||||||
@@ -338,7 +320,7 @@ class CRUDOrganization(CRUDBase[Organization, OrganizationCreate, OrganizationUp
|
|||||||
return user_org
|
return user_org
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
await db.rollback()
|
await db.rollback()
|
||||||
logger.error(f"Error updating user role: {e!s}", exc_info=True)
|
logger.exception("Error updating user role: %s", e)
|
||||||
raise
|
raise
|
||||||
|
|
||||||
async def get_organization_members(
|
async def get_organization_members(
|
||||||
@@ -348,16 +330,10 @@ class CRUDOrganization(CRUDBase[Organization, OrganizationCreate, OrganizationUp
|
|||||||
organization_id: UUID,
|
organization_id: UUID,
|
||||||
skip: int = 0,
|
skip: int = 0,
|
||||||
limit: int = 100,
|
limit: int = 100,
|
||||||
is_active: bool = True,
|
is_active: bool | None = True,
|
||||||
) -> tuple[list[dict[str, Any]], int]:
|
) -> tuple[list[dict[str, Any]], int]:
|
||||||
"""
|
"""Get members of an organization with user details."""
|
||||||
Get members of an organization with user details.
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
Tuple of (members list with user details, total count)
|
|
||||||
"""
|
|
||||||
try:
|
try:
|
||||||
# Build query with join
|
|
||||||
query = (
|
query = (
|
||||||
select(UserOrganization, User)
|
select(UserOrganization, User)
|
||||||
.join(User, UserOrganization.user_id == User.id)
|
.join(User, UserOrganization.user_id == User.id)
|
||||||
@@ -367,7 +343,6 @@ class CRUDOrganization(CRUDBase[Organization, OrganizationCreate, OrganizationUp
|
|||||||
if is_active is not None:
|
if is_active is not None:
|
||||||
query = query.where(UserOrganization.is_active == is_active)
|
query = query.where(UserOrganization.is_active == is_active)
|
||||||
|
|
||||||
# Get total count
|
|
||||||
count_query = select(func.count()).select_from(
|
count_query = select(func.count()).select_from(
|
||||||
select(UserOrganization)
|
select(UserOrganization)
|
||||||
.where(UserOrganization.organization_id == organization_id)
|
.where(UserOrganization.organization_id == organization_id)
|
||||||
@@ -381,7 +356,6 @@ class CRUDOrganization(CRUDBase[Organization, OrganizationCreate, OrganizationUp
|
|||||||
count_result = await db.execute(count_query)
|
count_result = await db.execute(count_query)
|
||||||
total = count_result.scalar_one()
|
total = count_result.scalar_one()
|
||||||
|
|
||||||
# Apply ordering and pagination
|
|
||||||
query = (
|
query = (
|
||||||
query.order_by(UserOrganization.created_at.desc())
|
query.order_by(UserOrganization.created_at.desc())
|
||||||
.offset(skip)
|
.offset(skip)
|
||||||
@@ -406,11 +380,11 @@ class CRUDOrganization(CRUDBase[Organization, OrganizationCreate, OrganizationUp
|
|||||||
|
|
||||||
return members, total
|
return members, total
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.error(f"Error getting organization members: {e!s}")
|
logger.error("Error getting organization members: %s", e)
|
||||||
raise
|
raise
|
||||||
|
|
||||||
async def get_user_organizations(
|
async def get_user_organizations(
|
||||||
self, db: AsyncSession, *, user_id: UUID, is_active: bool = True
|
self, db: AsyncSession, *, user_id: UUID, is_active: bool | None = True
|
||||||
) -> list[Organization]:
|
) -> list[Organization]:
|
||||||
"""Get all organizations a user belongs to."""
|
"""Get all organizations a user belongs to."""
|
||||||
try:
|
try:
|
||||||
@@ -429,21 +403,14 @@ class CRUDOrganization(CRUDBase[Organization, OrganizationCreate, OrganizationUp
|
|||||||
result = await db.execute(query)
|
result = await db.execute(query)
|
||||||
return list(result.scalars().all())
|
return list(result.scalars().all())
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.error(f"Error getting user organizations: {e!s}")
|
logger.error("Error getting user organizations: %s", e)
|
||||||
raise
|
raise
|
||||||
|
|
||||||
async def get_user_organizations_with_details(
|
async def get_user_organizations_with_details(
|
||||||
self, db: AsyncSession, *, user_id: UUID, is_active: bool = True
|
self, db: AsyncSession, *, user_id: UUID, is_active: bool | None = True
|
||||||
) -> list[dict[str, Any]]:
|
) -> list[dict[str, Any]]:
|
||||||
"""
|
"""Get user's organizations with role and member count in SINGLE QUERY."""
|
||||||
Get user's organizations with role and member count in SINGLE QUERY.
|
|
||||||
Eliminates N+1 problem by using subquery for member counts.
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
List of dicts with organization, role, and member_count
|
|
||||||
"""
|
|
||||||
try:
|
try:
|
||||||
# Subquery to get member counts for each organization
|
|
||||||
member_count_subq = (
|
member_count_subq = (
|
||||||
select(
|
select(
|
||||||
UserOrganization.organization_id,
|
UserOrganization.organization_id,
|
||||||
@@ -454,7 +421,6 @@ class CRUDOrganization(CRUDBase[Organization, OrganizationCreate, OrganizationUp
|
|||||||
.subquery()
|
.subquery()
|
||||||
)
|
)
|
||||||
|
|
||||||
# Main query with JOIN to get org, role, and member count
|
|
||||||
query = (
|
query = (
|
||||||
select(
|
select(
|
||||||
Organization,
|
Organization,
|
||||||
@@ -486,9 +452,7 @@ class CRUDOrganization(CRUDBase[Organization, OrganizationCreate, OrganizationUp
|
|||||||
]
|
]
|
||||||
|
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.error(
|
logger.exception("Error getting user organizations with details: %s", e)
|
||||||
f"Error getting user organizations with details: {e!s}", exc_info=True
|
|
||||||
)
|
|
||||||
raise
|
raise
|
||||||
|
|
||||||
async def get_user_role_in_org(
|
async def get_user_role_in_org(
|
||||||
@@ -507,9 +471,9 @@ class CRUDOrganization(CRUDBase[Organization, OrganizationCreate, OrganizationUp
|
|||||||
)
|
)
|
||||||
user_org = result.scalar_one_or_none()
|
user_org = result.scalar_one_or_none()
|
||||||
|
|
||||||
return user_org.role if user_org else None
|
return user_org.role if user_org else None # pyright: ignore[reportReturnType]
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.error(f"Error getting user role in org: {e!s}")
|
logger.error("Error getting user role in org: %s", e)
|
||||||
raise
|
raise
|
||||||
|
|
||||||
async def is_user_org_owner(
|
async def is_user_org_owner(
|
||||||
@@ -531,5 +495,5 @@ class CRUDOrganization(CRUDBase[Organization, OrganizationCreate, OrganizationUp
|
|||||||
return role in [OrganizationRole.OWNER, OrganizationRole.ADMIN]
|
return role in [OrganizationRole.OWNER, OrganizationRole.ADMIN]
|
||||||
|
|
||||||
|
|
||||||
# Create a singleton instance for use across the application
|
# Singleton instance
|
||||||
organization = CRUDOrganization(Organization)
|
organization_repo = OrganizationRepository(Organization)
|
||||||
231
backend/app/crud/session.py → backend/app/repositories/session.py
Executable file → Normal file
231
backend/app/crud/session.py → backend/app/repositories/session.py
Executable file → Normal file
@@ -1,6 +1,5 @@
|
|||||||
"""
|
# app/repositories/session.py
|
||||||
Async CRUD operations for user sessions using SQLAlchemy 2.0 patterns.
|
"""Repository for UserSession model async database operations using SQLAlchemy 2.0 patterns."""
|
||||||
"""
|
|
||||||
|
|
||||||
import logging
|
import logging
|
||||||
import uuid
|
import uuid
|
||||||
@@ -11,49 +10,32 @@ from sqlalchemy import and_, delete, func, select, update
|
|||||||
from sqlalchemy.ext.asyncio import AsyncSession
|
from sqlalchemy.ext.asyncio import AsyncSession
|
||||||
from sqlalchemy.orm import joinedload
|
from sqlalchemy.orm import joinedload
|
||||||
|
|
||||||
from app.crud.base import CRUDBase
|
from app.core.repository_exceptions import IntegrityConstraintError, InvalidInputError
|
||||||
from app.models.user_session import UserSession
|
from app.models.user_session import UserSession
|
||||||
|
from app.repositories.base import BaseRepository
|
||||||
from app.schemas.sessions import SessionCreate, SessionUpdate
|
from app.schemas.sessions import SessionCreate, SessionUpdate
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
class CRUDSession(CRUDBase[UserSession, SessionCreate, SessionUpdate]):
|
class SessionRepository(BaseRepository[UserSession, SessionCreate, SessionUpdate]):
|
||||||
"""Async CRUD operations for user sessions."""
|
"""Repository for UserSession model."""
|
||||||
|
|
||||||
async def get_by_jti(self, db: AsyncSession, *, jti: str) -> UserSession | None:
|
async def get_by_jti(self, db: AsyncSession, *, jti: str) -> UserSession | None:
|
||||||
"""
|
"""Get session by refresh token JTI."""
|
||||||
Get session by refresh token JTI.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
db: Database session
|
|
||||||
jti: Refresh token JWT ID
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
UserSession if found, None otherwise
|
|
||||||
"""
|
|
||||||
try:
|
try:
|
||||||
result = await db.execute(
|
result = await db.execute(
|
||||||
select(UserSession).where(UserSession.refresh_token_jti == jti)
|
select(UserSession).where(UserSession.refresh_token_jti == jti)
|
||||||
)
|
)
|
||||||
return result.scalar_one_or_none()
|
return result.scalar_one_or_none()
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.error(f"Error getting session by JTI {jti}: {e!s}")
|
logger.error("Error getting session by JTI %s: %s", jti, e)
|
||||||
raise
|
raise
|
||||||
|
|
||||||
async def get_active_by_jti(
|
async def get_active_by_jti(
|
||||||
self, db: AsyncSession, *, jti: str
|
self, db: AsyncSession, *, jti: str
|
||||||
) -> UserSession | None:
|
) -> UserSession | None:
|
||||||
"""
|
"""Get active session by refresh token JTI."""
|
||||||
Get active session by refresh token JTI.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
db: Database session
|
|
||||||
jti: Refresh token JWT ID
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
Active UserSession if found, None otherwise
|
|
||||||
"""
|
|
||||||
try:
|
try:
|
||||||
result = await db.execute(
|
result = await db.execute(
|
||||||
select(UserSession).where(
|
select(UserSession).where(
|
||||||
@@ -65,7 +47,7 @@ class CRUDSession(CRUDBase[UserSession, SessionCreate, SessionUpdate]):
|
|||||||
)
|
)
|
||||||
return result.scalar_one_or_none()
|
return result.scalar_one_or_none()
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.error(f"Error getting active session by JTI {jti}: {e!s}")
|
logger.error("Error getting active session by JTI %s: %s", jti, e)
|
||||||
raise
|
raise
|
||||||
|
|
||||||
async def get_user_sessions(
|
async def get_user_sessions(
|
||||||
@@ -76,25 +58,12 @@ class CRUDSession(CRUDBase[UserSession, SessionCreate, SessionUpdate]):
|
|||||||
active_only: bool = True,
|
active_only: bool = True,
|
||||||
with_user: bool = False,
|
with_user: bool = False,
|
||||||
) -> list[UserSession]:
|
) -> list[UserSession]:
|
||||||
"""
|
"""Get all sessions for a user with optional eager loading."""
|
||||||
Get all sessions for a user with optional eager loading.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
db: Database session
|
|
||||||
user_id: User ID
|
|
||||||
active_only: If True, return only active sessions
|
|
||||||
with_user: If True, eager load user relationship to prevent N+1
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
List of UserSession objects
|
|
||||||
"""
|
|
||||||
try:
|
try:
|
||||||
# Convert user_id string to UUID if needed
|
|
||||||
user_uuid = UUID(user_id) if isinstance(user_id, str) else user_id
|
user_uuid = UUID(user_id) if isinstance(user_id, str) else user_id
|
||||||
|
|
||||||
query = select(UserSession).where(UserSession.user_id == user_uuid)
|
query = select(UserSession).where(UserSession.user_id == user_uuid)
|
||||||
|
|
||||||
# Add eager loading if requested to prevent N+1 queries
|
|
||||||
if with_user:
|
if with_user:
|
||||||
query = query.options(joinedload(UserSession.user))
|
query = query.options(joinedload(UserSession.user))
|
||||||
|
|
||||||
@@ -105,25 +74,13 @@ class CRUDSession(CRUDBase[UserSession, SessionCreate, SessionUpdate]):
|
|||||||
result = await db.execute(query)
|
result = await db.execute(query)
|
||||||
return list(result.scalars().all())
|
return list(result.scalars().all())
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.error(f"Error getting sessions for user {user_id}: {e!s}")
|
logger.error("Error getting sessions for user %s: %s", user_id, e)
|
||||||
raise
|
raise
|
||||||
|
|
||||||
async def create_session(
|
async def create_session(
|
||||||
self, db: AsyncSession, *, obj_in: SessionCreate
|
self, db: AsyncSession, *, obj_in: SessionCreate
|
||||||
) -> UserSession:
|
) -> UserSession:
|
||||||
"""
|
"""Create a new user session."""
|
||||||
Create a new user session.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
db: Database session
|
|
||||||
obj_in: SessionCreate schema with session data
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
Created UserSession
|
|
||||||
|
|
||||||
Raises:
|
|
||||||
ValueError: If session creation fails
|
|
||||||
"""
|
|
||||||
try:
|
try:
|
||||||
db_obj = UserSession(
|
db_obj = UserSession(
|
||||||
user_id=obj_in.user_id,
|
user_id=obj_in.user_id,
|
||||||
@@ -143,33 +100,26 @@ class CRUDSession(CRUDBase[UserSession, SessionCreate, SessionUpdate]):
|
|||||||
await db.refresh(db_obj)
|
await db.refresh(db_obj)
|
||||||
|
|
||||||
logger.info(
|
logger.info(
|
||||||
f"Session created for user {obj_in.user_id} from {obj_in.device_name} "
|
"Session created for user %s from %s (IP: %s)",
|
||||||
f"(IP: {obj_in.ip_address})"
|
obj_in.user_id,
|
||||||
|
obj_in.device_name,
|
||||||
|
obj_in.ip_address,
|
||||||
)
|
)
|
||||||
|
|
||||||
return db_obj
|
return db_obj
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
await db.rollback()
|
await db.rollback()
|
||||||
logger.error(f"Error creating session: {e!s}", exc_info=True)
|
logger.exception("Error creating session: %s", e)
|
||||||
raise ValueError(f"Failed to create session: {e!s}")
|
raise IntegrityConstraintError(f"Failed to create session: {e!s}")
|
||||||
|
|
||||||
async def deactivate(
|
async def deactivate(
|
||||||
self, db: AsyncSession, *, session_id: str
|
self, db: AsyncSession, *, session_id: str
|
||||||
) -> UserSession | None:
|
) -> UserSession | None:
|
||||||
"""
|
"""Deactivate a session (logout from device)."""
|
||||||
Deactivate a session (logout from device).
|
|
||||||
|
|
||||||
Args:
|
|
||||||
db: Database session
|
|
||||||
session_id: Session UUID
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
Deactivated UserSession if found, None otherwise
|
|
||||||
"""
|
|
||||||
try:
|
try:
|
||||||
session = await self.get(db, id=session_id)
|
session = await self.get(db, id=session_id)
|
||||||
if not session:
|
if not session:
|
||||||
logger.warning(f"Session {session_id} not found for deactivation")
|
logger.warning("Session %s not found for deactivation", session_id)
|
||||||
return None
|
return None
|
||||||
|
|
||||||
session.is_active = False
|
session.is_active = False
|
||||||
@@ -178,31 +128,23 @@ class CRUDSession(CRUDBase[UserSession, SessionCreate, SessionUpdate]):
|
|||||||
await db.refresh(session)
|
await db.refresh(session)
|
||||||
|
|
||||||
logger.info(
|
logger.info(
|
||||||
f"Session {session_id} deactivated for user {session.user_id} "
|
"Session %s deactivated for user %s (%s)",
|
||||||
f"({session.device_name})"
|
session_id,
|
||||||
|
session.user_id,
|
||||||
|
session.device_name,
|
||||||
)
|
)
|
||||||
|
|
||||||
return session
|
return session
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
await db.rollback()
|
await db.rollback()
|
||||||
logger.error(f"Error deactivating session {session_id}: {e!s}")
|
logger.error("Error deactivating session %s: %s", session_id, e)
|
||||||
raise
|
raise
|
||||||
|
|
||||||
async def deactivate_all_user_sessions(
|
async def deactivate_all_user_sessions(
|
||||||
self, db: AsyncSession, *, user_id: str
|
self, db: AsyncSession, *, user_id: str
|
||||||
) -> int:
|
) -> int:
|
||||||
"""
|
"""Deactivate all active sessions for a user (logout from all devices)."""
|
||||||
Deactivate all active sessions for a user (logout from all devices).
|
|
||||||
|
|
||||||
Args:
|
|
||||||
db: Database session
|
|
||||||
user_id: User ID
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
Number of sessions deactivated
|
|
||||||
"""
|
|
||||||
try:
|
try:
|
||||||
# Convert user_id string to UUID if needed
|
|
||||||
user_uuid = UUID(user_id) if isinstance(user_id, str) else user_id
|
user_uuid = UUID(user_id) if isinstance(user_id, str) else user_id
|
||||||
|
|
||||||
stmt = (
|
stmt = (
|
||||||
@@ -216,27 +158,18 @@ class CRUDSession(CRUDBase[UserSession, SessionCreate, SessionUpdate]):
|
|||||||
|
|
||||||
count = result.rowcount
|
count = result.rowcount
|
||||||
|
|
||||||
logger.info(f"Deactivated {count} sessions for user {user_id}")
|
logger.info("Deactivated %s sessions for user %s", count, user_id)
|
||||||
|
|
||||||
return count
|
return count
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
await db.rollback()
|
await db.rollback()
|
||||||
logger.error(f"Error deactivating all sessions for user {user_id}: {e!s}")
|
logger.error("Error deactivating all sessions for user %s: %s", user_id, e)
|
||||||
raise
|
raise
|
||||||
|
|
||||||
async def update_last_used(
|
async def update_last_used(
|
||||||
self, db: AsyncSession, *, session: UserSession
|
self, db: AsyncSession, *, session: UserSession
|
||||||
) -> UserSession:
|
) -> UserSession:
|
||||||
"""
|
"""Update the last_used_at timestamp for a session."""
|
||||||
Update the last_used_at timestamp for a session.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
db: Database session
|
|
||||||
session: UserSession object
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
Updated UserSession
|
|
||||||
"""
|
|
||||||
try:
|
try:
|
||||||
session.last_used_at = datetime.now(UTC)
|
session.last_used_at = datetime.now(UTC)
|
||||||
db.add(session)
|
db.add(session)
|
||||||
@@ -245,7 +178,7 @@ class CRUDSession(CRUDBase[UserSession, SessionCreate, SessionUpdate]):
|
|||||||
return session
|
return session
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
await db.rollback()
|
await db.rollback()
|
||||||
logger.error(f"Error updating last_used for session {session.id}: {e!s}")
|
logger.error("Error updating last_used for session %s: %s", session.id, e)
|
||||||
raise
|
raise
|
||||||
|
|
||||||
async def update_refresh_token(
|
async def update_refresh_token(
|
||||||
@@ -256,20 +189,7 @@ class CRUDSession(CRUDBase[UserSession, SessionCreate, SessionUpdate]):
|
|||||||
new_jti: str,
|
new_jti: str,
|
||||||
new_expires_at: datetime,
|
new_expires_at: datetime,
|
||||||
) -> UserSession:
|
) -> UserSession:
|
||||||
"""
|
"""Update session with new refresh token JTI and expiration."""
|
||||||
Update session with new refresh token JTI and expiration.
|
|
||||||
|
|
||||||
Called during token refresh.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
db: Database session
|
|
||||||
session: UserSession object
|
|
||||||
new_jti: New refresh token JTI
|
|
||||||
new_expires_at: New expiration datetime
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
Updated UserSession
|
|
||||||
"""
|
|
||||||
try:
|
try:
|
||||||
session.refresh_token_jti = new_jti
|
session.refresh_token_jti = new_jti
|
||||||
session.expires_at = new_expires_at
|
session.expires_at = new_expires_at
|
||||||
@@ -281,32 +201,16 @@ class CRUDSession(CRUDBase[UserSession, SessionCreate, SessionUpdate]):
|
|||||||
except Exception as e:
|
except Exception as e:
|
||||||
await db.rollback()
|
await db.rollback()
|
||||||
logger.error(
|
logger.error(
|
||||||
f"Error updating refresh token for session {session.id}: {e!s}"
|
"Error updating refresh token for session %s: %s", session.id, e
|
||||||
)
|
)
|
||||||
raise
|
raise
|
||||||
|
|
||||||
async def cleanup_expired(self, db: AsyncSession, *, keep_days: int = 30) -> int:
|
async def cleanup_expired(self, db: AsyncSession, *, keep_days: int = 30) -> int:
|
||||||
"""
|
"""Clean up expired sessions using optimized bulk DELETE."""
|
||||||
Clean up expired sessions using optimized bulk DELETE.
|
|
||||||
|
|
||||||
Deletes sessions that are:
|
|
||||||
- Expired AND inactive
|
|
||||||
- Older than keep_days
|
|
||||||
|
|
||||||
Uses single DELETE query instead of N individual deletes for efficiency.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
db: Database session
|
|
||||||
keep_days: Keep inactive sessions for this many days (for audit)
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
Number of sessions deleted
|
|
||||||
"""
|
|
||||||
try:
|
try:
|
||||||
cutoff_date = datetime.now(UTC) - timedelta(days=keep_days)
|
cutoff_date = datetime.now(UTC) - timedelta(days=keep_days)
|
||||||
now = datetime.now(UTC)
|
now = datetime.now(UTC)
|
||||||
|
|
||||||
# Use bulk DELETE with WHERE clause - single query
|
|
||||||
stmt = delete(UserSession).where(
|
stmt = delete(UserSession).where(
|
||||||
and_(
|
and_(
|
||||||
UserSession.is_active == False, # noqa: E712
|
UserSession.is_active == False, # noqa: E712
|
||||||
@@ -321,38 +225,25 @@ class CRUDSession(CRUDBase[UserSession, SessionCreate, SessionUpdate]):
|
|||||||
count = result.rowcount
|
count = result.rowcount
|
||||||
|
|
||||||
if count > 0:
|
if count > 0:
|
||||||
logger.info(f"Cleaned up {count} expired sessions using bulk DELETE")
|
logger.info("Cleaned up %s expired sessions using bulk DELETE", count)
|
||||||
|
|
||||||
return count
|
return count
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
await db.rollback()
|
await db.rollback()
|
||||||
logger.error(f"Error cleaning up expired sessions: {e!s}")
|
logger.error("Error cleaning up expired sessions: %s", e)
|
||||||
raise
|
raise
|
||||||
|
|
||||||
async def cleanup_expired_for_user(self, db: AsyncSession, *, user_id: str) -> int:
|
async def cleanup_expired_for_user(self, db: AsyncSession, *, user_id: str) -> int:
|
||||||
"""
|
"""Clean up expired and inactive sessions for a specific user."""
|
||||||
Clean up expired and inactive sessions for a specific user.
|
|
||||||
|
|
||||||
Uses single bulk DELETE query for efficiency instead of N individual deletes.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
db: Database session
|
|
||||||
user_id: User ID to cleanup sessions for
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
Number of sessions deleted
|
|
||||||
"""
|
|
||||||
try:
|
try:
|
||||||
# Validate UUID
|
|
||||||
try:
|
try:
|
||||||
uuid_obj = uuid.UUID(user_id)
|
uuid_obj = uuid.UUID(user_id)
|
||||||
except (ValueError, AttributeError):
|
except (ValueError, AttributeError):
|
||||||
logger.error(f"Invalid UUID format: {user_id}")
|
logger.error("Invalid UUID format: %s", user_id)
|
||||||
raise ValueError(f"Invalid user ID format: {user_id}")
|
raise InvalidInputError(f"Invalid user ID format: {user_id}")
|
||||||
|
|
||||||
now = datetime.now(UTC)
|
now = datetime.now(UTC)
|
||||||
|
|
||||||
# Use bulk DELETE with WHERE clause - single query
|
|
||||||
stmt = delete(UserSession).where(
|
stmt = delete(UserSession).where(
|
||||||
and_(
|
and_(
|
||||||
UserSession.user_id == uuid_obj,
|
UserSession.user_id == uuid_obj,
|
||||||
@@ -368,30 +259,22 @@ class CRUDSession(CRUDBase[UserSession, SessionCreate, SessionUpdate]):
|
|||||||
|
|
||||||
if count > 0:
|
if count > 0:
|
||||||
logger.info(
|
logger.info(
|
||||||
f"Cleaned up {count} expired sessions for user {user_id} using bulk DELETE"
|
"Cleaned up %s expired sessions for user %s using bulk DELETE",
|
||||||
|
count,
|
||||||
|
user_id,
|
||||||
)
|
)
|
||||||
|
|
||||||
return count
|
return count
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
await db.rollback()
|
await db.rollback()
|
||||||
logger.error(
|
logger.error(
|
||||||
f"Error cleaning up expired sessions for user {user_id}: {e!s}"
|
"Error cleaning up expired sessions for user %s: %s", user_id, e
|
||||||
)
|
)
|
||||||
raise
|
raise
|
||||||
|
|
||||||
async def get_user_session_count(self, db: AsyncSession, *, user_id: str) -> int:
|
async def get_user_session_count(self, db: AsyncSession, *, user_id: str) -> int:
|
||||||
"""
|
"""Get count of active sessions for a user."""
|
||||||
Get count of active sessions for a user.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
db: Database session
|
|
||||||
user_id: User ID
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
Number of active sessions
|
|
||||||
"""
|
|
||||||
try:
|
try:
|
||||||
# Convert user_id string to UUID if needed
|
|
||||||
user_uuid = UUID(user_id) if isinstance(user_id, str) else user_id
|
user_uuid = UUID(user_id) if isinstance(user_id, str) else user_id
|
||||||
|
|
||||||
result = await db.execute(
|
result = await db.execute(
|
||||||
@@ -401,7 +284,7 @@ class CRUDSession(CRUDBase[UserSession, SessionCreate, SessionUpdate]):
|
|||||||
)
|
)
|
||||||
return result.scalar_one()
|
return result.scalar_one()
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.error(f"Error counting sessions for user {user_id}: {e!s}")
|
logger.error("Error counting sessions for user %s: %s", user_id, e)
|
||||||
raise
|
raise
|
||||||
|
|
||||||
async def get_all_sessions(
|
async def get_all_sessions(
|
||||||
@@ -413,31 +296,16 @@ class CRUDSession(CRUDBase[UserSession, SessionCreate, SessionUpdate]):
|
|||||||
active_only: bool = True,
|
active_only: bool = True,
|
||||||
with_user: bool = True,
|
with_user: bool = True,
|
||||||
) -> tuple[list[UserSession], int]:
|
) -> tuple[list[UserSession], int]:
|
||||||
"""
|
"""Get all sessions across all users with pagination (admin only)."""
|
||||||
Get all sessions across all users with pagination (admin only).
|
|
||||||
|
|
||||||
Args:
|
|
||||||
db: Database session
|
|
||||||
skip: Number of records to skip
|
|
||||||
limit: Maximum number of records to return
|
|
||||||
active_only: If True, return only active sessions
|
|
||||||
with_user: If True, eager load user relationship to prevent N+1
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
Tuple of (list of UserSession objects, total count)
|
|
||||||
"""
|
|
||||||
try:
|
try:
|
||||||
# Build query
|
|
||||||
query = select(UserSession)
|
query = select(UserSession)
|
||||||
|
|
||||||
# Add eager loading if requested to prevent N+1 queries
|
|
||||||
if with_user:
|
if with_user:
|
||||||
query = query.options(joinedload(UserSession.user))
|
query = query.options(joinedload(UserSession.user))
|
||||||
|
|
||||||
if active_only:
|
if active_only:
|
||||||
query = query.where(UserSession.is_active)
|
query = query.where(UserSession.is_active)
|
||||||
|
|
||||||
# Get total count
|
|
||||||
count_query = select(func.count(UserSession.id))
|
count_query = select(func.count(UserSession.id))
|
||||||
if active_only:
|
if active_only:
|
||||||
count_query = count_query.where(UserSession.is_active)
|
count_query = count_query.where(UserSession.is_active)
|
||||||
@@ -445,7 +313,6 @@ class CRUDSession(CRUDBase[UserSession, SessionCreate, SessionUpdate]):
|
|||||||
count_result = await db.execute(count_query)
|
count_result = await db.execute(count_query)
|
||||||
total = count_result.scalar_one()
|
total = count_result.scalar_one()
|
||||||
|
|
||||||
# Apply pagination and ordering
|
|
||||||
query = (
|
query = (
|
||||||
query.order_by(UserSession.last_used_at.desc())
|
query.order_by(UserSession.last_used_at.desc())
|
||||||
.offset(skip)
|
.offset(skip)
|
||||||
@@ -458,9 +325,9 @@ class CRUDSession(CRUDBase[UserSession, SessionCreate, SessionUpdate]):
|
|||||||
return sessions, total
|
return sessions, total
|
||||||
|
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.error(f"Error getting all sessions: {e!s}", exc_info=True)
|
logger.exception("Error getting all sessions: %s", e)
|
||||||
raise
|
raise
|
||||||
|
|
||||||
|
|
||||||
# Create singleton instance
|
# Singleton instance
|
||||||
session = CRUDSession(UserSession)
|
session_repo = SessionRepository(UserSession)
|
||||||
155
backend/app/crud/user.py → backend/app/repositories/user.py
Executable file → Normal file
155
backend/app/crud/user.py → backend/app/repositories/user.py
Executable file → Normal file
@@ -1,5 +1,5 @@
|
|||||||
# app/crud/user_async.py
|
# app/repositories/user.py
|
||||||
"""Async CRUD operations for User model using SQLAlchemy 2.0 patterns."""
|
"""Repository for User model async database operations using SQLAlchemy 2.0 patterns."""
|
||||||
|
|
||||||
import logging
|
import logging
|
||||||
from datetime import UTC, datetime
|
from datetime import UTC, datetime
|
||||||
@@ -11,15 +11,16 @@ from sqlalchemy.exc import IntegrityError
|
|||||||
from sqlalchemy.ext.asyncio import AsyncSession
|
from sqlalchemy.ext.asyncio import AsyncSession
|
||||||
|
|
||||||
from app.core.auth import get_password_hash_async
|
from app.core.auth import get_password_hash_async
|
||||||
from app.crud.base import CRUDBase
|
from app.core.repository_exceptions import DuplicateEntryError, InvalidInputError
|
||||||
from app.models.user import User
|
from app.models.user import User
|
||||||
|
from app.repositories.base import BaseRepository
|
||||||
from app.schemas.users import UserCreate, UserUpdate
|
from app.schemas.users import UserCreate, UserUpdate
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
class CRUDUser(CRUDBase[User, UserCreate, UserUpdate]):
|
class UserRepository(BaseRepository[User, UserCreate, UserUpdate]):
|
||||||
"""Async CRUD operations for User model."""
|
"""Repository for User model."""
|
||||||
|
|
||||||
async def get_by_email(self, db: AsyncSession, *, email: str) -> User | None:
|
async def get_by_email(self, db: AsyncSession, *, email: str) -> User | None:
|
||||||
"""Get user by email address."""
|
"""Get user by email address."""
|
||||||
@@ -27,13 +28,12 @@ class CRUDUser(CRUDBase[User, UserCreate, UserUpdate]):
|
|||||||
result = await db.execute(select(User).where(User.email == email))
|
result = await db.execute(select(User).where(User.email == email))
|
||||||
return result.scalar_one_or_none()
|
return result.scalar_one_or_none()
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.error(f"Error getting user by email {email}: {e!s}")
|
logger.error("Error getting user by email %s: %s", email, e)
|
||||||
raise
|
raise
|
||||||
|
|
||||||
async def create(self, db: AsyncSession, *, obj_in: UserCreate) -> User:
|
async def create(self, db: AsyncSession, *, obj_in: UserCreate) -> User:
|
||||||
"""Create a new user with async password hashing and error handling."""
|
"""Create a new user with async password hashing and error handling."""
|
||||||
try:
|
try:
|
||||||
# Hash password asynchronously to avoid blocking event loop
|
|
||||||
password_hash = await get_password_hash_async(obj_in.password)
|
password_hash = await get_password_hash_async(obj_in.password)
|
||||||
|
|
||||||
db_obj = User(
|
db_obj = User(
|
||||||
@@ -57,13 +57,49 @@ class CRUDUser(CRUDBase[User, UserCreate, UserUpdate]):
|
|||||||
await db.rollback()
|
await db.rollback()
|
||||||
error_msg = str(e.orig) if hasattr(e, "orig") else str(e)
|
error_msg = str(e.orig) if hasattr(e, "orig") else str(e)
|
||||||
if "email" in error_msg.lower():
|
if "email" in error_msg.lower():
|
||||||
logger.warning(f"Duplicate email attempted: {obj_in.email}")
|
logger.warning("Duplicate email attempted: %s", obj_in.email)
|
||||||
raise ValueError(f"User with email {obj_in.email} already exists")
|
raise DuplicateEntryError(
|
||||||
logger.error(f"Integrity error creating user: {error_msg}")
|
f"User with email {obj_in.email} already exists"
|
||||||
raise ValueError(f"Database integrity error: {error_msg}")
|
)
|
||||||
|
logger.error("Integrity error creating user: %s", error_msg)
|
||||||
|
raise DuplicateEntryError(f"Database integrity error: {error_msg}")
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
await db.rollback()
|
await db.rollback()
|
||||||
logger.error(f"Unexpected error creating user: {e!s}", exc_info=True)
|
logger.exception("Unexpected error creating user: %s", e)
|
||||||
|
raise
|
||||||
|
|
||||||
|
async def create_oauth_user(
|
||||||
|
self,
|
||||||
|
db: AsyncSession,
|
||||||
|
*,
|
||||||
|
email: str,
|
||||||
|
first_name: str = "User",
|
||||||
|
last_name: str | None = None,
|
||||||
|
) -> User:
|
||||||
|
"""Create a new passwordless user for OAuth sign-in."""
|
||||||
|
try:
|
||||||
|
db_obj = User(
|
||||||
|
email=email,
|
||||||
|
password_hash=None, # OAuth-only user
|
||||||
|
first_name=first_name,
|
||||||
|
last_name=last_name,
|
||||||
|
is_active=True,
|
||||||
|
is_superuser=False,
|
||||||
|
)
|
||||||
|
db.add(db_obj)
|
||||||
|
await db.flush() # Get user.id without committing
|
||||||
|
return db_obj
|
||||||
|
except IntegrityError as e:
|
||||||
|
await db.rollback()
|
||||||
|
error_msg = str(e.orig) if hasattr(e, "orig") else str(e)
|
||||||
|
if "email" in error_msg.lower():
|
||||||
|
logger.warning("Duplicate email attempted: %s", email)
|
||||||
|
raise DuplicateEntryError(f"User with email {email} already exists")
|
||||||
|
logger.error("Integrity error creating OAuth user: %s", error_msg)
|
||||||
|
raise DuplicateEntryError(f"Database integrity error: {error_msg}")
|
||||||
|
except Exception as e:
|
||||||
|
await db.rollback()
|
||||||
|
logger.exception("Unexpected error creating OAuth user: %s", e)
|
||||||
raise
|
raise
|
||||||
|
|
||||||
async def update(
|
async def update(
|
||||||
@@ -75,8 +111,6 @@ class CRUDUser(CRUDBase[User, UserCreate, UserUpdate]):
|
|||||||
else:
|
else:
|
||||||
update_data = obj_in.model_dump(exclude_unset=True)
|
update_data = obj_in.model_dump(exclude_unset=True)
|
||||||
|
|
||||||
# Handle password separately if it exists in update data
|
|
||||||
# Hash password asynchronously to avoid blocking event loop
|
|
||||||
if "password" in update_data:
|
if "password" in update_data:
|
||||||
update_data["password_hash"] = await get_password_hash_async(
|
update_data["password_hash"] = await get_password_hash_async(
|
||||||
update_data["password"]
|
update_data["password"]
|
||||||
@@ -85,6 +119,15 @@ class CRUDUser(CRUDBase[User, UserCreate, UserUpdate]):
|
|||||||
|
|
||||||
return await super().update(db, db_obj=db_obj, obj_in=update_data)
|
return await super().update(db, db_obj=db_obj, obj_in=update_data)
|
||||||
|
|
||||||
|
async def update_password(
|
||||||
|
self, db: AsyncSession, *, user: User, password_hash: str
|
||||||
|
) -> User:
|
||||||
|
"""Set a new password hash on a user and commit."""
|
||||||
|
user.password_hash = password_hash
|
||||||
|
await db.commit()
|
||||||
|
await db.refresh(user)
|
||||||
|
return user
|
||||||
|
|
||||||
async def get_multi_with_total(
|
async def get_multi_with_total(
|
||||||
self,
|
self,
|
||||||
db: AsyncSession,
|
db: AsyncSession,
|
||||||
@@ -96,43 +139,23 @@ class CRUDUser(CRUDBase[User, UserCreate, UserUpdate]):
|
|||||||
filters: dict[str, Any] | None = None,
|
filters: dict[str, Any] | None = None,
|
||||||
search: str | None = None,
|
search: str | None = None,
|
||||||
) -> tuple[list[User], int]:
|
) -> tuple[list[User], int]:
|
||||||
"""
|
"""Get multiple users with total count, filtering, sorting, and search."""
|
||||||
Get multiple users with total count, filtering, sorting, and search.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
db: Database session
|
|
||||||
skip: Number of records to skip
|
|
||||||
limit: Maximum number of records to return
|
|
||||||
sort_by: Field name to sort by
|
|
||||||
sort_order: Sort order ("asc" or "desc")
|
|
||||||
filters: Dictionary of filters (field_name: value)
|
|
||||||
search: Search term to match against email, first_name, last_name
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
Tuple of (users list, total count)
|
|
||||||
"""
|
|
||||||
# Validate pagination
|
|
||||||
if skip < 0:
|
if skip < 0:
|
||||||
raise ValueError("skip must be non-negative")
|
raise InvalidInputError("skip must be non-negative")
|
||||||
if limit < 0:
|
if limit < 0:
|
||||||
raise ValueError("limit must be non-negative")
|
raise InvalidInputError("limit must be non-negative")
|
||||||
if limit > 1000:
|
if limit > 1000:
|
||||||
raise ValueError("Maximum limit is 1000")
|
raise InvalidInputError("Maximum limit is 1000")
|
||||||
|
|
||||||
try:
|
try:
|
||||||
# Build base query
|
|
||||||
query = select(User)
|
query = select(User)
|
||||||
|
|
||||||
# Exclude soft-deleted users
|
|
||||||
query = query.where(User.deleted_at.is_(None))
|
query = query.where(User.deleted_at.is_(None))
|
||||||
|
|
||||||
# Apply filters
|
|
||||||
if filters:
|
if filters:
|
||||||
for field, value in filters.items():
|
for field, value in filters.items():
|
||||||
if hasattr(User, field) and value is not None:
|
if hasattr(User, field) and value is not None:
|
||||||
query = query.where(getattr(User, field) == value)
|
query = query.where(getattr(User, field) == value)
|
||||||
|
|
||||||
# Apply search
|
|
||||||
if search:
|
if search:
|
||||||
search_filter = or_(
|
search_filter = or_(
|
||||||
User.email.ilike(f"%{search}%"),
|
User.email.ilike(f"%{search}%"),
|
||||||
@@ -141,14 +164,12 @@ class CRUDUser(CRUDBase[User, UserCreate, UserUpdate]):
|
|||||||
)
|
)
|
||||||
query = query.where(search_filter)
|
query = query.where(search_filter)
|
||||||
|
|
||||||
# Get total count
|
|
||||||
from sqlalchemy import func
|
from sqlalchemy import func
|
||||||
|
|
||||||
count_query = select(func.count()).select_from(query.alias())
|
count_query = select(func.count()).select_from(query.alias())
|
||||||
count_result = await db.execute(count_query)
|
count_result = await db.execute(count_query)
|
||||||
total = count_result.scalar_one()
|
total = count_result.scalar_one()
|
||||||
|
|
||||||
# Apply sorting
|
|
||||||
if sort_by and hasattr(User, sort_by):
|
if sort_by and hasattr(User, sort_by):
|
||||||
sort_column = getattr(User, sort_by)
|
sort_column = getattr(User, sort_by)
|
||||||
if sort_order.lower() == "desc":
|
if sort_order.lower() == "desc":
|
||||||
@@ -156,7 +177,6 @@ class CRUDUser(CRUDBase[User, UserCreate, UserUpdate]):
|
|||||||
else:
|
else:
|
||||||
query = query.order_by(sort_column.asc())
|
query = query.order_by(sort_column.asc())
|
||||||
|
|
||||||
# Apply pagination
|
|
||||||
query = query.offset(skip).limit(limit)
|
query = query.offset(skip).limit(limit)
|
||||||
result = await db.execute(query)
|
result = await db.execute(query)
|
||||||
users = list(result.scalars().all())
|
users = list(result.scalars().all())
|
||||||
@@ -164,32 +184,21 @@ class CRUDUser(CRUDBase[User, UserCreate, UserUpdate]):
|
|||||||
return users, total
|
return users, total
|
||||||
|
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.error(f"Error retrieving paginated users: {e!s}")
|
logger.error("Error retrieving paginated users: %s", e)
|
||||||
raise
|
raise
|
||||||
|
|
||||||
async def bulk_update_status(
|
async def bulk_update_status(
|
||||||
self, db: AsyncSession, *, user_ids: list[UUID], is_active: bool
|
self, db: AsyncSession, *, user_ids: list[UUID], is_active: bool
|
||||||
) -> int:
|
) -> int:
|
||||||
"""
|
"""Bulk update is_active status for multiple users."""
|
||||||
Bulk update is_active status for multiple users.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
db: Database session
|
|
||||||
user_ids: List of user IDs to update
|
|
||||||
is_active: New active status
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
Number of users updated
|
|
||||||
"""
|
|
||||||
try:
|
try:
|
||||||
if not user_ids:
|
if not user_ids:
|
||||||
return 0
|
return 0
|
||||||
|
|
||||||
# Use UPDATE with WHERE IN for efficiency
|
|
||||||
stmt = (
|
stmt = (
|
||||||
update(User)
|
update(User)
|
||||||
.where(User.id.in_(user_ids))
|
.where(User.id.in_(user_ids))
|
||||||
.where(User.deleted_at.is_(None)) # Don't update deleted users
|
.where(User.deleted_at.is_(None))
|
||||||
.values(is_active=is_active, updated_at=datetime.now(UTC))
|
.values(is_active=is_active, updated_at=datetime.now(UTC))
|
||||||
)
|
)
|
||||||
|
|
||||||
@@ -197,12 +206,14 @@ class CRUDUser(CRUDBase[User, UserCreate, UserUpdate]):
|
|||||||
await db.commit()
|
await db.commit()
|
||||||
|
|
||||||
updated_count = result.rowcount
|
updated_count = result.rowcount
|
||||||
logger.info(f"Bulk updated {updated_count} users to is_active={is_active}")
|
logger.info(
|
||||||
|
"Bulk updated %s users to is_active=%s", updated_count, is_active
|
||||||
|
)
|
||||||
return updated_count
|
return updated_count
|
||||||
|
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
await db.rollback()
|
await db.rollback()
|
||||||
logger.error(f"Error bulk updating user status: {e!s}", exc_info=True)
|
logger.exception("Error bulk updating user status: %s", e)
|
||||||
raise
|
raise
|
||||||
|
|
||||||
async def bulk_soft_delete(
|
async def bulk_soft_delete(
|
||||||
@@ -212,34 +223,20 @@ class CRUDUser(CRUDBase[User, UserCreate, UserUpdate]):
|
|||||||
user_ids: list[UUID],
|
user_ids: list[UUID],
|
||||||
exclude_user_id: UUID | None = None,
|
exclude_user_id: UUID | None = None,
|
||||||
) -> int:
|
) -> int:
|
||||||
"""
|
"""Bulk soft delete multiple users."""
|
||||||
Bulk soft delete multiple users.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
db: Database session
|
|
||||||
user_ids: List of user IDs to delete
|
|
||||||
exclude_user_id: Optional user ID to exclude (e.g., the admin performing the action)
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
Number of users deleted
|
|
||||||
"""
|
|
||||||
try:
|
try:
|
||||||
if not user_ids:
|
if not user_ids:
|
||||||
return 0
|
return 0
|
||||||
|
|
||||||
# Remove excluded user from list
|
|
||||||
filtered_ids = [uid for uid in user_ids if uid != exclude_user_id]
|
filtered_ids = [uid for uid in user_ids if uid != exclude_user_id]
|
||||||
|
|
||||||
if not filtered_ids:
|
if not filtered_ids:
|
||||||
return 0
|
return 0
|
||||||
|
|
||||||
# Use UPDATE with WHERE IN for efficiency
|
|
||||||
stmt = (
|
stmt = (
|
||||||
update(User)
|
update(User)
|
||||||
.where(User.id.in_(filtered_ids))
|
.where(User.id.in_(filtered_ids))
|
||||||
.where(
|
.where(User.deleted_at.is_(None))
|
||||||
User.deleted_at.is_(None)
|
|
||||||
) # Don't re-delete already deleted users
|
|
||||||
.values(
|
.values(
|
||||||
deleted_at=datetime.now(UTC),
|
deleted_at=datetime.now(UTC),
|
||||||
is_active=False,
|
is_active=False,
|
||||||
@@ -251,22 +248,22 @@ class CRUDUser(CRUDBase[User, UserCreate, UserUpdate]):
|
|||||||
await db.commit()
|
await db.commit()
|
||||||
|
|
||||||
deleted_count = result.rowcount
|
deleted_count = result.rowcount
|
||||||
logger.info(f"Bulk soft deleted {deleted_count} users")
|
logger.info("Bulk soft deleted %s users", deleted_count)
|
||||||
return deleted_count
|
return deleted_count
|
||||||
|
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
await db.rollback()
|
await db.rollback()
|
||||||
logger.error(f"Error bulk deleting users: {e!s}", exc_info=True)
|
logger.exception("Error bulk deleting users: %s", e)
|
||||||
raise
|
raise
|
||||||
|
|
||||||
def is_active(self, user: User) -> bool:
|
def is_active(self, user: User) -> bool:
|
||||||
"""Check if user is active."""
|
"""Check if user is active."""
|
||||||
return user.is_active
|
return bool(user.is_active)
|
||||||
|
|
||||||
def is_superuser(self, user: User) -> bool:
|
def is_superuser(self, user: User) -> bool:
|
||||||
"""Check if user is a superuser."""
|
"""Check if user is a superuser."""
|
||||||
return user.is_superuser
|
return bool(user.is_superuser)
|
||||||
|
|
||||||
|
|
||||||
# Create a singleton instance for use across the application
|
# Singleton instance
|
||||||
user = CRUDUser(User)
|
user_repo = UserRepository(User)
|
||||||
395
backend/app/schemas/oauth.py
Normal file
395
backend/app/schemas/oauth.py
Normal file
@@ -0,0 +1,395 @@
|
|||||||
|
"""
|
||||||
|
Pydantic schemas for OAuth authentication.
|
||||||
|
"""
|
||||||
|
|
||||||
|
from datetime import datetime
|
||||||
|
from uuid import UUID
|
||||||
|
|
||||||
|
from pydantic import BaseModel, ConfigDict, Field
|
||||||
|
|
||||||
|
# ============================================================================
|
||||||
|
# OAuth Provider Info (for frontend to display available providers)
|
||||||
|
# ============================================================================
|
||||||
|
|
||||||
|
|
||||||
|
class OAuthProviderInfo(BaseModel):
|
||||||
|
"""Information about an available OAuth provider."""
|
||||||
|
|
||||||
|
provider: str = Field(..., description="Provider identifier (google, github)")
|
||||||
|
name: str = Field(..., description="Human-readable provider name")
|
||||||
|
icon: str | None = Field(None, description="Icon identifier for frontend")
|
||||||
|
|
||||||
|
|
||||||
|
class OAuthProvidersResponse(BaseModel):
|
||||||
|
"""Response containing list of enabled OAuth providers."""
|
||||||
|
|
||||||
|
enabled: bool = Field(..., description="Whether OAuth is globally enabled")
|
||||||
|
providers: list[OAuthProviderInfo] = Field(
|
||||||
|
default_factory=list, description="List of enabled providers"
|
||||||
|
)
|
||||||
|
|
||||||
|
model_config = ConfigDict(
|
||||||
|
json_schema_extra={
|
||||||
|
"example": {
|
||||||
|
"enabled": True,
|
||||||
|
"providers": [
|
||||||
|
{"provider": "google", "name": "Google", "icon": "google"},
|
||||||
|
{"provider": "github", "name": "GitHub", "icon": "github"},
|
||||||
|
],
|
||||||
|
}
|
||||||
|
}
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
# ============================================================================
|
||||||
|
# OAuth Account (linked provider accounts)
|
||||||
|
# ============================================================================
|
||||||
|
|
||||||
|
|
||||||
|
class OAuthAccountBase(BaseModel):
|
||||||
|
"""Base schema for OAuth accounts."""
|
||||||
|
|
||||||
|
provider: str = Field(..., max_length=50, description="OAuth provider name")
|
||||||
|
provider_email: str | None = Field(
|
||||||
|
None, max_length=255, description="Email from OAuth provider"
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
class OAuthAccountCreate(OAuthAccountBase):
|
||||||
|
"""Schema for creating an OAuth account link (internal use)."""
|
||||||
|
|
||||||
|
user_id: UUID
|
||||||
|
provider_user_id: str = Field(..., max_length=255)
|
||||||
|
access_token: str | None = None
|
||||||
|
refresh_token: str | None = None
|
||||||
|
token_expires_at: datetime | None = None
|
||||||
|
|
||||||
|
|
||||||
|
class OAuthAccountResponse(OAuthAccountBase):
|
||||||
|
"""Schema for OAuth account response to clients."""
|
||||||
|
|
||||||
|
id: UUID
|
||||||
|
created_at: datetime
|
||||||
|
|
||||||
|
model_config = ConfigDict(
|
||||||
|
from_attributes=True,
|
||||||
|
json_schema_extra={
|
||||||
|
"example": {
|
||||||
|
"id": "123e4567-e89b-12d3-a456-426614174000",
|
||||||
|
"provider": "google",
|
||||||
|
"provider_email": "user@gmail.com",
|
||||||
|
"created_at": "2025-11-24T12:00:00Z",
|
||||||
|
}
|
||||||
|
},
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
class OAuthAccountsListResponse(BaseModel):
|
||||||
|
"""Response containing list of linked OAuth accounts."""
|
||||||
|
|
||||||
|
accounts: list[OAuthAccountResponse]
|
||||||
|
|
||||||
|
model_config = ConfigDict(
|
||||||
|
json_schema_extra={
|
||||||
|
"example": {
|
||||||
|
"accounts": [
|
||||||
|
{
|
||||||
|
"id": "123e4567-e89b-12d3-a456-426614174000",
|
||||||
|
"provider": "google",
|
||||||
|
"provider_email": "user@gmail.com",
|
||||||
|
"created_at": "2025-11-24T12:00:00Z",
|
||||||
|
}
|
||||||
|
]
|
||||||
|
}
|
||||||
|
}
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
# ============================================================================
|
||||||
|
# OAuth Flow (authorization, callback, etc.)
|
||||||
|
# ============================================================================
|
||||||
|
|
||||||
|
|
||||||
|
class OAuthAuthorizeRequest(BaseModel):
|
||||||
|
"""Request parameters for OAuth authorization."""
|
||||||
|
|
||||||
|
provider: str = Field(..., description="OAuth provider (google, github)")
|
||||||
|
redirect_uri: str | None = Field(
|
||||||
|
None, description="Frontend callback URL after OAuth"
|
||||||
|
)
|
||||||
|
mode: str = Field(
|
||||||
|
default="login",
|
||||||
|
description="OAuth mode: login, register, or link",
|
||||||
|
pattern="^(login|register|link)$",
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
class OAuthCallbackRequest(BaseModel):
|
||||||
|
"""Request parameters for OAuth callback."""
|
||||||
|
|
||||||
|
code: str = Field(..., description="Authorization code from provider")
|
||||||
|
state: str = Field(..., description="State parameter for CSRF protection")
|
||||||
|
|
||||||
|
|
||||||
|
class OAuthCallbackResponse(BaseModel):
|
||||||
|
"""Response after successful OAuth authentication."""
|
||||||
|
|
||||||
|
access_token: str = Field(..., description="JWT access token")
|
||||||
|
refresh_token: str = Field(..., description="JWT refresh token")
|
||||||
|
token_type: str = Field(default="bearer")
|
||||||
|
expires_in: int = Field(..., description="Token expiration in seconds")
|
||||||
|
is_new_user: bool = Field(
|
||||||
|
default=False, description="Whether a new user was created"
|
||||||
|
)
|
||||||
|
|
||||||
|
model_config = ConfigDict(
|
||||||
|
json_schema_extra={
|
||||||
|
"example": {
|
||||||
|
"access_token": "eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9...",
|
||||||
|
"refresh_token": "eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9...",
|
||||||
|
"token_type": "bearer",
|
||||||
|
"expires_in": 900,
|
||||||
|
"is_new_user": False,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
class OAuthUnlinkResponse(BaseModel):
|
||||||
|
"""Response after unlinking an OAuth account."""
|
||||||
|
|
||||||
|
success: bool = Field(..., description="Whether the unlink was successful")
|
||||||
|
message: str = Field(..., description="Status message")
|
||||||
|
|
||||||
|
model_config = ConfigDict(
|
||||||
|
json_schema_extra={
|
||||||
|
"example": {"success": True, "message": "Google account unlinked"}
|
||||||
|
}
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
# ============================================================================
|
||||||
|
# OAuth State (CSRF protection - internal use)
|
||||||
|
# ============================================================================
|
||||||
|
|
||||||
|
|
||||||
|
class OAuthStateCreate(BaseModel):
|
||||||
|
"""Schema for creating OAuth state (internal use)."""
|
||||||
|
|
||||||
|
state: str = Field(..., max_length=255)
|
||||||
|
code_verifier: str | None = Field(None, max_length=128)
|
||||||
|
nonce: str | None = Field(None, max_length=255)
|
||||||
|
provider: str = Field(..., max_length=50)
|
||||||
|
redirect_uri: str | None = Field(None, max_length=500)
|
||||||
|
user_id: UUID | None = None
|
||||||
|
expires_at: datetime
|
||||||
|
|
||||||
|
|
||||||
|
# ============================================================================
|
||||||
|
# OAuth Client (Provider Mode - MCP clients)
|
||||||
|
# ============================================================================
|
||||||
|
|
||||||
|
|
||||||
|
class OAuthClientBase(BaseModel):
|
||||||
|
"""Base schema for OAuth clients."""
|
||||||
|
|
||||||
|
client_name: str = Field(..., max_length=255, description="Client application name")
|
||||||
|
client_description: str | None = Field(
|
||||||
|
None, max_length=1000, description="Client description"
|
||||||
|
)
|
||||||
|
redirect_uris: list[str] = Field(
|
||||||
|
default_factory=list, description="Allowed redirect URIs"
|
||||||
|
)
|
||||||
|
allowed_scopes: list[str] = Field(
|
||||||
|
default_factory=list, description="Allowed OAuth scopes"
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
class OAuthClientCreate(OAuthClientBase):
|
||||||
|
"""Schema for creating an OAuth client."""
|
||||||
|
|
||||||
|
client_type: str = Field(
|
||||||
|
default="public",
|
||||||
|
description="Client type: public or confidential",
|
||||||
|
pattern="^(public|confidential)$",
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
class OAuthClientResponse(OAuthClientBase):
|
||||||
|
"""Schema for OAuth client response."""
|
||||||
|
|
||||||
|
id: UUID
|
||||||
|
client_id: str = Field(..., description="OAuth client ID")
|
||||||
|
client_type: str
|
||||||
|
is_active: bool
|
||||||
|
created_at: datetime
|
||||||
|
|
||||||
|
model_config = ConfigDict(
|
||||||
|
from_attributes=True,
|
||||||
|
json_schema_extra={
|
||||||
|
"example": {
|
||||||
|
"id": "123e4567-e89b-12d3-a456-426614174000",
|
||||||
|
"client_id": "abc123def456",
|
||||||
|
"client_name": "My MCP App",
|
||||||
|
"client_description": "My application that uses MCP",
|
||||||
|
"client_type": "public",
|
||||||
|
"redirect_uris": ["http://localhost:3000/callback"],
|
||||||
|
"allowed_scopes": ["read:users", "write:users"],
|
||||||
|
"is_active": True,
|
||||||
|
"created_at": "2025-11-24T12:00:00Z",
|
||||||
|
}
|
||||||
|
},
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
class OAuthClientWithSecret(OAuthClientResponse):
|
||||||
|
"""Schema for OAuth client response including secret (only shown once)."""
|
||||||
|
|
||||||
|
client_secret: str | None = Field(
|
||||||
|
None, description="Client secret (only shown once for confidential clients)"
|
||||||
|
)
|
||||||
|
|
||||||
|
model_config = ConfigDict(
|
||||||
|
from_attributes=True,
|
||||||
|
json_schema_extra={
|
||||||
|
"example": {
|
||||||
|
"id": "123e4567-e89b-12d3-a456-426614174000",
|
||||||
|
"client_id": "abc123def456",
|
||||||
|
"client_secret": "secret_xyz789",
|
||||||
|
"client_name": "My MCP App",
|
||||||
|
"client_type": "confidential",
|
||||||
|
"redirect_uris": ["http://localhost:3000/callback"],
|
||||||
|
"allowed_scopes": ["read:users"],
|
||||||
|
"is_active": True,
|
||||||
|
"created_at": "2025-11-24T12:00:00Z",
|
||||||
|
}
|
||||||
|
},
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
# ============================================================================
|
||||||
|
# OAuth Provider Discovery (RFC 8414 - skeleton)
|
||||||
|
# ============================================================================
|
||||||
|
|
||||||
|
|
||||||
|
class OAuthServerMetadata(BaseModel):
|
||||||
|
"""OAuth 2.0 Authorization Server Metadata (RFC 8414)."""
|
||||||
|
|
||||||
|
issuer: str = Field(..., description="Authorization server issuer URL")
|
||||||
|
authorization_endpoint: str = Field(..., description="Authorization endpoint URL")
|
||||||
|
token_endpoint: str = Field(..., description="Token endpoint URL")
|
||||||
|
registration_endpoint: str | None = Field(
|
||||||
|
None, description="Dynamic client registration endpoint"
|
||||||
|
)
|
||||||
|
revocation_endpoint: str | None = Field(
|
||||||
|
None, description="Token revocation endpoint"
|
||||||
|
)
|
||||||
|
introspection_endpoint: str | None = Field(
|
||||||
|
None, description="Token introspection endpoint (RFC 7662)"
|
||||||
|
)
|
||||||
|
scopes_supported: list[str] = Field(
|
||||||
|
default_factory=list, description="Supported scopes"
|
||||||
|
)
|
||||||
|
response_types_supported: list[str] = Field(
|
||||||
|
default_factory=lambda: ["code"], description="Supported response types"
|
||||||
|
)
|
||||||
|
grant_types_supported: list[str] = Field(
|
||||||
|
default_factory=lambda: ["authorization_code", "refresh_token"],
|
||||||
|
description="Supported grant types",
|
||||||
|
)
|
||||||
|
code_challenge_methods_supported: list[str] = Field(
|
||||||
|
default_factory=lambda: ["S256"], description="Supported PKCE methods"
|
||||||
|
)
|
||||||
|
token_endpoint_auth_methods_supported: list[str] = Field(
|
||||||
|
default_factory=lambda: ["client_secret_basic", "client_secret_post", "none"],
|
||||||
|
description="Supported client authentication methods",
|
||||||
|
)
|
||||||
|
|
||||||
|
model_config = ConfigDict(
|
||||||
|
json_schema_extra={
|
||||||
|
"example": {
|
||||||
|
"issuer": "https://api.example.com",
|
||||||
|
"authorization_endpoint": "https://api.example.com/oauth/authorize",
|
||||||
|
"token_endpoint": "https://api.example.com/oauth/token",
|
||||||
|
"revocation_endpoint": "https://api.example.com/oauth/revoke",
|
||||||
|
"introspection_endpoint": "https://api.example.com/oauth/introspect",
|
||||||
|
"scopes_supported": ["openid", "profile", "email", "read:users"],
|
||||||
|
"response_types_supported": ["code"],
|
||||||
|
"grant_types_supported": ["authorization_code", "refresh_token"],
|
||||||
|
"code_challenge_methods_supported": ["S256"],
|
||||||
|
"token_endpoint_auth_methods_supported": [
|
||||||
|
"client_secret_basic",
|
||||||
|
"client_secret_post",
|
||||||
|
"none",
|
||||||
|
],
|
||||||
|
}
|
||||||
|
}
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
# ============================================================================
|
||||||
|
# OAuth Token Responses (RFC 6749)
|
||||||
|
# ============================================================================
|
||||||
|
|
||||||
|
|
||||||
|
class OAuthTokenResponse(BaseModel):
|
||||||
|
"""OAuth 2.0 Token Response (RFC 6749 Section 5.1)."""
|
||||||
|
|
||||||
|
access_token: str = Field(..., description="The access token issued by the server")
|
||||||
|
token_type: str = Field(
|
||||||
|
default="Bearer", description="The type of token (typically 'Bearer')"
|
||||||
|
)
|
||||||
|
expires_in: int | None = Field(None, description="Token lifetime in seconds")
|
||||||
|
refresh_token: str | None = Field(
|
||||||
|
None, description="Refresh token for obtaining new access tokens"
|
||||||
|
)
|
||||||
|
scope: str | None = Field(
|
||||||
|
None, description="Space-separated list of granted scopes"
|
||||||
|
)
|
||||||
|
|
||||||
|
model_config = ConfigDict(
|
||||||
|
json_schema_extra={
|
||||||
|
"example": {
|
||||||
|
"access_token": "eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9...",
|
||||||
|
"token_type": "Bearer",
|
||||||
|
"expires_in": 3600,
|
||||||
|
"refresh_token": "dGhpcyBpcyBhIHJlZnJlc2ggdG9rZW4...",
|
||||||
|
"scope": "openid profile email",
|
||||||
|
}
|
||||||
|
}
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
class OAuthTokenIntrospectionResponse(BaseModel):
|
||||||
|
"""OAuth 2.0 Token Introspection Response (RFC 7662)."""
|
||||||
|
|
||||||
|
active: bool = Field(..., description="Whether the token is currently active")
|
||||||
|
scope: str | None = Field(None, description="Space-separated list of scopes")
|
||||||
|
client_id: str | None = Field(None, description="Client identifier for the token")
|
||||||
|
username: str | None = Field(
|
||||||
|
None, description="Human-readable identifier for the resource owner"
|
||||||
|
)
|
||||||
|
token_type: str | None = Field(
|
||||||
|
None, description="Type of the token (e.g., 'Bearer')"
|
||||||
|
)
|
||||||
|
exp: int | None = Field(None, description="Token expiration time (Unix timestamp)")
|
||||||
|
iat: int | None = Field(None, description="Token issue time (Unix timestamp)")
|
||||||
|
nbf: int | None = Field(None, description="Token not-before time (Unix timestamp)")
|
||||||
|
sub: str | None = Field(None, description="Subject of the token (user ID)")
|
||||||
|
aud: str | None = Field(None, description="Intended audience of the token")
|
||||||
|
iss: str | None = Field(None, description="Issuer of the token")
|
||||||
|
|
||||||
|
model_config = ConfigDict(
|
||||||
|
json_schema_extra={
|
||||||
|
"example": {
|
||||||
|
"active": True,
|
||||||
|
"scope": "openid profile",
|
||||||
|
"client_id": "client123",
|
||||||
|
"username": "user@example.com",
|
||||||
|
"token_type": "Bearer",
|
||||||
|
"exp": 1735689600,
|
||||||
|
"iat": 1735686000,
|
||||||
|
"sub": "user-uuid-here",
|
||||||
|
}
|
||||||
|
}
|
||||||
|
)
|
||||||
@@ -48,7 +48,7 @@ class OrganizationCreate(OrganizationBase):
|
|||||||
"""Schema for creating a new organization."""
|
"""Schema for creating a new organization."""
|
||||||
|
|
||||||
name: str = Field(..., min_length=1, max_length=255)
|
name: str = Field(..., min_length=1, max_length=255)
|
||||||
slug: str = Field(..., min_length=1, max_length=255)
|
slug: str = Field(..., min_length=1, max_length=255) # pyright: ignore[reportIncompatibleVariableOverride]
|
||||||
|
|
||||||
|
|
||||||
class OrganizationUpdate(BaseModel):
|
class OrganizationUpdate(BaseModel):
|
||||||
|
|||||||
@@ -23,6 +23,7 @@ class UserBase(BaseModel):
|
|||||||
class UserCreate(UserBase):
|
class UserCreate(UserBase):
|
||||||
password: str
|
password: str
|
||||||
is_superuser: bool = False
|
is_superuser: bool = False
|
||||||
|
is_active: bool = True
|
||||||
|
|
||||||
@field_validator("password")
|
@field_validator("password")
|
||||||
@classmethod
|
@classmethod
|
||||||
@@ -40,9 +41,9 @@ class UserUpdate(BaseModel):
|
|||||||
locale: str | None = Field(
|
locale: str | None = Field(
|
||||||
None,
|
None,
|
||||||
max_length=10,
|
max_length=10,
|
||||||
pattern=r'^[a-z]{2}(-[A-Z]{2})?$',
|
pattern=r"^[a-z]{2}(-[A-Z]{2})?$",
|
||||||
description="User's preferred locale (BCP 47 format: en, it, en-US, it-IT)",
|
description="User's preferred locale (BCP 47 format: en, it, en-US, it-IT)",
|
||||||
examples=["en", "it", "en-US", "it-IT"]
|
examples=["en", "it", "en-US", "it-IT"],
|
||||||
)
|
)
|
||||||
is_active: bool | None = (
|
is_active: bool | None = (
|
||||||
None # Changed default from True to None to avoid unintended updates
|
None # Changed default from True to None to avoid unintended updates
|
||||||
@@ -70,12 +71,12 @@ class UserUpdate(BaseModel):
|
|||||||
return v
|
return v
|
||||||
# Only support English and Italian for template showcase
|
# Only support English and Italian for template showcase
|
||||||
# Note: Locales stored in lowercase for case-insensitive matching
|
# Note: Locales stored in lowercase for case-insensitive matching
|
||||||
SUPPORTED_LOCALES = {"en", "it", "en-us", "en-gb", "it-it"}
|
supported_locales = {"en", "it", "en-us", "en-gb", "it-it"}
|
||||||
# Normalize to lowercase for comparison and storage
|
# Normalize to lowercase for comparison and storage
|
||||||
v_lower = v.lower()
|
v_lower = v.lower()
|
||||||
if v_lower not in SUPPORTED_LOCALES:
|
if v_lower not in supported_locales:
|
||||||
raise ValueError(
|
raise ValueError(
|
||||||
f"Unsupported locale '{v}'. Supported locales: {sorted(SUPPORTED_LOCALES)}"
|
f"Unsupported locale '{v}'. Supported locales: {sorted(supported_locales)}"
|
||||||
)
|
)
|
||||||
# Return normalized lowercase version for consistency
|
# Return normalized lowercase version for consistency
|
||||||
return v_lower
|
return v_lower
|
||||||
|
|||||||
@@ -60,6 +60,15 @@ def validate_password_strength(password: str) -> str:
|
|||||||
>>> validate_password_strength("MySecureP@ss123") # Valid
|
>>> validate_password_strength("MySecureP@ss123") # Valid
|
||||||
>>> validate_password_strength("password1") # Invalid - too weak
|
>>> validate_password_strength("password1") # Invalid - too weak
|
||||||
"""
|
"""
|
||||||
|
# Check if we are in demo mode
|
||||||
|
from app.core.config import settings
|
||||||
|
|
||||||
|
if settings.DEMO_MODE:
|
||||||
|
# In demo mode, allow specific weak passwords for demo accounts
|
||||||
|
demo_passwords = {"Demo123!", "Admin123!"}
|
||||||
|
if password in demo_passwords:
|
||||||
|
return password
|
||||||
|
|
||||||
# Check minimum length
|
# Check minimum length
|
||||||
if len(password) < 12:
|
if len(password) < 12:
|
||||||
raise ValueError("Password must be at least 12 characters long")
|
raise ValueError("Password must be at least 12 characters long")
|
||||||
|
|||||||
@@ -0,0 +1,19 @@
|
|||||||
|
# app/services/__init__.py
|
||||||
|
from . import oauth_provider_service
|
||||||
|
from .auth_service import AuthService
|
||||||
|
from .oauth_service import OAuthService
|
||||||
|
from .organization_service import OrganizationService, organization_service
|
||||||
|
from .session_service import SessionService, session_service
|
||||||
|
from .user_service import UserService, user_service
|
||||||
|
|
||||||
|
__all__ = [
|
||||||
|
"AuthService",
|
||||||
|
"OAuthService",
|
||||||
|
"OrganizationService",
|
||||||
|
"SessionService",
|
||||||
|
"UserService",
|
||||||
|
"oauth_provider_service",
|
||||||
|
"organization_service",
|
||||||
|
"session_service",
|
||||||
|
"user_service",
|
||||||
|
]
|
||||||
|
|||||||
@@ -2,7 +2,6 @@
|
|||||||
import logging
|
import logging
|
||||||
from uuid import UUID
|
from uuid import UUID
|
||||||
|
|
||||||
from sqlalchemy import select
|
|
||||||
from sqlalchemy.ext.asyncio import AsyncSession
|
from sqlalchemy.ext.asyncio import AsyncSession
|
||||||
|
|
||||||
from app.core.auth import (
|
from app.core.auth import (
|
||||||
@@ -14,12 +13,18 @@ from app.core.auth import (
|
|||||||
verify_password_async,
|
verify_password_async,
|
||||||
)
|
)
|
||||||
from app.core.config import settings
|
from app.core.config import settings
|
||||||
from app.core.exceptions import AuthenticationError
|
from app.core.exceptions import AuthenticationError, DuplicateError
|
||||||
|
from app.core.repository_exceptions import DuplicateEntryError
|
||||||
from app.models.user import User
|
from app.models.user import User
|
||||||
|
from app.repositories.user import user_repo
|
||||||
from app.schemas.users import Token, UserCreate, UserResponse
|
from app.schemas.users import Token, UserCreate, UserResponse
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
# Pre-computed bcrypt hash used for constant-time comparison when user is not found,
|
||||||
|
# preventing timing attacks that could enumerate valid email addresses.
|
||||||
|
_DUMMY_HASH = "$2b$12$EixZaYVK1fsbw1ZfbX3OXePaWxn96p36zLFbnJHfxPSEFBzXKiHia"
|
||||||
|
|
||||||
|
|
||||||
class AuthService:
|
class AuthService:
|
||||||
"""Service for handling authentication operations"""
|
"""Service for handling authentication operations"""
|
||||||
@@ -39,10 +44,12 @@ class AuthService:
|
|||||||
Returns:
|
Returns:
|
||||||
User if authenticated, None otherwise
|
User if authenticated, None otherwise
|
||||||
"""
|
"""
|
||||||
result = await db.execute(select(User).where(User.email == email))
|
user = await user_repo.get_by_email(db, email=email)
|
||||||
user = result.scalar_one_or_none()
|
|
||||||
|
|
||||||
if not user:
|
if not user:
|
||||||
|
# Perform a dummy verification to match timing of a real bcrypt check,
|
||||||
|
# preventing email enumeration via response-time differences.
|
||||||
|
await verify_password_async(password, _DUMMY_HASH)
|
||||||
return None
|
return None
|
||||||
|
|
||||||
# Verify password asynchronously to avoid blocking event loop
|
# Verify password asynchronously to avoid blocking event loop
|
||||||
@@ -71,40 +78,23 @@ class AuthService:
|
|||||||
"""
|
"""
|
||||||
try:
|
try:
|
||||||
# Check if user already exists
|
# Check if user already exists
|
||||||
result = await db.execute(select(User).where(User.email == user_data.email))
|
existing_user = await user_repo.get_by_email(db, email=user_data.email)
|
||||||
existing_user = result.scalar_one_or_none()
|
|
||||||
if existing_user:
|
if existing_user:
|
||||||
raise AuthenticationError("User with this email already exists")
|
raise DuplicateError("User with this email already exists")
|
||||||
|
|
||||||
# Create new user with async password hashing
|
# Delegate creation (hashing + commit) to the repository
|
||||||
# Hash password asynchronously to avoid blocking event loop
|
user = await user_repo.create(db, obj_in=user_data)
|
||||||
hashed_password = await get_password_hash_async(user_data.password)
|
|
||||||
|
|
||||||
# Create user object from model
|
logger.info("User created successfully: %s", user.email)
|
||||||
user = User(
|
|
||||||
email=user_data.email,
|
|
||||||
password_hash=hashed_password,
|
|
||||||
first_name=user_data.first_name,
|
|
||||||
last_name=user_data.last_name,
|
|
||||||
phone_number=user_data.phone_number,
|
|
||||||
is_active=True,
|
|
||||||
is_superuser=False,
|
|
||||||
)
|
|
||||||
|
|
||||||
db.add(user)
|
|
||||||
await db.commit()
|
|
||||||
await db.refresh(user)
|
|
||||||
|
|
||||||
logger.info(f"User created successfully: {user.email}")
|
|
||||||
return user
|
return user
|
||||||
|
|
||||||
except AuthenticationError:
|
except (AuthenticationError, DuplicateError):
|
||||||
# Re-raise authentication errors without rollback
|
# Re-raise API exceptions without rollback
|
||||||
raise
|
raise
|
||||||
|
except DuplicateEntryError as e:
|
||||||
|
raise DuplicateError(str(e))
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
# Rollback on any database errors
|
logger.exception("Error creating user: %s", e)
|
||||||
await db.rollback()
|
|
||||||
logger.error(f"Error creating user: {e!s}", exc_info=True)
|
|
||||||
raise AuthenticationError(f"Failed to create user: {e!s}")
|
raise AuthenticationError(f"Failed to create user: {e!s}")
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
@@ -168,8 +158,7 @@ class AuthService:
|
|||||||
user_id = token_data.user_id
|
user_id = token_data.user_id
|
||||||
|
|
||||||
# Get user from database
|
# Get user from database
|
||||||
result = await db.execute(select(User).where(User.id == user_id))
|
user = await user_repo.get(db, id=str(user_id))
|
||||||
user = result.scalar_one_or_none()
|
|
||||||
if not user or not user.is_active:
|
if not user or not user.is_active:
|
||||||
raise TokenInvalidError("Invalid user or inactive account")
|
raise TokenInvalidError("Invalid user or inactive account")
|
||||||
|
|
||||||
@@ -177,7 +166,7 @@ class AuthService:
|
|||||||
return AuthService.create_tokens(user)
|
return AuthService.create_tokens(user)
|
||||||
|
|
||||||
except (TokenExpiredError, TokenInvalidError) as e:
|
except (TokenExpiredError, TokenInvalidError) as e:
|
||||||
logger.warning(f"Token refresh failed: {e!s}")
|
logger.warning("Token refresh failed: %s", e)
|
||||||
raise
|
raise
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
@@ -200,8 +189,7 @@ class AuthService:
|
|||||||
AuthenticationError: If current password is incorrect or update fails
|
AuthenticationError: If current password is incorrect or update fails
|
||||||
"""
|
"""
|
||||||
try:
|
try:
|
||||||
result = await db.execute(select(User).where(User.id == user_id))
|
user = await user_repo.get(db, id=str(user_id))
|
||||||
user = result.scalar_one_or_none()
|
|
||||||
if not user:
|
if not user:
|
||||||
raise AuthenticationError("User not found")
|
raise AuthenticationError("User not found")
|
||||||
|
|
||||||
@@ -210,10 +198,10 @@ class AuthService:
|
|||||||
raise AuthenticationError("Current password is incorrect")
|
raise AuthenticationError("Current password is incorrect")
|
||||||
|
|
||||||
# Hash new password asynchronously to avoid blocking event loop
|
# Hash new password asynchronously to avoid blocking event loop
|
||||||
user.password_hash = await get_password_hash_async(new_password)
|
new_hash = await get_password_hash_async(new_password)
|
||||||
await db.commit()
|
await user_repo.update_password(db, user=user, password_hash=new_hash)
|
||||||
|
|
||||||
logger.info(f"Password changed successfully for user {user_id}")
|
logger.info("Password changed successfully for user %s", user_id)
|
||||||
return True
|
return True
|
||||||
|
|
||||||
except AuthenticationError:
|
except AuthenticationError:
|
||||||
@@ -222,7 +210,34 @@ class AuthService:
|
|||||||
except Exception as e:
|
except Exception as e:
|
||||||
# Rollback on any database errors
|
# Rollback on any database errors
|
||||||
await db.rollback()
|
await db.rollback()
|
||||||
logger.error(
|
logger.exception("Error changing password for user %s: %s", user_id, e)
|
||||||
f"Error changing password for user {user_id}: {e!s}", exc_info=True
|
|
||||||
)
|
|
||||||
raise AuthenticationError(f"Failed to change password: {e!s}")
|
raise AuthenticationError(f"Failed to change password: {e!s}")
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
async def reset_password(
|
||||||
|
db: AsyncSession, *, email: str, new_password: str
|
||||||
|
) -> User:
|
||||||
|
"""
|
||||||
|
Reset a user's password without requiring the current password.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
db: Database session
|
||||||
|
email: User email address
|
||||||
|
new_password: New password to set
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Updated user
|
||||||
|
|
||||||
|
Raises:
|
||||||
|
AuthenticationError: If user not found or inactive
|
||||||
|
"""
|
||||||
|
user = await user_repo.get_by_email(db, email=email)
|
||||||
|
if not user:
|
||||||
|
raise AuthenticationError("User not found")
|
||||||
|
if not user.is_active:
|
||||||
|
raise AuthenticationError("User account is inactive")
|
||||||
|
|
||||||
|
new_hash = await get_password_hash_async(new_password)
|
||||||
|
user = await user_repo.update_password(db, user=user, password_hash=new_hash)
|
||||||
|
logger.info("Password reset successfully for %s", email)
|
||||||
|
return user
|
||||||
|
|||||||
@@ -58,8 +58,8 @@ class ConsoleEmailBackend(EmailBackend):
|
|||||||
logger.info("=" * 80)
|
logger.info("=" * 80)
|
||||||
logger.info("EMAIL SENT (Console Backend)")
|
logger.info("EMAIL SENT (Console Backend)")
|
||||||
logger.info("=" * 80)
|
logger.info("=" * 80)
|
||||||
logger.info(f"To: {', '.join(to)}")
|
logger.info("To: %s", ", ".join(to))
|
||||||
logger.info(f"Subject: {subject}")
|
logger.info("Subject: %s", subject)
|
||||||
logger.info("-" * 80)
|
logger.info("-" * 80)
|
||||||
if text_content:
|
if text_content:
|
||||||
logger.info("Plain Text Content:")
|
logger.info("Plain Text Content:")
|
||||||
@@ -199,7 +199,7 @@ The {settings.PROJECT_NAME} Team
|
|||||||
text_content=text_content,
|
text_content=text_content,
|
||||||
)
|
)
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.error(f"Failed to send password reset email to {to_email}: {e!s}")
|
logger.error("Failed to send password reset email to %s: %s", to_email, e)
|
||||||
return False
|
return False
|
||||||
|
|
||||||
async def send_email_verification(
|
async def send_email_verification(
|
||||||
@@ -287,7 +287,7 @@ The {settings.PROJECT_NAME} Team
|
|||||||
text_content=text_content,
|
text_content=text_content,
|
||||||
)
|
)
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.error(f"Failed to send verification email to {to_email}: {e!s}")
|
logger.error("Failed to send verification email to %s: %s", to_email, e)
|
||||||
return False
|
return False
|
||||||
|
|
||||||
|
|
||||||
|
|||||||
970
backend/app/services/oauth_provider_service.py
Executable file
970
backend/app/services/oauth_provider_service.py
Executable file
@@ -0,0 +1,970 @@
|
|||||||
|
"""
|
||||||
|
OAuth Provider Service for MCP integration.
|
||||||
|
|
||||||
|
Implements OAuth 2.0 Authorization Server functionality:
|
||||||
|
- Authorization code flow with PKCE
|
||||||
|
- Token issuance (JWT access tokens, opaque refresh tokens)
|
||||||
|
- Token refresh
|
||||||
|
- Token revocation
|
||||||
|
- Consent management
|
||||||
|
|
||||||
|
Security features:
|
||||||
|
- PKCE required for public clients (S256)
|
||||||
|
- Short-lived authorization codes (10 minutes)
|
||||||
|
- JWT access tokens (self-contained, no DB lookup)
|
||||||
|
- Secure refresh token storage (hashed)
|
||||||
|
- Token rotation on refresh
|
||||||
|
- Comprehensive validation
|
||||||
|
"""
|
||||||
|
|
||||||
|
import base64
|
||||||
|
import hashlib
|
||||||
|
import logging
|
||||||
|
import secrets
|
||||||
|
from datetime import UTC, datetime, timedelta
|
||||||
|
from typing import Any
|
||||||
|
from uuid import UUID
|
||||||
|
|
||||||
|
import jwt
|
||||||
|
from jwt.exceptions import ExpiredSignatureError, InvalidTokenError
|
||||||
|
from sqlalchemy.ext.asyncio import AsyncSession
|
||||||
|
|
||||||
|
from app.core.config import settings
|
||||||
|
from app.models.oauth_client import OAuthClient
|
||||||
|
from app.models.user import User
|
||||||
|
from app.repositories.oauth_authorization_code import oauth_authorization_code_repo
|
||||||
|
from app.repositories.oauth_client import oauth_client_repo
|
||||||
|
from app.repositories.oauth_consent import oauth_consent_repo
|
||||||
|
from app.repositories.oauth_provider_token import oauth_provider_token_repo
|
||||||
|
from app.repositories.user import user_repo
|
||||||
|
from app.schemas.oauth import OAuthClientCreate
|
||||||
|
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
# Constants
|
||||||
|
AUTHORIZATION_CODE_EXPIRY_MINUTES = 10
|
||||||
|
ACCESS_TOKEN_EXPIRY_MINUTES = 60 # 1 hour for MCP clients
|
||||||
|
REFRESH_TOKEN_EXPIRY_DAYS = 30
|
||||||
|
|
||||||
|
|
||||||
|
class OAuthProviderError(Exception):
|
||||||
|
"""Base exception for OAuth provider errors."""
|
||||||
|
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
error: str,
|
||||||
|
error_description: str | None = None,
|
||||||
|
error_uri: str | None = None,
|
||||||
|
):
|
||||||
|
self.error = error
|
||||||
|
self.error_description = error_description
|
||||||
|
self.error_uri = error_uri
|
||||||
|
super().__init__(error_description or error)
|
||||||
|
|
||||||
|
|
||||||
|
class InvalidClientError(OAuthProviderError):
|
||||||
|
"""Client authentication failed."""
|
||||||
|
|
||||||
|
def __init__(self, description: str = "Invalid client credentials"):
|
||||||
|
super().__init__("invalid_client", description)
|
||||||
|
|
||||||
|
|
||||||
|
class InvalidGrantError(OAuthProviderError):
|
||||||
|
"""Invalid authorization grant."""
|
||||||
|
|
||||||
|
def __init__(self, description: str = "Invalid grant"):
|
||||||
|
super().__init__("invalid_grant", description)
|
||||||
|
|
||||||
|
|
||||||
|
class InvalidRequestError(OAuthProviderError):
|
||||||
|
"""Invalid request parameters."""
|
||||||
|
|
||||||
|
def __init__(self, description: str = "Invalid request"):
|
||||||
|
super().__init__("invalid_request", description)
|
||||||
|
|
||||||
|
|
||||||
|
class InvalidScopeError(OAuthProviderError):
|
||||||
|
"""Invalid scope requested."""
|
||||||
|
|
||||||
|
def __init__(self, description: str = "Invalid scope"):
|
||||||
|
super().__init__("invalid_scope", description)
|
||||||
|
|
||||||
|
|
||||||
|
class UnauthorizedClientError(OAuthProviderError):
|
||||||
|
"""Client not authorized for this grant type."""
|
||||||
|
|
||||||
|
def __init__(self, description: str = "Unauthorized client"):
|
||||||
|
super().__init__("unauthorized_client", description)
|
||||||
|
|
||||||
|
|
||||||
|
class AccessDeniedError(OAuthProviderError):
|
||||||
|
"""User denied authorization."""
|
||||||
|
|
||||||
|
def __init__(self, description: str = "Access denied"):
|
||||||
|
super().__init__("access_denied", description)
|
||||||
|
|
||||||
|
|
||||||
|
# ============================================================================
|
||||||
|
# Helper Functions
|
||||||
|
# ============================================================================
|
||||||
|
|
||||||
|
|
||||||
|
def generate_code() -> str:
|
||||||
|
"""Generate a cryptographically secure authorization code."""
|
||||||
|
return secrets.token_urlsafe(64)
|
||||||
|
|
||||||
|
|
||||||
|
def generate_token() -> str:
|
||||||
|
"""Generate a cryptographically secure token."""
|
||||||
|
return secrets.token_urlsafe(48)
|
||||||
|
|
||||||
|
|
||||||
|
def generate_jti() -> str:
|
||||||
|
"""Generate a unique JWT ID."""
|
||||||
|
return secrets.token_urlsafe(32)
|
||||||
|
|
||||||
|
|
||||||
|
def hash_token(token: str) -> str:
|
||||||
|
"""Hash a token using SHA-256."""
|
||||||
|
return hashlib.sha256(token.encode()).hexdigest()
|
||||||
|
|
||||||
|
|
||||||
|
def verify_pkce(code_verifier: str, code_challenge: str, method: str) -> bool:
|
||||||
|
"""
|
||||||
|
Verify PKCE code_verifier against stored code_challenge.
|
||||||
|
|
||||||
|
SECURITY: Only S256 method is supported. The 'plain' method provides
|
||||||
|
no security benefit and is explicitly rejected per RFC 7636 Section 4.3.
|
||||||
|
"""
|
||||||
|
if method != "S256":
|
||||||
|
# SECURITY: Reject any method other than S256
|
||||||
|
# 'plain' method provides no security against code interception attacks
|
||||||
|
logger.warning("PKCE verification rejected for unsupported method: %s", method)
|
||||||
|
return False
|
||||||
|
|
||||||
|
# SHA-256 hash, then base64url encode (RFC 7636 Section 4.2)
|
||||||
|
digest = hashlib.sha256(code_verifier.encode()).digest()
|
||||||
|
computed = base64.urlsafe_b64encode(digest).rstrip(b"=").decode()
|
||||||
|
return secrets.compare_digest(computed, code_challenge)
|
||||||
|
|
||||||
|
|
||||||
|
def parse_scope(scope: str) -> list[str]:
|
||||||
|
"""Parse space-separated scope string into list."""
|
||||||
|
return [s.strip() for s in scope.split() if s.strip()]
|
||||||
|
|
||||||
|
|
||||||
|
def join_scope(scopes: list[str]) -> str:
|
||||||
|
"""Join scope list into space-separated string."""
|
||||||
|
return " ".join(sorted(set(scopes)))
|
||||||
|
|
||||||
|
|
||||||
|
# ============================================================================
|
||||||
|
# Client Validation
|
||||||
|
# ============================================================================
|
||||||
|
|
||||||
|
|
||||||
|
async def get_client(db: AsyncSession, client_id: str) -> OAuthClient | None:
|
||||||
|
"""Get OAuth client by client_id."""
|
||||||
|
return await oauth_client_repo.get_by_client_id(db, client_id=client_id)
|
||||||
|
|
||||||
|
|
||||||
|
async def validate_client(
|
||||||
|
db: AsyncSession,
|
||||||
|
client_id: str,
|
||||||
|
client_secret: str | None = None,
|
||||||
|
require_secret: bool = False,
|
||||||
|
) -> OAuthClient:
|
||||||
|
"""
|
||||||
|
Validate OAuth client credentials.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
db: Database session
|
||||||
|
client_id: Client identifier
|
||||||
|
client_secret: Client secret (required for confidential clients)
|
||||||
|
require_secret: Whether to require secret validation
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Validated OAuthClient
|
||||||
|
|
||||||
|
Raises:
|
||||||
|
InvalidClientError: If client validation fails
|
||||||
|
"""
|
||||||
|
client = await get_client(db, client_id)
|
||||||
|
if not client:
|
||||||
|
raise InvalidClientError("Unknown client_id")
|
||||||
|
|
||||||
|
# Confidential clients must provide valid secret
|
||||||
|
if client.client_type == "confidential" or require_secret:
|
||||||
|
if not client_secret:
|
||||||
|
raise InvalidClientError("Client secret required")
|
||||||
|
if not client.client_secret_hash:
|
||||||
|
raise InvalidClientError("Client not configured with secret")
|
||||||
|
|
||||||
|
# SECURITY: Verify secret using bcrypt
|
||||||
|
from app.core.auth import verify_password
|
||||||
|
|
||||||
|
stored_hash = str(client.client_secret_hash)
|
||||||
|
|
||||||
|
if not stored_hash.startswith("$2"):
|
||||||
|
raise InvalidClientError(
|
||||||
|
"Client secret uses deprecated hash format. "
|
||||||
|
"Please regenerate your client credentials."
|
||||||
|
)
|
||||||
|
|
||||||
|
if not verify_password(client_secret, stored_hash):
|
||||||
|
raise InvalidClientError("Invalid client secret")
|
||||||
|
|
||||||
|
return client
|
||||||
|
|
||||||
|
|
||||||
|
def validate_redirect_uri(client: OAuthClient, redirect_uri: str) -> None:
|
||||||
|
"""
|
||||||
|
Validate redirect_uri against client's registered URIs.
|
||||||
|
|
||||||
|
Raises:
|
||||||
|
InvalidRequestError: If redirect_uri is not registered
|
||||||
|
"""
|
||||||
|
if not client.redirect_uris:
|
||||||
|
raise InvalidRequestError("Client has no registered redirect URIs")
|
||||||
|
|
||||||
|
if redirect_uri not in client.redirect_uris:
|
||||||
|
raise InvalidRequestError("Invalid redirect_uri")
|
||||||
|
|
||||||
|
|
||||||
|
def validate_scopes(client: OAuthClient, requested_scopes: list[str]) -> list[str]:
|
||||||
|
"""
|
||||||
|
Validate requested scopes against client's allowed scopes.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
List of valid scopes (intersection of requested and allowed)
|
||||||
|
|
||||||
|
Raises:
|
||||||
|
InvalidScopeError: If no valid scopes
|
||||||
|
"""
|
||||||
|
allowed = set(client.allowed_scopes or [])
|
||||||
|
requested = set(requested_scopes)
|
||||||
|
|
||||||
|
# If no scopes requested, use all allowed scopes
|
||||||
|
if not requested:
|
||||||
|
return list(allowed)
|
||||||
|
|
||||||
|
valid = requested & allowed
|
||||||
|
if not valid:
|
||||||
|
raise InvalidScopeError(
|
||||||
|
"None of the requested scopes are allowed for this client"
|
||||||
|
)
|
||||||
|
|
||||||
|
# Warn if some scopes were filtered out
|
||||||
|
invalid = requested - allowed
|
||||||
|
if invalid:
|
||||||
|
logger.warning(
|
||||||
|
"Client %s requested invalid scopes: %s", client.client_id, invalid
|
||||||
|
)
|
||||||
|
|
||||||
|
return list(valid)
|
||||||
|
|
||||||
|
|
||||||
|
# ============================================================================
|
||||||
|
# Authorization Code Flow
|
||||||
|
# ============================================================================
|
||||||
|
|
||||||
|
|
||||||
|
async def create_authorization_code(
|
||||||
|
db: AsyncSession,
|
||||||
|
client: OAuthClient,
|
||||||
|
user: User,
|
||||||
|
redirect_uri: str,
|
||||||
|
scope: str,
|
||||||
|
code_challenge: str | None = None,
|
||||||
|
code_challenge_method: str | None = None,
|
||||||
|
state: str | None = None,
|
||||||
|
nonce: str | None = None,
|
||||||
|
) -> str:
|
||||||
|
"""
|
||||||
|
Create an authorization code for the authorization code flow.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
db: Database session
|
||||||
|
client: Validated OAuth client
|
||||||
|
user: Authenticated user
|
||||||
|
redirect_uri: Validated redirect URI
|
||||||
|
scope: Granted scopes (space-separated)
|
||||||
|
code_challenge: PKCE code challenge
|
||||||
|
code_challenge_method: PKCE method (S256)
|
||||||
|
state: CSRF state parameter
|
||||||
|
nonce: OpenID Connect nonce
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Authorization code string
|
||||||
|
"""
|
||||||
|
# Public clients MUST use PKCE
|
||||||
|
if client.client_type == "public":
|
||||||
|
if not code_challenge or code_challenge_method != "S256":
|
||||||
|
raise InvalidRequestError("PKCE with S256 is required for public clients")
|
||||||
|
|
||||||
|
code = generate_code()
|
||||||
|
expires_at = datetime.now(UTC) + timedelta(
|
||||||
|
minutes=AUTHORIZATION_CODE_EXPIRY_MINUTES
|
||||||
|
)
|
||||||
|
|
||||||
|
await oauth_authorization_code_repo.create_code(
|
||||||
|
db,
|
||||||
|
code=code,
|
||||||
|
client_id=client.client_id,
|
||||||
|
user_id=user.id,
|
||||||
|
redirect_uri=redirect_uri,
|
||||||
|
scope=scope,
|
||||||
|
expires_at=expires_at,
|
||||||
|
code_challenge=code_challenge,
|
||||||
|
code_challenge_method=code_challenge_method,
|
||||||
|
state=state,
|
||||||
|
nonce=nonce,
|
||||||
|
)
|
||||||
|
|
||||||
|
logger.info(
|
||||||
|
"Created authorization code for user %s and client %s",
|
||||||
|
user.id,
|
||||||
|
client.client_id,
|
||||||
|
)
|
||||||
|
return code
|
||||||
|
|
||||||
|
|
||||||
|
async def exchange_authorization_code(
|
||||||
|
db: AsyncSession,
|
||||||
|
code: str,
|
||||||
|
client_id: str,
|
||||||
|
redirect_uri: str,
|
||||||
|
code_verifier: str | None = None,
|
||||||
|
client_secret: str | None = None,
|
||||||
|
device_info: str | None = None,
|
||||||
|
ip_address: str | None = None,
|
||||||
|
) -> dict[str, Any]:
|
||||||
|
"""
|
||||||
|
Exchange authorization code for tokens.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
db: Database session
|
||||||
|
code: Authorization code
|
||||||
|
client_id: Client identifier
|
||||||
|
redirect_uri: Must match the original redirect_uri
|
||||||
|
code_verifier: PKCE code verifier
|
||||||
|
client_secret: Client secret (for confidential clients)
|
||||||
|
device_info: Optional device information
|
||||||
|
ip_address: Optional IP address
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Token response dict with access_token, refresh_token, etc.
|
||||||
|
|
||||||
|
Raises:
|
||||||
|
InvalidGrantError: If code is invalid, expired, or already used
|
||||||
|
InvalidClientError: If client validation fails
|
||||||
|
"""
|
||||||
|
# Atomically mark code as used and fetch it (prevents race condition)
|
||||||
|
# RFC 6749 Section 4.1.2: Authorization codes MUST be single-use
|
||||||
|
updated_id = await oauth_authorization_code_repo.consume_code_atomically(
|
||||||
|
db, code=code
|
||||||
|
)
|
||||||
|
|
||||||
|
if not updated_id:
|
||||||
|
# Either code doesn't exist or was already used
|
||||||
|
# Check if it exists to provide appropriate error
|
||||||
|
existing_code = await oauth_authorization_code_repo.get_by_code(db, code=code)
|
||||||
|
|
||||||
|
if existing_code and existing_code.used:
|
||||||
|
# Code reuse is a security incident - revoke all tokens for this grant
|
||||||
|
logger.warning(
|
||||||
|
"Authorization code reuse detected for client %s",
|
||||||
|
existing_code.client_id,
|
||||||
|
)
|
||||||
|
await revoke_tokens_for_user_client(
|
||||||
|
db, UUID(str(existing_code.user_id)), str(existing_code.client_id)
|
||||||
|
)
|
||||||
|
raise InvalidGrantError("Authorization code has already been used")
|
||||||
|
else:
|
||||||
|
raise InvalidGrantError("Invalid authorization code")
|
||||||
|
|
||||||
|
# Now fetch the full auth code record
|
||||||
|
auth_code = await oauth_authorization_code_repo.get_by_id(db, code_id=updated_id)
|
||||||
|
if auth_code is None:
|
||||||
|
raise InvalidGrantError("Authorization code not found after consumption")
|
||||||
|
|
||||||
|
if auth_code.is_expired:
|
||||||
|
raise InvalidGrantError("Authorization code has expired")
|
||||||
|
|
||||||
|
if auth_code.client_id != client_id:
|
||||||
|
raise InvalidGrantError("Authorization code was not issued to this client")
|
||||||
|
|
||||||
|
if auth_code.redirect_uri != redirect_uri:
|
||||||
|
raise InvalidGrantError("redirect_uri mismatch")
|
||||||
|
|
||||||
|
# Validate client - ALWAYS require secret for confidential clients
|
||||||
|
client = await get_client(db, client_id)
|
||||||
|
if not client:
|
||||||
|
raise InvalidClientError("Unknown client_id")
|
||||||
|
|
||||||
|
# Confidential clients MUST authenticate (RFC 6749 Section 3.2.1)
|
||||||
|
if client.client_type == "confidential":
|
||||||
|
if not client_secret:
|
||||||
|
raise InvalidClientError("Client secret required for confidential clients")
|
||||||
|
client = await validate_client(
|
||||||
|
db, client_id, client_secret, require_secret=True
|
||||||
|
)
|
||||||
|
elif client_secret:
|
||||||
|
# Public client provided secret - validate it if given
|
||||||
|
client = await validate_client(
|
||||||
|
db, client_id, client_secret, require_secret=True
|
||||||
|
)
|
||||||
|
|
||||||
|
# Verify PKCE
|
||||||
|
if auth_code.code_challenge:
|
||||||
|
if not code_verifier:
|
||||||
|
raise InvalidGrantError("code_verifier required")
|
||||||
|
if not verify_pkce(
|
||||||
|
code_verifier,
|
||||||
|
str(auth_code.code_challenge),
|
||||||
|
str(auth_code.code_challenge_method or "S256"),
|
||||||
|
):
|
||||||
|
raise InvalidGrantError("Invalid code_verifier")
|
||||||
|
elif client.client_type == "public":
|
||||||
|
# Public clients without PKCE - this shouldn't happen if we validated on authorize
|
||||||
|
raise InvalidGrantError("PKCE required for public clients")
|
||||||
|
|
||||||
|
# Get user
|
||||||
|
user = await user_repo.get(db, id=str(auth_code.user_id))
|
||||||
|
if not user or not user.is_active:
|
||||||
|
raise InvalidGrantError("User not found or inactive")
|
||||||
|
|
||||||
|
# Generate tokens
|
||||||
|
return await create_tokens(
|
||||||
|
db=db,
|
||||||
|
client=client,
|
||||||
|
user=user,
|
||||||
|
scope=str(auth_code.scope),
|
||||||
|
nonce=str(auth_code.nonce) if auth_code.nonce else None,
|
||||||
|
device_info=device_info,
|
||||||
|
ip_address=ip_address,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
# ============================================================================
|
||||||
|
# Token Generation
|
||||||
|
# ============================================================================
|
||||||
|
|
||||||
|
|
||||||
|
async def create_tokens(
|
||||||
|
db: AsyncSession,
|
||||||
|
client: OAuthClient,
|
||||||
|
user: User,
|
||||||
|
scope: str,
|
||||||
|
nonce: str | None = None,
|
||||||
|
device_info: str | None = None,
|
||||||
|
ip_address: str | None = None,
|
||||||
|
) -> dict[str, Any]:
|
||||||
|
"""
|
||||||
|
Create access and refresh tokens.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
db: Database session
|
||||||
|
client: OAuth client
|
||||||
|
user: User
|
||||||
|
scope: Granted scopes
|
||||||
|
nonce: OpenID Connect nonce (included in ID token)
|
||||||
|
device_info: Optional device information
|
||||||
|
ip_address: Optional IP address
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Token response dict
|
||||||
|
"""
|
||||||
|
now = datetime.now(UTC)
|
||||||
|
jti = generate_jti()
|
||||||
|
|
||||||
|
# Access token expiry
|
||||||
|
access_token_lifetime = int(client.access_token_lifetime or "3600")
|
||||||
|
access_expires = now + timedelta(seconds=access_token_lifetime)
|
||||||
|
|
||||||
|
# Refresh token expiry
|
||||||
|
refresh_token_lifetime = int(
|
||||||
|
client.refresh_token_lifetime or str(REFRESH_TOKEN_EXPIRY_DAYS * 86400)
|
||||||
|
)
|
||||||
|
refresh_expires = now + timedelta(seconds=refresh_token_lifetime)
|
||||||
|
|
||||||
|
# Create JWT access token
|
||||||
|
# SECURITY: Include all standard JWT claims per RFC 7519
|
||||||
|
access_token_payload = {
|
||||||
|
"iss": settings.OAUTH_ISSUER,
|
||||||
|
"sub": str(user.id),
|
||||||
|
"aud": client.client_id,
|
||||||
|
"exp": int(access_expires.timestamp()),
|
||||||
|
"iat": int(now.timestamp()),
|
||||||
|
"nbf": int(now.timestamp()), # Not Before - token is valid immediately
|
||||||
|
"jti": jti,
|
||||||
|
"scope": scope,
|
||||||
|
"client_id": client.client_id,
|
||||||
|
# User info (basic claims)
|
||||||
|
"email": user.email,
|
||||||
|
"name": f"{user.first_name or ''} {user.last_name or ''}".strip() or user.email,
|
||||||
|
}
|
||||||
|
|
||||||
|
# Add nonce for OpenID Connect
|
||||||
|
if nonce:
|
||||||
|
access_token_payload["nonce"] = nonce
|
||||||
|
|
||||||
|
access_token = jwt.encode(
|
||||||
|
access_token_payload,
|
||||||
|
settings.SECRET_KEY,
|
||||||
|
algorithm=settings.ALGORITHM,
|
||||||
|
)
|
||||||
|
|
||||||
|
# Create opaque refresh token
|
||||||
|
refresh_token = generate_token()
|
||||||
|
refresh_token_hash = hash_token(refresh_token)
|
||||||
|
|
||||||
|
# Store refresh token in database
|
||||||
|
await oauth_provider_token_repo.create_token(
|
||||||
|
db,
|
||||||
|
token_hash=refresh_token_hash,
|
||||||
|
jti=jti,
|
||||||
|
client_id=client.client_id,
|
||||||
|
user_id=user.id,
|
||||||
|
scope=scope,
|
||||||
|
expires_at=refresh_expires,
|
||||||
|
device_info=device_info,
|
||||||
|
ip_address=ip_address,
|
||||||
|
)
|
||||||
|
|
||||||
|
logger.info("Issued tokens for user %s to client %s", user.id, client.client_id)
|
||||||
|
|
||||||
|
return {
|
||||||
|
"access_token": access_token,
|
||||||
|
"token_type": "Bearer",
|
||||||
|
"expires_in": access_token_lifetime,
|
||||||
|
"refresh_token": refresh_token,
|
||||||
|
"scope": scope,
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
async def refresh_tokens(
|
||||||
|
db: AsyncSession,
|
||||||
|
refresh_token: str,
|
||||||
|
client_id: str,
|
||||||
|
client_secret: str | None = None,
|
||||||
|
scope: str | None = None,
|
||||||
|
device_info: str | None = None,
|
||||||
|
ip_address: str | None = None,
|
||||||
|
) -> dict[str, Any]:
|
||||||
|
"""
|
||||||
|
Refresh access token using refresh token.
|
||||||
|
|
||||||
|
Implements token rotation - old refresh token is invalidated,
|
||||||
|
new refresh token is issued.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
db: Database session
|
||||||
|
refresh_token: Refresh token
|
||||||
|
client_id: Client identifier
|
||||||
|
client_secret: Client secret (for confidential clients)
|
||||||
|
scope: Optional reduced scope
|
||||||
|
device_info: Optional device information
|
||||||
|
ip_address: Optional IP address
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
New token response dict
|
||||||
|
|
||||||
|
Raises:
|
||||||
|
InvalidGrantError: If refresh token is invalid
|
||||||
|
"""
|
||||||
|
# Find refresh token
|
||||||
|
token_hash = hash_token(refresh_token)
|
||||||
|
token_record = await oauth_provider_token_repo.get_by_token_hash(
|
||||||
|
db, token_hash=token_hash
|
||||||
|
)
|
||||||
|
|
||||||
|
if not token_record:
|
||||||
|
raise InvalidGrantError("Invalid refresh token")
|
||||||
|
|
||||||
|
if token_record.revoked:
|
||||||
|
# Token reuse after revocation - security incident
|
||||||
|
logger.warning(
|
||||||
|
"Revoked refresh token reuse detected for client %s", token_record.client_id
|
||||||
|
)
|
||||||
|
raise InvalidGrantError("Refresh token has been revoked")
|
||||||
|
|
||||||
|
if token_record.is_expired:
|
||||||
|
raise InvalidGrantError("Refresh token has expired")
|
||||||
|
|
||||||
|
if token_record.client_id != client_id:
|
||||||
|
raise InvalidGrantError("Refresh token was not issued to this client")
|
||||||
|
|
||||||
|
# Validate client
|
||||||
|
client = await validate_client(
|
||||||
|
db,
|
||||||
|
client_id,
|
||||||
|
client_secret,
|
||||||
|
require_secret=(client_secret is not None),
|
||||||
|
)
|
||||||
|
|
||||||
|
# Get user
|
||||||
|
user = await user_repo.get(db, id=str(token_record.user_id))
|
||||||
|
if not user or not user.is_active:
|
||||||
|
raise InvalidGrantError("User not found or inactive")
|
||||||
|
|
||||||
|
# Validate scope (can only reduce, not expand)
|
||||||
|
token_scope = str(token_record.scope) if token_record.scope else ""
|
||||||
|
original_scopes = set(parse_scope(token_scope))
|
||||||
|
if scope:
|
||||||
|
requested_scopes = set(parse_scope(scope))
|
||||||
|
if not requested_scopes.issubset(original_scopes):
|
||||||
|
raise InvalidScopeError("Cannot expand scope on refresh")
|
||||||
|
final_scope = join_scope(list(requested_scopes))
|
||||||
|
else:
|
||||||
|
final_scope = token_scope
|
||||||
|
|
||||||
|
# Revoke old refresh token (token rotation)
|
||||||
|
await oauth_provider_token_repo.revoke(db, token=token_record)
|
||||||
|
|
||||||
|
# Issue new tokens
|
||||||
|
device = str(token_record.device_info) if token_record.device_info else None
|
||||||
|
ip_addr = str(token_record.ip_address) if token_record.ip_address else None
|
||||||
|
return await create_tokens(
|
||||||
|
db=db,
|
||||||
|
client=client,
|
||||||
|
user=user,
|
||||||
|
scope=final_scope,
|
||||||
|
device_info=device_info or device,
|
||||||
|
ip_address=ip_address or ip_addr,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
# ============================================================================
|
||||||
|
# Token Revocation
|
||||||
|
# ============================================================================
|
||||||
|
|
||||||
|
|
||||||
|
async def revoke_token(
|
||||||
|
db: AsyncSession,
|
||||||
|
token: str,
|
||||||
|
token_type_hint: str | None = None,
|
||||||
|
client_id: str | None = None,
|
||||||
|
client_secret: str | None = None,
|
||||||
|
) -> bool:
|
||||||
|
"""
|
||||||
|
Revoke a token (access or refresh).
|
||||||
|
|
||||||
|
For refresh tokens: marks as revoked in database
|
||||||
|
For access tokens: we can't truly revoke JWTs, but we can revoke
|
||||||
|
the associated refresh token to prevent further refreshes
|
||||||
|
|
||||||
|
Args:
|
||||||
|
db: Database session
|
||||||
|
token: Token to revoke
|
||||||
|
token_type_hint: "access_token" or "refresh_token"
|
||||||
|
client_id: Client identifier (for validation)
|
||||||
|
client_secret: Client secret (for confidential clients)
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
True if token was revoked, False if not found
|
||||||
|
"""
|
||||||
|
# Try as refresh token first (more likely)
|
||||||
|
if token_type_hint != "access_token":
|
||||||
|
token_hash = hash_token(token)
|
||||||
|
refresh_record = await oauth_provider_token_repo.get_by_token_hash(
|
||||||
|
db, token_hash=token_hash
|
||||||
|
)
|
||||||
|
|
||||||
|
if refresh_record:
|
||||||
|
# Validate client if provided
|
||||||
|
if client_id and refresh_record.client_id != client_id:
|
||||||
|
raise InvalidClientError("Token was not issued to this client")
|
||||||
|
|
||||||
|
await oauth_provider_token_repo.revoke(db, token=refresh_record)
|
||||||
|
logger.info("Revoked refresh token %s...", refresh_record.jti[:8])
|
||||||
|
return True
|
||||||
|
|
||||||
|
# Try as access token (JWT)
|
||||||
|
if token_type_hint != "refresh_token":
|
||||||
|
try:
|
||||||
|
payload = jwt.decode(
|
||||||
|
token,
|
||||||
|
settings.SECRET_KEY,
|
||||||
|
algorithms=[settings.ALGORITHM],
|
||||||
|
options={
|
||||||
|
"verify_exp": False,
|
||||||
|
"verify_aud": False,
|
||||||
|
}, # Allow expired tokens
|
||||||
|
)
|
||||||
|
jti = payload.get("jti")
|
||||||
|
if jti:
|
||||||
|
# Find and revoke the associated refresh token
|
||||||
|
refresh_record = await oauth_provider_token_repo.get_by_jti(db, jti=jti)
|
||||||
|
if refresh_record:
|
||||||
|
if client_id and refresh_record.client_id != client_id:
|
||||||
|
raise InvalidClientError("Token was not issued to this client")
|
||||||
|
await oauth_provider_token_repo.revoke(db, token=refresh_record)
|
||||||
|
logger.info(
|
||||||
|
"Revoked refresh token via access token JTI %s...", jti[:8]
|
||||||
|
)
|
||||||
|
return True
|
||||||
|
except InvalidTokenError:
|
||||||
|
pass
|
||||||
|
except Exception: # noqa: S110 - Intentional: invalid JWT not an error
|
||||||
|
pass
|
||||||
|
|
||||||
|
return False
|
||||||
|
|
||||||
|
|
||||||
|
async def revoke_tokens_for_user_client(
|
||||||
|
db: AsyncSession,
|
||||||
|
user_id: UUID,
|
||||||
|
client_id: str,
|
||||||
|
) -> int:
|
||||||
|
"""
|
||||||
|
Revoke all tokens for a specific user-client pair.
|
||||||
|
|
||||||
|
Used when security incidents are detected (e.g., code reuse).
|
||||||
|
|
||||||
|
Args:
|
||||||
|
db: Database session
|
||||||
|
user_id: User identifier
|
||||||
|
client_id: Client identifier
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Number of tokens revoked
|
||||||
|
"""
|
||||||
|
count = await oauth_provider_token_repo.revoke_all_for_user_client(
|
||||||
|
db, user_id=user_id, client_id=client_id
|
||||||
|
)
|
||||||
|
|
||||||
|
if count > 0:
|
||||||
|
logger.warning(
|
||||||
|
"Revoked %s tokens for user %s and client %s", count, user_id, client_id
|
||||||
|
)
|
||||||
|
|
||||||
|
return count
|
||||||
|
|
||||||
|
|
||||||
|
async def revoke_all_user_tokens(db: AsyncSession, user_id: UUID) -> int:
|
||||||
|
"""
|
||||||
|
Revoke all OAuth provider tokens for a user.
|
||||||
|
|
||||||
|
Used when user changes password or explicitly logs out everywhere.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
db: Database session
|
||||||
|
user_id: User identifier
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Number of tokens revoked
|
||||||
|
"""
|
||||||
|
count = await oauth_provider_token_repo.revoke_all_for_user(db, user_id=user_id)
|
||||||
|
|
||||||
|
if count > 0:
|
||||||
|
logger.info("Revoked %s OAuth provider tokens for user %s", count, user_id)
|
||||||
|
|
||||||
|
return count
|
||||||
|
|
||||||
|
|
||||||
|
# ============================================================================
|
||||||
|
# Token Introspection (RFC 7662)
|
||||||
|
# ============================================================================
|
||||||
|
|
||||||
|
|
||||||
|
async def introspect_token(
|
||||||
|
db: AsyncSession,
|
||||||
|
token: str,
|
||||||
|
token_type_hint: str | None = None,
|
||||||
|
client_id: str | None = None,
|
||||||
|
client_secret: str | None = None,
|
||||||
|
) -> dict[str, Any]:
|
||||||
|
"""
|
||||||
|
Introspect a token to determine its validity and metadata.
|
||||||
|
|
||||||
|
Implements RFC 7662 Token Introspection.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
db: Database session
|
||||||
|
token: Token to introspect
|
||||||
|
token_type_hint: "access_token" or "refresh_token"
|
||||||
|
client_id: Client requesting introspection
|
||||||
|
client_secret: Client secret
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Introspection response dict
|
||||||
|
"""
|
||||||
|
# Validate client if credentials provided
|
||||||
|
if client_id:
|
||||||
|
await validate_client(db, client_id, client_secret)
|
||||||
|
|
||||||
|
# Try as access token (JWT) first
|
||||||
|
if token_type_hint != "refresh_token":
|
||||||
|
try:
|
||||||
|
payload = jwt.decode(
|
||||||
|
token,
|
||||||
|
settings.SECRET_KEY,
|
||||||
|
algorithms=[settings.ALGORITHM],
|
||||||
|
options={
|
||||||
|
"verify_aud": False
|
||||||
|
}, # Don't require audience match for introspection
|
||||||
|
)
|
||||||
|
|
||||||
|
# Check if associated refresh token is revoked
|
||||||
|
jti = payload.get("jti")
|
||||||
|
if jti:
|
||||||
|
refresh_record = await oauth_provider_token_repo.get_by_jti(db, jti=jti)
|
||||||
|
if refresh_record and refresh_record.revoked:
|
||||||
|
return {"active": False}
|
||||||
|
|
||||||
|
return {
|
||||||
|
"active": True,
|
||||||
|
"scope": payload.get("scope", ""),
|
||||||
|
"client_id": payload.get("client_id"),
|
||||||
|
"username": payload.get("email"),
|
||||||
|
"token_type": "Bearer",
|
||||||
|
"exp": payload.get("exp"),
|
||||||
|
"iat": payload.get("iat"),
|
||||||
|
"sub": payload.get("sub"),
|
||||||
|
"aud": payload.get("aud"),
|
||||||
|
"iss": payload.get("iss"),
|
||||||
|
}
|
||||||
|
except ExpiredSignatureError:
|
||||||
|
return {"active": False}
|
||||||
|
except InvalidTokenError:
|
||||||
|
pass
|
||||||
|
except Exception: # noqa: S110 - Intentional: invalid JWT falls through to refresh token check
|
||||||
|
pass
|
||||||
|
|
||||||
|
# Try as refresh token
|
||||||
|
if token_type_hint != "access_token":
|
||||||
|
token_hash = hash_token(token)
|
||||||
|
refresh_record = await oauth_provider_token_repo.get_by_token_hash(
|
||||||
|
db, token_hash=token_hash
|
||||||
|
)
|
||||||
|
|
||||||
|
if refresh_record and refresh_record.is_valid:
|
||||||
|
return {
|
||||||
|
"active": True,
|
||||||
|
"scope": refresh_record.scope,
|
||||||
|
"client_id": refresh_record.client_id,
|
||||||
|
"token_type": "refresh_token",
|
||||||
|
"exp": int(refresh_record.expires_at.timestamp()),
|
||||||
|
"iat": int(refresh_record.created_at.timestamp()),
|
||||||
|
"sub": str(refresh_record.user_id),
|
||||||
|
}
|
||||||
|
|
||||||
|
return {"active": False}
|
||||||
|
|
||||||
|
|
||||||
|
# ============================================================================
|
||||||
|
# Consent Management
|
||||||
|
# ============================================================================
|
||||||
|
|
||||||
|
|
||||||
|
async def get_consent(
|
||||||
|
db: AsyncSession,
|
||||||
|
user_id: UUID,
|
||||||
|
client_id: str,
|
||||||
|
):
|
||||||
|
"""Get existing consent record for user-client pair."""
|
||||||
|
return await oauth_consent_repo.get_consent(
|
||||||
|
db, user_id=user_id, client_id=client_id
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
async def check_consent(
|
||||||
|
db: AsyncSession,
|
||||||
|
user_id: UUID,
|
||||||
|
client_id: str,
|
||||||
|
requested_scopes: list[str],
|
||||||
|
) -> bool:
|
||||||
|
"""
|
||||||
|
Check if user has already consented to the requested scopes.
|
||||||
|
|
||||||
|
Returns True if all requested scopes are already granted.
|
||||||
|
"""
|
||||||
|
consent = await get_consent(db, user_id, client_id)
|
||||||
|
if not consent:
|
||||||
|
return False
|
||||||
|
return consent.has_scopes(requested_scopes)
|
||||||
|
|
||||||
|
|
||||||
|
async def grant_consent(
|
||||||
|
db: AsyncSession,
|
||||||
|
user_id: UUID,
|
||||||
|
client_id: str,
|
||||||
|
scopes: list[str],
|
||||||
|
):
|
||||||
|
"""
|
||||||
|
Grant or update consent for a user-client pair.
|
||||||
|
|
||||||
|
If consent already exists, updates the granted scopes.
|
||||||
|
"""
|
||||||
|
return await oauth_consent_repo.grant_consent(
|
||||||
|
db, user_id=user_id, client_id=client_id, scopes=scopes
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
async def revoke_consent(
|
||||||
|
db: AsyncSession,
|
||||||
|
user_id: UUID,
|
||||||
|
client_id: str,
|
||||||
|
) -> bool:
|
||||||
|
"""
|
||||||
|
Revoke consent and all tokens for a user-client pair.
|
||||||
|
|
||||||
|
Returns True if consent was found and revoked.
|
||||||
|
"""
|
||||||
|
# Revoke all tokens first
|
||||||
|
await revoke_tokens_for_user_client(db, user_id, client_id)
|
||||||
|
|
||||||
|
# Delete consent record
|
||||||
|
return await oauth_consent_repo.revoke_consent(
|
||||||
|
db, user_id=user_id, client_id=client_id
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
# ============================================================================
|
||||||
|
# Cleanup
|
||||||
|
# ============================================================================
|
||||||
|
|
||||||
|
|
||||||
|
async def register_client(db: AsyncSession, client_data: OAuthClientCreate) -> tuple:
|
||||||
|
"""Create a new OAuth client. Returns (client, secret)."""
|
||||||
|
return await oauth_client_repo.create_client(db, obj_in=client_data)
|
||||||
|
|
||||||
|
|
||||||
|
async def list_clients(db: AsyncSession) -> list:
|
||||||
|
"""List all registered OAuth clients."""
|
||||||
|
return await oauth_client_repo.get_all_clients(db)
|
||||||
|
|
||||||
|
|
||||||
|
async def delete_client_by_id(db: AsyncSession, client_id: str) -> None:
|
||||||
|
"""Delete an OAuth client by client_id."""
|
||||||
|
await oauth_client_repo.delete_client(db, client_id=client_id)
|
||||||
|
|
||||||
|
|
||||||
|
async def list_user_consents(db: AsyncSession, user_id: UUID) -> list[dict]:
|
||||||
|
"""Get all OAuth consents for a user with client details."""
|
||||||
|
return await oauth_consent_repo.get_user_consents_with_clients(db, user_id=user_id)
|
||||||
|
|
||||||
|
|
||||||
|
async def cleanup_expired_codes(db: AsyncSession) -> int:
|
||||||
|
"""
|
||||||
|
Delete expired authorization codes.
|
||||||
|
|
||||||
|
Should be called periodically (e.g., every hour).
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Number of codes deleted
|
||||||
|
"""
|
||||||
|
return await oauth_authorization_code_repo.cleanup_expired(db)
|
||||||
|
|
||||||
|
|
||||||
|
async def cleanup_expired_tokens(db: AsyncSession) -> int:
|
||||||
|
"""
|
||||||
|
Delete expired and revoked refresh tokens.
|
||||||
|
|
||||||
|
Should be called periodically (e.g., daily).
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Number of tokens deleted
|
||||||
|
"""
|
||||||
|
return await oauth_provider_token_repo.cleanup_expired(db, cutoff_days=7)
|
||||||
744
backend/app/services/oauth_service.py
Normal file
744
backend/app/services/oauth_service.py
Normal file
@@ -0,0 +1,744 @@
|
|||||||
|
"""
|
||||||
|
OAuth Service for handling social authentication flows.
|
||||||
|
|
||||||
|
Supports:
|
||||||
|
- Google OAuth (OpenID Connect)
|
||||||
|
- GitHub OAuth
|
||||||
|
|
||||||
|
Features:
|
||||||
|
- PKCE support for public clients
|
||||||
|
- State parameter for CSRF protection
|
||||||
|
- Auto-linking by email (configurable)
|
||||||
|
- Account linking for existing users
|
||||||
|
"""
|
||||||
|
|
||||||
|
import logging
|
||||||
|
import secrets
|
||||||
|
from datetime import UTC, datetime, timedelta
|
||||||
|
from typing import TypedDict, cast
|
||||||
|
from uuid import UUID
|
||||||
|
|
||||||
|
from authlib.integrations.httpx_client import AsyncOAuth2Client
|
||||||
|
from sqlalchemy.ext.asyncio import AsyncSession
|
||||||
|
|
||||||
|
from app.core.auth import create_access_token, create_refresh_token
|
||||||
|
from app.core.config import settings
|
||||||
|
from app.core.exceptions import AuthenticationError
|
||||||
|
from app.models.user import User
|
||||||
|
from app.repositories.oauth_account import oauth_account_repo as oauth_account
|
||||||
|
from app.repositories.oauth_state import oauth_state_repo as oauth_state
|
||||||
|
from app.repositories.user import user_repo
|
||||||
|
from app.schemas.oauth import (
|
||||||
|
OAuthAccountCreate,
|
||||||
|
OAuthCallbackResponse,
|
||||||
|
OAuthProviderInfo,
|
||||||
|
OAuthProvidersResponse,
|
||||||
|
OAuthStateCreate,
|
||||||
|
)
|
||||||
|
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
|
class _OAuthProviderConfigRequired(TypedDict):
|
||||||
|
name: str
|
||||||
|
icon: str
|
||||||
|
authorize_url: str
|
||||||
|
token_url: str
|
||||||
|
userinfo_url: str
|
||||||
|
scopes: list[str]
|
||||||
|
supports_pkce: bool
|
||||||
|
|
||||||
|
|
||||||
|
class OAuthProviderConfig(_OAuthProviderConfigRequired, total=False):
|
||||||
|
"""Type definition for OAuth provider configuration."""
|
||||||
|
|
||||||
|
email_url: str # Optional, GitHub-only
|
||||||
|
|
||||||
|
|
||||||
|
# Provider configurations
|
||||||
|
OAUTH_PROVIDERS: dict[str, OAuthProviderConfig] = {
|
||||||
|
"google": {
|
||||||
|
"name": "Google",
|
||||||
|
"icon": "google",
|
||||||
|
"authorize_url": "https://accounts.google.com/o/oauth2/v2/auth",
|
||||||
|
"token_url": "https://oauth2.googleapis.com/token",
|
||||||
|
"userinfo_url": "https://www.googleapis.com/oauth2/v3/userinfo",
|
||||||
|
"scopes": ["openid", "email", "profile"],
|
||||||
|
"supports_pkce": True,
|
||||||
|
},
|
||||||
|
"github": {
|
||||||
|
"name": "GitHub",
|
||||||
|
"icon": "github",
|
||||||
|
"authorize_url": "https://github.com/login/oauth/authorize",
|
||||||
|
"token_url": "https://github.com/login/oauth/access_token",
|
||||||
|
"userinfo_url": "https://api.github.com/user",
|
||||||
|
"email_url": "https://api.github.com/user/emails",
|
||||||
|
"scopes": ["read:user", "user:email"],
|
||||||
|
"supports_pkce": False, # GitHub doesn't support PKCE
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
class OAuthService:
|
||||||
|
"""Service for handling OAuth authentication flows."""
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def get_enabled_providers() -> OAuthProvidersResponse:
|
||||||
|
"""
|
||||||
|
Get list of enabled OAuth providers.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
OAuthProvidersResponse with enabled providers
|
||||||
|
"""
|
||||||
|
providers = []
|
||||||
|
|
||||||
|
for provider_id in settings.enabled_oauth_providers:
|
||||||
|
if provider_id in OAUTH_PROVIDERS:
|
||||||
|
config = OAUTH_PROVIDERS[provider_id]
|
||||||
|
providers.append(
|
||||||
|
OAuthProviderInfo(
|
||||||
|
provider=provider_id,
|
||||||
|
name=config["name"],
|
||||||
|
icon=config["icon"],
|
||||||
|
)
|
||||||
|
)
|
||||||
|
|
||||||
|
return OAuthProvidersResponse(
|
||||||
|
enabled=settings.OAUTH_ENABLED and len(providers) > 0,
|
||||||
|
providers=providers,
|
||||||
|
)
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def _get_provider_credentials(provider: str) -> tuple[str, str]:
|
||||||
|
"""Get client ID and secret for a provider."""
|
||||||
|
if provider == "google":
|
||||||
|
client_id = settings.OAUTH_GOOGLE_CLIENT_ID
|
||||||
|
client_secret = settings.OAUTH_GOOGLE_CLIENT_SECRET
|
||||||
|
elif provider == "github":
|
||||||
|
client_id = settings.OAUTH_GITHUB_CLIENT_ID
|
||||||
|
client_secret = settings.OAUTH_GITHUB_CLIENT_SECRET
|
||||||
|
else:
|
||||||
|
raise AuthenticationError(f"Unknown OAuth provider: {provider}")
|
||||||
|
|
||||||
|
if not client_id or not client_secret:
|
||||||
|
raise AuthenticationError(f"OAuth provider {provider} is not configured")
|
||||||
|
|
||||||
|
return client_id, client_secret
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
async def create_authorization_url(
|
||||||
|
db: AsyncSession,
|
||||||
|
*,
|
||||||
|
provider: str,
|
||||||
|
redirect_uri: str,
|
||||||
|
user_id: str | None = None,
|
||||||
|
) -> tuple[str, str]:
|
||||||
|
"""
|
||||||
|
Create OAuth authorization URL with state and optional PKCE.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
db: Database session
|
||||||
|
provider: OAuth provider (google, github)
|
||||||
|
redirect_uri: Callback URL after OAuth
|
||||||
|
user_id: User ID if linking account (user is logged in)
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Tuple of (authorization_url, state)
|
||||||
|
|
||||||
|
Raises:
|
||||||
|
AuthenticationError: If provider is not configured
|
||||||
|
"""
|
||||||
|
if not settings.OAUTH_ENABLED:
|
||||||
|
raise AuthenticationError("OAuth is not enabled")
|
||||||
|
|
||||||
|
if provider not in OAUTH_PROVIDERS:
|
||||||
|
raise AuthenticationError(f"Unknown OAuth provider: {provider}")
|
||||||
|
|
||||||
|
if provider not in settings.enabled_oauth_providers:
|
||||||
|
raise AuthenticationError(f"OAuth provider {provider} is not enabled")
|
||||||
|
|
||||||
|
config = OAUTH_PROVIDERS[provider]
|
||||||
|
client_id, client_secret = OAuthService._get_provider_credentials(provider)
|
||||||
|
|
||||||
|
# Generate state for CSRF protection
|
||||||
|
state = secrets.token_urlsafe(32)
|
||||||
|
|
||||||
|
# Generate PKCE code verifier and challenge if supported
|
||||||
|
code_verifier = None
|
||||||
|
code_challenge = None
|
||||||
|
if config.get("supports_pkce"):
|
||||||
|
code_verifier = secrets.token_urlsafe(64)
|
||||||
|
# Create code_challenge using S256 method
|
||||||
|
import base64
|
||||||
|
import hashlib
|
||||||
|
|
||||||
|
code_challenge_bytes = hashlib.sha256(code_verifier.encode()).digest()
|
||||||
|
code_challenge = (
|
||||||
|
base64.urlsafe_b64encode(code_challenge_bytes).decode().rstrip("=")
|
||||||
|
)
|
||||||
|
|
||||||
|
# Generate nonce for OIDC (Google)
|
||||||
|
nonce = secrets.token_urlsafe(32) if provider == "google" else None
|
||||||
|
|
||||||
|
# Store state in database
|
||||||
|
from uuid import UUID
|
||||||
|
|
||||||
|
state_data = OAuthStateCreate(
|
||||||
|
state=state,
|
||||||
|
code_verifier=code_verifier,
|
||||||
|
nonce=nonce,
|
||||||
|
provider=provider,
|
||||||
|
redirect_uri=redirect_uri,
|
||||||
|
user_id=UUID(user_id) if user_id else None,
|
||||||
|
expires_at=datetime.now(UTC)
|
||||||
|
+ timedelta(minutes=settings.OAUTH_STATE_EXPIRE_MINUTES),
|
||||||
|
)
|
||||||
|
await oauth_state.create_state(db, obj_in=state_data)
|
||||||
|
|
||||||
|
# Build authorization URL
|
||||||
|
async with AsyncOAuth2Client(
|
||||||
|
client_id=client_id,
|
||||||
|
client_secret=client_secret,
|
||||||
|
redirect_uri=redirect_uri,
|
||||||
|
) as client:
|
||||||
|
# Prepare authorization params
|
||||||
|
auth_params = {
|
||||||
|
"state": state,
|
||||||
|
"scope": " ".join(config["scopes"]),
|
||||||
|
}
|
||||||
|
|
||||||
|
if code_challenge:
|
||||||
|
auth_params["code_challenge"] = code_challenge
|
||||||
|
auth_params["code_challenge_method"] = "S256"
|
||||||
|
|
||||||
|
if nonce:
|
||||||
|
auth_params["nonce"] = nonce
|
||||||
|
|
||||||
|
url, _ = client.create_authorization_url(
|
||||||
|
config["authorize_url"],
|
||||||
|
**auth_params,
|
||||||
|
)
|
||||||
|
|
||||||
|
logger.info("OAuth authorization URL created for %s", provider)
|
||||||
|
return url, state
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
async def handle_callback(
|
||||||
|
db: AsyncSession,
|
||||||
|
*,
|
||||||
|
code: str,
|
||||||
|
state: str,
|
||||||
|
redirect_uri: str,
|
||||||
|
) -> OAuthCallbackResponse:
|
||||||
|
"""
|
||||||
|
Handle OAuth callback and authenticate/create user.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
db: Database session
|
||||||
|
code: Authorization code from provider
|
||||||
|
state: State parameter for CSRF verification
|
||||||
|
redirect_uri: Callback URL (must match authorization request)
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
OAuthCallbackResponse with tokens
|
||||||
|
|
||||||
|
Raises:
|
||||||
|
AuthenticationError: If authentication fails
|
||||||
|
"""
|
||||||
|
# Validate and consume state
|
||||||
|
state_record = await oauth_state.get_and_consume_state(db, state=state)
|
||||||
|
if not state_record:
|
||||||
|
raise AuthenticationError("Invalid or expired OAuth state")
|
||||||
|
|
||||||
|
# SECURITY: Validate redirect_uri matches the one from authorization request
|
||||||
|
# This prevents authorization code injection attacks (RFC 6749 Section 10.6)
|
||||||
|
if state_record.redirect_uri != redirect_uri:
|
||||||
|
logger.warning(
|
||||||
|
"OAuth redirect_uri mismatch: expected %s, got %s",
|
||||||
|
state_record.redirect_uri,
|
||||||
|
redirect_uri,
|
||||||
|
)
|
||||||
|
raise AuthenticationError("Redirect URI mismatch")
|
||||||
|
|
||||||
|
# Extract provider from state record (str for type safety)
|
||||||
|
provider: str = str(state_record.provider)
|
||||||
|
|
||||||
|
if provider not in OAUTH_PROVIDERS:
|
||||||
|
raise AuthenticationError(f"Unknown OAuth provider: {provider}")
|
||||||
|
|
||||||
|
config = OAUTH_PROVIDERS[provider]
|
||||||
|
client_id, client_secret = OAuthService._get_provider_credentials(provider)
|
||||||
|
|
||||||
|
# Exchange code for tokens
|
||||||
|
async with AsyncOAuth2Client(
|
||||||
|
client_id=client_id,
|
||||||
|
client_secret=client_secret,
|
||||||
|
redirect_uri=redirect_uri,
|
||||||
|
) as client:
|
||||||
|
try:
|
||||||
|
# Prepare token request params
|
||||||
|
token_params: dict[str, str] = {"code": code}
|
||||||
|
|
||||||
|
if state_record.code_verifier:
|
||||||
|
token_params["code_verifier"] = str(state_record.code_verifier)
|
||||||
|
|
||||||
|
token = await client.fetch_token(
|
||||||
|
config["token_url"],
|
||||||
|
**token_params,
|
||||||
|
)
|
||||||
|
|
||||||
|
# SECURITY: Validate ID token signature and nonce for OpenID Connect
|
||||||
|
# This prevents token forgery and replay attacks (OIDC Core 3.1.3.7)
|
||||||
|
if provider == "google" and state_record.nonce:
|
||||||
|
id_token = token.get("id_token")
|
||||||
|
if id_token:
|
||||||
|
await OAuthService._verify_google_id_token(
|
||||||
|
id_token=str(id_token),
|
||||||
|
expected_nonce=str(state_record.nonce),
|
||||||
|
client_id=client_id,
|
||||||
|
)
|
||||||
|
except AuthenticationError:
|
||||||
|
raise
|
||||||
|
except Exception as e:
|
||||||
|
logger.error("OAuth token exchange failed: %s", e)
|
||||||
|
raise AuthenticationError("Failed to exchange authorization code")
|
||||||
|
|
||||||
|
# Get user info from provider
|
||||||
|
try:
|
||||||
|
access_token = token.get("access_token")
|
||||||
|
if not access_token:
|
||||||
|
raise AuthenticationError("No access token received")
|
||||||
|
|
||||||
|
user_info = await OAuthService._get_user_info(
|
||||||
|
client, provider, config, access_token
|
||||||
|
)
|
||||||
|
except Exception as e:
|
||||||
|
logger.error("Failed to get user info: %s", e)
|
||||||
|
raise AuthenticationError(
|
||||||
|
"Failed to get user information from provider"
|
||||||
|
)
|
||||||
|
|
||||||
|
# Process user info and create/link account
|
||||||
|
provider_user_id = str(user_info.get("id") or user_info.get("sub"))
|
||||||
|
# Email can be None if user didn't grant email permission
|
||||||
|
# SECURITY: Normalize email (lowercase, strip) to prevent case-based account duplication
|
||||||
|
email_raw = user_info.get("email")
|
||||||
|
provider_email: str | None = (
|
||||||
|
str(email_raw).lower().strip() if email_raw else None
|
||||||
|
)
|
||||||
|
|
||||||
|
if not provider_user_id:
|
||||||
|
raise AuthenticationError("Provider did not return user ID")
|
||||||
|
|
||||||
|
# Check if this OAuth account already exists
|
||||||
|
existing_oauth = await oauth_account.get_by_provider_id(
|
||||||
|
db, provider=provider, provider_user_id=provider_user_id
|
||||||
|
)
|
||||||
|
|
||||||
|
is_new_user = False
|
||||||
|
|
||||||
|
if existing_oauth:
|
||||||
|
# Existing OAuth account - login
|
||||||
|
user = existing_oauth.user
|
||||||
|
if not user.is_active:
|
||||||
|
raise AuthenticationError("User account is inactive")
|
||||||
|
|
||||||
|
# Update tokens if stored
|
||||||
|
if token.get("access_token"):
|
||||||
|
await oauth_account.update_tokens(
|
||||||
|
db,
|
||||||
|
account=existing_oauth,
|
||||||
|
access_token=token.get("access_token"),
|
||||||
|
refresh_token=token.get("refresh_token"),
|
||||||
|
token_expires_at=datetime.now(UTC)
|
||||||
|
+ timedelta(seconds=token.get("expires_in", 3600)),
|
||||||
|
)
|
||||||
|
|
||||||
|
logger.info("OAuth login successful for %s via %s", user.email, provider)
|
||||||
|
|
||||||
|
elif state_record.user_id:
|
||||||
|
# Account linking flow (user is already logged in)
|
||||||
|
user = await user_repo.get(db, id=str(state_record.user_id))
|
||||||
|
|
||||||
|
if not user:
|
||||||
|
raise AuthenticationError("User not found for account linking")
|
||||||
|
|
||||||
|
# Check if user already has this provider linked
|
||||||
|
user_id = cast(UUID, user.id)
|
||||||
|
existing_provider = await oauth_account.get_user_account_by_provider(
|
||||||
|
db, user_id=user_id, provider=provider
|
||||||
|
)
|
||||||
|
if existing_provider:
|
||||||
|
raise AuthenticationError(
|
||||||
|
f"You already have a {provider} account linked"
|
||||||
|
)
|
||||||
|
|
||||||
|
# Create OAuth account link
|
||||||
|
oauth_create = OAuthAccountCreate(
|
||||||
|
user_id=user_id,
|
||||||
|
provider=provider,
|
||||||
|
provider_user_id=provider_user_id,
|
||||||
|
provider_email=provider_email,
|
||||||
|
access_token=token.get("access_token"),
|
||||||
|
refresh_token=token.get("refresh_token"),
|
||||||
|
token_expires_at=datetime.now(UTC)
|
||||||
|
+ timedelta(seconds=token.get("expires_in", 3600))
|
||||||
|
if token.get("expires_in")
|
||||||
|
else None,
|
||||||
|
)
|
||||||
|
await oauth_account.create_account(db, obj_in=oauth_create)
|
||||||
|
|
||||||
|
logger.info("OAuth account linked: %s -> %s", provider, user.email)
|
||||||
|
|
||||||
|
else:
|
||||||
|
# New OAuth login - check for existing user by email
|
||||||
|
user = None
|
||||||
|
|
||||||
|
if provider_email and settings.OAUTH_AUTO_LINK_BY_EMAIL:
|
||||||
|
user = await user_repo.get_by_email(db, email=provider_email)
|
||||||
|
|
||||||
|
if user:
|
||||||
|
# Auto-link to existing user
|
||||||
|
if not user.is_active:
|
||||||
|
raise AuthenticationError("User account is inactive")
|
||||||
|
|
||||||
|
# Check if user already has this provider linked
|
||||||
|
user_id = cast(UUID, user.id)
|
||||||
|
existing_provider = await oauth_account.get_user_account_by_provider(
|
||||||
|
db, user_id=user_id, provider=provider
|
||||||
|
)
|
||||||
|
if existing_provider:
|
||||||
|
# This shouldn't happen if we got here, but safety check
|
||||||
|
logger.warning(
|
||||||
|
"OAuth account already linked (race condition?): %s -> %s",
|
||||||
|
provider,
|
||||||
|
user.email,
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
# Create OAuth account link
|
||||||
|
oauth_create = OAuthAccountCreate(
|
||||||
|
user_id=user_id,
|
||||||
|
provider=provider,
|
||||||
|
provider_user_id=provider_user_id,
|
||||||
|
provider_email=provider_email,
|
||||||
|
access_token=token.get("access_token"),
|
||||||
|
refresh_token=token.get("refresh_token"),
|
||||||
|
token_expires_at=datetime.now(UTC)
|
||||||
|
+ timedelta(seconds=token.get("expires_in", 3600))
|
||||||
|
if token.get("expires_in")
|
||||||
|
else None,
|
||||||
|
)
|
||||||
|
await oauth_account.create_account(db, obj_in=oauth_create)
|
||||||
|
|
||||||
|
logger.info(
|
||||||
|
"OAuth auto-linked by email: %s -> %s", provider, user.email
|
||||||
|
)
|
||||||
|
|
||||||
|
else:
|
||||||
|
# Create new user
|
||||||
|
if not provider_email:
|
||||||
|
raise AuthenticationError(
|
||||||
|
f"Email is required for registration. "
|
||||||
|
f"Please grant email permission to {provider}."
|
||||||
|
)
|
||||||
|
|
||||||
|
user = await OAuthService._create_oauth_user(
|
||||||
|
db,
|
||||||
|
email=provider_email,
|
||||||
|
provider=provider,
|
||||||
|
provider_user_id=provider_user_id,
|
||||||
|
user_info=user_info,
|
||||||
|
token=token,
|
||||||
|
)
|
||||||
|
is_new_user = True
|
||||||
|
|
||||||
|
logger.info("New user created via OAuth: %s (%s)", user.email, provider)
|
||||||
|
|
||||||
|
# Generate JWT tokens
|
||||||
|
claims = {
|
||||||
|
"is_superuser": user.is_superuser,
|
||||||
|
"email": user.email,
|
||||||
|
"first_name": user.first_name,
|
||||||
|
}
|
||||||
|
|
||||||
|
access_token_jwt = create_access_token(subject=str(user.id), claims=claims)
|
||||||
|
refresh_token_jwt = create_refresh_token(subject=str(user.id))
|
||||||
|
|
||||||
|
return OAuthCallbackResponse(
|
||||||
|
access_token=access_token_jwt,
|
||||||
|
refresh_token=refresh_token_jwt,
|
||||||
|
token_type="bearer",
|
||||||
|
expires_in=settings.ACCESS_TOKEN_EXPIRE_MINUTES * 60,
|
||||||
|
is_new_user=is_new_user,
|
||||||
|
)
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
async def _get_user_info(
|
||||||
|
client: AsyncOAuth2Client,
|
||||||
|
provider: str,
|
||||||
|
config: OAuthProviderConfig,
|
||||||
|
access_token: str,
|
||||||
|
) -> dict[str, object]:
|
||||||
|
"""Get user info from OAuth provider."""
|
||||||
|
headers = {"Authorization": f"Bearer {access_token}"}
|
||||||
|
|
||||||
|
if provider == "github":
|
||||||
|
# GitHub returns JSON with Accept header
|
||||||
|
headers["Accept"] = "application/vnd.github+json"
|
||||||
|
|
||||||
|
resp = await client.get(config["userinfo_url"], headers=headers)
|
||||||
|
resp.raise_for_status()
|
||||||
|
user_info = resp.json()
|
||||||
|
|
||||||
|
# GitHub requires separate request for email
|
||||||
|
if provider == "github" and not user_info.get("email"):
|
||||||
|
email_resp = await client.get(
|
||||||
|
config["email_url"], # pyright: ignore[reportTypedDictNotRequiredAccess]
|
||||||
|
headers=headers,
|
||||||
|
)
|
||||||
|
email_resp.raise_for_status()
|
||||||
|
emails = email_resp.json()
|
||||||
|
|
||||||
|
# Find primary verified email
|
||||||
|
for email_data in emails:
|
||||||
|
if email_data.get("primary") and email_data.get("verified"):
|
||||||
|
user_info["email"] = email_data["email"]
|
||||||
|
break
|
||||||
|
|
||||||
|
return user_info
|
||||||
|
|
||||||
|
# Google's OIDC configuration endpoints
|
||||||
|
GOOGLE_JWKS_URL = "https://www.googleapis.com/oauth2/v3/certs"
|
||||||
|
GOOGLE_ISSUERS = ("https://accounts.google.com", "accounts.google.com")
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
async def _verify_google_id_token(
|
||||||
|
id_token: str,
|
||||||
|
expected_nonce: str,
|
||||||
|
client_id: str,
|
||||||
|
) -> dict[str, object]:
|
||||||
|
"""
|
||||||
|
Verify Google ID token signature and claims.
|
||||||
|
|
||||||
|
SECURITY: This properly verifies the ID token by:
|
||||||
|
1. Fetching Google's public keys (JWKS)
|
||||||
|
2. Verifying the JWT signature against the public key
|
||||||
|
3. Validating issuer, audience, expiry, and nonce claims
|
||||||
|
|
||||||
|
Args:
|
||||||
|
id_token: The ID token JWT string
|
||||||
|
expected_nonce: The nonce we sent in the authorization request
|
||||||
|
client_id: Our OAuth client ID (expected audience)
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Decoded ID token payload
|
||||||
|
|
||||||
|
Raises:
|
||||||
|
AuthenticationError: If verification fails
|
||||||
|
"""
|
||||||
|
import httpx
|
||||||
|
import jwt as pyjwt
|
||||||
|
from jwt.algorithms import RSAAlgorithm
|
||||||
|
from jwt.exceptions import InvalidTokenError
|
||||||
|
|
||||||
|
try:
|
||||||
|
# Fetch Google's public keys (JWKS)
|
||||||
|
# In production, consider caching this with TTL matching Cache-Control header
|
||||||
|
async with httpx.AsyncClient() as client:
|
||||||
|
jwks_response = await client.get(
|
||||||
|
OAuthService.GOOGLE_JWKS_URL,
|
||||||
|
timeout=10.0,
|
||||||
|
)
|
||||||
|
jwks_response.raise_for_status()
|
||||||
|
jwks = jwks_response.json()
|
||||||
|
|
||||||
|
# Get the key ID from the token header
|
||||||
|
unverified_header = pyjwt.get_unverified_header(id_token)
|
||||||
|
kid = unverified_header.get("kid")
|
||||||
|
if not kid:
|
||||||
|
raise AuthenticationError("ID token missing key ID (kid)")
|
||||||
|
|
||||||
|
# Find the matching public key
|
||||||
|
jwk_data = None
|
||||||
|
for key in jwks.get("keys", []):
|
||||||
|
if key.get("kid") == kid:
|
||||||
|
jwk_data = key
|
||||||
|
break
|
||||||
|
|
||||||
|
if not jwk_data:
|
||||||
|
raise AuthenticationError("ID token signed with unknown key")
|
||||||
|
|
||||||
|
# Convert JWK to a public key object for PyJWT
|
||||||
|
public_key = RSAAlgorithm.from_jwk(jwk_data)
|
||||||
|
|
||||||
|
# Verify the token signature and decode claims
|
||||||
|
# PyJWT will verify signature against the RSA public key
|
||||||
|
payload = pyjwt.decode(
|
||||||
|
id_token,
|
||||||
|
public_key,
|
||||||
|
algorithms=["RS256"], # Google uses RS256
|
||||||
|
audience=client_id,
|
||||||
|
issuer=OAuthService.GOOGLE_ISSUERS,
|
||||||
|
options={
|
||||||
|
"verify_signature": True,
|
||||||
|
"verify_aud": True,
|
||||||
|
"verify_iss": True,
|
||||||
|
"verify_exp": True,
|
||||||
|
"verify_iat": True,
|
||||||
|
},
|
||||||
|
)
|
||||||
|
|
||||||
|
# Verify nonce (OIDC replay attack protection)
|
||||||
|
token_nonce = payload.get("nonce")
|
||||||
|
if token_nonce != expected_nonce:
|
||||||
|
logger.warning(
|
||||||
|
"OAuth ID token nonce mismatch: expected %s, got %s",
|
||||||
|
expected_nonce,
|
||||||
|
token_nonce,
|
||||||
|
)
|
||||||
|
raise AuthenticationError("Invalid ID token nonce")
|
||||||
|
|
||||||
|
logger.debug("Google ID token verified successfully")
|
||||||
|
return payload
|
||||||
|
|
||||||
|
except InvalidTokenError as e:
|
||||||
|
logger.warning("Google ID token verification failed: %s", e)
|
||||||
|
raise AuthenticationError("Invalid ID token signature")
|
||||||
|
except httpx.HTTPError as e:
|
||||||
|
logger.error("Failed to fetch Google JWKS: %s", e)
|
||||||
|
# If we can't verify the ID token, fail closed for security
|
||||||
|
raise AuthenticationError("Failed to verify ID token")
|
||||||
|
except Exception as e:
|
||||||
|
logger.error("Unexpected error verifying Google ID token: %s", e)
|
||||||
|
raise AuthenticationError("ID token verification error")
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
async def _create_oauth_user(
|
||||||
|
db: AsyncSession,
|
||||||
|
*,
|
||||||
|
email: str,
|
||||||
|
provider: str,
|
||||||
|
provider_user_id: str,
|
||||||
|
user_info: dict,
|
||||||
|
token: dict,
|
||||||
|
) -> User:
|
||||||
|
"""Create a new user from OAuth provider data."""
|
||||||
|
# Extract name from user_info
|
||||||
|
first_name = "User"
|
||||||
|
last_name = None
|
||||||
|
|
||||||
|
if provider == "google":
|
||||||
|
first_name = user_info.get("given_name") or user_info.get("name", "User")
|
||||||
|
last_name = user_info.get("family_name")
|
||||||
|
elif provider == "github":
|
||||||
|
# GitHub has full name, try to split
|
||||||
|
name = user_info.get("name") or user_info.get("login", "User")
|
||||||
|
parts = name.split(" ", 1)
|
||||||
|
first_name = parts[0]
|
||||||
|
last_name = parts[1] if len(parts) > 1 else None
|
||||||
|
|
||||||
|
# Create user (no password for OAuth-only users)
|
||||||
|
user = User(
|
||||||
|
email=email,
|
||||||
|
password_hash=None, # OAuth-only user
|
||||||
|
first_name=first_name,
|
||||||
|
last_name=last_name,
|
||||||
|
is_active=True,
|
||||||
|
is_superuser=False,
|
||||||
|
)
|
||||||
|
db.add(user)
|
||||||
|
await db.flush() # Get user.id
|
||||||
|
|
||||||
|
# Create OAuth account link
|
||||||
|
user_id = cast(UUID, user.id)
|
||||||
|
oauth_create = OAuthAccountCreate(
|
||||||
|
user_id=user_id,
|
||||||
|
provider=provider,
|
||||||
|
provider_user_id=provider_user_id,
|
||||||
|
provider_email=email,
|
||||||
|
access_token=token.get("access_token"),
|
||||||
|
refresh_token=token.get("refresh_token"),
|
||||||
|
token_expires_at=datetime.now(UTC)
|
||||||
|
+ timedelta(seconds=token.get("expires_in", 3600))
|
||||||
|
if token.get("expires_in")
|
||||||
|
else None,
|
||||||
|
)
|
||||||
|
await oauth_account.create_account(db, obj_in=oauth_create)
|
||||||
|
|
||||||
|
await db.refresh(user)
|
||||||
|
|
||||||
|
return user
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
async def unlink_provider(
|
||||||
|
db: AsyncSession,
|
||||||
|
*,
|
||||||
|
user: User,
|
||||||
|
provider: str,
|
||||||
|
) -> bool:
|
||||||
|
"""
|
||||||
|
Unlink an OAuth provider from a user account.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
db: Database session
|
||||||
|
user: User to unlink from
|
||||||
|
provider: Provider to unlink
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
True if unlinked successfully
|
||||||
|
|
||||||
|
Raises:
|
||||||
|
AuthenticationError: If unlinking would leave user without login method
|
||||||
|
"""
|
||||||
|
# Check if user can safely remove this OAuth account
|
||||||
|
# Note: We query directly instead of using user.can_remove_oauth property
|
||||||
|
# because the property uses lazy loading which doesn't work in async context
|
||||||
|
user_id = cast(UUID, user.id)
|
||||||
|
has_password = user.password_hash is not None
|
||||||
|
oauth_accounts = await oauth_account.get_user_accounts(db, user_id=user_id)
|
||||||
|
can_remove = has_password or len(oauth_accounts) > 1
|
||||||
|
|
||||||
|
if not can_remove:
|
||||||
|
raise AuthenticationError(
|
||||||
|
"Cannot unlink OAuth account. You must have either a password set "
|
||||||
|
"or at least one other OAuth provider linked."
|
||||||
|
)
|
||||||
|
|
||||||
|
deleted = await oauth_account.delete_account(
|
||||||
|
db, user_id=user_id, provider=provider
|
||||||
|
)
|
||||||
|
|
||||||
|
if not deleted:
|
||||||
|
raise AuthenticationError(f"No {provider} account found to unlink")
|
||||||
|
|
||||||
|
logger.info("OAuth provider unlinked: %s from %s", provider, user.email)
|
||||||
|
return True
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
async def get_user_accounts(db: AsyncSession, *, user_id: UUID) -> list:
|
||||||
|
"""Get all OAuth accounts linked to a user."""
|
||||||
|
return await oauth_account.get_user_accounts(db, user_id=user_id)
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
async def get_user_account_by_provider(
|
||||||
|
db: AsyncSession, *, user_id: UUID, provider: str
|
||||||
|
):
|
||||||
|
"""Get a specific OAuth account for a user and provider."""
|
||||||
|
return await oauth_account.get_user_account_by_provider(
|
||||||
|
db, user_id=user_id, provider=provider
|
||||||
|
)
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
async def cleanup_expired_states(db: AsyncSession) -> int:
|
||||||
|
"""
|
||||||
|
Clean up expired OAuth states.
|
||||||
|
|
||||||
|
Should be called periodically (e.g., by a background task).
|
||||||
|
|
||||||
|
Args:
|
||||||
|
db: Database session
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Number of states cleaned up
|
||||||
|
"""
|
||||||
|
return await oauth_state.cleanup_expired(db)
|
||||||
155
backend/app/services/organization_service.py
Normal file
155
backend/app/services/organization_service.py
Normal file
@@ -0,0 +1,155 @@
|
|||||||
|
# app/services/organization_service.py
|
||||||
|
"""Service layer for organization operations — delegates to OrganizationRepository."""
|
||||||
|
|
||||||
|
import logging
|
||||||
|
from typing import Any
|
||||||
|
from uuid import UUID
|
||||||
|
|
||||||
|
from sqlalchemy.ext.asyncio import AsyncSession
|
||||||
|
|
||||||
|
from app.core.exceptions import NotFoundError
|
||||||
|
from app.models.organization import Organization
|
||||||
|
from app.models.user_organization import OrganizationRole, UserOrganization
|
||||||
|
from app.repositories.organization import OrganizationRepository, organization_repo
|
||||||
|
from app.schemas.organizations import OrganizationCreate, OrganizationUpdate
|
||||||
|
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
|
class OrganizationService:
|
||||||
|
"""Service for organization management operations."""
|
||||||
|
|
||||||
|
def __init__(
|
||||||
|
self, organization_repository: OrganizationRepository | None = None
|
||||||
|
) -> None:
|
||||||
|
self._repo = organization_repository or organization_repo
|
||||||
|
|
||||||
|
async def get_organization(self, db: AsyncSession, org_id: str) -> Organization:
|
||||||
|
"""Get organization by ID, raising NotFoundError if not found."""
|
||||||
|
org = await self._repo.get(db, id=org_id)
|
||||||
|
if not org:
|
||||||
|
raise NotFoundError(f"Organization {org_id} not found")
|
||||||
|
return org
|
||||||
|
|
||||||
|
async def create_organization(
|
||||||
|
self, db: AsyncSession, *, obj_in: OrganizationCreate
|
||||||
|
) -> Organization:
|
||||||
|
"""Create a new organization."""
|
||||||
|
return await self._repo.create(db, obj_in=obj_in)
|
||||||
|
|
||||||
|
async def update_organization(
|
||||||
|
self,
|
||||||
|
db: AsyncSession,
|
||||||
|
*,
|
||||||
|
org: Organization,
|
||||||
|
obj_in: OrganizationUpdate | dict[str, Any],
|
||||||
|
) -> Organization:
|
||||||
|
"""Update an existing organization."""
|
||||||
|
return await self._repo.update(db, db_obj=org, obj_in=obj_in)
|
||||||
|
|
||||||
|
async def remove_organization(self, db: AsyncSession, org_id: str) -> None:
|
||||||
|
"""Permanently delete an organization by ID."""
|
||||||
|
await self._repo.remove(db, id=org_id)
|
||||||
|
|
||||||
|
async def get_member_count(self, db: AsyncSession, *, organization_id: UUID) -> int:
|
||||||
|
"""Get number of active members in an organization."""
|
||||||
|
return await self._repo.get_member_count(db, organization_id=organization_id)
|
||||||
|
|
||||||
|
async def get_multi_with_member_counts(
|
||||||
|
self,
|
||||||
|
db: AsyncSession,
|
||||||
|
*,
|
||||||
|
skip: int = 0,
|
||||||
|
limit: int = 100,
|
||||||
|
is_active: bool | None = None,
|
||||||
|
search: str | None = None,
|
||||||
|
) -> tuple[list[dict[str, Any]], int]:
|
||||||
|
"""List organizations with member counts and pagination."""
|
||||||
|
return await self._repo.get_multi_with_member_counts(
|
||||||
|
db, skip=skip, limit=limit, is_active=is_active, search=search
|
||||||
|
)
|
||||||
|
|
||||||
|
async def get_user_organizations_with_details(
|
||||||
|
self,
|
||||||
|
db: AsyncSession,
|
||||||
|
*,
|
||||||
|
user_id: UUID,
|
||||||
|
is_active: bool | None = None,
|
||||||
|
) -> list[dict[str, Any]]:
|
||||||
|
"""Get all organizations a user belongs to, with membership details."""
|
||||||
|
return await self._repo.get_user_organizations_with_details(
|
||||||
|
db, user_id=user_id, is_active=is_active
|
||||||
|
)
|
||||||
|
|
||||||
|
async def get_organization_members(
|
||||||
|
self,
|
||||||
|
db: AsyncSession,
|
||||||
|
*,
|
||||||
|
organization_id: UUID,
|
||||||
|
skip: int = 0,
|
||||||
|
limit: int = 100,
|
||||||
|
is_active: bool | None = True,
|
||||||
|
) -> tuple[list[dict[str, Any]], int]:
|
||||||
|
"""Get members of an organization with pagination."""
|
||||||
|
return await self._repo.get_organization_members(
|
||||||
|
db,
|
||||||
|
organization_id=organization_id,
|
||||||
|
skip=skip,
|
||||||
|
limit=limit,
|
||||||
|
is_active=is_active,
|
||||||
|
)
|
||||||
|
|
||||||
|
async def add_member(
|
||||||
|
self,
|
||||||
|
db: AsyncSession,
|
||||||
|
*,
|
||||||
|
organization_id: UUID,
|
||||||
|
user_id: UUID,
|
||||||
|
role: OrganizationRole = OrganizationRole.MEMBER,
|
||||||
|
) -> UserOrganization:
|
||||||
|
"""Add a user to an organization."""
|
||||||
|
return await self._repo.add_user(
|
||||||
|
db, organization_id=organization_id, user_id=user_id, role=role
|
||||||
|
)
|
||||||
|
|
||||||
|
async def remove_member(
|
||||||
|
self,
|
||||||
|
db: AsyncSession,
|
||||||
|
*,
|
||||||
|
organization_id: UUID,
|
||||||
|
user_id: UUID,
|
||||||
|
) -> bool:
|
||||||
|
"""Remove a user from an organization. Returns True if found and removed."""
|
||||||
|
return await self._repo.remove_user(
|
||||||
|
db, organization_id=organization_id, user_id=user_id
|
||||||
|
)
|
||||||
|
|
||||||
|
async def get_user_role_in_org(
|
||||||
|
self, db: AsyncSession, *, user_id: UUID, organization_id: UUID
|
||||||
|
) -> OrganizationRole | None:
|
||||||
|
"""Get the role of a user in an organization."""
|
||||||
|
return await self._repo.get_user_role_in_org(
|
||||||
|
db, user_id=user_id, organization_id=organization_id
|
||||||
|
)
|
||||||
|
|
||||||
|
async def get_org_distribution(
|
||||||
|
self, db: AsyncSession, *, limit: int = 6
|
||||||
|
) -> list[dict[str, Any]]:
|
||||||
|
"""Return top organizations by member count for admin dashboard."""
|
||||||
|
from sqlalchemy import func, select
|
||||||
|
|
||||||
|
result = await db.execute(
|
||||||
|
select(
|
||||||
|
Organization.name,
|
||||||
|
func.count(UserOrganization.user_id).label("count"),
|
||||||
|
)
|
||||||
|
.join(UserOrganization, Organization.id == UserOrganization.organization_id)
|
||||||
|
.group_by(Organization.name)
|
||||||
|
.order_by(func.count(UserOrganization.user_id).desc())
|
||||||
|
.limit(limit)
|
||||||
|
)
|
||||||
|
return [{"name": row.name, "value": row.count} for row in result.all()]
|
||||||
|
|
||||||
|
|
||||||
|
# Default singleton
|
||||||
|
organization_service = OrganizationService()
|
||||||
@@ -8,7 +8,7 @@ import logging
|
|||||||
from datetime import UTC, datetime
|
from datetime import UTC, datetime
|
||||||
|
|
||||||
from app.core.database import SessionLocal
|
from app.core.database import SessionLocal
|
||||||
from app.crud.session import session as session_crud
|
from app.repositories.session import session_repo as session_repo
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
@@ -32,15 +32,15 @@ async def cleanup_expired_sessions(keep_days: int = 30) -> int:
|
|||||||
|
|
||||||
async with SessionLocal() as db:
|
async with SessionLocal() as db:
|
||||||
try:
|
try:
|
||||||
# Use CRUD method to cleanup
|
# Use repository method to cleanup
|
||||||
count = await session_crud.cleanup_expired(db, keep_days=keep_days)
|
count = await session_repo.cleanup_expired(db, keep_days=keep_days)
|
||||||
|
|
||||||
logger.info(f"Session cleanup complete: {count} sessions deleted")
|
logger.info("Session cleanup complete: %s sessions deleted", count)
|
||||||
|
|
||||||
return count
|
return count
|
||||||
|
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.error(f"Error during session cleanup: {e!s}", exc_info=True)
|
logger.exception("Error during session cleanup: %s", e)
|
||||||
return 0
|
return 0
|
||||||
|
|
||||||
|
|
||||||
@@ -79,10 +79,10 @@ async def get_session_statistics() -> dict:
|
|||||||
"expired": expired_sessions,
|
"expired": expired_sessions,
|
||||||
}
|
}
|
||||||
|
|
||||||
logger.info(f"Session statistics: {stats}")
|
logger.info("Session statistics: %s", stats)
|
||||||
|
|
||||||
return stats
|
return stats
|
||||||
|
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.error(f"Error getting session statistics: {e!s}", exc_info=True)
|
logger.exception("Error getting session statistics: %s", e)
|
||||||
return {}
|
return {}
|
||||||
|
|||||||
97
backend/app/services/session_service.py
Normal file
97
backend/app/services/session_service.py
Normal file
@@ -0,0 +1,97 @@
|
|||||||
|
# app/services/session_service.py
|
||||||
|
"""Service layer for session operations — delegates to SessionRepository."""
|
||||||
|
|
||||||
|
import logging
|
||||||
|
from datetime import datetime
|
||||||
|
|
||||||
|
from sqlalchemy.ext.asyncio import AsyncSession
|
||||||
|
|
||||||
|
from app.models.user_session import UserSession
|
||||||
|
from app.repositories.session import SessionRepository, session_repo
|
||||||
|
from app.schemas.sessions import SessionCreate
|
||||||
|
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
|
class SessionService:
|
||||||
|
"""Service for user session management operations."""
|
||||||
|
|
||||||
|
def __init__(self, session_repository: SessionRepository | None = None) -> None:
|
||||||
|
self._repo = session_repository or session_repo
|
||||||
|
|
||||||
|
async def create_session(
|
||||||
|
self, db: AsyncSession, *, obj_in: SessionCreate
|
||||||
|
) -> UserSession:
|
||||||
|
"""Create a new session record."""
|
||||||
|
return await self._repo.create_session(db, obj_in=obj_in)
|
||||||
|
|
||||||
|
async def get_session(
|
||||||
|
self, db: AsyncSession, session_id: str
|
||||||
|
) -> UserSession | None:
|
||||||
|
"""Get session by ID."""
|
||||||
|
return await self._repo.get(db, id=session_id)
|
||||||
|
|
||||||
|
async def get_user_sessions(
|
||||||
|
self, db: AsyncSession, *, user_id: str, active_only: bool = True
|
||||||
|
) -> list[UserSession]:
|
||||||
|
"""Get all sessions for a user."""
|
||||||
|
return await self._repo.get_user_sessions(
|
||||||
|
db, user_id=user_id, active_only=active_only
|
||||||
|
)
|
||||||
|
|
||||||
|
async def get_active_by_jti(
|
||||||
|
self, db: AsyncSession, *, jti: str
|
||||||
|
) -> UserSession | None:
|
||||||
|
"""Get active session by refresh token JTI."""
|
||||||
|
return await self._repo.get_active_by_jti(db, jti=jti)
|
||||||
|
|
||||||
|
async def get_by_jti(self, db: AsyncSession, *, jti: str) -> UserSession | None:
|
||||||
|
"""Get session by refresh token JTI (active or inactive)."""
|
||||||
|
return await self._repo.get_by_jti(db, jti=jti)
|
||||||
|
|
||||||
|
async def deactivate(
|
||||||
|
self, db: AsyncSession, *, session_id: str
|
||||||
|
) -> UserSession | None:
|
||||||
|
"""Deactivate a session (logout from device)."""
|
||||||
|
return await self._repo.deactivate(db, session_id=session_id)
|
||||||
|
|
||||||
|
async def deactivate_all_user_sessions(
|
||||||
|
self, db: AsyncSession, *, user_id: str
|
||||||
|
) -> int:
|
||||||
|
"""Deactivate all sessions for a user. Returns count deactivated."""
|
||||||
|
return await self._repo.deactivate_all_user_sessions(db, user_id=user_id)
|
||||||
|
|
||||||
|
async def update_refresh_token(
|
||||||
|
self,
|
||||||
|
db: AsyncSession,
|
||||||
|
*,
|
||||||
|
session: UserSession,
|
||||||
|
new_jti: str,
|
||||||
|
new_expires_at: datetime,
|
||||||
|
) -> UserSession:
|
||||||
|
"""Update session with a rotated refresh token."""
|
||||||
|
return await self._repo.update_refresh_token(
|
||||||
|
db, session=session, new_jti=new_jti, new_expires_at=new_expires_at
|
||||||
|
)
|
||||||
|
|
||||||
|
async def cleanup_expired_for_user(self, db: AsyncSession, *, user_id: str) -> int:
|
||||||
|
"""Remove expired sessions for a user. Returns count removed."""
|
||||||
|
return await self._repo.cleanup_expired_for_user(db, user_id=user_id)
|
||||||
|
|
||||||
|
async def get_all_sessions(
|
||||||
|
self,
|
||||||
|
db: AsyncSession,
|
||||||
|
*,
|
||||||
|
skip: int = 0,
|
||||||
|
limit: int = 100,
|
||||||
|
active_only: bool = True,
|
||||||
|
with_user: bool = True,
|
||||||
|
) -> tuple[list[UserSession], int]:
|
||||||
|
"""Get all sessions with pagination (admin only)."""
|
||||||
|
return await self._repo.get_all_sessions(
|
||||||
|
db, skip=skip, limit=limit, active_only=active_only, with_user=with_user
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
# Default singleton
|
||||||
|
session_service = SessionService()
|
||||||
120
backend/app/services/user_service.py
Normal file
120
backend/app/services/user_service.py
Normal file
@@ -0,0 +1,120 @@
|
|||||||
|
# app/services/user_service.py
|
||||||
|
"""Service layer for user operations — delegates to UserRepository."""
|
||||||
|
|
||||||
|
import logging
|
||||||
|
from typing import Any
|
||||||
|
from uuid import UUID
|
||||||
|
|
||||||
|
from sqlalchemy.ext.asyncio import AsyncSession
|
||||||
|
|
||||||
|
from app.core.exceptions import NotFoundError
|
||||||
|
from app.models.user import User
|
||||||
|
from app.repositories.user import UserRepository, user_repo
|
||||||
|
from app.schemas.users import UserCreate, UserUpdate
|
||||||
|
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
|
class UserService:
|
||||||
|
"""Service for user management operations."""
|
||||||
|
|
||||||
|
def __init__(self, user_repository: UserRepository | None = None) -> None:
|
||||||
|
self._repo = user_repository or user_repo
|
||||||
|
|
||||||
|
async def get_user(self, db: AsyncSession, user_id: str) -> User:
|
||||||
|
"""Get user by ID, raising NotFoundError if not found."""
|
||||||
|
user = await self._repo.get(db, id=user_id)
|
||||||
|
if not user:
|
||||||
|
raise NotFoundError(f"User {user_id} not found")
|
||||||
|
return user
|
||||||
|
|
||||||
|
async def get_by_email(self, db: AsyncSession, email: str) -> User | None:
|
||||||
|
"""Get user by email address."""
|
||||||
|
return await self._repo.get_by_email(db, email=email)
|
||||||
|
|
||||||
|
async def create_user(self, db: AsyncSession, user_data: UserCreate) -> User:
|
||||||
|
"""Create a new user."""
|
||||||
|
return await self._repo.create(db, obj_in=user_data)
|
||||||
|
|
||||||
|
async def update_user(
|
||||||
|
self, db: AsyncSession, *, user: User, obj_in: UserUpdate | dict[str, Any]
|
||||||
|
) -> User:
|
||||||
|
"""Update an existing user."""
|
||||||
|
return await self._repo.update(db, db_obj=user, obj_in=obj_in)
|
||||||
|
|
||||||
|
async def soft_delete_user(self, db: AsyncSession, user_id: str) -> None:
|
||||||
|
"""Soft-delete a user by ID."""
|
||||||
|
await self._repo.soft_delete(db, id=user_id)
|
||||||
|
|
||||||
|
async def list_users(
|
||||||
|
self,
|
||||||
|
db: AsyncSession,
|
||||||
|
*,
|
||||||
|
skip: int = 0,
|
||||||
|
limit: int = 100,
|
||||||
|
sort_by: str | None = None,
|
||||||
|
sort_order: str = "asc",
|
||||||
|
filters: dict[str, Any] | None = None,
|
||||||
|
search: str | None = None,
|
||||||
|
) -> tuple[list[User], int]:
|
||||||
|
"""List users with pagination, sorting, filtering, and search."""
|
||||||
|
return await self._repo.get_multi_with_total(
|
||||||
|
db,
|
||||||
|
skip=skip,
|
||||||
|
limit=limit,
|
||||||
|
sort_by=sort_by,
|
||||||
|
sort_order=sort_order,
|
||||||
|
filters=filters,
|
||||||
|
search=search,
|
||||||
|
)
|
||||||
|
|
||||||
|
async def bulk_update_status(
|
||||||
|
self, db: AsyncSession, *, user_ids: list[UUID], is_active: bool
|
||||||
|
) -> int:
|
||||||
|
"""Bulk update active status for multiple users. Returns count updated."""
|
||||||
|
return await self._repo.bulk_update_status(
|
||||||
|
db, user_ids=user_ids, is_active=is_active
|
||||||
|
)
|
||||||
|
|
||||||
|
async def bulk_soft_delete(
|
||||||
|
self,
|
||||||
|
db: AsyncSession,
|
||||||
|
*,
|
||||||
|
user_ids: list[UUID],
|
||||||
|
exclude_user_id: UUID | None = None,
|
||||||
|
) -> int:
|
||||||
|
"""Bulk soft-delete multiple users. Returns count deleted."""
|
||||||
|
return await self._repo.bulk_soft_delete(
|
||||||
|
db, user_ids=user_ids, exclude_user_id=exclude_user_id
|
||||||
|
)
|
||||||
|
|
||||||
|
async def get_stats(self, db: AsyncSession) -> dict[str, Any]:
|
||||||
|
"""Return user stats needed for the admin dashboard."""
|
||||||
|
from sqlalchemy import func, select
|
||||||
|
|
||||||
|
total_users = (
|
||||||
|
await db.execute(select(func.count()).select_from(User))
|
||||||
|
).scalar() or 0
|
||||||
|
active_count = (
|
||||||
|
await db.execute(
|
||||||
|
select(func.count()).select_from(User).where(User.is_active)
|
||||||
|
)
|
||||||
|
).scalar() or 0
|
||||||
|
inactive_count = (
|
||||||
|
await db.execute(
|
||||||
|
select(func.count()).select_from(User).where(User.is_active.is_(False))
|
||||||
|
)
|
||||||
|
).scalar() or 0
|
||||||
|
all_users = list(
|
||||||
|
(await db.execute(select(User).order_by(User.created_at))).scalars().all()
|
||||||
|
)
|
||||||
|
return {
|
||||||
|
"total_users": total_users,
|
||||||
|
"active_count": active_count,
|
||||||
|
"inactive_count": inactive_count,
|
||||||
|
"all_users": all_users,
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
# Default singleton
|
||||||
|
user_service = UserService()
|
||||||
@@ -65,10 +65,10 @@ async def setup_async_test_db():
|
|||||||
async with test_engine.begin() as conn:
|
async with test_engine.begin() as conn:
|
||||||
await conn.run_sync(Base.metadata.create_all)
|
await conn.run_sync(Base.metadata.create_all)
|
||||||
|
|
||||||
AsyncTestingSessionLocal = sessionmaker(
|
AsyncTestingSessionLocal = sessionmaker( # pyright: ignore[reportCallIssue]
|
||||||
autocommit=False,
|
autocommit=False,
|
||||||
autoflush=False,
|
autoflush=False,
|
||||||
bind=test_engine,
|
bind=test_engine, # pyright: ignore[reportArgumentType]
|
||||||
expire_on_commit=False,
|
expire_on_commit=False,
|
||||||
class_=AsyncSession,
|
class_=AsyncSession,
|
||||||
)
|
)
|
||||||
|
|||||||
@@ -79,12 +79,13 @@ This FastAPI backend application follows a **clean layered architecture** patter
|
|||||||
|
|
||||||
### Authentication & Security
|
### Authentication & Security
|
||||||
|
|
||||||
- **python-jose**: JWT token generation and validation
|
- **PyJWT**: JWT token generation and validation
|
||||||
- Cryptographic signing
|
- Cryptographic signing (HS256, RS256)
|
||||||
- Token expiration handling
|
- Token expiration handling
|
||||||
- Claims validation
|
- Claims validation
|
||||||
|
- JWK support for Google ID token verification
|
||||||
|
|
||||||
- **passlib + bcrypt**: Password hashing
|
- **bcrypt**: Password hashing
|
||||||
- Industry-standard bcrypt algorithm
|
- Industry-standard bcrypt algorithm
|
||||||
- Configurable cost factor
|
- Configurable cost factor
|
||||||
- Salt generation
|
- Salt generation
|
||||||
@@ -117,7 +118,8 @@ backend/
|
|||||||
│ ├── api/ # API layer
|
│ ├── api/ # API layer
|
||||||
│ │ ├── dependencies/ # Dependency injection
|
│ │ ├── dependencies/ # Dependency injection
|
||||||
│ │ │ ├── auth.py # Authentication dependencies
|
│ │ │ ├── auth.py # Authentication dependencies
|
||||||
│ │ │ └── permissions.py # Authorization dependencies
|
│ │ │ ├── permissions.py # Authorization dependencies
|
||||||
|
│ │ │ └── services.py # Service singleton injection
|
||||||
│ │ ├── routes/ # API endpoints
|
│ │ ├── routes/ # API endpoints
|
||||||
│ │ │ ├── auth.py # Authentication routes
|
│ │ │ ├── auth.py # Authentication routes
|
||||||
│ │ │ ├── users.py # User management routes
|
│ │ │ ├── users.py # User management routes
|
||||||
@@ -131,13 +133,14 @@ backend/
|
|||||||
│ │ ├── config.py # Application configuration
|
│ │ ├── config.py # Application configuration
|
||||||
│ │ ├── database.py # Database connection
|
│ │ ├── database.py # Database connection
|
||||||
│ │ ├── exceptions.py # Custom exception classes
|
│ │ ├── exceptions.py # Custom exception classes
|
||||||
|
│ │ ├── repository_exceptions.py # Repository-level exception hierarchy
|
||||||
│ │ └── middleware.py # Custom middleware
|
│ │ └── middleware.py # Custom middleware
|
||||||
│ │
|
│ │
|
||||||
│ ├── crud/ # Database operations
|
│ ├── repositories/ # Data access layer
|
||||||
│ │ ├── base.py # Generic CRUD base class
|
│ │ ├── base.py # Generic repository base class
|
||||||
│ │ ├── user.py # User CRUD operations
|
│ │ ├── user.py # User repository
|
||||||
│ │ ├── session.py # Session CRUD operations
|
│ │ ├── session.py # Session repository
|
||||||
│ │ └── organization.py # Organization CRUD
|
│ │ └── organization.py # Organization repository
|
||||||
│ │
|
│ │
|
||||||
│ ├── models/ # SQLAlchemy models
|
│ ├── models/ # SQLAlchemy models
|
||||||
│ │ ├── base.py # Base model with mixins
|
│ │ ├── base.py # Base model with mixins
|
||||||
@@ -153,8 +156,11 @@ backend/
|
|||||||
│ │ ├── sessions.py # Session schemas
|
│ │ ├── sessions.py # Session schemas
|
||||||
│ │ └── organizations.py # Organization schemas
|
│ │ └── organizations.py # Organization schemas
|
||||||
│ │
|
│ │
|
||||||
│ ├── services/ # Business logic
|
│ ├── services/ # Business logic layer
|
||||||
│ │ ├── auth_service.py # Authentication service
|
│ │ ├── auth_service.py # Authentication service
|
||||||
|
│ │ ├── user_service.py # User management service
|
||||||
|
│ │ ├── session_service.py # Session management service
|
||||||
|
│ │ ├── organization_service.py # Organization service
|
||||||
│ │ ├── email_service.py # Email service
|
│ │ ├── email_service.py # Email service
|
||||||
│ │ └── session_cleanup.py # Background cleanup
|
│ │ └── session_cleanup.py # Background cleanup
|
||||||
│ │
|
│ │
|
||||||
@@ -168,20 +174,25 @@ backend/
|
|||||||
│
|
│
|
||||||
├── tests/ # Test suite
|
├── tests/ # Test suite
|
||||||
│ ├── api/ # Integration tests
|
│ ├── api/ # Integration tests
|
||||||
│ ├── crud/ # CRUD tests
|
│ ├── repositories/ # Repository unit tests
|
||||||
|
│ ├── services/ # Service unit tests
|
||||||
│ ├── models/ # Model tests
|
│ ├── models/ # Model tests
|
||||||
│ ├── services/ # Service tests
|
|
||||||
│ └── conftest.py # Test configuration
|
│ └── conftest.py # Test configuration
|
||||||
│
|
│
|
||||||
├── docs/ # Documentation
|
├── docs/ # Documentation
|
||||||
│ ├── ARCHITECTURE.md # This file
|
│ ├── ARCHITECTURE.md # This file
|
||||||
│ ├── CODING_STANDARDS.md # Coding standards
|
│ ├── CODING_STANDARDS.md # Coding standards
|
||||||
|
│ ├── COMMON_PITFALLS.md # Common mistakes to avoid
|
||||||
|
│ ├── E2E_TESTING.md # E2E testing guide
|
||||||
│ └── FEATURE_EXAMPLE.md # Feature implementation guide
|
│ └── FEATURE_EXAMPLE.md # Feature implementation guide
|
||||||
│
|
│
|
||||||
├── requirements.txt # Python dependencies
|
├── pyproject.toml # Dependencies, tool configs (Ruff, pytest, coverage, Pyright)
|
||||||
├── pytest.ini # Pytest configuration
|
├── uv.lock # Locked dependency versions (commit to git)
|
||||||
├── .coveragerc # Coverage configuration
|
├── Makefile # Development commands (quality, security, testing)
|
||||||
└── alembic.ini # Alembic configuration
|
├── .pre-commit-config.yaml # Pre-commit hook configuration
|
||||||
|
├── .secrets.baseline # detect-secrets baseline (known false positives)
|
||||||
|
├── alembic.ini # Alembic configuration
|
||||||
|
└── migrate.py # Migration helper script
|
||||||
```
|
```
|
||||||
|
|
||||||
## Layered Architecture
|
## Layered Architecture
|
||||||
@@ -214,11 +225,11 @@ The application follows a strict 5-layer architecture:
|
|||||||
└──────────────────────────┬──────────────────────────────────┘
|
└──────────────────────────┬──────────────────────────────────┘
|
||||||
│ calls
|
│ calls
|
||||||
┌──────────────────────────▼──────────────────────────────────┐
|
┌──────────────────────────▼──────────────────────────────────┐
|
||||||
│ CRUD Layer (crud/) │
|
│ Repository Layer (repositories/) │
|
||||||
│ - Database operations │
|
│ - Database operations │
|
||||||
│ - Query building │
|
│ - Query building │
|
||||||
│ - Transaction management │
|
│ - Custom repository exceptions │
|
||||||
│ - Error handling │
|
│ - No business logic │
|
||||||
└──────────────────────────┬──────────────────────────────────┘
|
└──────────────────────────┬──────────────────────────────────┘
|
||||||
│ uses
|
│ uses
|
||||||
┌──────────────────────────▼──────────────────────────────────┐
|
┌──────────────────────────▼──────────────────────────────────┐
|
||||||
@@ -262,7 +273,7 @@ async def get_current_user_info(
|
|||||||
|
|
||||||
**Rules**:
|
**Rules**:
|
||||||
- Should NOT contain business logic
|
- Should NOT contain business logic
|
||||||
- Should NOT directly perform database operations (use CRUD or services)
|
- Should NOT directly call repositories (use services injected via `dependencies/services.py`)
|
||||||
- Must validate all input via Pydantic schemas
|
- Must validate all input via Pydantic schemas
|
||||||
- Must specify response models
|
- Must specify response models
|
||||||
- Should apply appropriate rate limits
|
- Should apply appropriate rate limits
|
||||||
@@ -279,9 +290,9 @@ async def get_current_user_info(
|
|||||||
|
|
||||||
**Example**:
|
**Example**:
|
||||||
```python
|
```python
|
||||||
def get_current_user(
|
async def get_current_user(
|
||||||
token: str = Depends(oauth2_scheme),
|
token: str = Depends(oauth2_scheme),
|
||||||
db: Session = Depends(get_db)
|
db: AsyncSession = Depends(get_db)
|
||||||
) -> User:
|
) -> User:
|
||||||
"""
|
"""
|
||||||
Extract and validate user from JWT token.
|
Extract and validate user from JWT token.
|
||||||
@@ -295,7 +306,7 @@ def get_current_user(
|
|||||||
except Exception:
|
except Exception:
|
||||||
raise AuthenticationError("Invalid authentication credentials")
|
raise AuthenticationError("Invalid authentication credentials")
|
||||||
|
|
||||||
user = user_crud.get(db, id=user_id)
|
user = await user_repo.get(db, id=user_id)
|
||||||
if not user:
|
if not user:
|
||||||
raise AuthenticationError("User not found")
|
raise AuthenticationError("User not found")
|
||||||
|
|
||||||
@@ -313,7 +324,7 @@ def get_current_user(
|
|||||||
**Responsibility**: Implement complex business logic
|
**Responsibility**: Implement complex business logic
|
||||||
|
|
||||||
**Key Functions**:
|
**Key Functions**:
|
||||||
- Orchestrate multiple CRUD operations
|
- Orchestrate multiple repository operations
|
||||||
- Implement business rules
|
- Implement business rules
|
||||||
- Handle external service integration
|
- Handle external service integration
|
||||||
- Coordinate transactions
|
- Coordinate transactions
|
||||||
@@ -323,9 +334,9 @@ def get_current_user(
|
|||||||
class AuthService:
|
class AuthService:
|
||||||
"""Authentication service with business logic."""
|
"""Authentication service with business logic."""
|
||||||
|
|
||||||
def login(
|
async def login(
|
||||||
self,
|
self,
|
||||||
db: Session,
|
db: AsyncSession,
|
||||||
email: str,
|
email: str,
|
||||||
password: str,
|
password: str,
|
||||||
request: Request
|
request: Request
|
||||||
@@ -339,8 +350,8 @@ class AuthService:
|
|||||||
3. Generate tokens
|
3. Generate tokens
|
||||||
4. Return tokens and user info
|
4. Return tokens and user info
|
||||||
"""
|
"""
|
||||||
# Validate credentials
|
# Validate credentials via repository
|
||||||
user = user_crud.get_by_email(db, email=email)
|
user = await user_repo.get_by_email(db, email=email)
|
||||||
if not user or not verify_password(password, user.hashed_password):
|
if not user or not verify_password(password, user.hashed_password):
|
||||||
raise AuthenticationError("Invalid credentials")
|
raise AuthenticationError("Invalid credentials")
|
||||||
|
|
||||||
@@ -350,11 +361,10 @@ class AuthService:
|
|||||||
# Extract device info
|
# Extract device info
|
||||||
device_info = extract_device_info(request)
|
device_info = extract_device_info(request)
|
||||||
|
|
||||||
# Create session
|
# Create session via repository
|
||||||
session = session_crud.create_session(
|
session = await session_repo.create(
|
||||||
db,
|
db,
|
||||||
user_id=user.id,
|
obj_in=SessionCreate(user_id=user.id, **device_info)
|
||||||
device_info=device_info
|
|
||||||
)
|
)
|
||||||
|
|
||||||
# Generate tokens
|
# Generate tokens
|
||||||
@@ -373,75 +383,60 @@ class AuthService:
|
|||||||
|
|
||||||
**Rules**:
|
**Rules**:
|
||||||
- Contains business logic, not just data operations
|
- Contains business logic, not just data operations
|
||||||
- Can call multiple CRUD operations
|
- Can call multiple repository operations
|
||||||
- Should handle complex workflows
|
- Should handle complex workflows
|
||||||
- Must maintain data consistency
|
- Must maintain data consistency
|
||||||
- Should use transactions when needed
|
- Should use transactions when needed
|
||||||
|
|
||||||
#### 4. CRUD Layer (`app/crud/`)
|
#### 4. Repository Layer (`app/repositories/`)
|
||||||
|
|
||||||
**Responsibility**: Database operations and queries
|
**Responsibility**: Database operations and queries — no business logic
|
||||||
|
|
||||||
**Key Functions**:
|
**Key Functions**:
|
||||||
- Create, read, update, delete operations
|
- Create, read, update, delete operations
|
||||||
- Build database queries
|
- Build database queries
|
||||||
- Handle database errors
|
- Raise custom repository exceptions (`DuplicateEntryError`, `IntegrityConstraintError`)
|
||||||
- Manage soft deletes
|
- Manage soft deletes
|
||||||
- Implement pagination and filtering
|
- Implement pagination and filtering
|
||||||
|
|
||||||
**Example**:
|
**Example**:
|
||||||
```python
|
```python
|
||||||
class CRUDSession(CRUDBase[UserSession, SessionCreate, SessionUpdate]):
|
class SessionRepository(RepositoryBase[UserSession, SessionCreate, SessionUpdate]):
|
||||||
"""CRUD operations for user sessions."""
|
"""Repository for user sessions — database operations only."""
|
||||||
|
|
||||||
def get_by_jti(self, db: Session, jti: UUID) -> Optional[UserSession]:
|
async def get_by_jti(self, db: AsyncSession, *, jti: str) -> UserSession | None:
|
||||||
"""Get session by refresh token JTI."""
|
"""Get session by refresh token JTI."""
|
||||||
try:
|
result = await db.execute(
|
||||||
return (
|
select(UserSession).where(UserSession.refresh_token_jti == jti)
|
||||||
db.query(UserSession)
|
)
|
||||||
.filter(UserSession.refresh_token_jti == jti)
|
return result.scalar_one_or_none()
|
||||||
.first()
|
|
||||||
)
|
|
||||||
except Exception as e:
|
|
||||||
logger.error(f"Error getting session by JTI: {str(e)}")
|
|
||||||
return None
|
|
||||||
|
|
||||||
def get_active_by_jti(
|
async def deactivate(self, db: AsyncSession, *, session_id: UUID) -> bool:
|
||||||
self,
|
|
||||||
db: Session,
|
|
||||||
jti: UUID
|
|
||||||
) -> Optional[UserSession]:
|
|
||||||
"""Get active session by refresh token JTI."""
|
|
||||||
session = self.get_by_jti(db, jti=jti)
|
|
||||||
if session and session.is_active and not session.is_expired:
|
|
||||||
return session
|
|
||||||
return None
|
|
||||||
|
|
||||||
def deactivate(self, db: Session, session_id: UUID) -> bool:
|
|
||||||
"""Deactivate a session (logout)."""
|
"""Deactivate a session (logout)."""
|
||||||
try:
|
try:
|
||||||
session = self.get(db, id=session_id)
|
session = await self.get(db, id=session_id)
|
||||||
if not session:
|
if not session:
|
||||||
return False
|
return False
|
||||||
|
|
||||||
session.is_active = False
|
session.is_active = False
|
||||||
db.commit()
|
await db.commit()
|
||||||
logger.info(f"Session {session_id} deactivated")
|
logger.info(f"Session {session_id} deactivated")
|
||||||
return True
|
return True
|
||||||
|
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
db.rollback()
|
await db.rollback()
|
||||||
logger.error(f"Error deactivating session: {str(e)}")
|
logger.error(f"Error deactivating session: {str(e)}")
|
||||||
return False
|
return False
|
||||||
```
|
```
|
||||||
|
|
||||||
**Rules**:
|
**Rules**:
|
||||||
- Should NOT contain business logic
|
- Should NOT contain business logic
|
||||||
- Must handle database exceptions
|
- Must raise custom repository exceptions (not raw `ValueError`/`IntegrityError`)
|
||||||
- Must use parameterized queries (SQLAlchemy does this)
|
- Must use async SQLAlchemy 2.0 `select()` API (never `db.query()`)
|
||||||
- Should log all database errors
|
- Should log all database errors
|
||||||
- Must rollback on errors
|
- Must rollback on errors
|
||||||
- Should use soft deletes when possible
|
- Should use soft deletes when possible
|
||||||
|
- **Never imported directly by routes** — always called through services
|
||||||
|
|
||||||
#### 5. Data Layer (`app/models/` + `app/schemas/`)
|
#### 5. Data Layer (`app/models/` + `app/schemas/`)
|
||||||
|
|
||||||
@@ -546,51 +541,23 @@ SessionLocal = sessionmaker(
|
|||||||
#### Dependency Injection Pattern
|
#### Dependency Injection Pattern
|
||||||
|
|
||||||
```python
|
```python
|
||||||
def get_db() -> Generator[Session, None, None]:
|
async def get_db() -> AsyncGenerator[AsyncSession, None]:
|
||||||
"""
|
"""
|
||||||
Database session dependency for FastAPI routes.
|
Async database session dependency for FastAPI routes.
|
||||||
|
|
||||||
Automatically commits on success, rolls back on error.
|
The session is passed to service methods; commit/rollback is
|
||||||
|
managed inside service or repository methods.
|
||||||
"""
|
"""
|
||||||
db = SessionLocal()
|
async with AsyncSessionLocal() as db:
|
||||||
try:
|
|
||||||
yield db
|
yield db
|
||||||
finally:
|
|
||||||
db.close()
|
|
||||||
|
|
||||||
# Usage in routes
|
# Usage in routes — always through a service, never direct repository
|
||||||
@router.get("/users")
|
@router.get("/users")
|
||||||
def list_users(db: Session = Depends(get_db)):
|
async def list_users(
|
||||||
return user_crud.get_multi(db)
|
user_service: UserService = Depends(get_user_service),
|
||||||
```
|
db: AsyncSession = Depends(get_db),
|
||||||
|
):
|
||||||
#### Context Manager Pattern
|
return await user_service.get_users(db)
|
||||||
|
|
||||||
```python
|
|
||||||
@contextmanager
|
|
||||||
def transaction_scope() -> Generator[Session, None, None]:
|
|
||||||
"""
|
|
||||||
Context manager for database transactions.
|
|
||||||
|
|
||||||
Use for complex operations requiring multiple steps.
|
|
||||||
Automatically commits on success, rolls back on error.
|
|
||||||
"""
|
|
||||||
db = SessionLocal()
|
|
||||||
try:
|
|
||||||
yield db
|
|
||||||
db.commit()
|
|
||||||
except Exception:
|
|
||||||
db.rollback()
|
|
||||||
raise
|
|
||||||
finally:
|
|
||||||
db.close()
|
|
||||||
|
|
||||||
# Usage in services
|
|
||||||
def complex_operation():
|
|
||||||
with transaction_scope() as db:
|
|
||||||
user = user_crud.create(db, obj_in=user_data)
|
|
||||||
session = session_crud.create(db, session_data)
|
|
||||||
return user, session
|
|
||||||
```
|
```
|
||||||
|
|
||||||
### Model Mixins
|
### Model Mixins
|
||||||
@@ -782,22 +749,15 @@ def get_profile(
|
|||||||
|
|
||||||
```python
|
```python
|
||||||
@router.delete("/sessions/{session_id}")
|
@router.delete("/sessions/{session_id}")
|
||||||
def revoke_session(
|
async def revoke_session(
|
||||||
session_id: UUID,
|
session_id: UUID,
|
||||||
current_user: User = Depends(get_current_user),
|
current_user: User = Depends(get_current_user),
|
||||||
db: Session = Depends(get_db)
|
session_service: SessionService = Depends(get_session_service),
|
||||||
|
db: AsyncSession = Depends(get_db),
|
||||||
):
|
):
|
||||||
"""Users can only revoke their own sessions."""
|
"""Users can only revoke their own sessions."""
|
||||||
session = session_crud.get(db, id=session_id)
|
# SessionService verifies ownership and raises NotFoundError / AuthorizationError
|
||||||
|
await session_service.revoke_session(db, session_id=session_id, user_id=current_user.id)
|
||||||
if not session:
|
|
||||||
raise NotFoundError("Session not found")
|
|
||||||
|
|
||||||
# Check ownership
|
|
||||||
if session.user_id != current_user.id:
|
|
||||||
raise AuthorizationError("You can only revoke your own sessions")
|
|
||||||
|
|
||||||
session_crud.deactivate(db, session_id=session_id)
|
|
||||||
return MessageResponse(success=True, message="Session revoked")
|
return MessageResponse(success=True, message="Session revoked")
|
||||||
```
|
```
|
||||||
|
|
||||||
@@ -818,6 +778,84 @@ def add_member(
|
|||||||
pass
|
pass
|
||||||
```
|
```
|
||||||
|
|
||||||
|
### OAuth Integration
|
||||||
|
|
||||||
|
The system supports two OAuth modes:
|
||||||
|
|
||||||
|
#### OAuth Consumer Mode (Social Login)
|
||||||
|
|
||||||
|
Users can authenticate via Google or GitHub OAuth providers:
|
||||||
|
|
||||||
|
```python
|
||||||
|
# Get authorization URL with PKCE support
|
||||||
|
GET /oauth/authorize/{provider}?redirect_uri=https://yourapp.com/callback
|
||||||
|
|
||||||
|
# Handle callback and exchange code for tokens
|
||||||
|
POST /oauth/callback/{provider}
|
||||||
|
{
|
||||||
|
"code": "authorization_code_from_provider",
|
||||||
|
"state": "csrf_state_token"
|
||||||
|
}
|
||||||
|
```
|
||||||
|
|
||||||
|
**Security Features:**
|
||||||
|
- PKCE (S256) for Google
|
||||||
|
- State parameter for CSRF protection
|
||||||
|
- Nonce for Google OIDC replay attack prevention
|
||||||
|
- Google ID token signature verification via JWKS
|
||||||
|
- Email normalization to prevent account duplication
|
||||||
|
- Auto-linking by email (configurable)
|
||||||
|
|
||||||
|
#### OAuth Provider Mode (MCP Integration)
|
||||||
|
|
||||||
|
Full OAuth 2.0 Authorization Server for third-party clients (RFC compliant):
|
||||||
|
|
||||||
|
```
|
||||||
|
┌─────────────┐ ┌─────────────┐
|
||||||
|
│ MCP Client │ │ Backend │
|
||||||
|
└──────┬──────┘ └──────┬──────┘
|
||||||
|
│ │
|
||||||
|
│ GET /.well-known/oauth-authorization-server│
|
||||||
|
│─────────────────────────────────────────────>│
|
||||||
|
│ {metadata} │
|
||||||
|
│<─────────────────────────────────────────────│
|
||||||
|
│ │
|
||||||
|
│ GET /oauth/provider/authorize │
|
||||||
|
│ ?response_type=code&client_id=... │
|
||||||
|
│ &redirect_uri=...&code_challenge=... │
|
||||||
|
│─────────────────────────────────────────────>│
|
||||||
|
│ │
|
||||||
|
│ (User consents) │
|
||||||
|
│ │
|
||||||
|
│ 302 redirect_uri?code=AUTH_CODE&state=... │
|
||||||
|
│<─────────────────────────────────────────────│
|
||||||
|
│ │
|
||||||
|
│ POST /oauth/provider/token │
|
||||||
|
│ {grant_type=authorization_code, │
|
||||||
|
│ code=AUTH_CODE, code_verifier=...} │
|
||||||
|
│─────────────────────────────────────────────>│
|
||||||
|
│ │
|
||||||
|
│ {access_token, refresh_token, expires_in} │
|
||||||
|
│<─────────────────────────────────────────────│
|
||||||
|
│ │
|
||||||
|
```
|
||||||
|
|
||||||
|
**Endpoints:**
|
||||||
|
- `GET /.well-known/oauth-authorization-server` - RFC 8414 metadata
|
||||||
|
- `GET /oauth/provider/authorize` - Authorization endpoint
|
||||||
|
- `POST /oauth/provider/token` - Token endpoint (authorization_code, refresh_token)
|
||||||
|
- `POST /oauth/provider/revoke` - RFC 7009 token revocation
|
||||||
|
- `POST /oauth/provider/introspect` - RFC 7662 token introspection
|
||||||
|
|
||||||
|
**Security Features:**
|
||||||
|
- PKCE S256 required for public clients (plain method rejected)
|
||||||
|
- Authorization codes are single-use with 10-minute expiry
|
||||||
|
- Code reuse detection triggers security incident (all tokens revoked)
|
||||||
|
- Refresh token rotation on use
|
||||||
|
- Opaque refresh tokens (hashed in database)
|
||||||
|
- JWT access tokens with standard claims
|
||||||
|
- Consent management per client
|
||||||
|
|
||||||
## Error Handling
|
## Error Handling
|
||||||
|
|
||||||
### Exception Hierarchy
|
### Exception Hierarchy
|
||||||
@@ -983,23 +1021,27 @@ from app.services.session_cleanup import cleanup_expired_sessions
|
|||||||
|
|
||||||
scheduler = AsyncIOScheduler()
|
scheduler = AsyncIOScheduler()
|
||||||
|
|
||||||
@app.on_event("startup")
|
@asynccontextmanager
|
||||||
async def startup_event():
|
async def lifespan(app: FastAPI):
|
||||||
"""Start background jobs on application startup."""
|
"""Application lifespan context manager."""
|
||||||
if not settings.IS_TEST: # Don't run in tests
|
# Startup
|
||||||
|
if os.getenv("IS_TEST", "False") != "True":
|
||||||
scheduler.add_job(
|
scheduler.add_job(
|
||||||
cleanup_expired_sessions,
|
cleanup_expired_sessions,
|
||||||
"cron",
|
"cron",
|
||||||
hour=2, # Run at 2 AM daily
|
hour=2, # Run at 2 AM daily
|
||||||
id="cleanup_expired_sessions"
|
id="cleanup_expired_sessions",
|
||||||
|
replace_existing=True,
|
||||||
)
|
)
|
||||||
scheduler.start()
|
scheduler.start()
|
||||||
logger.info("Background jobs started")
|
logger.info("Background jobs started")
|
||||||
|
|
||||||
@app.on_event("shutdown")
|
yield
|
||||||
async def shutdown_event():
|
|
||||||
"""Stop background jobs on application shutdown."""
|
# Shutdown
|
||||||
scheduler.shutdown()
|
if os.getenv("IS_TEST", "False") != "True":
|
||||||
|
scheduler.shutdown()
|
||||||
|
await close_async_db() # Dispose database engine connections
|
||||||
```
|
```
|
||||||
|
|
||||||
### Job Implementation
|
### Job Implementation
|
||||||
@@ -1014,8 +1056,8 @@ async def cleanup_expired_sessions():
|
|||||||
Runs daily at 2 AM. Removes sessions expired for more than 30 days.
|
Runs daily at 2 AM. Removes sessions expired for more than 30 days.
|
||||||
"""
|
"""
|
||||||
try:
|
try:
|
||||||
with transaction_scope() as db:
|
async with AsyncSessionLocal() as db:
|
||||||
count = session_crud.cleanup_expired(db, keep_days=30)
|
count = await session_repo.cleanup_expired(db, keep_days=30)
|
||||||
logger.info(f"Cleaned up {count} expired sessions")
|
logger.info(f"Cleaned up {count} expired sessions")
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.error(f"Error cleaning up sessions: {str(e)}", exc_info=True)
|
logger.error(f"Error cleaning up sessions: {str(e)}", exc_info=True)
|
||||||
@@ -1032,7 +1074,7 @@ async def cleanup_expired_sessions():
|
|||||||
│Integration │ ← API endpoint tests
|
│Integration │ ← API endpoint tests
|
||||||
│ Tests │
|
│ Tests │
|
||||||
├─────────────┤
|
├─────────────┤
|
||||||
│ Unit │ ← CRUD, services, utilities
|
│ Unit │ ← repositories, services, utilities
|
||||||
│ Tests │
|
│ Tests │
|
||||||
└─────────────┘
|
└─────────────┘
|
||||||
```
|
```
|
||||||
@@ -1127,6 +1169,8 @@ app.add_middleware(
|
|||||||
|
|
||||||
## Performance Considerations
|
## Performance Considerations
|
||||||
|
|
||||||
|
> 📖 For the full benchmarking guide (how to run, read results, write new benchmarks, and manage baselines), see **[BENCHMARKS.md](BENCHMARKS.md)**.
|
||||||
|
|
||||||
### Database Connection Pooling
|
### Database Connection Pooling
|
||||||
|
|
||||||
- Pool size: 20 connections
|
- Pool size: 20 connections
|
||||||
|
|||||||
311
backend/docs/BENCHMARKS.md
Normal file
311
backend/docs/BENCHMARKS.md
Normal file
@@ -0,0 +1,311 @@
|
|||||||
|
# Performance Benchmarks Guide
|
||||||
|
|
||||||
|
Automated performance benchmarking infrastructure using **pytest-benchmark** to detect latency regressions in critical API endpoints.
|
||||||
|
|
||||||
|
## Table of Contents
|
||||||
|
|
||||||
|
- [Why Benchmark?](#why-benchmark)
|
||||||
|
- [Quick Start](#quick-start)
|
||||||
|
- [How It Works](#how-it-works)
|
||||||
|
- [Understanding Results](#understanding-results)
|
||||||
|
- [Test Organization](#test-organization)
|
||||||
|
- [Writing Benchmark Tests](#writing-benchmark-tests)
|
||||||
|
- [Baseline Management](#baseline-management)
|
||||||
|
- [CI/CD Integration](#cicd-integration)
|
||||||
|
- [Troubleshooting](#troubleshooting)
|
||||||
|
|
||||||
|
---
|
||||||
|
|
||||||
|
## Why Benchmark?
|
||||||
|
|
||||||
|
Performance regressions are silent bugs — they don't break tests or cause errors, but they degrade the user experience over time. Common causes include:
|
||||||
|
|
||||||
|
- **Unintended N+1 queries** after adding a relationship
|
||||||
|
- **Heavier serialization** after adding new fields to a response model
|
||||||
|
- **Middleware overhead** from new security headers or logging
|
||||||
|
- **Dependency upgrades** that introduce slower code paths
|
||||||
|
|
||||||
|
Without automated benchmarks, these regressions go unnoticed until users complain. Performance benchmarks serve as an **early warning system** — they measure endpoint latency on every run and flag significant deviations from an established baseline.
|
||||||
|
|
||||||
|
### What benchmarks give you
|
||||||
|
|
||||||
|
| Benefit | Description |
|
||||||
|
|---------|-------------|
|
||||||
|
| **Regression detection** | Automatically flags when an endpoint becomes significantly slower |
|
||||||
|
| **Baseline tracking** | Stores known-good performance numbers for comparison |
|
||||||
|
| **Confidence in refactors** | Verify that code changes don't degrade response times |
|
||||||
|
| **Visibility** | Makes performance a first-class, measurable quality attribute |
|
||||||
|
|
||||||
|
---
|
||||||
|
|
||||||
|
## Quick Start
|
||||||
|
|
||||||
|
```bash
|
||||||
|
# Run benchmarks (no comparison, just see current numbers)
|
||||||
|
make benchmark
|
||||||
|
|
||||||
|
# Save current results as the baseline
|
||||||
|
make benchmark-save
|
||||||
|
|
||||||
|
# Run benchmarks and compare against the saved baseline
|
||||||
|
make benchmark-check
|
||||||
|
```
|
||||||
|
|
||||||
|
---
|
||||||
|
|
||||||
|
## How It Works
|
||||||
|
|
||||||
|
The benchmarking system has three layers:
|
||||||
|
|
||||||
|
### 1. pytest-benchmark integration
|
||||||
|
|
||||||
|
[pytest-benchmark](https://pytest-benchmark.readthedocs.io/) is a pytest plugin that provides a `benchmark` fixture. It handles:
|
||||||
|
|
||||||
|
- **Calibration**: Automatically determines how many iterations to run for statistical significance
|
||||||
|
- **Timing**: Uses `time.perf_counter` for high-resolution measurements
|
||||||
|
- **Statistics**: Computes min, max, mean, median, standard deviation, IQR, and outlier detection
|
||||||
|
- **Comparison**: Compares current results against saved baselines and flags regressions
|
||||||
|
|
||||||
|
### 2. Benchmark types
|
||||||
|
|
||||||
|
The test suite includes two categories of performance tests:
|
||||||
|
|
||||||
|
| Type | How it works | Examples |
|
||||||
|
|------|-------------|----------|
|
||||||
|
| **pytest-benchmark tests** | Uses the `benchmark` fixture for precise, multi-round timing | `test_health_endpoint_performance`, `test_openapi_schema_performance`, `test_password_hashing_performance`, `test_password_verification_performance`, `test_access_token_creation_performance`, `test_refresh_token_creation_performance`, `test_token_decode_performance` |
|
||||||
|
| **Manual latency tests** | Uses `time.perf_counter` with explicit thresholds (for async endpoints that pytest-benchmark doesn't support natively) | `test_login_latency`, `test_get_current_user_latency`, `test_register_latency`, `test_token_refresh_latency`, `test_sessions_list_latency`, `test_user_profile_update_latency` |
|
||||||
|
|
||||||
|
### 3. Regression detection
|
||||||
|
|
||||||
|
When running `make benchmark-check`, the system:
|
||||||
|
|
||||||
|
1. Runs all benchmark tests
|
||||||
|
2. Compares results against the saved baseline (`.benchmarks/` directory)
|
||||||
|
3. **Fails the build** if any test's mean time exceeds **200%** of the baseline (i.e., 3× slower)
|
||||||
|
|
||||||
|
The `200%` threshold in `--benchmark-compare-fail=mean:200%` means "fail if the mean increased by more than 200% relative to the baseline." This is deliberately generous to avoid false positives from normal run-to-run variance while still catching real regressions.
|
||||||
|
|
||||||
|
---
|
||||||
|
|
||||||
|
## Understanding Results
|
||||||
|
|
||||||
|
A typical benchmark output looks like this:
|
||||||
|
|
||||||
|
```
|
||||||
|
--------------------------------------------------------------------------------------- benchmark: 2 tests --------------------------------------------------------------------------------------
|
||||||
|
Name (time in ms) Min Max Mean StdDev Median IQR Outliers OPS Rounds Iterations
|
||||||
|
-------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------
|
||||||
|
test_health_endpoint_performance 0.9841 (1.0) 1.5513 (1.0) 1.1390 (1.0) 0.1098 (1.0) 1.1151 (1.0) 0.1672 (1.0) 39;2 877.9666 (1.0) 133 1
|
||||||
|
test_openapi_schema_performance 1.6523 (1.68) 2.0892 (1.35) 1.7843 (1.57) 0.1553 (1.41) 1.7200 (1.54) 0.1727 (1.03) 2;0 560.4471 (0.64) 10 1
|
||||||
|
-------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------
|
||||||
|
```
|
||||||
|
|
||||||
|
### Column reference
|
||||||
|
|
||||||
|
| Column | Meaning |
|
||||||
|
|--------|---------|
|
||||||
|
| **Min** | Fastest single execution |
|
||||||
|
| **Max** | Slowest single execution |
|
||||||
|
| **Mean** | Average across all rounds — the primary metric for regression detection |
|
||||||
|
| **StdDev** | How much results vary between rounds (lower = more stable) |
|
||||||
|
| **Median** | Middle value, less sensitive to outliers than mean |
|
||||||
|
| **IQR** | Interquartile range — spread of the middle 50% of results |
|
||||||
|
| **Outliers** | Format `A;B` — A = within 1 StdDev, B = within 1.5 IQR from quartiles |
|
||||||
|
| **OPS** | Operations per second (`1 / Mean`) |
|
||||||
|
| **Rounds** | How many times the test was executed (auto-calibrated) |
|
||||||
|
| **Iterations** | Iterations per round (usually 1 for ms-scale tests) |
|
||||||
|
|
||||||
|
### The ratio numbers `(1.0)`, `(1.68)`, etc.
|
||||||
|
|
||||||
|
These show how each test compares **to the best result in that column**. The fastest test is always `(1.0)`, and others show their relative factor. For example, `(1.68)` means "1.68× slower than the fastest."
|
||||||
|
|
||||||
|
### Color coding
|
||||||
|
|
||||||
|
- **Green**: The fastest (best) value in each column
|
||||||
|
- **Red**: The slowest (worst) value in each column
|
||||||
|
|
||||||
|
This is a **relative ranking within the current run** — red does NOT mean the test failed or that performance is bad. It simply highlights which endpoint is the slower one in the group.
|
||||||
|
|
||||||
|
### What's "normal"?
|
||||||
|
|
||||||
|
For this project's current endpoints:
|
||||||
|
|
||||||
|
| Test | Expected range | Why |
|
||||||
|
|------|---------------|-----|
|
||||||
|
| `GET /health` | ~1–1.5ms | Minimal logic, mocked DB check |
|
||||||
|
| `GET /api/v1/openapi.json` | ~1.5–2.5ms | Serializes entire API schema |
|
||||||
|
| `get_password_hash` | ~200ms | CPU-bound bcrypt hashing |
|
||||||
|
| `verify_password` | ~200ms | CPU-bound bcrypt verification |
|
||||||
|
| `create_access_token` | ~17–20µs | JWT encoding with HMAC-SHA256 |
|
||||||
|
| `create_refresh_token` | ~17–20µs | JWT encoding with HMAC-SHA256 |
|
||||||
|
| `decode_token` | ~20–25µs | JWT decoding and claim validation |
|
||||||
|
| `POST /api/v1/auth/login` | < 500ms threshold | Includes bcrypt password verification |
|
||||||
|
| `POST /api/v1/auth/register` | < 500ms threshold | Includes bcrypt password hashing |
|
||||||
|
| `POST /api/v1/auth/refresh` | < 200ms threshold | Token rotation + DB session update |
|
||||||
|
| `GET /api/v1/users/me` | < 200ms threshold | DB lookup + token validation |
|
||||||
|
| `GET /api/v1/sessions/me` | < 200ms threshold | Session list query + token validation |
|
||||||
|
| `PATCH /api/v1/users/me` | < 200ms threshold | DB update + token validation |
|
||||||
|
|
||||||
|
---
|
||||||
|
|
||||||
|
## Test Organization
|
||||||
|
|
||||||
|
```
|
||||||
|
backend/tests/
|
||||||
|
├── benchmarks/
|
||||||
|
│ └── test_endpoint_performance.py # All performance benchmark tests
|
||||||
|
│
|
||||||
|
backend/.benchmarks/ # Saved baselines (auto-generated)
|
||||||
|
└── Linux-CPython-3.12-64bit/
|
||||||
|
└── 0001_baseline.json # Platform-specific baseline file
|
||||||
|
```
|
||||||
|
|
||||||
|
### Test markers
|
||||||
|
|
||||||
|
All benchmark tests use the `@pytest.mark.benchmark` marker. The `--benchmark-only` flag ensures that only tests using the `benchmark` fixture are executed during benchmark runs, while manual latency tests (async) are skipped.
|
||||||
|
|
||||||
|
---
|
||||||
|
|
||||||
|
## Writing Benchmark Tests
|
||||||
|
|
||||||
|
### Stateless endpoint (using pytest-benchmark fixture)
|
||||||
|
|
||||||
|
```python
|
||||||
|
import pytest
|
||||||
|
from fastapi.testclient import TestClient
|
||||||
|
|
||||||
|
def test_my_endpoint_performance(sync_client, benchmark):
|
||||||
|
"""Benchmark: GET /my-endpoint should respond within acceptable latency."""
|
||||||
|
result = benchmark(sync_client.get, "/my-endpoint")
|
||||||
|
assert result.status_code == 200
|
||||||
|
```
|
||||||
|
|
||||||
|
The `benchmark` fixture handles all timing, calibration, and statistics automatically. Just pass it the callable and arguments.
|
||||||
|
|
||||||
|
### Async / DB-dependent endpoint (manual timing)
|
||||||
|
|
||||||
|
For async endpoints that require database access, use manual timing with an explicit threshold:
|
||||||
|
|
||||||
|
```python
|
||||||
|
import time
|
||||||
|
import pytest
|
||||||
|
|
||||||
|
MAX_RESPONSE_MS = 300
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_my_async_endpoint_latency(client, setup_fixture):
|
||||||
|
"""Performance: endpoint must respond under threshold."""
|
||||||
|
iterations = 5
|
||||||
|
total_ms = 0.0
|
||||||
|
|
||||||
|
for _ in range(iterations):
|
||||||
|
start = time.perf_counter()
|
||||||
|
response = await client.get("/api/v1/my-endpoint")
|
||||||
|
elapsed_ms = (time.perf_counter() - start) * 1000
|
||||||
|
total_ms += elapsed_ms
|
||||||
|
assert response.status_code == 200
|
||||||
|
|
||||||
|
mean_ms = total_ms / iterations
|
||||||
|
assert mean_ms < MAX_RESPONSE_MS, (
|
||||||
|
f"Latency regression: {mean_ms:.1f}ms exceeds {MAX_RESPONSE_MS}ms threshold"
|
||||||
|
)
|
||||||
|
```
|
||||||
|
|
||||||
|
### Guidelines for new benchmarks
|
||||||
|
|
||||||
|
1. **Benchmark critical paths** — endpoints users hit frequently or where latency matters most
|
||||||
|
2. **Mock external dependencies** for stateless tests to isolate endpoint overhead
|
||||||
|
3. **Set generous thresholds** for manual tests — account for CI variability
|
||||||
|
4. **Keep benchmarks fast** — they run on every check, so avoid heavy setup
|
||||||
|
|
||||||
|
---
|
||||||
|
|
||||||
|
## Baseline Management
|
||||||
|
|
||||||
|
### Saving a baseline
|
||||||
|
|
||||||
|
```bash
|
||||||
|
make benchmark-save
|
||||||
|
```
|
||||||
|
|
||||||
|
This runs all benchmarks and saves results to `.benchmarks/<platform>/0001_baseline.json`. The baseline captures:
|
||||||
|
- Mean, min, max, median, stddev for each test
|
||||||
|
- Machine info (CPU, OS, Python version)
|
||||||
|
- Timestamp
|
||||||
|
|
||||||
|
### Comparing against baseline
|
||||||
|
|
||||||
|
```bash
|
||||||
|
make benchmark-check
|
||||||
|
```
|
||||||
|
|
||||||
|
If no baseline exists, this command automatically creates one and prints a warning. On subsequent runs, it compares current results against the saved baseline.
|
||||||
|
|
||||||
|
### When to update the baseline
|
||||||
|
|
||||||
|
- **After intentional performance changes** (e.g., you optimized an endpoint — save the new, faster baseline)
|
||||||
|
- **After infrastructure changes** (e.g., new CI runner, different hardware)
|
||||||
|
- **After adding new benchmark tests** (the new tests need a baseline entry)
|
||||||
|
|
||||||
|
```bash
|
||||||
|
# Update the baseline after intentional changes
|
||||||
|
make benchmark-save
|
||||||
|
```
|
||||||
|
|
||||||
|
### Version control
|
||||||
|
|
||||||
|
The `.benchmarks/` directory can be committed to version control so that CI pipelines can compare against a known-good baseline. However, since benchmark results are machine-specific, you may prefer to generate baselines in CI rather than committing local results.
|
||||||
|
|
||||||
|
---
|
||||||
|
|
||||||
|
## CI/CD Integration
|
||||||
|
|
||||||
|
Add benchmark checking to your CI pipeline to catch regressions on every PR:
|
||||||
|
|
||||||
|
```yaml
|
||||||
|
# Example GitHub Actions step
|
||||||
|
- name: Performance regression check
|
||||||
|
run: |
|
||||||
|
cd backend
|
||||||
|
make benchmark-save # Create baseline from main branch
|
||||||
|
# ... apply PR changes ...
|
||||||
|
make benchmark-check # Compare PR against baseline
|
||||||
|
```
|
||||||
|
|
||||||
|
A more robust approach:
|
||||||
|
1. Save the baseline on the `main` branch after each merge
|
||||||
|
2. On PR branches, run `make benchmark-check` against the `main` baseline
|
||||||
|
3. The pipeline fails if any endpoint regresses beyond the 200% threshold
|
||||||
|
|
||||||
|
---
|
||||||
|
|
||||||
|
## Troubleshooting
|
||||||
|
|
||||||
|
### "No benchmark baseline found" warning
|
||||||
|
|
||||||
|
```
|
||||||
|
⚠️ No benchmark baseline found. Run 'make benchmark-save' first to create one.
|
||||||
|
```
|
||||||
|
|
||||||
|
This means no baseline file exists yet. The command will auto-create one. Future runs of `make benchmark-check` will compare against it.
|
||||||
|
|
||||||
|
### Machine info mismatch warning
|
||||||
|
|
||||||
|
```
|
||||||
|
WARNING: benchmark machine_info is different
|
||||||
|
```
|
||||||
|
|
||||||
|
This is expected when comparing baselines generated on a different machine or OS. The comparison still works, but absolute numbers may differ. Re-save the baseline on the current machine if needed.
|
||||||
|
|
||||||
|
### High variance (large StdDev)
|
||||||
|
|
||||||
|
If StdDev is high relative to the Mean, results may be unreliable. Common causes:
|
||||||
|
- System under load during benchmark run
|
||||||
|
- Garbage collection interference
|
||||||
|
- Thermal throttling
|
||||||
|
|
||||||
|
Try running benchmarks on an idle system or increasing `min_rounds` in `pyproject.toml`.
|
||||||
|
|
||||||
|
### Only 7 of 13 tests run
|
||||||
|
|
||||||
|
The async tests (`test_login_latency`, `test_get_current_user_latency`, `test_register_latency`, `test_token_refresh_latency`, `test_sessions_list_latency`, `test_user_profile_update_latency`) are skipped during `--benchmark-only` runs because they don't use the `benchmark` fixture. They run as part of the normal test suite (`make test`) with manual threshold assertions.
|
||||||
@@ -8,6 +8,7 @@ This document outlines the coding standards and best practices for the FastAPI b
|
|||||||
- [Code Organization](#code-organization)
|
- [Code Organization](#code-organization)
|
||||||
- [Naming Conventions](#naming-conventions)
|
- [Naming Conventions](#naming-conventions)
|
||||||
- [Error Handling](#error-handling)
|
- [Error Handling](#error-handling)
|
||||||
|
- [Data Models and Migrations](#data-models-and-migrations)
|
||||||
- [Database Operations](#database-operations)
|
- [Database Operations](#database-operations)
|
||||||
- [API Endpoints](#api-endpoints)
|
- [API Endpoints](#api-endpoints)
|
||||||
- [Authentication & Security](#authentication--security)
|
- [Authentication & Security](#authentication--security)
|
||||||
@@ -74,15 +75,14 @@ def create_user(db: Session, user_in: UserCreate) -> User:
|
|||||||
### 4. Code Formatting
|
### 4. Code Formatting
|
||||||
|
|
||||||
Use automated formatters:
|
Use automated formatters:
|
||||||
- **Black**: Code formatting
|
- **Ruff**: Code formatting and linting (replaces Black, isort, flake8)
|
||||||
- **isort**: Import sorting
|
- **pyright**: Static type checking
|
||||||
- **flake8**: Linting
|
|
||||||
|
|
||||||
Run before committing:
|
Run before committing (or use `make validate`):
|
||||||
```bash
|
```bash
|
||||||
black app tests
|
uv run ruff format app tests
|
||||||
isort app tests
|
uv run ruff check app tests
|
||||||
flake8 app tests
|
uv run pyright app
|
||||||
```
|
```
|
||||||
|
|
||||||
## Code Organization
|
## Code Organization
|
||||||
@@ -93,19 +93,17 @@ Follow the 5-layer architecture strictly:
|
|||||||
|
|
||||||
```
|
```
|
||||||
API Layer (routes/)
|
API Layer (routes/)
|
||||||
↓ calls
|
↓ calls (via service injected from dependencies/services.py)
|
||||||
Dependencies (dependencies/)
|
|
||||||
↓ injects
|
|
||||||
Service Layer (services/)
|
Service Layer (services/)
|
||||||
↓ calls
|
↓ calls
|
||||||
CRUD Layer (crud/)
|
Repository Layer (repositories/)
|
||||||
↓ uses
|
↓ uses
|
||||||
Models & Schemas (models/, schemas/)
|
Models & Schemas (models/, schemas/)
|
||||||
```
|
```
|
||||||
|
|
||||||
**Rules:**
|
**Rules:**
|
||||||
- Routes should NOT directly call CRUD operations (use services when business logic is needed)
|
- Routes must NEVER import repositories directly — always use a service
|
||||||
- CRUD operations should NOT contain business logic
|
- Services call repositories; repositories contain only database operations
|
||||||
- Models should NOT import from higher layers
|
- Models should NOT import from higher layers
|
||||||
- Each layer should only depend on the layer directly below it
|
- Each layer should only depend on the layer directly below it
|
||||||
|
|
||||||
@@ -124,7 +122,7 @@ from sqlalchemy.orm import Session
|
|||||||
|
|
||||||
# 3. Local application imports
|
# 3. Local application imports
|
||||||
from app.api.dependencies.auth import get_current_user
|
from app.api.dependencies.auth import get_current_user
|
||||||
from app.crud import user_crud
|
from app.api.dependencies.services import get_user_service
|
||||||
from app.models.user import User
|
from app.models.user import User
|
||||||
from app.schemas.users import UserResponse, UserCreate
|
from app.schemas.users import UserResponse, UserCreate
|
||||||
```
|
```
|
||||||
@@ -216,7 +214,7 @@ if not user:
|
|||||||
|
|
||||||
### Error Handling Pattern
|
### Error Handling Pattern
|
||||||
|
|
||||||
Always follow this pattern in CRUD operations (Async version):
|
Always follow this pattern in repository operations (Async version):
|
||||||
|
|
||||||
```python
|
```python
|
||||||
from sqlalchemy.exc import IntegrityError, OperationalError, DataError
|
from sqlalchemy.exc import IntegrityError, OperationalError, DataError
|
||||||
@@ -282,9 +280,154 @@ All error responses follow this structure:
|
|||||||
}
|
}
|
||||||
```
|
```
|
||||||
|
|
||||||
|
## Data Models and Migrations
|
||||||
|
|
||||||
|
### Model Definition Best Practices
|
||||||
|
|
||||||
|
To ensure Alembic autogenerate works reliably without drift, follow these rules:
|
||||||
|
|
||||||
|
#### 1. Simple Indexes: Use Column-Level or `__table_args__`, Not Both
|
||||||
|
|
||||||
|
```python
|
||||||
|
# ❌ BAD - Creates DUPLICATE indexes with different names
|
||||||
|
class User(Base):
|
||||||
|
role = Column(String(50), index=True) # Creates ix_users_role
|
||||||
|
|
||||||
|
__table_args__ = (
|
||||||
|
Index("ix_user_role", "role"), # Creates ANOTHER index!
|
||||||
|
)
|
||||||
|
|
||||||
|
# ✅ GOOD - Choose ONE approach
|
||||||
|
class User(Base):
|
||||||
|
role = Column(String(50)) # No index=True
|
||||||
|
|
||||||
|
__table_args__ = (
|
||||||
|
Index("ix_user_role", "role"), # Single index with explicit name
|
||||||
|
)
|
||||||
|
|
||||||
|
# ✅ ALSO GOOD - For simple single-column indexes
|
||||||
|
class User(Base):
|
||||||
|
role = Column(String(50), index=True) # Auto-named ix_users_role
|
||||||
|
```
|
||||||
|
|
||||||
|
#### 2. Composite Indexes: Always Use `__table_args__`
|
||||||
|
|
||||||
|
```python
|
||||||
|
class UserOrganization(Base):
|
||||||
|
__tablename__ = "user_organizations"
|
||||||
|
|
||||||
|
user_id = Column(UUID, nullable=False)
|
||||||
|
organization_id = Column(UUID, nullable=False)
|
||||||
|
is_active = Column(Boolean, default=True, nullable=False, index=True)
|
||||||
|
|
||||||
|
__table_args__ = (
|
||||||
|
Index("ix_user_org_user_active", "user_id", "is_active"),
|
||||||
|
Index("ix_user_org_org_active", "organization_id", "is_active"),
|
||||||
|
)
|
||||||
|
```
|
||||||
|
|
||||||
|
#### 3. Functional/Partial Indexes: Use `ix_perf_` Prefix
|
||||||
|
|
||||||
|
Alembic **cannot** auto-detect:
|
||||||
|
- **Functional indexes**: `LOWER(column)`, `UPPER(column)`, expressions
|
||||||
|
- **Partial indexes**: Indexes with `WHERE` clauses
|
||||||
|
|
||||||
|
**Solution**: Use the `ix_perf_` naming prefix. Any index with this prefix is automatically excluded from autogenerate by `env.py`.
|
||||||
|
|
||||||
|
```python
|
||||||
|
# In migration file (NOT in model) - use ix_perf_ prefix:
|
||||||
|
op.create_index(
|
||||||
|
"ix_perf_users_email_lower", # <-- ix_perf_ prefix!
|
||||||
|
"users",
|
||||||
|
[sa.text("LOWER(email)")], # Functional
|
||||||
|
postgresql_where=sa.text("deleted_at IS NULL"), # Partial
|
||||||
|
)
|
||||||
|
```
|
||||||
|
|
||||||
|
**No need to update `env.py`** - the prefix convention handles it automatically:
|
||||||
|
|
||||||
|
```python
|
||||||
|
# env.py - already configured:
|
||||||
|
def include_object(object, name, type_, reflected, compare_to):
|
||||||
|
if type_ == "index" and name:
|
||||||
|
if name.startswith("ix_perf_"): # Auto-excluded!
|
||||||
|
return False
|
||||||
|
return True
|
||||||
|
```
|
||||||
|
|
||||||
|
**To add new performance indexes:**
|
||||||
|
1. Create a new migration file
|
||||||
|
2. Name your indexes with `ix_perf_` prefix
|
||||||
|
3. Done - Alembic will ignore them automatically
|
||||||
|
|
||||||
|
#### 4. Use Correct Types
|
||||||
|
|
||||||
|
```python
|
||||||
|
# ✅ GOOD - PostgreSQL-native types
|
||||||
|
from sqlalchemy.dialects.postgresql import JSONB, UUID
|
||||||
|
|
||||||
|
class User(Base):
|
||||||
|
id = Column(UUID(as_uuid=True), primary_key=True)
|
||||||
|
preferences = Column(JSONB) # Not JSON!
|
||||||
|
|
||||||
|
# ❌ BAD - Generic types may cause migration drift
|
||||||
|
from sqlalchemy import JSON
|
||||||
|
preferences = Column(JSON) # May detect as different from JSONB
|
||||||
|
```
|
||||||
|
|
||||||
|
### Migration Workflow
|
||||||
|
|
||||||
|
#### Creating Migrations
|
||||||
|
|
||||||
|
```bash
|
||||||
|
# Generate autogenerate migration:
|
||||||
|
python migrate.py generate "Add new field"
|
||||||
|
|
||||||
|
# Or inside Docker:
|
||||||
|
docker exec -w /app backend uv run alembic revision --autogenerate -m "Add new field"
|
||||||
|
|
||||||
|
# Apply migration:
|
||||||
|
python migrate.py apply
|
||||||
|
# Or: docker exec -w /app backend uv run alembic upgrade head
|
||||||
|
```
|
||||||
|
|
||||||
|
#### Testing for Drift
|
||||||
|
|
||||||
|
After any model changes, verify no unintended drift:
|
||||||
|
|
||||||
|
```bash
|
||||||
|
# Generate test migration
|
||||||
|
docker exec -w /app backend uv run alembic revision --autogenerate -m "test_drift"
|
||||||
|
|
||||||
|
# Check the generated file - should be empty (just 'pass')
|
||||||
|
# If it has operations, investigate why
|
||||||
|
|
||||||
|
# Delete test file
|
||||||
|
rm backend/app/alembic/versions/*_test_drift.py
|
||||||
|
```
|
||||||
|
|
||||||
|
#### Migration File Structure
|
||||||
|
|
||||||
|
```
|
||||||
|
backend/app/alembic/versions/
|
||||||
|
├── cbddc8aa6eda_initial_models.py # Auto-generated, tracks all models
|
||||||
|
├── 0002_performance_indexes.py # Manual, functional/partial indexes
|
||||||
|
└── __init__.py
|
||||||
|
```
|
||||||
|
|
||||||
|
### Summary: What Goes Where
|
||||||
|
|
||||||
|
| Index Type | In Model? | Alembic Detects? | Where to Define |
|
||||||
|
|------------|-----------|------------------|-----------------|
|
||||||
|
| Simple column (`index=True`) | Yes | Yes | Column definition |
|
||||||
|
| Composite (`col1, col2`) | Yes | Yes | `__table_args__` |
|
||||||
|
| Unique composite | Yes | Yes | `__table_args__` with `unique=True` |
|
||||||
|
| Functional (`LOWER(col)`) | No | No | Migration with `ix_perf_` prefix |
|
||||||
|
| Partial (`WHERE ...`) | No | No | Migration with `ix_perf_` prefix |
|
||||||
|
|
||||||
## Database Operations
|
## Database Operations
|
||||||
|
|
||||||
### Async CRUD Pattern
|
### Async Repository Pattern
|
||||||
|
|
||||||
**IMPORTANT**: This application uses **async SQLAlchemy** with modern patterns for better performance and testability.
|
**IMPORTANT**: This application uses **async SQLAlchemy** with modern patterns for better performance and testability.
|
||||||
|
|
||||||
@@ -296,19 +439,19 @@ All error responses follow this structure:
|
|||||||
4. **Testability**: Easy to mock and test
|
4. **Testability**: Easy to mock and test
|
||||||
5. **Consistent Ordering**: Always order queries for pagination
|
5. **Consistent Ordering**: Always order queries for pagination
|
||||||
|
|
||||||
### Use the Async CRUD Base Class
|
### Use the Async Repository Base Class
|
||||||
|
|
||||||
Always inherit from `CRUDBase` for database operations:
|
Always inherit from `RepositoryBase` for database operations:
|
||||||
|
|
||||||
```python
|
```python
|
||||||
from sqlalchemy.ext.asyncio import AsyncSession
|
from sqlalchemy.ext.asyncio import AsyncSession
|
||||||
from sqlalchemy import select
|
from sqlalchemy import select
|
||||||
from app.crud.base import CRUDBase
|
from app.repositories.base import RepositoryBase
|
||||||
from app.models.user import User
|
from app.models.user import User
|
||||||
from app.schemas.users import UserCreate, UserUpdate
|
from app.schemas.users import UserCreate, UserUpdate
|
||||||
|
|
||||||
class CRUDUser(CRUDBase[User, UserCreate, UserUpdate]):
|
class UserRepository(RepositoryBase[User, UserCreate, UserUpdate]):
|
||||||
"""CRUD operations for User model."""
|
"""Repository for User model — database operations only."""
|
||||||
|
|
||||||
async def get_by_email(
|
async def get_by_email(
|
||||||
self,
|
self,
|
||||||
@@ -321,7 +464,7 @@ class CRUDUser(CRUDBase[User, UserCreate, UserUpdate]):
|
|||||||
)
|
)
|
||||||
return result.scalar_one_or_none()
|
return result.scalar_one_or_none()
|
||||||
|
|
||||||
user_crud = CRUDUser(User)
|
user_repo = UserRepository(User)
|
||||||
```
|
```
|
||||||
|
|
||||||
**Key Points:**
|
**Key Points:**
|
||||||
@@ -330,6 +473,7 @@ user_crud = CRUDUser(User)
|
|||||||
- Use `await db.execute()` for queries
|
- Use `await db.execute()` for queries
|
||||||
- Use `.scalar_one_or_none()` instead of `.first()`
|
- Use `.scalar_one_or_none()` instead of `.first()`
|
||||||
- Use `T | None` instead of `Optional[T]`
|
- Use `T | None` instead of `Optional[T]`
|
||||||
|
- Repository instances are used internally by services — never import them in routes
|
||||||
|
|
||||||
### Modern SQLAlchemy Patterns
|
### Modern SQLAlchemy Patterns
|
||||||
|
|
||||||
@@ -417,13 +561,13 @@ async def create_user(
|
|||||||
The database session is automatically managed by FastAPI.
|
The database session is automatically managed by FastAPI.
|
||||||
Commit on success, rollback on error.
|
Commit on success, rollback on error.
|
||||||
"""
|
"""
|
||||||
return await user_crud.create(db, obj_in=user_in)
|
return await user_service.create_user(db, obj_in=user_in)
|
||||||
```
|
```
|
||||||
|
|
||||||
**Key Points:**
|
**Key Points:**
|
||||||
- Route functions must be `async def`
|
- Route functions must be `async def`
|
||||||
- Database parameter is `AsyncSession`
|
- Database parameter is `AsyncSession`
|
||||||
- Always `await` CRUD operations
|
- Always `await` repository operations
|
||||||
|
|
||||||
#### In Services (Multiple Operations)
|
#### In Services (Multiple Operations)
|
||||||
|
|
||||||
@@ -436,12 +580,11 @@ async def complex_operation(
|
|||||||
"""
|
"""
|
||||||
Perform multiple database operations atomically.
|
Perform multiple database operations atomically.
|
||||||
|
|
||||||
The session automatically commits on success or rolls back on error.
|
Services call repositories; commit/rollback is handled inside
|
||||||
|
each repository method.
|
||||||
"""
|
"""
|
||||||
user = await user_crud.create(db, obj_in=user_data)
|
user = await user_repo.create(db, obj_in=user_data)
|
||||||
session = await session_crud.create(db, obj_in=session_data)
|
session = await session_repo.create(db, obj_in=session_data)
|
||||||
|
|
||||||
# Commit is handled by the route's dependency
|
|
||||||
return user, session
|
return user, session
|
||||||
```
|
```
|
||||||
|
|
||||||
@@ -451,10 +594,10 @@ Prefer soft deletes over hard deletes for audit trails:
|
|||||||
|
|
||||||
```python
|
```python
|
||||||
# Good - Soft delete (sets deleted_at)
|
# Good - Soft delete (sets deleted_at)
|
||||||
await user_crud.soft_delete(db, id=user_id)
|
await user_repo.soft_delete(db, id=user_id)
|
||||||
|
|
||||||
# Acceptable only when required - Hard delete
|
# Acceptable only when required - Hard delete
|
||||||
user_crud.remove(db, id=user_id)
|
await user_repo.remove(db, id=user_id)
|
||||||
```
|
```
|
||||||
|
|
||||||
### Query Patterns
|
### Query Patterns
|
||||||
@@ -594,9 +737,10 @@ Always implement pagination for list endpoints:
|
|||||||
from app.schemas.common import PaginationParams, PaginatedResponse
|
from app.schemas.common import PaginationParams, PaginatedResponse
|
||||||
|
|
||||||
@router.get("/users", response_model=PaginatedResponse[UserResponse])
|
@router.get("/users", response_model=PaginatedResponse[UserResponse])
|
||||||
def list_users(
|
async def list_users(
|
||||||
pagination: PaginationParams = Depends(),
|
pagination: PaginationParams = Depends(),
|
||||||
db: Session = Depends(get_db)
|
user_service: UserService = Depends(get_user_service),
|
||||||
|
db: AsyncSession = Depends(get_db),
|
||||||
):
|
):
|
||||||
"""
|
"""
|
||||||
List all users with pagination.
|
List all users with pagination.
|
||||||
@@ -604,10 +748,8 @@ def list_users(
|
|||||||
Default page size: 20
|
Default page size: 20
|
||||||
Maximum page size: 100
|
Maximum page size: 100
|
||||||
"""
|
"""
|
||||||
users, total = user_crud.get_multi_with_total(
|
users, total = await user_service.get_users(
|
||||||
db,
|
db, skip=pagination.offset, limit=pagination.limit
|
||||||
skip=pagination.offset,
|
|
||||||
limit=pagination.limit
|
|
||||||
)
|
)
|
||||||
return PaginatedResponse(data=users, pagination=pagination.create_meta(total))
|
return PaginatedResponse(data=users, pagination=pagination.create_meta(total))
|
||||||
```
|
```
|
||||||
@@ -670,19 +812,17 @@ def admin_route(
|
|||||||
pass
|
pass
|
||||||
|
|
||||||
# Check ownership
|
# Check ownership
|
||||||
def delete_resource(
|
async def delete_resource(
|
||||||
resource_id: UUID,
|
resource_id: UUID,
|
||||||
current_user: User = Depends(get_current_user),
|
current_user: User = Depends(get_current_user),
|
||||||
db: Session = Depends(get_db)
|
resource_service: ResourceService = Depends(get_resource_service),
|
||||||
|
db: AsyncSession = Depends(get_db),
|
||||||
):
|
):
|
||||||
resource = resource_crud.get(db, id=resource_id)
|
# Service handles ownership check and raises appropriate errors
|
||||||
if not resource:
|
await resource_service.delete_resource(
|
||||||
raise NotFoundError("Resource not found")
|
db, resource_id=resource_id, user_id=current_user.id,
|
||||||
|
is_superuser=current_user.is_superuser,
|
||||||
if resource.user_id != current_user.id and not current_user.is_superuser:
|
)
|
||||||
raise AuthorizationError("You can only delete your own resources")
|
|
||||||
|
|
||||||
resource_crud.remove(db, id=resource_id)
|
|
||||||
```
|
```
|
||||||
|
|
||||||
### Input Validation
|
### Input Validation
|
||||||
@@ -716,9 +856,9 @@ tests/
|
|||||||
├── api/ # Integration tests
|
├── api/ # Integration tests
|
||||||
│ ├── test_users.py
|
│ ├── test_users.py
|
||||||
│ └── test_auth.py
|
│ └── test_auth.py
|
||||||
├── crud/ # Unit tests for CRUD
|
├── repositories/ # Unit tests for repositories
|
||||||
├── models/ # Model tests
|
├── services/ # Unit tests for services
|
||||||
└── services/ # Service tests
|
└── models/ # Model tests
|
||||||
```
|
```
|
||||||
|
|
||||||
### Async Testing with pytest-asyncio
|
### Async Testing with pytest-asyncio
|
||||||
@@ -781,7 +921,7 @@ async def test_user(db_session: AsyncSession) -> User:
|
|||||||
@pytest.mark.asyncio
|
@pytest.mark.asyncio
|
||||||
async def test_get_user(db_session: AsyncSession, test_user: User):
|
async def test_get_user(db_session: AsyncSession, test_user: User):
|
||||||
"""Test retrieving a user by ID."""
|
"""Test retrieving a user by ID."""
|
||||||
user = await user_crud.get(db_session, id=test_user.id)
|
user = await user_repo.get(db_session, id=test_user.id)
|
||||||
assert user is not None
|
assert user is not None
|
||||||
assert user.email == test_user.email
|
assert user.email == test_user.email
|
||||||
```
|
```
|
||||||
|
|||||||
@@ -334,14 +334,14 @@ def login(request: Request, credentials: OAuth2PasswordRequestForm):
|
|||||||
# ❌ WRONG - Returns password hash!
|
# ❌ WRONG - Returns password hash!
|
||||||
@router.get("/users/{user_id}")
|
@router.get("/users/{user_id}")
|
||||||
def get_user(user_id: UUID, db: Session = Depends(get_db)) -> User:
|
def get_user(user_id: UUID, db: Session = Depends(get_db)) -> User:
|
||||||
return user_crud.get(db, id=user_id) # Returns ORM model with ALL fields!
|
return user_repo.get(db, id=user_id) # Returns ORM model with ALL fields!
|
||||||
```
|
```
|
||||||
|
|
||||||
```python
|
```python
|
||||||
# ✅ CORRECT - Use response schema
|
# ✅ CORRECT - Use response schema
|
||||||
@router.get("/users/{user_id}", response_model=UserResponse)
|
@router.get("/users/{user_id}", response_model=UserResponse)
|
||||||
def get_user(user_id: UUID, db: Session = Depends(get_db)):
|
def get_user(user_id: UUID, db: Session = Depends(get_db)):
|
||||||
user = user_crud.get(db, id=user_id)
|
user = user_repo.get(db, id=user_id)
|
||||||
if not user:
|
if not user:
|
||||||
raise HTTPException(status_code=404, detail="User not found")
|
raise HTTPException(status_code=404, detail="User not found")
|
||||||
return user # Pydantic filters to only UserResponse fields
|
return user # Pydantic filters to only UserResponse fields
|
||||||
@@ -506,8 +506,8 @@ def revoke_session(
|
|||||||
current_user: User = Depends(get_current_user),
|
current_user: User = Depends(get_current_user),
|
||||||
db: Session = Depends(get_db)
|
db: Session = Depends(get_db)
|
||||||
):
|
):
|
||||||
session = session_crud.get(db, id=session_id)
|
session = session_repo.get(db, id=session_id)
|
||||||
session_crud.deactivate(db, session_id=session_id)
|
session_repo.deactivate(db, session_id=session_id)
|
||||||
# BUG: User can revoke ANYONE'S session!
|
# BUG: User can revoke ANYONE'S session!
|
||||||
return {"message": "Session revoked"}
|
return {"message": "Session revoked"}
|
||||||
```
|
```
|
||||||
@@ -520,7 +520,7 @@ def revoke_session(
|
|||||||
current_user: User = Depends(get_current_user),
|
current_user: User = Depends(get_current_user),
|
||||||
db: Session = Depends(get_db)
|
db: Session = Depends(get_db)
|
||||||
):
|
):
|
||||||
session = session_crud.get(db, id=session_id)
|
session = session_repo.get(db, id=session_id)
|
||||||
|
|
||||||
if not session:
|
if not session:
|
||||||
raise NotFoundError("Session not found")
|
raise NotFoundError("Session not found")
|
||||||
@@ -529,7 +529,7 @@ def revoke_session(
|
|||||||
if session.user_id != current_user.id:
|
if session.user_id != current_user.id:
|
||||||
raise AuthorizationError("You can only revoke your own sessions")
|
raise AuthorizationError("You can only revoke your own sessions")
|
||||||
|
|
||||||
session_crud.deactivate(db, session_id=session_id)
|
session_repo.deactivate(db, session_id=session_id)
|
||||||
return {"message": "Session revoked"}
|
return {"message": "Session revoked"}
|
||||||
```
|
```
|
||||||
|
|
||||||
@@ -616,7 +616,43 @@ def create_user(
|
|||||||
return user
|
return user
|
||||||
```
|
```
|
||||||
|
|
||||||
**Rule**: Add type hints to ALL functions. Use `mypy` to enforce type checking.
|
**Rule**: Add type hints to ALL functions. Use `pyright` to enforce type checking (`make type-check`).
|
||||||
|
|
||||||
|
---
|
||||||
|
|
||||||
|
---
|
||||||
|
|
||||||
|
### ❌ PITFALL #19: Importing Repositories Directly in Routes
|
||||||
|
|
||||||
|
**Issue**: Routes should never call repositories directly. The layered architecture requires all business operations to go through the service layer.
|
||||||
|
|
||||||
|
```python
|
||||||
|
# ❌ WRONG - Route bypasses service layer
|
||||||
|
from app.repositories.session import session_repo
|
||||||
|
|
||||||
|
@router.get("/sessions/me")
|
||||||
|
async def list_sessions(
|
||||||
|
current_user: User = Depends(get_current_active_user),
|
||||||
|
db: AsyncSession = Depends(get_db),
|
||||||
|
):
|
||||||
|
return await session_repo.get_user_sessions(db, user_id=current_user.id)
|
||||||
|
```
|
||||||
|
|
||||||
|
```python
|
||||||
|
# ✅ CORRECT - Route calls service injected via dependency
|
||||||
|
from app.api.dependencies.services import get_session_service
|
||||||
|
from app.services.session_service import SessionService
|
||||||
|
|
||||||
|
@router.get("/sessions/me")
|
||||||
|
async def list_sessions(
|
||||||
|
current_user: User = Depends(get_current_active_user),
|
||||||
|
session_service: SessionService = Depends(get_session_service),
|
||||||
|
db: AsyncSession = Depends(get_db),
|
||||||
|
):
|
||||||
|
return await session_service.get_user_sessions(db, user_id=current_user.id)
|
||||||
|
```
|
||||||
|
|
||||||
|
**Rule**: Routes import from `app.api.dependencies.services`, never from `app.repositories.*`. Services are the only callers of repositories.
|
||||||
|
|
||||||
---
|
---
|
||||||
|
|
||||||
@@ -649,6 +685,11 @@ Use this checklist to catch issues before code review:
|
|||||||
- [ ] Resource ownership verification
|
- [ ] Resource ownership verification
|
||||||
- [ ] CORS configured (no wildcards in production)
|
- [ ] CORS configured (no wildcards in production)
|
||||||
|
|
||||||
|
### Architecture
|
||||||
|
- [ ] Routes never import repositories directly (only services)
|
||||||
|
- [ ] Services call repositories; repositories call database only
|
||||||
|
- [ ] New service registered in `app/api/dependencies/services.py`
|
||||||
|
|
||||||
### Python
|
### Python
|
||||||
- [ ] Use `==` not `is` for value comparison
|
- [ ] Use `==` not `is` for value comparison
|
||||||
- [ ] No mutable default arguments
|
- [ ] No mutable default arguments
|
||||||
@@ -661,21 +702,18 @@ Use this checklist to catch issues before code review:
|
|||||||
|
|
||||||
### Pre-commit Checks
|
### Pre-commit Checks
|
||||||
|
|
||||||
Add these to your development workflow:
|
Add these to your development workflow (or use `make validate`):
|
||||||
|
|
||||||
```bash
|
```bash
|
||||||
# Format code
|
# Format + lint (Ruff replaces Black, isort, flake8)
|
||||||
black app tests
|
uv run ruff format app tests
|
||||||
isort app tests
|
uv run ruff check app tests
|
||||||
|
|
||||||
# Type checking
|
# Type checking
|
||||||
mypy app --strict
|
uv run pyright app
|
||||||
|
|
||||||
# Linting
|
|
||||||
flake8 app tests
|
|
||||||
|
|
||||||
# Run tests
|
# Run tests
|
||||||
pytest --cov=app --cov-report=term-missing
|
IS_TEST=True uv run pytest --cov=app --cov-report=term-missing
|
||||||
|
|
||||||
# Check coverage (should be 80%+)
|
# Check coverage (should be 80%+)
|
||||||
coverage report --fail-under=80
|
coverage report --fail-under=80
|
||||||
@@ -693,6 +731,6 @@ Add new entries when:
|
|||||||
|
|
||||||
---
|
---
|
||||||
|
|
||||||
**Last Updated**: 2025-10-31
|
**Last Updated**: 2026-02-28
|
||||||
**Issues Cataloged**: 18 common pitfalls
|
**Issues Cataloged**: 19 common pitfalls
|
||||||
**Remember**: This document exists because these issues HAVE occurred. Don't skip it.
|
**Remember**: This document exists because these issues HAVE occurred. Don't skip it.
|
||||||
|
|||||||
348
backend/docs/E2E_TESTING.md
Normal file
348
backend/docs/E2E_TESTING.md
Normal file
@@ -0,0 +1,348 @@
|
|||||||
|
# Backend E2E Testing Guide
|
||||||
|
|
||||||
|
End-to-end testing infrastructure using **Testcontainers** (real PostgreSQL) and **Schemathesis** (OpenAPI contract testing).
|
||||||
|
|
||||||
|
## Table of Contents
|
||||||
|
|
||||||
|
- [Quick Start](#quick-start)
|
||||||
|
- [Requirements](#requirements)
|
||||||
|
- [How It Works](#how-it-works)
|
||||||
|
- [Test Organization](#test-organization)
|
||||||
|
- [Writing E2E Tests](#writing-e2e-tests)
|
||||||
|
- [Running Tests](#running-tests)
|
||||||
|
- [Troubleshooting](#troubleshooting)
|
||||||
|
- [CI/CD Integration](#cicd-integration)
|
||||||
|
|
||||||
|
---
|
||||||
|
|
||||||
|
## Quick Start
|
||||||
|
|
||||||
|
```bash
|
||||||
|
# 1. Install E2E dependencies
|
||||||
|
make install-e2e
|
||||||
|
|
||||||
|
# 2. Ensure Docker is running
|
||||||
|
make check-docker
|
||||||
|
|
||||||
|
# 3. Run E2E tests
|
||||||
|
make test-e2e
|
||||||
|
```
|
||||||
|
|
||||||
|
---
|
||||||
|
|
||||||
|
## Requirements
|
||||||
|
|
||||||
|
### Docker
|
||||||
|
|
||||||
|
E2E tests use Testcontainers to spin up real PostgreSQL containers. Docker must be running:
|
||||||
|
|
||||||
|
- **macOS/Windows**: Docker Desktop
|
||||||
|
- **Linux**: Docker Engine (`sudo systemctl start docker`)
|
||||||
|
|
||||||
|
### Dependencies
|
||||||
|
|
||||||
|
E2E tests require additional packages beyond the standard dev dependencies:
|
||||||
|
|
||||||
|
```bash
|
||||||
|
# Install E2E dependencies
|
||||||
|
make install-e2e
|
||||||
|
|
||||||
|
# Or manually:
|
||||||
|
uv sync --extra dev --extra e2e
|
||||||
|
```
|
||||||
|
|
||||||
|
This installs:
|
||||||
|
- `testcontainers[postgres]>=4.0.0` - Docker container management
|
||||||
|
- `schemathesis>=3.30.0` - OpenAPI contract testing
|
||||||
|
|
||||||
|
---
|
||||||
|
|
||||||
|
## How It Works
|
||||||
|
|
||||||
|
### Testcontainers
|
||||||
|
|
||||||
|
Testcontainers automatically manages Docker containers for tests:
|
||||||
|
|
||||||
|
1. **Session-scoped container**: A single PostgreSQL 17 container starts once per test session
|
||||||
|
2. **Function-scoped isolation**: Each test gets fresh tables (drop + recreate)
|
||||||
|
3. **Automatic cleanup**: Container is destroyed when tests complete
|
||||||
|
|
||||||
|
This approach catches bugs that SQLite-based tests miss:
|
||||||
|
- PostgreSQL-specific SQL behavior
|
||||||
|
- Real constraint violations
|
||||||
|
- Actual transaction semantics
|
||||||
|
- JSONB column behavior
|
||||||
|
|
||||||
|
### Schemathesis
|
||||||
|
|
||||||
|
Schemathesis generates test cases from your OpenAPI schema:
|
||||||
|
|
||||||
|
1. **Schema loading**: Reads `/api/v1/openapi.json` from your FastAPI app
|
||||||
|
2. **Test generation**: Creates test cases for each endpoint
|
||||||
|
3. **Response validation**: Verifies responses match documented schema
|
||||||
|
|
||||||
|
This catches:
|
||||||
|
- Undocumented response codes
|
||||||
|
- Schema mismatches (wrong types, missing fields)
|
||||||
|
- Edge cases in input validation
|
||||||
|
|
||||||
|
---
|
||||||
|
|
||||||
|
## Test Organization
|
||||||
|
|
||||||
|
```
|
||||||
|
backend/tests/
|
||||||
|
├── e2e/ # E2E tests (PostgreSQL, Docker required)
|
||||||
|
│ ├── __init__.py
|
||||||
|
│ ├── conftest.py # Testcontainers fixtures
|
||||||
|
│ ├── test_api_contracts.py # Schemathesis schema tests
|
||||||
|
│ └── test_database_workflows.py # PostgreSQL workflow tests
|
||||||
|
│
|
||||||
|
├── api/ # Integration tests (SQLite, fast)
|
||||||
|
├── repositories/ # Repository unit tests
|
||||||
|
└── conftest.py # Standard fixtures
|
||||||
|
```
|
||||||
|
|
||||||
|
### Test Markers
|
||||||
|
|
||||||
|
Tests use pytest markers for filtering:
|
||||||
|
|
||||||
|
| Marker | Description |
|
||||||
|
|--------|-------------|
|
||||||
|
| `@pytest.mark.e2e` | End-to-end test requiring Docker |
|
||||||
|
| `@pytest.mark.postgres` | PostgreSQL-specific test |
|
||||||
|
| `@pytest.mark.schemathesis` | Schemathesis schema test |
|
||||||
|
|
||||||
|
---
|
||||||
|
|
||||||
|
## Writing E2E Tests
|
||||||
|
|
||||||
|
### Basic E2E Test
|
||||||
|
|
||||||
|
```python
|
||||||
|
import pytest
|
||||||
|
from uuid import uuid4
|
||||||
|
|
||||||
|
@pytest.mark.e2e
|
||||||
|
@pytest.mark.postgres
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_user_workflow(e2e_client):
|
||||||
|
"""Test user registration with real PostgreSQL."""
|
||||||
|
email = f"test-{uuid4().hex[:8]}@example.com"
|
||||||
|
|
||||||
|
response = await e2e_client.post(
|
||||||
|
"/api/v1/auth/register",
|
||||||
|
json={
|
||||||
|
"email": email,
|
||||||
|
"password": "SecurePassword123!",
|
||||||
|
"first_name": "Test",
|
||||||
|
"last_name": "User",
|
||||||
|
},
|
||||||
|
)
|
||||||
|
|
||||||
|
assert response.status_code in [200, 201]
|
||||||
|
assert response.json()["email"] == email
|
||||||
|
```
|
||||||
|
|
||||||
|
### Available Fixtures
|
||||||
|
|
||||||
|
| Fixture | Scope | Description |
|
||||||
|
|---------|-------|-------------|
|
||||||
|
| `postgres_container` | session | Raw Testcontainers PostgreSQL container |
|
||||||
|
| `async_postgres_url` | session | Asyncpg-compatible connection URL |
|
||||||
|
| `e2e_db_session` | function | SQLAlchemy AsyncSession with fresh tables |
|
||||||
|
| `e2e_client` | function | httpx AsyncClient connected to real DB |
|
||||||
|
|
||||||
|
### Schemathesis Test
|
||||||
|
|
||||||
|
```python
|
||||||
|
import pytest
|
||||||
|
import schemathesis
|
||||||
|
from hypothesis import settings, Phase
|
||||||
|
|
||||||
|
from app.main import app
|
||||||
|
|
||||||
|
schema = schemathesis.from_asgi("/api/v1/openapi.json", app=app)
|
||||||
|
|
||||||
|
@pytest.mark.e2e
|
||||||
|
@pytest.mark.schemathesis
|
||||||
|
@schema.parametrize(endpoint="/api/v1/auth/register")
|
||||||
|
@settings(max_examples=20)
|
||||||
|
def test_registration_schema(case):
|
||||||
|
"""Test registration endpoint conforms to schema."""
|
||||||
|
response = case.call_asgi()
|
||||||
|
case.validate_response(response)
|
||||||
|
```
|
||||||
|
|
||||||
|
---
|
||||||
|
|
||||||
|
## Running Tests
|
||||||
|
|
||||||
|
### Commands
|
||||||
|
|
||||||
|
```bash
|
||||||
|
# Run all E2E tests
|
||||||
|
make test-e2e
|
||||||
|
|
||||||
|
# Run only Schemathesis schema tests
|
||||||
|
make test-e2e-schema
|
||||||
|
|
||||||
|
# Run all tests (unit + integration + E2E)
|
||||||
|
make test-all
|
||||||
|
|
||||||
|
# Check Docker availability
|
||||||
|
make check-docker
|
||||||
|
```
|
||||||
|
|
||||||
|
### Direct pytest
|
||||||
|
|
||||||
|
```bash
|
||||||
|
# All E2E tests
|
||||||
|
IS_TEST=True PYTHONPATH=. uv run pytest tests/e2e/ -v
|
||||||
|
|
||||||
|
# Only PostgreSQL tests
|
||||||
|
IS_TEST=True PYTHONPATH=. uv run pytest tests/e2e/ -v -m postgres
|
||||||
|
|
||||||
|
# Only Schemathesis tests
|
||||||
|
IS_TEST=True PYTHONPATH=. uv run pytest tests/e2e/ -v -m schemathesis
|
||||||
|
```
|
||||||
|
|
||||||
|
---
|
||||||
|
|
||||||
|
## Troubleshooting
|
||||||
|
|
||||||
|
### Docker Not Running
|
||||||
|
|
||||||
|
**Error:**
|
||||||
|
```
|
||||||
|
Docker is not running!
|
||||||
|
E2E tests require Docker to be running.
|
||||||
|
```
|
||||||
|
|
||||||
|
**Solution:**
|
||||||
|
```bash
|
||||||
|
# macOS/Windows
|
||||||
|
# Open Docker Desktop
|
||||||
|
|
||||||
|
# Linux
|
||||||
|
sudo systemctl start docker
|
||||||
|
```
|
||||||
|
|
||||||
|
### Testcontainers Not Installed
|
||||||
|
|
||||||
|
**Error:**
|
||||||
|
```
|
||||||
|
SKIPPED: testcontainers not installed - run: make install-e2e
|
||||||
|
```
|
||||||
|
|
||||||
|
**Solution:**
|
||||||
|
```bash
|
||||||
|
make install-e2e
|
||||||
|
# Or: uv sync --extra dev --extra e2e
|
||||||
|
```
|
||||||
|
|
||||||
|
### Container Startup Timeout
|
||||||
|
|
||||||
|
**Error:**
|
||||||
|
```
|
||||||
|
testcontainers.core.waiting_utils.UnexpectedResponse
|
||||||
|
```
|
||||||
|
|
||||||
|
**Solutions:**
|
||||||
|
1. Increase Docker resources (memory, CPU)
|
||||||
|
2. Pull the image manually: `docker pull postgres:17-alpine`
|
||||||
|
3. Check Docker daemon logs: `docker logs`
|
||||||
|
|
||||||
|
### Port Conflicts
|
||||||
|
|
||||||
|
**Error:**
|
||||||
|
```
|
||||||
|
Error starting container: port is already allocated
|
||||||
|
```
|
||||||
|
|
||||||
|
**Solution:**
|
||||||
|
Testcontainers uses random ports, so conflicts are rare. If occurring:
|
||||||
|
1. Stop other PostgreSQL containers: `docker stop $(docker ps -q)`
|
||||||
|
2. Check for orphaned containers: `docker container prune`
|
||||||
|
|
||||||
|
### Ryuk/Reaper Port 8080 Issues
|
||||||
|
|
||||||
|
**Error:**
|
||||||
|
```
|
||||||
|
ConnectionError: Port mapping for container ... and port 8080 is not available
|
||||||
|
```
|
||||||
|
|
||||||
|
**Solution:**
|
||||||
|
This is related to the Testcontainers Reaper (Ryuk) which handles automatic cleanup.
|
||||||
|
The `conftest.py` automatically disables Ryuk to avoid this issue. If you still encounter
|
||||||
|
this error, ensure you're using the latest conftest.py or set the environment variable:
|
||||||
|
|
||||||
|
```bash
|
||||||
|
export TESTCONTAINERS_RYUK_DISABLED=true
|
||||||
|
```
|
||||||
|
|
||||||
|
### Parallel Test Execution Issues
|
||||||
|
|
||||||
|
**Error:**
|
||||||
|
```
|
||||||
|
ScopeMismatch: ... cannot use a higher-scoped fixture 'postgres_container'
|
||||||
|
```
|
||||||
|
|
||||||
|
**Solution:**
|
||||||
|
E2E tests must run sequentially (not in parallel) because they share a session-scoped
|
||||||
|
PostgreSQL container. The Makefile commands use `-n 0` to disable parallel execution.
|
||||||
|
If running pytest directly, add `-n 0`:
|
||||||
|
|
||||||
|
```bash
|
||||||
|
IS_TEST=True PYTHONPATH=. uv run pytest tests/e2e/ -v -n 0
|
||||||
|
```
|
||||||
|
|
||||||
|
---
|
||||||
|
|
||||||
|
## CI/CD Integration
|
||||||
|
|
||||||
|
### GitHub Actions
|
||||||
|
|
||||||
|
A workflow template is provided at `.github/workflows/backend-e2e-tests.yml.template`.
|
||||||
|
|
||||||
|
To enable:
|
||||||
|
1. Rename to `backend-e2e-tests.yml`
|
||||||
|
2. Push to repository
|
||||||
|
|
||||||
|
The workflow:
|
||||||
|
- Runs on pushes to `main`/`develop` affecting `backend/`
|
||||||
|
- Uses `continue-on-error: true` (E2E failures don't block merge)
|
||||||
|
- Caches uv dependencies for speed
|
||||||
|
|
||||||
|
### Local CI Simulation
|
||||||
|
|
||||||
|
```bash
|
||||||
|
# Run what CI runs
|
||||||
|
make test-all
|
||||||
|
```
|
||||||
|
|
||||||
|
---
|
||||||
|
|
||||||
|
## Best Practices
|
||||||
|
|
||||||
|
### DO
|
||||||
|
|
||||||
|
- Use unique emails per test: `f"test-{uuid4().hex[:8]}@example.com"`
|
||||||
|
- Mark tests with appropriate markers: `@pytest.mark.e2e`
|
||||||
|
- Keep E2E tests focused on critical workflows
|
||||||
|
- Use `e2e_client` fixture for most tests
|
||||||
|
|
||||||
|
### DON'T
|
||||||
|
|
||||||
|
- Share state between tests (each test gets fresh tables)
|
||||||
|
- Test every endpoint in E2E (use unit tests for edge cases)
|
||||||
|
- Skip the `IS_TEST=True` environment variable
|
||||||
|
- Run E2E tests without Docker
|
||||||
|
|
||||||
|
---
|
||||||
|
|
||||||
|
## Further Reading
|
||||||
|
|
||||||
|
- [Testcontainers Documentation](https://testcontainers.com/guides/getting-started-with-testcontainers-for-python/)
|
||||||
|
- [Schemathesis Documentation](https://schemathesis.readthedocs.io/)
|
||||||
|
- [pytest-asyncio Documentation](https://pytest-asyncio.readthedocs.io/)
|
||||||
File diff suppressed because it is too large
Load Diff
16
backend/entrypoint.sh
Normal file → Executable file
16
backend/entrypoint.sh
Normal file → Executable file
@@ -1,12 +1,22 @@
|
|||||||
#!/bin/bash
|
#!/bin/sh
|
||||||
set -e
|
set -e
|
||||||
echo "Starting Backend"
|
echo "Starting Backend"
|
||||||
|
|
||||||
|
# Ensure the project's virtualenv binaries are on PATH so commands like
|
||||||
|
# 'uvicorn' work even when not prefixed by 'uv run'. This matches how uv
|
||||||
|
# installs the env into /app/.venv in our containers.
|
||||||
|
if [ -d "/app/.venv/bin" ]; then
|
||||||
|
export PATH="/app/.venv/bin:$PATH"
|
||||||
|
fi
|
||||||
|
|
||||||
# Apply database migrations
|
# Apply database migrations
|
||||||
uv run alembic upgrade head
|
# Avoid installing the project in editable mode (which tries to write egg-info)
|
||||||
|
# when running inside a bind-mounted volume with restricted permissions.
|
||||||
|
# See: https://github.com/astral-sh/uv (use --no-project to skip project build)
|
||||||
|
uv run --no-project alembic upgrade head
|
||||||
|
|
||||||
# Initialize database (creates first superuser if needed)
|
# Initialize database (creates first superuser if needed)
|
||||||
uv run python app/init_db.py
|
uv run --no-project python app/init_db.py
|
||||||
|
|
||||||
# Execute the command passed to docker run
|
# Execute the command passed to docker run
|
||||||
exec "$@"
|
exec "$@"
|
||||||
@@ -2,8 +2,32 @@
|
|||||||
"""
|
"""
|
||||||
Database migration helper script.
|
Database migration helper script.
|
||||||
Provides convenient commands for generating and applying Alembic migrations.
|
Provides convenient commands for generating and applying Alembic migrations.
|
||||||
|
|
||||||
|
Usage:
|
||||||
|
# Generate migration (auto-increments revision ID: 0001, 0002, etc.)
|
||||||
|
python migrate.py --local generate "Add new field"
|
||||||
|
python migrate.py --local auto "Add new field"
|
||||||
|
|
||||||
|
# Apply migrations
|
||||||
|
python migrate.py --local apply
|
||||||
|
|
||||||
|
# Show next revision ID
|
||||||
|
python migrate.py next
|
||||||
|
|
||||||
|
# Reset after deleting migrations (clears alembic_version table)
|
||||||
|
python migrate.py --local reset
|
||||||
|
|
||||||
|
# Override auto-increment with custom revision ID
|
||||||
|
python migrate.py --local generate "initial_models" --rev-id custom_id
|
||||||
|
|
||||||
|
# Generate empty migration template without database (no autogenerate)
|
||||||
|
python migrate.py generate "Add performance indexes" --offline
|
||||||
|
|
||||||
|
# Inside Docker (without --local flag):
|
||||||
|
python migrate.py auto "Add new field"
|
||||||
"""
|
"""
|
||||||
import argparse
|
import argparse
|
||||||
|
import os
|
||||||
import subprocess
|
import subprocess
|
||||||
import sys
|
import sys
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
@@ -13,15 +37,21 @@ project_root = Path(__file__).resolve().parent
|
|||||||
if str(project_root) not in sys.path:
|
if str(project_root) not in sys.path:
|
||||||
sys.path.append(str(project_root))
|
sys.path.append(str(project_root))
|
||||||
|
|
||||||
try:
|
|
||||||
# Import settings to check if configuration is working
|
|
||||||
from app.core.config import settings
|
|
||||||
|
|
||||||
print(f"Using database URL: {settings.database_url}")
|
def setup_database_url(use_local: bool) -> str:
|
||||||
except ImportError as e:
|
"""Setup database URL, optionally using localhost for local development."""
|
||||||
print(f"Error importing settings: {e}")
|
if use_local:
|
||||||
print("Make sure your Python path includes the project root.")
|
# Override DATABASE_URL to use localhost instead of Docker hostname
|
||||||
sys.exit(1)
|
local_url = os.environ.get(
|
||||||
|
"LOCAL_DATABASE_URL",
|
||||||
|
"postgresql://postgres:postgres@localhost:5432/app"
|
||||||
|
)
|
||||||
|
os.environ["DATABASE_URL"] = local_url
|
||||||
|
return local_url
|
||||||
|
|
||||||
|
# Use the configured DATABASE_URL from environment/.env
|
||||||
|
from app.core.config import settings
|
||||||
|
return settings.database_url
|
||||||
|
|
||||||
|
|
||||||
def check_models():
|
def check_models():
|
||||||
@@ -40,11 +70,30 @@ def check_models():
|
|||||||
return False
|
return False
|
||||||
|
|
||||||
|
|
||||||
def generate_migration(message):
|
def generate_migration(message, rev_id=None, auto_rev_id=True, offline=False):
|
||||||
"""Generate an Alembic migration with the given message"""
|
"""Generate an Alembic migration with the given message.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
message: Migration message
|
||||||
|
rev_id: Custom revision ID (overrides auto_rev_id)
|
||||||
|
auto_rev_id: If True and rev_id is None, auto-generate sequential ID
|
||||||
|
offline: If True, generate empty migration without database (no autogenerate)
|
||||||
|
"""
|
||||||
|
# Auto-generate sequential revision ID if not provided
|
||||||
|
if rev_id is None and auto_rev_id:
|
||||||
|
rev_id = get_next_rev_id()
|
||||||
|
|
||||||
print(f"Generating migration: {message}")
|
print(f"Generating migration: {message}")
|
||||||
|
if rev_id:
|
||||||
|
print(f"Using revision ID: {rev_id}")
|
||||||
|
|
||||||
|
if offline:
|
||||||
|
# Generate migration file directly without database connection
|
||||||
|
return generate_offline_migration(message, rev_id)
|
||||||
|
|
||||||
cmd = ["alembic", "revision", "--autogenerate", "-m", message]
|
cmd = ["alembic", "revision", "--autogenerate", "-m", message]
|
||||||
|
if rev_id:
|
||||||
|
cmd.extend(["--rev-id", rev_id])
|
||||||
result = subprocess.run(cmd, capture_output=True, text=True)
|
result = subprocess.run(cmd, capture_output=True, text=True)
|
||||||
|
|
||||||
print(result.stdout)
|
print(result.stdout)
|
||||||
@@ -64,8 +113,9 @@ def generate_migration(message):
|
|||||||
if len(part) >= 12 and all(c in "0123456789abcdef" for c in part[:12]):
|
if len(part) >= 12 and all(c in "0123456789abcdef" for c in part[:12]):
|
||||||
revision = part[:12]
|
revision = part[:12]
|
||||||
break
|
break
|
||||||
except Exception:
|
except Exception as e:
|
||||||
pass
|
# If parsing fails, we can still proceed without a detected revision
|
||||||
|
print(f"Warning: could not parse revision from line '{line}': {e}")
|
||||||
|
|
||||||
if revision:
|
if revision:
|
||||||
print(f"Generated revision: {revision}")
|
print(f"Generated revision: {revision}")
|
||||||
@@ -131,8 +181,14 @@ def check_database_connection():
|
|||||||
from sqlalchemy.exc import SQLAlchemyError
|
from sqlalchemy.exc import SQLAlchemyError
|
||||||
|
|
||||||
try:
|
try:
|
||||||
engine = create_engine(settings.database_url)
|
# Use DATABASE_URL from environment (set by setup_database_url)
|
||||||
with engine.connect() as conn:
|
db_url = os.environ.get("DATABASE_URL")
|
||||||
|
if not db_url:
|
||||||
|
from app.core.config import settings
|
||||||
|
db_url = settings.database_url
|
||||||
|
|
||||||
|
engine = create_engine(db_url)
|
||||||
|
with engine.connect():
|
||||||
print("✓ Database connection successful!")
|
print("✓ Database connection successful!")
|
||||||
return True
|
return True
|
||||||
except SQLAlchemyError as e:
|
except SQLAlchemyError as e:
|
||||||
@@ -140,16 +196,172 @@ def check_database_connection():
|
|||||||
return False
|
return False
|
||||||
|
|
||||||
|
|
||||||
|
def get_next_rev_id():
|
||||||
|
"""Get the next sequential revision ID based on existing migrations."""
|
||||||
|
import re
|
||||||
|
|
||||||
|
versions_dir = project_root / "app" / "alembic" / "versions"
|
||||||
|
if not versions_dir.exists():
|
||||||
|
return "0001"
|
||||||
|
|
||||||
|
# Find all migration files with numeric prefixes
|
||||||
|
max_num = 0
|
||||||
|
pattern = re.compile(r"^(\d{4})_.*\.py$")
|
||||||
|
|
||||||
|
for f in versions_dir.iterdir():
|
||||||
|
if f.is_file() and f.suffix == ".py":
|
||||||
|
match = pattern.match(f.name)
|
||||||
|
if match:
|
||||||
|
num = int(match.group(1))
|
||||||
|
max_num = max(max_num, num)
|
||||||
|
|
||||||
|
next_num = max_num + 1
|
||||||
|
return f"{next_num:04d}"
|
||||||
|
|
||||||
|
|
||||||
|
def get_current_rev_id():
|
||||||
|
"""Get the current (latest) revision ID from existing migrations."""
|
||||||
|
import re
|
||||||
|
|
||||||
|
versions_dir = project_root / "app" / "alembic" / "versions"
|
||||||
|
if not versions_dir.exists():
|
||||||
|
return None
|
||||||
|
|
||||||
|
# Find all migration files with numeric prefixes and get the highest
|
||||||
|
max_num = 0
|
||||||
|
max_rev_id = None
|
||||||
|
pattern = re.compile(r"^(\d{4})_.*\.py$")
|
||||||
|
|
||||||
|
for f in versions_dir.iterdir():
|
||||||
|
if f.is_file() and f.suffix == ".py":
|
||||||
|
match = pattern.match(f.name)
|
||||||
|
if match:
|
||||||
|
num = int(match.group(1))
|
||||||
|
if num > max_num:
|
||||||
|
max_num = num
|
||||||
|
max_rev_id = match.group(1)
|
||||||
|
|
||||||
|
return max_rev_id
|
||||||
|
|
||||||
|
|
||||||
|
def generate_offline_migration(message, rev_id):
|
||||||
|
"""Generate a migration file without database connection.
|
||||||
|
|
||||||
|
Creates an empty migration template that can be filled in manually.
|
||||||
|
Useful for performance indexes or when database is not available.
|
||||||
|
"""
|
||||||
|
from datetime import datetime
|
||||||
|
|
||||||
|
versions_dir = project_root / "app" / "alembic" / "versions"
|
||||||
|
versions_dir.mkdir(parents=True, exist_ok=True)
|
||||||
|
|
||||||
|
# Slugify the message for filename
|
||||||
|
slug = message.lower().replace(" ", "_").replace("-", "_")
|
||||||
|
slug = "".join(c for c in slug if c.isalnum() or c == "_")
|
||||||
|
|
||||||
|
filename = f"{rev_id}_{slug}.py"
|
||||||
|
filepath = versions_dir / filename
|
||||||
|
|
||||||
|
# Get the previous revision ID
|
||||||
|
down_revision = get_current_rev_id()
|
||||||
|
down_rev_str = f'"{down_revision}"' if down_revision else "None"
|
||||||
|
|
||||||
|
# Generate the migration file content
|
||||||
|
content = f'''"""{message}
|
||||||
|
|
||||||
|
Revision ID: {rev_id}
|
||||||
|
Revises: {down_revision or ''}
|
||||||
|
Create Date: {datetime.now().strftime('%Y-%m-%d %H:%M:%S.%f')}
|
||||||
|
|
||||||
|
"""
|
||||||
|
|
||||||
|
from collections.abc import Sequence
|
||||||
|
|
||||||
|
import sqlalchemy as sa
|
||||||
|
from alembic import op
|
||||||
|
|
||||||
|
# revision identifiers, used by Alembic.
|
||||||
|
revision: str = "{rev_id}"
|
||||||
|
down_revision: str | None = {down_rev_str}
|
||||||
|
branch_labels: str | Sequence[str] | None = None
|
||||||
|
depends_on: str | Sequence[str] | None = None
|
||||||
|
|
||||||
|
|
||||||
|
def upgrade() -> None:
|
||||||
|
# TODO: Add your upgrade operations here
|
||||||
|
pass
|
||||||
|
|
||||||
|
|
||||||
|
def downgrade() -> None:
|
||||||
|
# TODO: Add your downgrade operations here
|
||||||
|
pass
|
||||||
|
'''
|
||||||
|
|
||||||
|
filepath.write_text(content)
|
||||||
|
print(f"Generated offline migration: {filepath}")
|
||||||
|
return rev_id
|
||||||
|
|
||||||
|
|
||||||
|
def show_next_rev_id():
|
||||||
|
"""Show the next sequential revision ID."""
|
||||||
|
next_id = get_next_rev_id()
|
||||||
|
print(f"Next revision ID: {next_id}")
|
||||||
|
print(f"\nUsage:")
|
||||||
|
print(f" python migrate.py --local generate 'your_message' --rev-id {next_id}")
|
||||||
|
print(f" python migrate.py --local auto 'your_message' --rev-id {next_id}")
|
||||||
|
return next_id
|
||||||
|
|
||||||
|
|
||||||
|
def reset_alembic_version():
|
||||||
|
"""Reset the alembic_version table (for fresh start after deleting migrations)."""
|
||||||
|
from sqlalchemy import create_engine, text
|
||||||
|
from sqlalchemy.exc import SQLAlchemyError
|
||||||
|
|
||||||
|
db_url = os.environ.get("DATABASE_URL")
|
||||||
|
if not db_url:
|
||||||
|
from app.core.config import settings
|
||||||
|
db_url = settings.database_url
|
||||||
|
|
||||||
|
try:
|
||||||
|
engine = create_engine(db_url)
|
||||||
|
with engine.connect() as conn:
|
||||||
|
conn.execute(text("DROP TABLE IF EXISTS alembic_version"))
|
||||||
|
conn.commit()
|
||||||
|
print("✓ Alembic version table reset successfully")
|
||||||
|
print(" You can now run migrations from scratch")
|
||||||
|
return True
|
||||||
|
except SQLAlchemyError as e:
|
||||||
|
print(f"✗ Error resetting alembic version: {e}")
|
||||||
|
return False
|
||||||
|
|
||||||
|
|
||||||
def main():
|
def main():
|
||||||
"""Main function"""
|
"""Main function"""
|
||||||
parser = argparse.ArgumentParser(
|
parser = argparse.ArgumentParser(
|
||||||
description='Database migration helper for FastNext template'
|
description='Database migration helper for Generative Models Arena'
|
||||||
)
|
)
|
||||||
|
|
||||||
|
# Global options
|
||||||
|
parser.add_argument(
|
||||||
|
'--local', '-l',
|
||||||
|
action='store_true',
|
||||||
|
help='Use localhost instead of Docker hostname (for local development)'
|
||||||
|
)
|
||||||
|
|
||||||
subparsers = parser.add_subparsers(dest='command', help='Command to run')
|
subparsers = parser.add_subparsers(dest='command', help='Command to run')
|
||||||
|
|
||||||
# Generate command
|
# Generate command
|
||||||
generate_parser = subparsers.add_parser('generate', help='Generate a migration')
|
generate_parser = subparsers.add_parser('generate', help='Generate a migration')
|
||||||
generate_parser.add_argument('message', help='Migration message')
|
generate_parser.add_argument('message', help='Migration message')
|
||||||
|
generate_parser.add_argument(
|
||||||
|
'--rev-id',
|
||||||
|
help='Custom revision ID (e.g., 0001, 0002 for sequential naming)'
|
||||||
|
)
|
||||||
|
generate_parser.add_argument(
|
||||||
|
'--offline',
|
||||||
|
action='store_true',
|
||||||
|
help='Generate empty migration template without database connection'
|
||||||
|
)
|
||||||
|
|
||||||
# Apply command
|
# Apply command
|
||||||
apply_parser = subparsers.add_parser('apply', help='Apply migrations')
|
apply_parser = subparsers.add_parser('apply', help='Apply migrations')
|
||||||
@@ -164,15 +376,56 @@ def main():
|
|||||||
# Check command
|
# Check command
|
||||||
subparsers.add_parser('check', help='Check database connection and models')
|
subparsers.add_parser('check', help='Check database connection and models')
|
||||||
|
|
||||||
|
# Next command (show next revision ID)
|
||||||
|
subparsers.add_parser('next', help='Show the next sequential revision ID')
|
||||||
|
|
||||||
|
# Reset command (clear alembic_version table)
|
||||||
|
subparsers.add_parser(
|
||||||
|
'reset',
|
||||||
|
help='Reset alembic_version table (use after deleting all migrations)'
|
||||||
|
)
|
||||||
|
|
||||||
# Auto command (generate and apply)
|
# Auto command (generate and apply)
|
||||||
auto_parser = subparsers.add_parser('auto', help='Generate and apply migration')
|
auto_parser = subparsers.add_parser('auto', help='Generate and apply migration')
|
||||||
auto_parser.add_argument('message', help='Migration message')
|
auto_parser.add_argument('message', help='Migration message')
|
||||||
|
auto_parser.add_argument(
|
||||||
|
'--rev-id',
|
||||||
|
help='Custom revision ID (e.g., 0001, 0002 for sequential naming)'
|
||||||
|
)
|
||||||
|
auto_parser.add_argument(
|
||||||
|
'--offline',
|
||||||
|
action='store_true',
|
||||||
|
help='Generate empty migration template without database connection'
|
||||||
|
)
|
||||||
|
|
||||||
args = parser.parse_args()
|
args = parser.parse_args()
|
||||||
|
|
||||||
|
# Commands that don't need database connection
|
||||||
|
if args.command == 'next':
|
||||||
|
show_next_rev_id()
|
||||||
|
return
|
||||||
|
|
||||||
|
# Check if offline mode is requested
|
||||||
|
offline = getattr(args, 'offline', False)
|
||||||
|
|
||||||
|
# Offline generate doesn't need database or model check
|
||||||
|
if args.command == 'generate' and offline:
|
||||||
|
generate_migration(args.message, rev_id=args.rev_id, offline=True)
|
||||||
|
return
|
||||||
|
|
||||||
|
if args.command == 'auto' and offline:
|
||||||
|
generate_migration(args.message, rev_id=args.rev_id, offline=True)
|
||||||
|
print("\nOffline migration generated. Apply it later with:")
|
||||||
|
print(f" python migrate.py --local apply")
|
||||||
|
return
|
||||||
|
|
||||||
|
# Setup database URL (must be done before importing settings elsewhere)
|
||||||
|
db_url = setup_database_url(args.local)
|
||||||
|
print(f"Using database URL: {db_url}")
|
||||||
|
|
||||||
if args.command == 'generate':
|
if args.command == 'generate':
|
||||||
check_models()
|
check_models()
|
||||||
generate_migration(args.message)
|
generate_migration(args.message, rev_id=args.rev_id)
|
||||||
|
|
||||||
elif args.command == 'apply':
|
elif args.command == 'apply':
|
||||||
apply_migration(args.revision)
|
apply_migration(args.revision)
|
||||||
@@ -187,11 +440,14 @@ def main():
|
|||||||
check_database_connection()
|
check_database_connection()
|
||||||
check_models()
|
check_models()
|
||||||
|
|
||||||
|
elif args.command == 'reset':
|
||||||
|
reset_alembic_version()
|
||||||
|
|
||||||
elif args.command == 'auto':
|
elif args.command == 'auto':
|
||||||
check_models()
|
check_models()
|
||||||
revision = generate_migration(args.message)
|
revision = generate_migration(args.message, rev_id=args.rev_id)
|
||||||
if revision:
|
if revision:
|
||||||
proceed = input("\nPress Enter to apply migration or Ctrl+C to abort... ")
|
input("\nPress Enter to apply migration or Ctrl+C to abort... ")
|
||||||
apply_migration()
|
apply_migration()
|
||||||
|
|
||||||
else:
|
else:
|
||||||
|
|||||||
@@ -20,40 +20,36 @@ dependencies = [
|
|||||||
"uvicorn>=0.34.0",
|
"uvicorn>=0.34.0",
|
||||||
"pydantic>=2.10.6",
|
"pydantic>=2.10.6",
|
||||||
"pydantic-settings>=2.2.1",
|
"pydantic-settings>=2.2.1",
|
||||||
"python-multipart>=0.0.19",
|
"python-multipart>=0.0.22",
|
||||||
"fastapi-utils==0.8.0",
|
"fastapi-utils==0.8.0",
|
||||||
|
|
||||||
# Database
|
# Database
|
||||||
"sqlalchemy>=2.0.29",
|
"sqlalchemy>=2.0.29",
|
||||||
"alembic>=1.14.1",
|
"alembic>=1.14.1",
|
||||||
"psycopg2-binary>=2.9.9",
|
"psycopg2-binary>=2.9.9",
|
||||||
"asyncpg>=0.29.0",
|
"asyncpg>=0.29.0",
|
||||||
"aiosqlite==0.21.0",
|
"aiosqlite==0.21.0",
|
||||||
|
|
||||||
# Environment configuration
|
# Environment configuration
|
||||||
"python-dotenv>=1.0.1",
|
"python-dotenv>=1.0.1",
|
||||||
|
|
||||||
# API utilities
|
# API utilities
|
||||||
"email-validator>=2.1.0.post1",
|
"email-validator>=2.1.0.post1",
|
||||||
"ujson>=5.9.0",
|
"ujson>=5.9.0",
|
||||||
|
|
||||||
# CORS and security
|
# CORS and security
|
||||||
"starlette>=0.40.0",
|
"starlette>=0.40.0",
|
||||||
"starlette-csrf>=1.4.5",
|
"starlette-csrf>=1.4.5",
|
||||||
"slowapi>=0.1.9",
|
"slowapi>=0.1.9",
|
||||||
|
|
||||||
# Utilities
|
# Utilities
|
||||||
"httpx>=0.27.0",
|
"httpx>=0.27.0",
|
||||||
"tenacity>=8.2.3",
|
"tenacity>=8.2.3",
|
||||||
"pytz>=2024.1",
|
"pytz>=2024.1",
|
||||||
"pillow>=10.3.0",
|
"pillow>=12.1.1",
|
||||||
"apscheduler==3.11.0",
|
"apscheduler==3.11.0",
|
||||||
|
# Security and authentication
|
||||||
# Security and authentication (pinned for reproducibility)
|
"PyJWT>=2.9.0",
|
||||||
"python-jose==3.4.0",
|
|
||||||
"passlib==1.7.4",
|
|
||||||
"bcrypt==4.2.1",
|
"bcrypt==4.2.1",
|
||||||
"cryptography==44.0.1",
|
"cryptography>=46.0.5",
|
||||||
|
# OAuth authentication
|
||||||
|
"authlib>=1.6.6",
|
||||||
|
"urllib3>=2.6.3",
|
||||||
]
|
]
|
||||||
|
|
||||||
# Development dependencies
|
# Development dependencies
|
||||||
@@ -69,7 +65,24 @@ dev = [
|
|||||||
|
|
||||||
# Development tools
|
# Development tools
|
||||||
"ruff>=0.8.0", # All-in-one: linting, formatting, import sorting
|
"ruff>=0.8.0", # All-in-one: linting, formatting, import sorting
|
||||||
"mypy>=1.8.0", # Type checking
|
"pyright>=1.1.390", # Type checking
|
||||||
|
|
||||||
|
# Security auditing
|
||||||
|
"pip-audit>=2.7.0", # Dependency vulnerability scanning (PyPA/OSV)
|
||||||
|
"pip-licenses>=4.0.0", # License compliance checking
|
||||||
|
"detect-secrets>=1.5.0", # Hardcoded secrets detection
|
||||||
|
|
||||||
|
# Performance benchmarking
|
||||||
|
"pytest-benchmark>=4.0.0", # Performance regression detection
|
||||||
|
|
||||||
|
# Pre-commit hooks
|
||||||
|
"pre-commit>=4.0.0", # Git pre-commit hook framework
|
||||||
|
]
|
||||||
|
|
||||||
|
# E2E testing with real PostgreSQL (requires Docker)
|
||||||
|
e2e = [
|
||||||
|
"testcontainers[postgres]>=4.0.0",
|
||||||
|
"schemathesis>=3.30.0",
|
||||||
]
|
]
|
||||||
|
|
||||||
# ============================================================================
|
# ============================================================================
|
||||||
@@ -122,6 +135,8 @@ select = [
|
|||||||
"RUF", # Ruff-specific
|
"RUF", # Ruff-specific
|
||||||
"ASYNC", # flake8-async
|
"ASYNC", # flake8-async
|
||||||
"S", # flake8-bandit (security)
|
"S", # flake8-bandit (security)
|
||||||
|
"G", # flake8-logging-format (logging best practices)
|
||||||
|
"T20", # flake8-print (no print statements in production code)
|
||||||
]
|
]
|
||||||
|
|
||||||
# Ignore specific rules
|
# Ignore specific rules
|
||||||
@@ -145,11 +160,13 @@ unfixable = []
|
|||||||
[tool.ruff.lint.per-file-ignores]
|
[tool.ruff.lint.per-file-ignores]
|
||||||
"app/alembic/env.py" = ["E402", "F403", "F405"] # Alembic requires specific import order
|
"app/alembic/env.py" = ["E402", "F403", "F405"] # Alembic requires specific import order
|
||||||
"app/alembic/versions/*.py" = ["E402"] # Migration files have specific structure
|
"app/alembic/versions/*.py" = ["E402"] # Migration files have specific structure
|
||||||
"tests/**/*.py" = ["S101", "N806", "B017", "N817", "S110", "ASYNC251", "RUF043"] # pytest: asserts, CamelCase fixtures, blind exceptions, try-pass patterns, and async test helpers are intentional
|
"tests/**/*.py" = ["S101", "N806", "B017", "N817", "ASYNC251", "RUF043", "T20"] # pytest: asserts, CamelCase fixtures, blind exceptions, async test helpers, and print for debugging are intentional
|
||||||
"app/models/__init__.py" = ["F401"] # __init__ files re-export modules
|
"app/models/__init__.py" = ["F401"] # __init__ files re-export modules
|
||||||
"app/models/base.py" = ["F401"] # Re-exports Base for use by other models
|
"app/models/base.py" = ["F401"] # Re-exports Base for use by other models
|
||||||
"app/utils/test_utils.py" = ["N806"] # SQLAlchemy session factories use CamelCase convention
|
"app/utils/test_utils.py" = ["N806"] # SQLAlchemy session factories use CamelCase convention
|
||||||
"app/main.py" = ["N806"] # Constants use UPPER_CASE convention
|
"app/main.py" = ["N806"] # Constants use UPPER_CASE convention
|
||||||
|
"app/init_db.py" = ["T20"] # CLI script uses print for user-facing output
|
||||||
|
"migrate.py" = ["T20"] # CLI script uses print for user-facing output
|
||||||
|
|
||||||
# ============================================================================
|
# ============================================================================
|
||||||
# Ruff Import Sorting (isort replacement)
|
# Ruff Import Sorting (isort replacement)
|
||||||
@@ -176,116 +193,6 @@ indent-style = "space"
|
|||||||
skip-magic-trailing-comma = false
|
skip-magic-trailing-comma = false
|
||||||
line-ending = "lf"
|
line-ending = "lf"
|
||||||
|
|
||||||
# ============================================================================
|
|
||||||
# mypy Configuration - Type Checking
|
|
||||||
# ============================================================================
|
|
||||||
[tool.mypy]
|
|
||||||
python_version = "3.12"
|
|
||||||
warn_return_any = false # SQLAlchemy queries return Any - overly strict
|
|
||||||
warn_unused_configs = true
|
|
||||||
disallow_untyped_defs = false # Gradual typing - enable later
|
|
||||||
disallow_incomplete_defs = false
|
|
||||||
check_untyped_defs = true
|
|
||||||
no_implicit_optional = true
|
|
||||||
warn_redundant_casts = true
|
|
||||||
warn_unused_ignores = true
|
|
||||||
warn_no_return = true
|
|
||||||
strict_equality = true
|
|
||||||
ignore_missing_imports = false
|
|
||||||
explicit_package_bases = true
|
|
||||||
namespace_packages = true
|
|
||||||
|
|
||||||
# Pydantic plugin for better validation
|
|
||||||
plugins = ["pydantic.mypy"]
|
|
||||||
|
|
||||||
# Per-module options
|
|
||||||
[[tool.mypy.overrides]]
|
|
||||||
module = "alembic.*"
|
|
||||||
ignore_errors = true
|
|
||||||
|
|
||||||
[[tool.mypy.overrides]]
|
|
||||||
module = "app.alembic.*"
|
|
||||||
ignore_errors = true
|
|
||||||
|
|
||||||
[[tool.mypy.overrides]]
|
|
||||||
module = "sqlalchemy.*"
|
|
||||||
ignore_missing_imports = true
|
|
||||||
|
|
||||||
[[tool.mypy.overrides]]
|
|
||||||
module = "fastapi_utils.*"
|
|
||||||
ignore_missing_imports = true
|
|
||||||
|
|
||||||
[[tool.mypy.overrides]]
|
|
||||||
module = "slowapi.*"
|
|
||||||
ignore_missing_imports = true
|
|
||||||
|
|
||||||
[[tool.mypy.overrides]]
|
|
||||||
module = "jose.*"
|
|
||||||
ignore_missing_imports = true
|
|
||||||
|
|
||||||
[[tool.mypy.overrides]]
|
|
||||||
module = "passlib.*"
|
|
||||||
ignore_missing_imports = true
|
|
||||||
|
|
||||||
[[tool.mypy.overrides]]
|
|
||||||
module = "pydantic_settings.*"
|
|
||||||
ignore_missing_imports = true
|
|
||||||
|
|
||||||
[[tool.mypy.overrides]]
|
|
||||||
module = "fastapi.*"
|
|
||||||
ignore_missing_imports = true
|
|
||||||
|
|
||||||
[[tool.mypy.overrides]]
|
|
||||||
module = "apscheduler.*"
|
|
||||||
ignore_missing_imports = true
|
|
||||||
|
|
||||||
[[tool.mypy.overrides]]
|
|
||||||
module = "starlette.*"
|
|
||||||
ignore_missing_imports = true
|
|
||||||
|
|
||||||
# SQLAlchemy ORM models - Column descriptors cause type confusion
|
|
||||||
[[tool.mypy.overrides]]
|
|
||||||
module = "app.models.*"
|
|
||||||
disable_error_code = ["assignment", "arg-type", "return-value"]
|
|
||||||
|
|
||||||
# CRUD operations - Generic ModelType and SQLAlchemy Result issues
|
|
||||||
[[tool.mypy.overrides]]
|
|
||||||
module = "app.crud.*"
|
|
||||||
disable_error_code = ["attr-defined", "assignment", "arg-type", "return-value"]
|
|
||||||
|
|
||||||
# API routes - SQLAlchemy Column to Pydantic schema conversions
|
|
||||||
[[tool.mypy.overrides]]
|
|
||||||
module = "app.api.routes.*"
|
|
||||||
disable_error_code = ["arg-type", "call-arg", "call-overload", "assignment"]
|
|
||||||
|
|
||||||
# API dependencies - Similar SQLAlchemy Column issues
|
|
||||||
[[tool.mypy.overrides]]
|
|
||||||
module = "app.api.dependencies.*"
|
|
||||||
disable_error_code = ["arg-type"]
|
|
||||||
|
|
||||||
# FastAPI exception handlers have correct signatures despite mypy warnings
|
|
||||||
[[tool.mypy.overrides]]
|
|
||||||
module = "app.main"
|
|
||||||
disable_error_code = ["arg-type"]
|
|
||||||
|
|
||||||
# Auth service - SQLAlchemy Column issues
|
|
||||||
[[tool.mypy.overrides]]
|
|
||||||
module = "app.services.auth_service"
|
|
||||||
disable_error_code = ["assignment", "arg-type"]
|
|
||||||
|
|
||||||
# Test utils - Testing patterns
|
|
||||||
[[tool.mypy.overrides]]
|
|
||||||
module = "app.utils.auth_test_utils"
|
|
||||||
disable_error_code = ["assignment", "arg-type"]
|
|
||||||
|
|
||||||
# ============================================================================
|
|
||||||
# Pydantic mypy plugin configuration
|
|
||||||
# ============================================================================
|
|
||||||
[tool.pydantic-mypy]
|
|
||||||
init_forbid_extra = true
|
|
||||||
init_typed = true
|
|
||||||
warn_required_dynamic_aliases = true
|
|
||||||
|
|
||||||
# ============================================================================
|
# ============================================================================
|
||||||
# Pytest Configuration
|
# Pytest Configuration
|
||||||
# ============================================================================
|
# ============================================================================
|
||||||
@@ -302,10 +209,15 @@ addopts = [
|
|||||||
"--cov=app",
|
"--cov=app",
|
||||||
"--cov-report=term-missing",
|
"--cov-report=term-missing",
|
||||||
"--cov-report=html",
|
"--cov-report=html",
|
||||||
|
"--ignore=tests/benchmarks", # benchmarks are incompatible with xdist; run via 'make benchmark'
|
||||||
|
"-p", "no:benchmark", # disable pytest-benchmark plugin during normal runs (conflicts with xdist)
|
||||||
]
|
]
|
||||||
markers = [
|
markers = [
|
||||||
"sqlite: marks tests that should run on SQLite (mocked).",
|
"sqlite: marks tests that should run on SQLite (mocked).",
|
||||||
"postgres: marks tests that require a real PostgreSQL database.",
|
"postgres: marks tests that require a real PostgreSQL database.",
|
||||||
|
"e2e: marks end-to-end tests requiring Docker containers.",
|
||||||
|
"schemathesis: marks Schemathesis-generated API tests.",
|
||||||
|
"benchmark: marks performance benchmark tests.",
|
||||||
]
|
]
|
||||||
asyncio_default_fixture_loop_scope = "function"
|
asyncio_default_fixture_loop_scope = "function"
|
||||||
|
|
||||||
@@ -319,6 +231,7 @@ omit = [
|
|||||||
"*/__pycache__/*",
|
"*/__pycache__/*",
|
||||||
"*/alembic/versions/*",
|
"*/alembic/versions/*",
|
||||||
"*/.venv/*",
|
"*/.venv/*",
|
||||||
|
"app/init_db.py", # CLI script for database initialization
|
||||||
]
|
]
|
||||||
branch = true
|
branch = true
|
||||||
|
|
||||||
|
|||||||
23
backend/pyrightconfig.json
Normal file
23
backend/pyrightconfig.json
Normal file
@@ -0,0 +1,23 @@
|
|||||||
|
{
|
||||||
|
"include": ["app"],
|
||||||
|
"exclude": ["app/alembic"],
|
||||||
|
"pythonVersion": "3.12",
|
||||||
|
"venvPath": ".",
|
||||||
|
"venv": ".venv",
|
||||||
|
"typeCheckingMode": "standard",
|
||||||
|
"reportMissingImports": true,
|
||||||
|
"reportMissingTypeStubs": false,
|
||||||
|
"reportUnknownMemberType": false,
|
||||||
|
"reportUnknownVariableType": false,
|
||||||
|
"reportUnknownArgumentType": false,
|
||||||
|
"reportUnknownParameterType": false,
|
||||||
|
"reportUnknownLambdaType": false,
|
||||||
|
"reportReturnType": true,
|
||||||
|
"reportUnusedImport": false,
|
||||||
|
"reportGeneralTypeIssues": false,
|
||||||
|
"reportAttributeAccessIssue": false,
|
||||||
|
"reportArgumentType": false,
|
||||||
|
"strictListInference": false,
|
||||||
|
"strictDictionaryInference": false,
|
||||||
|
"strictSetInference": false
|
||||||
|
}
|
||||||
@@ -67,9 +67,7 @@ class TestParseAcceptLanguage:
|
|||||||
|
|
||||||
def test_parse_complex_header(self):
|
def test_parse_complex_header(self):
|
||||||
"""Test complex Accept-Language header with multiple locales"""
|
"""Test complex Accept-Language header with multiple locales"""
|
||||||
result = parse_accept_language(
|
result = parse_accept_language("it-IT,it;q=0.9,en-US;q=0.8,en;q=0.7,fr;q=0.6")
|
||||||
"it-IT,it;q=0.9,en-US;q=0.8,en;q=0.7,fr;q=0.6"
|
|
||||||
)
|
|
||||||
assert result == "it-it"
|
assert result == "it-it"
|
||||||
|
|
||||||
def test_parse_whitespace_handling(self):
|
def test_parse_whitespace_handling(self):
|
||||||
@@ -199,9 +197,7 @@ class TestGetLocale:
|
|||||||
assert result == "en"
|
assert result == "en"
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
@pytest.mark.asyncio
|
||||||
async def test_locale_from_accept_language_header(
|
async def test_locale_from_accept_language_header(self, async_user_without_locale):
|
||||||
self, async_user_without_locale
|
|
||||||
):
|
|
||||||
"""Test locale detection from Accept-Language header when user has no preference"""
|
"""Test locale detection from Accept-Language header when user has no preference"""
|
||||||
# Mock request with Italian Accept-Language (it-IT has highest priority)
|
# Mock request with Italian Accept-Language (it-IT has highest priority)
|
||||||
mock_request = MagicMock()
|
mock_request = MagicMock()
|
||||||
|
|||||||
@@ -147,7 +147,7 @@ class TestAdminCreateUser:
|
|||||||
headers={"Authorization": f"Bearer {superuser_token}"},
|
headers={"Authorization": f"Bearer {superuser_token}"},
|
||||||
)
|
)
|
||||||
|
|
||||||
assert response.status_code == status.HTTP_404_NOT_FOUND
|
assert response.status_code == status.HTTP_409_CONFLICT
|
||||||
|
|
||||||
|
|
||||||
class TestAdminGetUser:
|
class TestAdminGetUser:
|
||||||
@@ -565,7 +565,7 @@ class TestAdminCreateOrganization:
|
|||||||
headers={"Authorization": f"Bearer {superuser_token}"},
|
headers={"Authorization": f"Bearer {superuser_token}"},
|
||||||
)
|
)
|
||||||
|
|
||||||
assert response.status_code == status.HTTP_404_NOT_FOUND
|
assert response.status_code == status.HTTP_409_CONFLICT
|
||||||
|
|
||||||
|
|
||||||
class TestAdminGetOrganization:
|
class TestAdminGetOrganization:
|
||||||
@@ -923,6 +923,27 @@ class TestAdminRemoveOrganizationMember:
|
|||||||
|
|
||||||
assert response.status_code == status.HTTP_404_NOT_FOUND
|
assert response.status_code == status.HTTP_404_NOT_FOUND
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_admin_remove_organization_member_user_not_found(
|
||||||
|
self, client, async_test_superuser, async_test_db, superuser_token
|
||||||
|
):
|
||||||
|
"""Test removing non-existent user from organization."""
|
||||||
|
_test_engine, AsyncTestingSessionLocal = async_test_db
|
||||||
|
|
||||||
|
# Create organization
|
||||||
|
async with AsyncTestingSessionLocal() as session:
|
||||||
|
org = Organization(name="User Not Found Org", slug="user-not-found-org")
|
||||||
|
session.add(org)
|
||||||
|
await session.commit()
|
||||||
|
org_id = org.id
|
||||||
|
|
||||||
|
response = await client.delete(
|
||||||
|
f"/api/v1/admin/organizations/{org_id}/members/{uuid4()}",
|
||||||
|
headers={"Authorization": f"Bearer {superuser_token}"},
|
||||||
|
)
|
||||||
|
|
||||||
|
assert response.status_code == status.HTTP_404_NOT_FOUND
|
||||||
|
|
||||||
|
|
||||||
# ===== SESSION MANAGEMENT TESTS =====
|
# ===== SESSION MANAGEMENT TESTS =====
|
||||||
|
|
||||||
@@ -1097,3 +1118,102 @@ class TestAdminListSessions:
|
|||||||
)
|
)
|
||||||
|
|
||||||
assert response.status_code == status.HTTP_403_FORBIDDEN
|
assert response.status_code == status.HTTP_403_FORBIDDEN
|
||||||
|
|
||||||
|
|
||||||
|
# ===== ADMIN STATS TESTS =====
|
||||||
|
|
||||||
|
|
||||||
|
class TestAdminStats:
|
||||||
|
"""Tests for GET /admin/stats endpoint."""
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_admin_get_stats_with_data(
|
||||||
|
self,
|
||||||
|
client,
|
||||||
|
async_test_superuser,
|
||||||
|
async_test_user,
|
||||||
|
async_test_db,
|
||||||
|
superuser_token,
|
||||||
|
):
|
||||||
|
"""Test getting admin stats with real data in database."""
|
||||||
|
_test_engine, AsyncTestingSessionLocal = async_test_db
|
||||||
|
|
||||||
|
# Create multiple users and organizations with members
|
||||||
|
async with AsyncTestingSessionLocal() as session:
|
||||||
|
from app.core.auth import get_password_hash
|
||||||
|
from app.models.user import User
|
||||||
|
|
||||||
|
# Create several users
|
||||||
|
for i in range(5):
|
||||||
|
user = User(
|
||||||
|
email=f"statsuser{i}@example.com",
|
||||||
|
password_hash=get_password_hash("TestPassword123!"),
|
||||||
|
first_name=f"Stats{i}",
|
||||||
|
last_name="User",
|
||||||
|
is_active=i % 2 == 0, # Mix of active/inactive
|
||||||
|
)
|
||||||
|
session.add(user)
|
||||||
|
await session.commit()
|
||||||
|
|
||||||
|
# Create organizations with members
|
||||||
|
async with AsyncTestingSessionLocal() as session:
|
||||||
|
orgs = []
|
||||||
|
for i in range(3):
|
||||||
|
org = Organization(name=f"Stats Org {i}", slug=f"stats-org-{i}")
|
||||||
|
session.add(org)
|
||||||
|
orgs.append(org)
|
||||||
|
await session.flush()
|
||||||
|
|
||||||
|
# Add some members to organizations
|
||||||
|
user_org = UserOrganization(
|
||||||
|
user_id=async_test_user.id,
|
||||||
|
organization_id=orgs[0].id,
|
||||||
|
role=OrganizationRole.MEMBER,
|
||||||
|
is_active=True,
|
||||||
|
)
|
||||||
|
session.add(user_org)
|
||||||
|
await session.commit()
|
||||||
|
|
||||||
|
response = await client.get(
|
||||||
|
"/api/v1/admin/stats",
|
||||||
|
headers={"Authorization": f"Bearer {superuser_token}"},
|
||||||
|
)
|
||||||
|
|
||||||
|
assert response.status_code == status.HTTP_200_OK
|
||||||
|
data = response.json()
|
||||||
|
|
||||||
|
# Verify response structure
|
||||||
|
assert "user_growth" in data
|
||||||
|
assert "organization_distribution" in data
|
||||||
|
assert "registration_activity" in data
|
||||||
|
assert "user_status" in data
|
||||||
|
|
||||||
|
# Verify user_growth has 30 days of data
|
||||||
|
assert len(data["user_growth"]) == 30
|
||||||
|
for item in data["user_growth"]:
|
||||||
|
assert "date" in item
|
||||||
|
assert "total_users" in item
|
||||||
|
assert "active_users" in item
|
||||||
|
|
||||||
|
# Verify registration_activity has 14 days of data
|
||||||
|
assert len(data["registration_activity"]) == 14
|
||||||
|
for item in data["registration_activity"]:
|
||||||
|
assert "date" in item
|
||||||
|
assert "registrations" in item
|
||||||
|
|
||||||
|
# Verify user_status has active/inactive counts
|
||||||
|
assert len(data["user_status"]) == 2
|
||||||
|
status_names = {item["name"] for item in data["user_status"]}
|
||||||
|
assert status_names == {"Active", "Inactive"}
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_admin_get_stats_unauthorized(
|
||||||
|
self, client, async_test_user, user_token
|
||||||
|
):
|
||||||
|
"""Test that non-admin users cannot access stats endpoint."""
|
||||||
|
response = await client.get(
|
||||||
|
"/api/v1/admin/stats",
|
||||||
|
headers={"Authorization": f"Bearer {user_token}"},
|
||||||
|
)
|
||||||
|
|
||||||
|
assert response.status_code == status.HTTP_403_FORBIDDEN
|
||||||
|
|||||||
Some files were not shown because too many files have changed in this diff Show More
Reference in New Issue
Block a user