Compare commits
158 Commits
4420756741
...
main
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
a94e29d99c | ||
|
|
81e48c73ca | ||
|
|
a3f78dc801 | ||
|
|
07309013d7 | ||
|
|
846fc31190 | ||
|
|
ff7a67cb58 | ||
|
|
0760a8284d | ||
|
|
ce4d0c7b0d | ||
|
|
4ceb8ad98c | ||
|
|
f8aafb250d | ||
|
|
4385d20ca6 | ||
|
|
1a36907f10 | ||
|
|
0553a1fc53 | ||
|
|
57e969ed67 | ||
|
|
68275b1dd3 | ||
|
|
80d2dc0cb2 | ||
|
|
a8aa416ecb | ||
|
|
4c6bf55bcc | ||
|
|
98b455fdc3 | ||
|
|
0646c96b19 | ||
|
|
62afb328fe | ||
|
|
b9a746bc16 | ||
|
|
de8e18e97d | ||
|
|
a3e557d022 | ||
|
|
4e357db25d | ||
|
|
568aad3673 | ||
|
|
ddcf926158 | ||
|
|
865eeece58 | ||
|
|
05fb3612f9 | ||
|
|
1b2e7dde35 | ||
|
|
29074f26a6 | ||
|
|
77ed190310 | ||
|
|
2bbe925cef | ||
|
|
4a06b96b2e | ||
|
|
088c1725b0 | ||
|
|
7ba1767cea | ||
|
|
c63b6a4f76 | ||
|
|
803b720530 | ||
|
|
7ff00426f2 | ||
|
|
b3f0dd4005 | ||
|
|
707315facd | ||
|
|
38114b79f9 | ||
|
|
1cb3658369 | ||
|
|
dc875c5c95 | ||
|
|
0ea428b718 | ||
|
|
400d6f6f75 | ||
|
|
7716468d72 | ||
|
|
48f052200f | ||
|
|
fbb030da69 | ||
|
|
d49f819469 | ||
|
|
507f2e9c00 | ||
|
|
c0b253a010 | ||
|
|
fcbcff99e9 | ||
|
|
b49678b7df | ||
|
|
aeed9dfdbc | ||
|
|
13f617828b | ||
|
|
84e0a7fe81 | ||
|
|
063a35e698 | ||
|
|
a2246fb6e1 | ||
|
|
16ee4e0cb3 | ||
|
|
e6792c2d6c | ||
|
|
1d20b149dc | ||
|
|
570848cc2d | ||
|
|
6b970765ba | ||
|
|
e79215b4de | ||
|
|
3bf28aa121 | ||
|
|
cda9810a7e | ||
|
|
d47bd34a92 | ||
|
|
5b0ae54365 | ||
|
|
372af25aaa | ||
|
|
d0b717a128 | ||
|
|
9d40aece30 | ||
|
|
487c8a3863 | ||
|
|
8659e884e9 | ||
|
|
a05def5906 | ||
|
|
9f655913b1 | ||
|
|
13abd159fa | ||
|
|
acfe59c8b3 | ||
|
|
2e4700ae9b | ||
|
|
8c83e2a699 | ||
|
|
9b6356b0db | ||
|
|
a410586cfb | ||
|
|
0e34cab921 | ||
|
|
3cf3858fca | ||
|
|
db0c555041 | ||
|
|
51ad80071a | ||
|
|
d730ab7526 | ||
|
|
b218be9318 | ||
|
|
e6813c87c3 | ||
|
|
210204eb7a | ||
|
|
6ad4cda3f4 | ||
|
|
54ceaa6f5d | ||
|
|
34e7f69465 | ||
|
|
8fdbc2b359 | ||
|
|
28b1cc6e48 | ||
|
|
5a21847382 | ||
|
|
444d495f83 | ||
|
|
a943f79ce7 | ||
|
|
f54905abd0 | ||
|
|
0105e765b3 | ||
|
|
bb06b450fd | ||
|
|
c1d6a04276 | ||
|
|
d7b333385d | ||
|
|
f02320e57c | ||
|
|
3ec589293c | ||
|
|
7b1bea2966 | ||
|
|
da7b6b5bfa | ||
|
|
7aa63d79df | ||
|
|
333c9c40af | ||
|
|
0b192ce030 | ||
|
|
da021d0640 | ||
|
|
d1b47006f4 | ||
|
|
a73d3c7d3e | ||
|
|
55ae92c460 | ||
|
|
fe6a98c379 | ||
|
|
b7c1191335 | ||
|
|
68e04a911a | ||
|
|
3001484948 | ||
|
|
c9f4772196 | ||
|
|
14e5839476 | ||
|
|
228d12b379 | ||
|
|
46ff95d8b9 | ||
|
|
235c309e4e | ||
|
|
5c47be2ee5 | ||
|
|
e9f787040a | ||
|
|
2532d1ac3c | ||
|
|
1f45ca2b50 | ||
|
|
8a343580ce | ||
|
|
424ca166b8 | ||
|
|
c589b565f0 | ||
|
|
a5c671c133 | ||
|
|
d8bde80d4f | ||
|
|
35efa24ce5 | ||
|
|
96df7edf88 | ||
|
|
464a6140c4 | ||
|
|
b2f3ec8f25 | ||
|
|
c8f90e9e8c | ||
|
|
2169618bc8 | ||
|
|
a84fd11cc7 | ||
|
|
6824fd7c33 | ||
|
|
d5eb855ae1 | ||
|
|
a6a10855fa | ||
|
|
bf95aab7ec | ||
|
|
214d0b1765 | ||
|
|
b630559e0b | ||
|
|
fe289228e1 | ||
|
|
63c171f83e | ||
|
|
e02329b734 | ||
|
|
e1d5914e7f | ||
|
|
d6a06e45ec | ||
|
|
e74830bec5 | ||
|
|
51ef4632e6 | ||
|
|
b749f62abd | ||
|
|
3b28b5cf97 | ||
|
|
652fb6b180 | ||
|
|
6b556431d3 | ||
|
|
f8b77200f0 | ||
|
|
f99de75dc6 |
55
.env.demo
Normal file
55
.env.demo
Normal file
@@ -0,0 +1,55 @@
|
||||
# Common settings
|
||||
PROJECT_NAME=App
|
||||
VERSION=1.0.0
|
||||
|
||||
# Database settings
|
||||
POSTGRES_USER=postgres
|
||||
POSTGRES_PASSWORD=postgres
|
||||
POSTGRES_DB=app
|
||||
POSTGRES_HOST=db
|
||||
POSTGRES_PORT=5432
|
||||
DATABASE_URL=postgresql://postgres:postgres@db:5432/app
|
||||
|
||||
# Backend settings
|
||||
BACKEND_PORT=8000
|
||||
# CRITICAL: Generate a secure SECRET_KEY for production!
|
||||
# Generate with: python -c 'import secrets; print(secrets.token_urlsafe(32))'
|
||||
# Must be at least 32 characters
|
||||
SECRET_KEY=demo_secret_key_for_testing_only_do_not_use_in_prod
|
||||
ENVIRONMENT=development
|
||||
DEMO_MODE=true
|
||||
DEBUG=true
|
||||
BACKEND_CORS_ORIGINS=["http://localhost:3000"]
|
||||
FIRST_SUPERUSER_EMAIL=admin@example.com
|
||||
# IMPORTANT: Use a strong password (min 12 chars, mixed case, digits)
|
||||
# Default weak passwords like 'Admin123' are rejected
|
||||
FIRST_SUPERUSER_PASSWORD=Admin123!
|
||||
|
||||
# OAuth Configuration (Social Login)
|
||||
# Set OAUTH_ENABLED=true and configure at least one provider
|
||||
OAUTH_ENABLED=false
|
||||
OAUTH_AUTO_LINK_BY_EMAIL=true
|
||||
|
||||
# Google OAuth (from Google Cloud Console > APIs & Services > Credentials)
|
||||
# https://console.cloud.google.com/apis/credentials
|
||||
# OAUTH_GOOGLE_CLIENT_ID=your-google-client-id.apps.googleusercontent.com
|
||||
# OAUTH_GOOGLE_CLIENT_SECRET=your-google-client-secret
|
||||
|
||||
# GitHub OAuth (from GitHub > Settings > Developer settings > OAuth Apps)
|
||||
# https://github.com/settings/developers
|
||||
# OAUTH_GITHUB_CLIENT_ID=your-github-client-id
|
||||
# OAUTH_GITHUB_CLIENT_SECRET=your-github-client-secret
|
||||
|
||||
# OAuth Provider Mode (Authorization Server for MCP/third-party clients)
|
||||
# Set OAUTH_PROVIDER_ENABLED=true to act as an OAuth 2.0 Authorization Server
|
||||
OAUTH_PROVIDER_ENABLED=true
|
||||
# IMPORTANT: Must be HTTPS in production!
|
||||
OAUTH_ISSUER=http://localhost:8000
|
||||
|
||||
# Frontend settings
|
||||
FRONTEND_PORT=3000
|
||||
FRONTEND_URL=http://localhost:3000
|
||||
NEXT_PUBLIC_API_URL=http://localhost:8000
|
||||
NEXT_PUBLIC_API_BASE_URL=http://localhost:8000
|
||||
NEXT_PUBLIC_APP_URL=http://localhost:3000
|
||||
NODE_ENV=development
|
||||
@@ -17,6 +17,7 @@ BACKEND_PORT=8000
|
||||
# Must be at least 32 characters
|
||||
SECRET_KEY=your_secret_key_here_REPLACE_WITH_GENERATED_KEY_32_CHARS_MIN
|
||||
ENVIRONMENT=development
|
||||
DEMO_MODE=false
|
||||
DEBUG=true
|
||||
BACKEND_CORS_ORIGINS=["http://localhost:3000"]
|
||||
FIRST_SUPERUSER_EMAIL=admin@example.com
|
||||
@@ -24,7 +25,31 @@ FIRST_SUPERUSER_EMAIL=admin@example.com
|
||||
# Default weak passwords like 'Admin123' are rejected
|
||||
FIRST_SUPERUSER_PASSWORD=YourStrongPassword123!
|
||||
|
||||
# OAuth Configuration (Social Login)
|
||||
# Set OAUTH_ENABLED=true and configure at least one provider
|
||||
OAUTH_ENABLED=false
|
||||
OAUTH_AUTO_LINK_BY_EMAIL=true
|
||||
|
||||
# Google OAuth (from Google Cloud Console > APIs & Services > Credentials)
|
||||
# https://console.cloud.google.com/apis/credentials
|
||||
# OAUTH_GOOGLE_CLIENT_ID=your-google-client-id.apps.googleusercontent.com
|
||||
# OAUTH_GOOGLE_CLIENT_SECRET=your-google-client-secret
|
||||
|
||||
# GitHub OAuth (from GitHub > Settings > Developer settings > OAuth Apps)
|
||||
# https://github.com/settings/developers
|
||||
# OAUTH_GITHUB_CLIENT_ID=your-github-client-id
|
||||
# OAUTH_GITHUB_CLIENT_SECRET=your-github-client-secret
|
||||
|
||||
# OAuth Provider Mode (Authorization Server for MCP/third-party clients)
|
||||
# Set OAUTH_PROVIDER_ENABLED=true to act as an OAuth 2.0 Authorization Server
|
||||
OAUTH_PROVIDER_ENABLED=false
|
||||
# IMPORTANT: Must be HTTPS in production!
|
||||
OAUTH_ISSUER=http://localhost:8000
|
||||
|
||||
# Frontend settings
|
||||
FRONTEND_PORT=3000
|
||||
FRONTEND_URL=http://localhost:3000
|
||||
NEXT_PUBLIC_API_URL=http://localhost:8000
|
||||
NEXT_PUBLIC_API_BASE_URL=http://localhost:8000
|
||||
NEXT_PUBLIC_APP_URL=http://localhost:3000
|
||||
NODE_ENV=development
|
||||
|
||||
2
.github/workflows/README.md
vendored
2
.github/workflows/README.md
vendored
@@ -41,7 +41,7 @@ To enable CI/CD workflows:
|
||||
- Runs on: Push to main/develop, PRs affecting frontend code
|
||||
- Tests: Frontend unit tests (Jest)
|
||||
- Coverage: Uploads to Codecov
|
||||
- Fast: Uses npm cache
|
||||
- Fast: Uses bun cache
|
||||
|
||||
### `e2e-tests.yml`
|
||||
- Runs on: All pushes and PRs
|
||||
|
||||
77
.github/workflows/backend-e2e-tests.yml.template
vendored
Normal file
77
.github/workflows/backend-e2e-tests.yml.template
vendored
Normal file
@@ -0,0 +1,77 @@
|
||||
# Backend E2E Tests CI Pipeline
|
||||
#
|
||||
# Runs end-to-end tests with real PostgreSQL via Testcontainers
|
||||
# and validates API contracts with Schemathesis.
|
||||
#
|
||||
# To enable: Rename this file to backend-e2e-tests.yml
|
||||
|
||||
name: Backend E2E Tests
|
||||
|
||||
on:
|
||||
push:
|
||||
branches: [main, develop]
|
||||
paths:
|
||||
- 'backend/**'
|
||||
- '.github/workflows/backend-e2e-tests.yml'
|
||||
pull_request:
|
||||
branches: [main, develop]
|
||||
paths:
|
||||
- 'backend/**'
|
||||
workflow_dispatch:
|
||||
|
||||
jobs:
|
||||
e2e-tests:
|
||||
runs-on: ubuntu-latest
|
||||
# E2E test failures don't block merge - they're advisory
|
||||
continue-on-error: true
|
||||
|
||||
steps:
|
||||
- name: Checkout code
|
||||
uses: actions/checkout@v4
|
||||
|
||||
- name: Set up Python
|
||||
uses: actions/setup-python@v5
|
||||
with:
|
||||
python-version: '3.12'
|
||||
|
||||
- name: Install uv
|
||||
uses: astral-sh/setup-uv@v4
|
||||
with:
|
||||
version: "latest"
|
||||
|
||||
- name: Cache uv dependencies
|
||||
uses: actions/cache@v4
|
||||
with:
|
||||
path: ~/.cache/uv
|
||||
key: uv-${{ runner.os }}-${{ hashFiles('backend/uv.lock') }}
|
||||
restore-keys: |
|
||||
uv-${{ runner.os }}-
|
||||
|
||||
- name: Install dependencies (with E2E)
|
||||
working-directory: ./backend
|
||||
run: uv sync --extra dev --extra e2e
|
||||
|
||||
- name: Check Docker availability
|
||||
id: docker-check
|
||||
run: |
|
||||
if docker info > /dev/null 2>&1; then
|
||||
echo "available=true" >> $GITHUB_OUTPUT
|
||||
echo "Docker is available"
|
||||
else
|
||||
echo "available=false" >> $GITHUB_OUTPUT
|
||||
echo "::warning::Docker not available - E2E tests will be skipped"
|
||||
fi
|
||||
|
||||
- name: Run E2E tests
|
||||
if: steps.docker-check.outputs.available == 'true'
|
||||
working-directory: ./backend
|
||||
env:
|
||||
IS_TEST: "True"
|
||||
SECRET_KEY: "e2e-test-secret-key-minimum-32-characters-long"
|
||||
PYTHONPATH: "."
|
||||
run: |
|
||||
uv run pytest tests/e2e/ -v --tb=short
|
||||
|
||||
- name: E2E tests skipped
|
||||
if: steps.docker-check.outputs.available != 'true'
|
||||
run: echo "E2E tests were skipped due to Docker unavailability"
|
||||
7
.gitignore
vendored
7
.gitignore
vendored
@@ -27,6 +27,10 @@ coverage
|
||||
# nyc test coverage
|
||||
.nyc_output
|
||||
|
||||
# Playwright authentication state (contains test auth tokens)
|
||||
frontend/e2e/.auth/
|
||||
**/playwright/.auth/
|
||||
|
||||
# Grunt intermediate storage (https://gruntjs.com/creating-plugins#storing-task-files)
|
||||
.grunt
|
||||
|
||||
@@ -183,7 +187,7 @@ coverage.xml
|
||||
.hypothesis/
|
||||
.pytest_cache/
|
||||
cover/
|
||||
|
||||
backend/.benchmarks
|
||||
# Translations
|
||||
*.mo
|
||||
*.pot
|
||||
@@ -264,6 +268,7 @@ celerybeat.pid
|
||||
.env
|
||||
.env.*
|
||||
!.env.template
|
||||
!.env.demo
|
||||
.venv
|
||||
env/
|
||||
venv/
|
||||
|
||||
315
AGENTS.md
Normal file
315
AGENTS.md
Normal file
@@ -0,0 +1,315 @@
|
||||
# AGENTS.md
|
||||
|
||||
AI coding assistant context for FastAPI + Next.js Full-Stack Template.
|
||||
|
||||
## Quick Start
|
||||
|
||||
```bash
|
||||
# Backend (Python with uv)
|
||||
cd backend
|
||||
make install-dev # Install dependencies
|
||||
make test # Run tests
|
||||
uv run uvicorn app.main:app --reload # Start dev server
|
||||
|
||||
# Frontend (Node.js)
|
||||
cd frontend
|
||||
bun install # Install dependencies
|
||||
bun run dev # Start dev server
|
||||
bun run generate:api # Generate API client from OpenAPI
|
||||
bun run test:e2e # Run E2E tests
|
||||
```
|
||||
|
||||
**Access points:**
|
||||
- Frontend: **http://localhost:3000**
|
||||
- Backend API: **http://localhost:8000**
|
||||
- API Docs: **http://localhost:8000/docs**
|
||||
|
||||
Default superuser (change in production):
|
||||
- Email: `admin@example.com`
|
||||
- Password: `admin123`
|
||||
|
||||
## Project Architecture
|
||||
|
||||
**Full-stack TypeScript/Python application:**
|
||||
|
||||
```
|
||||
├── backend/ # FastAPI backend
|
||||
│ ├── app/
|
||||
│ │ ├── api/ # API routes (auth, users, organizations, admin)
|
||||
│ │ ├── core/ # Core functionality (auth, config, database)
|
||||
│ │ ├── repositories/ # Repository pattern (database operations)
|
||||
│ │ ├── models/ # SQLAlchemy ORM models
|
||||
│ │ ├── schemas/ # Pydantic request/response schemas
|
||||
│ │ ├── services/ # Business logic layer
|
||||
│ │ └── utils/ # Utilities (security, device detection)
|
||||
│ ├── tests/ # 96% coverage, 987 tests
|
||||
│ └── alembic/ # Database migrations
|
||||
│
|
||||
└── frontend/ # Next.js 16 frontend
|
||||
├── src/
|
||||
│ ├── app/ # App Router pages (Next.js 16)
|
||||
│ ├── components/ # React components
|
||||
│ ├── lib/
|
||||
│ │ ├── api/ # Auto-generated API client
|
||||
│ │ └── stores/ # Zustand state management
|
||||
│ └── hooks/ # Custom React hooks
|
||||
└── e2e/ # Playwright E2E tests (56 passing)
|
||||
```
|
||||
|
||||
## Critical Implementation Notes
|
||||
|
||||
### Authentication Flow
|
||||
- **JWT-based**: Access tokens (15 min) + refresh tokens (7 days)
|
||||
- **OAuth/Social Login**: Google and GitHub with PKCE support
|
||||
- **Session tracking**: Database-backed with device info, IP, user agent
|
||||
- **Token refresh**: Validates JTI in database, not just JWT signature
|
||||
- **Authorization**: FastAPI dependencies in `api/dependencies/auth.py`
|
||||
- `get_current_user`: Requires valid access token
|
||||
- `get_current_active_user`: Requires active account
|
||||
- `get_optional_current_user`: Accepts authenticated or anonymous
|
||||
- `get_current_superuser`: Requires superuser flag
|
||||
|
||||
### OAuth Provider Mode (MCP Integration)
|
||||
Full OAuth 2.0 Authorization Server for MCP (Model Context Protocol) clients:
|
||||
- **Authorization Code Flow with PKCE**: RFC 7636 compliant
|
||||
- **JWT access tokens**: Self-contained, no DB lookup required
|
||||
- **Opaque refresh tokens**: Stored hashed in database, supports rotation
|
||||
- **Token introspection**: RFC 7662 compliant endpoint
|
||||
- **Token revocation**: RFC 7009 compliant endpoint
|
||||
- **Server metadata**: RFC 8414 compliant discovery endpoint
|
||||
- **Consent management**: User can review and revoke app permissions
|
||||
|
||||
**API endpoints:**
|
||||
- `GET /.well-known/oauth-authorization-server` - Server metadata
|
||||
- `GET /oauth/provider/authorize` - Authorization endpoint
|
||||
- `POST /oauth/provider/authorize/consent` - Consent submission
|
||||
- `POST /oauth/provider/token` - Token endpoint
|
||||
- `POST /oauth/provider/revoke` - Token revocation
|
||||
- `POST /oauth/provider/introspect` - Token introspection
|
||||
- Client management endpoints (admin only)
|
||||
|
||||
**Scopes supported:** `openid`, `profile`, `email`, `read:users`, `write:users`, `admin`
|
||||
|
||||
**OAuth Configuration (backend `.env`):**
|
||||
```bash
|
||||
# OAuth Social Login (as OAuth Consumer)
|
||||
OAUTH_ENABLED=true # Enable OAuth social login
|
||||
OAUTH_AUTO_LINK_BY_EMAIL=true # Auto-link accounts by email
|
||||
OAUTH_STATE_EXPIRE_MINUTES=10 # CSRF state expiration
|
||||
|
||||
# Google OAuth
|
||||
OAUTH_GOOGLE_CLIENT_ID=your-google-client-id
|
||||
OAUTH_GOOGLE_CLIENT_SECRET=your-google-client-secret
|
||||
|
||||
# GitHub OAuth
|
||||
OAUTH_GITHUB_CLIENT_ID=your-github-client-id
|
||||
OAUTH_GITHUB_CLIENT_SECRET=your-github-client-secret
|
||||
|
||||
# OAuth Provider Mode (as Authorization Server for MCP)
|
||||
OAUTH_PROVIDER_ENABLED=true # Enable OAuth provider mode
|
||||
OAUTH_ISSUER=https://api.yourdomain.com # JWT issuer URL (must be HTTPS in production)
|
||||
```
|
||||
|
||||
### Database Pattern
|
||||
- **Async SQLAlchemy 2.0** with PostgreSQL
|
||||
- **Connection pooling**: 20 base connections, 50 max overflow
|
||||
- **Repository base class**: `repositories/base.py` with common operations
|
||||
- **Migrations**: Alembic with helper script `migrate.py`
|
||||
- `python migrate.py auto "message"` - Generate and apply
|
||||
- `python migrate.py list` - View history
|
||||
|
||||
### Frontend State Management
|
||||
- **Zustand stores**: Lightweight state management
|
||||
- **TanStack Query**: API data fetching/caching
|
||||
- **Auto-generated client**: From OpenAPI spec via `bun run generate:api`
|
||||
- **Dependency Injection**: ALWAYS use `useAuth()` from `AuthContext`, NEVER import `useAuthStore` directly
|
||||
|
||||
### Internationalization (i18n)
|
||||
- **next-intl v4**: Type-safe internationalization library
|
||||
- **Locale routing**: `/en/*`, `/it/*` (English and Italian supported)
|
||||
- **Translation files**: `frontend/messages/en.json`, `frontend/messages/it.json`
|
||||
- **LocaleSwitcher**: Component for seamless language switching
|
||||
- **SEO-friendly**: Locale-aware metadata, sitemaps, and robots.txt
|
||||
- **Type safety**: Full TypeScript support for translations
|
||||
- **Utilities**: `frontend/src/lib/i18n/` (metadata, routing, utils)
|
||||
|
||||
### Organization System
|
||||
Three-tier RBAC:
|
||||
- **Owner**: Full control (delete org, manage all members)
|
||||
- **Admin**: Add/remove members, assign admin role (not owner)
|
||||
- **Member**: Read-only organization access
|
||||
|
||||
Permission dependencies in `api/dependencies/permissions.py`:
|
||||
- `require_organization_owner`
|
||||
- `require_organization_admin`
|
||||
- `require_organization_member`
|
||||
- `can_manage_organization_member`
|
||||
|
||||
### Testing Infrastructure
|
||||
|
||||
**Backend Unit/Integration (pytest + SQLite):**
|
||||
- 96% coverage, 987 tests
|
||||
- Security-focused: JWT attacks, session hijacking, privilege escalation
|
||||
- Async fixtures in `tests/conftest.py`
|
||||
- Run: `IS_TEST=True uv run pytest` or `make test`
|
||||
- Coverage: `make test-cov`
|
||||
|
||||
**Backend E2E (pytest + Testcontainers + Schemathesis):**
|
||||
- Real PostgreSQL via Docker containers
|
||||
- OpenAPI contract testing with Schemathesis
|
||||
- Install: `make install-e2e`
|
||||
- Run: `make test-e2e`
|
||||
- Schema tests: `make test-e2e-schema`
|
||||
- Docs: `backend/docs/E2E_TESTING.md`
|
||||
|
||||
**Frontend Unit Tests (Jest):**
|
||||
- 97% coverage
|
||||
- Component, hook, and utility testing
|
||||
- Run: `bun run test`
|
||||
- Coverage: `bun run test:coverage`
|
||||
|
||||
**Frontend E2E Tests (Playwright):**
|
||||
- 56 passing, 1 skipped (zero flaky tests)
|
||||
- Complete user flows (auth, navigation, settings)
|
||||
- Run: `bun run test:e2e`
|
||||
- UI mode: `bun run test:e2e:ui`
|
||||
|
||||
### Development Tooling
|
||||
|
||||
**Backend:**
|
||||
- **uv**: Modern Python package manager (10-100x faster than pip)
|
||||
- **Ruff**: All-in-one linting/formatting (replaces Black, Flake8, isort)
|
||||
- **Pyright**: Static type checking (strict mode)
|
||||
- **pip-audit**: Dependency vulnerability scanning (OSV database)
|
||||
- **detect-secrets**: Hardcoded secrets detection
|
||||
- **pip-licenses**: License compliance checking
|
||||
- **pre-commit**: Git hook framework (Ruff, detect-secrets, standard checks)
|
||||
- **Makefile**: `make help` for all commands
|
||||
|
||||
**Frontend:**
|
||||
- **Next.js 16**: App Router with React 19
|
||||
- **TypeScript**: Full type safety
|
||||
- **TailwindCSS + shadcn/ui**: Design system
|
||||
- **ESLint + Prettier**: Code quality
|
||||
|
||||
### Environment Configuration
|
||||
|
||||
**Backend** (`.env`):
|
||||
```bash
|
||||
POSTGRES_USER=postgres
|
||||
POSTGRES_PASSWORD=your_password
|
||||
POSTGRES_HOST=db
|
||||
POSTGRES_PORT=5432
|
||||
POSTGRES_DB=app
|
||||
|
||||
SECRET_KEY=your-secret-key-min-32-chars
|
||||
ENVIRONMENT=development|production
|
||||
CSP_MODE=relaxed|strict|disabled
|
||||
|
||||
FIRST_SUPERUSER_EMAIL=admin@example.com
|
||||
FIRST_SUPERUSER_PASSWORD=admin123
|
||||
|
||||
BACKEND_CORS_ORIGINS=["http://localhost:3000"]
|
||||
```
|
||||
|
||||
**Frontend** (`.env.local`):
|
||||
```bash
|
||||
NEXT_PUBLIC_API_URL=http://localhost:8000/api/v1
|
||||
```
|
||||
|
||||
## Common Development Workflows
|
||||
|
||||
### Adding a New API Endpoint
|
||||
|
||||
1. **Define schema** in `backend/app/schemas/`
|
||||
2. **Create repository** in `backend/app/repositories/`
|
||||
3. **Implement route** in `backend/app/api/routes/`
|
||||
4. **Register router** in `backend/app/api/main.py`
|
||||
5. **Write tests** in `backend/tests/api/`
|
||||
6. **Generate frontend client**: `bun run generate:api`
|
||||
|
||||
### Database Migrations
|
||||
|
||||
```bash
|
||||
cd backend
|
||||
python migrate.py generate "description" # Create migration
|
||||
python migrate.py apply # Apply migrations
|
||||
python migrate.py auto "description" # Generate + apply
|
||||
```
|
||||
|
||||
### Frontend Component Development
|
||||
|
||||
1. **Create component** in `frontend/src/components/`
|
||||
2. **Follow design system** (see `frontend/docs/design-system/`)
|
||||
3. **Use dependency injection** for auth (`useAuth()` not `useAuthStore`)
|
||||
4. **Write tests** in `frontend/tests/` or `__tests__/`
|
||||
5. **Run type check**: `bun run type-check`
|
||||
|
||||
## Security Features
|
||||
|
||||
- **Password hashing**: bcrypt with salt rounds
|
||||
- **Rate limiting**: 60 req/min default, 10 req/min on auth endpoints
|
||||
- **Security headers**: CSP, X-Frame-Options, HSTS, etc.
|
||||
- **CSRF protection**: Built into FastAPI
|
||||
- **Session revocation**: Database-backed session tracking
|
||||
- **Comprehensive security tests**: JWT algorithm attacks, session hijacking, privilege escalation
|
||||
- **Dependency vulnerability scanning**: `make dep-audit` (pip-audit against OSV database)
|
||||
- **License compliance**: `make license-check` (blocks GPL-3.0/AGPL)
|
||||
- **Secrets detection**: Pre-commit hook blocks hardcoded secrets
|
||||
- **Unified security pipeline**: `make audit` (all security checks), `make check` (quality + security + tests)
|
||||
|
||||
## Docker Deployment
|
||||
|
||||
```bash
|
||||
# Development (with hot reload)
|
||||
docker-compose -f docker-compose.dev.yml up
|
||||
|
||||
# Production
|
||||
docker-compose up -d
|
||||
|
||||
# Run migrations
|
||||
docker-compose exec backend alembic upgrade head
|
||||
|
||||
# Create first superuser
|
||||
docker-compose exec backend python -c "from app.init_db import init_db; import asyncio; asyncio.run(init_db())"
|
||||
```
|
||||
|
||||
## Documentation
|
||||
|
||||
**For comprehensive documentation, see:**
|
||||
- **[README.md](./README.md)** - User-facing project overview
|
||||
- **[CLAUDE.md](./CLAUDE.md)** - Claude Code-specific guidance
|
||||
- **Backend docs**: `backend/docs/` (Architecture, Coding Standards, Common Pitfalls, Feature Examples)
|
||||
- **Frontend docs**: `frontend/docs/` (Design System, Architecture, E2E Testing)
|
||||
- **API docs**: http://localhost:8000/docs (Swagger UI when running)
|
||||
|
||||
## Current Status (Nov 2025)
|
||||
|
||||
### Completed Features ✅
|
||||
- Authentication system (JWT with refresh tokens, OAuth/social login)
|
||||
- **OAuth Provider Mode (MCP-ready)**: Full OAuth 2.0 Authorization Server
|
||||
- Session management (device tracking, revocation)
|
||||
- User management (full lifecycle, password change)
|
||||
- Organization system (multi-tenant with RBAC)
|
||||
- Admin panel (user/org management, bulk operations)
|
||||
- **Internationalization (i18n)** with English and Italian
|
||||
- Comprehensive test coverage (96% backend, 97% frontend unit, 56 E2E tests)
|
||||
- Design system documentation
|
||||
- **Marketing landing page** with animations
|
||||
- **`/dev` documentation portal** with live examples
|
||||
- **Toast notifications**, charts, markdown rendering
|
||||
- **SEO optimization** (sitemap, robots.txt, locale metadata)
|
||||
- Docker deployment
|
||||
|
||||
### In Progress 🚧
|
||||
- Frontend admin pages (70% complete)
|
||||
- Email integration (templates ready, SMTP pending)
|
||||
|
||||
### Planned 🔮
|
||||
- GitHub Actions CI/CD
|
||||
- Additional languages (Spanish, French, German, etc.)
|
||||
- SSO/SAML authentication
|
||||
- Real-time notifications (WebSockets)
|
||||
- Webhook system
|
||||
- Background job processing
|
||||
- File upload/storage
|
||||
734
CLAUDE.md
734
CLAUDE.md
@@ -1,10 +1,14 @@
|
||||
# CLAUDE.md
|
||||
|
||||
This file provides guidance to Claude Code (claude.ai/code) when working with code in this repository.
|
||||
Claude Code context for FastAPI + Next.js Full-Stack Template.
|
||||
|
||||
## Critical User Preferences
|
||||
**See [AGENTS.md](./AGENTS.md) for project context, architecture, and development commands.**
|
||||
|
||||
### File Operations - NEVER Use Heredoc/Cat Append
|
||||
## Claude Code-Specific Guidance
|
||||
|
||||
### Critical User Preferences
|
||||
|
||||
#### File Operations - NEVER Use Heredoc/Cat Append
|
||||
**ALWAYS use Read/Write/Edit tools instead of `cat >> file << EOF` commands.**
|
||||
|
||||
This triggers manual approval dialogs and disrupts workflow.
|
||||
@@ -18,193 +22,53 @@ EOF
|
||||
# CORRECT ✅ - Use Read, then Write tools
|
||||
```
|
||||
|
||||
### Work Style
|
||||
#### Work Style
|
||||
- User prefers autonomous operation without frequent interruptions
|
||||
- Ask for batch permissions upfront for long work sessions
|
||||
- Work independently, document decisions clearly
|
||||
- Only use emojis if the user explicitly requests it
|
||||
|
||||
## Project Architecture
|
||||
### When Working with This Stack
|
||||
|
||||
This is a **FastAPI + Next.js full-stack application** with the following structure:
|
||||
**Dependency Management:**
|
||||
- Backend uses **uv** (modern Python package manager), not pip
|
||||
- Always use `uv run` prefix: `IS_TEST=True uv run pytest`
|
||||
- Or use Makefile commands: `make test`, `make install-dev`
|
||||
- Add dependencies: `uv add <package>` or `uv add --dev <package>`
|
||||
|
||||
### Backend (FastAPI)
|
||||
```
|
||||
backend/app/
|
||||
├── api/ # API routes organized by version
|
||||
│ ├── routes/ # Endpoint implementations (auth, users, sessions, admin, organizations)
|
||||
│ └── dependencies/ # FastAPI dependencies (auth, permissions)
|
||||
├── core/ # Core functionality
|
||||
│ ├── config.py # Settings (Pydantic BaseSettings)
|
||||
│ ├── database.py # SQLAlchemy async engine setup
|
||||
│ ├── auth.py # JWT token generation/validation
|
||||
│ └── exceptions.py # Custom exception classes and handlers
|
||||
├── crud/ # Database CRUD operations (base, user, session, organization)
|
||||
├── models/ # SQLAlchemy ORM models
|
||||
├── schemas/ # Pydantic request/response schemas
|
||||
├── services/ # Business logic layer (auth_service)
|
||||
└── utils/ # Utilities (security, device detection, test helpers)
|
||||
```
|
||||
**Database Migrations:**
|
||||
- Use the `migrate.py` helper script, not Alembic directly
|
||||
- Generate + apply: `python migrate.py auto "message"`
|
||||
- Never commit migrations without testing them first
|
||||
- Check current state: `python migrate.py current`
|
||||
|
||||
### Frontend (Next.js 15)
|
||||
```
|
||||
frontend/src/
|
||||
├── app/ # Next.js App Router pages
|
||||
├── components/ # React components (auth/, ui/)
|
||||
├── lib/
|
||||
│ ├── api/ # API client (auto-generated from OpenAPI)
|
||||
│ ├── stores/ # Zustand state management
|
||||
│ └── utils/ # Utility functions
|
||||
└── hooks/ # Custom React hooks
|
||||
```
|
||||
**Frontend API Client Generation:**
|
||||
- Run `bun run generate:api` after backend schema changes
|
||||
- Client is auto-generated from OpenAPI spec
|
||||
- Located in `frontend/src/lib/api/generated/`
|
||||
- NEVER manually edit generated files
|
||||
|
||||
## Development Commands
|
||||
**Testing Commands:**
|
||||
- Backend unit/integration: `IS_TEST=True uv run pytest` (always prefix with `IS_TEST=True`)
|
||||
- Backend E2E (requires Docker): `make test-e2e`
|
||||
- Frontend unit: `bun run test`
|
||||
- Frontend E2E: `bun run test:e2e`
|
||||
- Use `make test` or `make test-cov` in backend for convenience
|
||||
|
||||
### Backend
|
||||
**Security & Quality Commands (Backend):**
|
||||
- `make validate` — lint + format + type checks
|
||||
- `make audit` — dependency vulnerabilities + license compliance
|
||||
- `make validate-all` — quality + security checks
|
||||
- `make check` — **full pipeline**: quality + security + tests
|
||||
|
||||
#### Setup
|
||||
```bash
|
||||
cd backend
|
||||
python -m venv .venv
|
||||
source .venv/bin/activate # or .venv\Scripts\activate on Windows
|
||||
pip install -r requirements.txt
|
||||
```
|
||||
|
||||
#### Database Migrations
|
||||
```bash
|
||||
# Using the migration helper (preferred)
|
||||
python migrate.py generate "migration message" # Generate migration
|
||||
python migrate.py apply # Apply migrations
|
||||
python migrate.py auto "message" # Generate and apply in one step
|
||||
python migrate.py list # List all migrations
|
||||
python migrate.py current # Show current revision
|
||||
python migrate.py check # Check DB connection
|
||||
|
||||
# Or using Alembic directly
|
||||
alembic revision --autogenerate -m "message"
|
||||
alembic upgrade head
|
||||
```
|
||||
|
||||
#### Testing
|
||||
|
||||
**Test Coverage: High (comprehensive test suite)**
|
||||
- Security-focused testing with JWT algorithm attack prevention (CVE-2015-9235)
|
||||
- Session hijacking and privilege escalation tests included
|
||||
- Missing lines justified as defensive code, error handlers, and production-only code
|
||||
|
||||
```bash
|
||||
# Run all tests (uses pytest-xdist for parallel execution)
|
||||
IS_TEST=True pytest
|
||||
|
||||
# Run with coverage (use -n 0 for accurate coverage)
|
||||
IS_TEST=True pytest --cov=app --cov-report=term-missing -n 0
|
||||
|
||||
# Run specific test file
|
||||
IS_TEST=True pytest tests/api/test_auth.py -v
|
||||
|
||||
# Run single test
|
||||
IS_TEST=True pytest tests/api/test_auth.py::TestLogin::test_login_success -v
|
||||
|
||||
# Run with HTML coverage report
|
||||
IS_TEST=True pytest --cov=app --cov-report=html -n 0
|
||||
open htmlcov/index.html
|
||||
```
|
||||
|
||||
#### Running Locally
|
||||
```bash
|
||||
cd backend
|
||||
source .venv/bin/activate
|
||||
uvicorn app.main:app --reload --host 0.0.0.0 --port 8000
|
||||
```
|
||||
|
||||
### Frontend
|
||||
|
||||
#### Setup
|
||||
```bash
|
||||
cd frontend
|
||||
npm install
|
||||
```
|
||||
|
||||
#### Development
|
||||
```bash
|
||||
npm run dev # Start dev server on http://localhost:3000
|
||||
npm run build # Production build
|
||||
npm run lint # ESLint
|
||||
npm run type-check # TypeScript checking
|
||||
```
|
||||
|
||||
#### Testing
|
||||
```bash
|
||||
# Unit tests (Jest)
|
||||
npm test # Run all unit tests
|
||||
npm run test:watch # Watch mode
|
||||
npm run test:coverage # With coverage
|
||||
|
||||
# E2E tests (Playwright)
|
||||
npm run test:e2e # Run all E2E tests
|
||||
npm run test:e2e:ui # Open Playwright UI
|
||||
npm run test:e2e:debug # Debug mode
|
||||
npx playwright test auth-login.spec.ts # Run specific file
|
||||
```
|
||||
|
||||
**E2E Test Best Practices:**
|
||||
- Use `Promise.all()` pattern for Next.js Link navigation:
|
||||
```typescript
|
||||
await Promise.all([
|
||||
page.waitForURL('/target', { timeout: 10000 }),
|
||||
link.click()
|
||||
]);
|
||||
```
|
||||
- Use ID-based selectors for validation errors (e.g., `#email-error`)
|
||||
- Error IDs use dashes not underscores (`#new-password-error`)
|
||||
- Target `.border-destructive[role="alert"]` to avoid Next.js route announcer conflicts
|
||||
- Uses 12 workers in non-CI mode (`workers: 12` in `playwright.config.ts`)
|
||||
- URL assertions should use regex to handle query params: `/\/auth\/login/`
|
||||
|
||||
### Docker
|
||||
|
||||
```bash
|
||||
# Development (with hot reload)
|
||||
docker-compose -f docker-compose.dev.yml up
|
||||
|
||||
# Production
|
||||
docker-compose up -d
|
||||
|
||||
# Rebuild specific service
|
||||
docker-compose build backend
|
||||
docker-compose build frontend
|
||||
```
|
||||
|
||||
## Key Architectural Patterns
|
||||
|
||||
### Authentication Flow
|
||||
1. **Login**: `POST /api/v1/auth/login` returns access + refresh tokens
|
||||
- Access token: 15 minutes expiry (JWT)
|
||||
- Refresh token: 7 days expiry (JWT with JTI stored in DB)
|
||||
- Session tracking with device info (IP, user agent, device ID)
|
||||
|
||||
2. **Token Refresh**: `POST /api/v1/auth/refresh` validates refresh token JTI
|
||||
- Checks session is active in database
|
||||
- Issues new access token (refresh token remains valid)
|
||||
- Updates session `last_used_at`
|
||||
|
||||
3. **Authorization**: FastAPI dependencies in `api/dependencies/auth.py`
|
||||
- `get_current_user`: Validates access token, returns User (raises 401 if invalid)
|
||||
- `get_current_active_user`: Requires valid access token + active account
|
||||
- `get_optional_current_user`: Accepts both authenticated and anonymous users (returns User or None)
|
||||
- `get_current_superuser`: Requires superuser flag
|
||||
|
||||
### Database Pattern: Async SQLAlchemy
|
||||
- **Engine**: Created in `core/database.py` with connection pooling
|
||||
- **Sessions**: AsyncSession from `async_sessionmaker`
|
||||
- **CRUD**: Base class in `crud/base.py` with common operations
|
||||
- Inherits: `CRUDUser`, `CRUDSession`, `CRUDOrganization`
|
||||
- Pattern: `async def get(db: AsyncSession, id: str) -> Model | None`
|
||||
|
||||
### Frontend State Management
|
||||
- **Zustand stores**: `lib/stores/` (authStore, etc.)
|
||||
- **TanStack Query**: API data fetching/caching
|
||||
- **Auto-generated client**: `lib/api/generated/` from OpenAPI spec
|
||||
- Generate with: `npm run generate:api` (runs `scripts/generate-api-client.sh`)
|
||||
**Backend E2E Testing (requires Docker):**
|
||||
- Install deps: `make install-e2e`
|
||||
- Run all E2E tests: `make test-e2e`
|
||||
- Run schema tests only: `make test-e2e-schema`
|
||||
- Run all tests: `make test-all` (unit + E2E)
|
||||
- Uses Testcontainers (real PostgreSQL) + Schemathesis (OpenAPI contract testing)
|
||||
- Markers: `@pytest.mark.e2e`, `@pytest.mark.postgres`, `@pytest.mark.schemathesis`
|
||||
- See: `backend/docs/E2E_TESTING.md` for complete guide
|
||||
|
||||
### 🔴 CRITICAL: Auth Store Dependency Injection Pattern
|
||||
|
||||
@@ -230,394 +94,160 @@ const { user, isAuthenticated } = useAuth();
|
||||
1. `AuthContext.tsx` - DI boundary, legitimately needs real store
|
||||
2. `client.ts` - Non-React context, uses dynamic import + `__TEST_AUTH_STORE__` check
|
||||
|
||||
**See**: `frontend/docs/ARCHITECTURE_FIX_REPORT.md` for full details.
|
||||
### E2E Test Best Practices
|
||||
|
||||
### Session Management Architecture
|
||||
**Database-backed session tracking** (not just JWT):
|
||||
- Each refresh token has a corresponding `UserSession` record
|
||||
- Tracks: device info, IP, location, last used timestamp
|
||||
- Supports session revocation (logout from specific devices)
|
||||
- Cleanup job removes expired sessions
|
||||
When writing or fixing Playwright tests:
|
||||
|
||||
### Permission System
|
||||
Three-tier organization roles:
|
||||
- **Owner**: Full control (delete org, manage all members)
|
||||
- **Admin**: Can add/remove members, assign admin role (not owner)
|
||||
- **Member**: Read-only organization access
|
||||
**Navigation Pattern:**
|
||||
```typescript
|
||||
// ✅ CORRECT - Use Promise.all for Next.js Link clicks
|
||||
await Promise.all([
|
||||
page.waitForURL('/target', { timeout: 10000 }),
|
||||
link.click()
|
||||
]);
|
||||
```
|
||||
|
||||
Dependencies in `api/dependencies/permissions.py`:
|
||||
- `require_organization_owner`
|
||||
- `require_organization_admin`
|
||||
- `require_organization_member`
|
||||
- `can_manage_organization_member` (owner or admin, but not self-demotion)
|
||||
**Selectors:**
|
||||
- Use ID-based selectors for validation errors: `#email-error`
|
||||
- Error IDs use dashes not underscores: `#new-password-error`
|
||||
- Target `.border-destructive[role="alert"]` to avoid Next.js route announcer conflicts
|
||||
- Avoid generic `[role="alert"]` which matches multiple elements
|
||||
|
||||
## Testing Infrastructure
|
||||
**URL Assertions:**
|
||||
```typescript
|
||||
// ✅ Use regex to handle query params
|
||||
await expect(page).toHaveURL(/\/auth\/login/);
|
||||
|
||||
### Backend Test Patterns
|
||||
// ❌ Don't use exact strings (fails with query params)
|
||||
await expect(page).toHaveURL('/auth/login');
|
||||
```
|
||||
|
||||
**Fixtures** (in `tests/conftest.py`):
|
||||
- `async_test_db`: Fresh SQLite in-memory database per test
|
||||
- `client`: AsyncClient with test database override
|
||||
- `async_test_user`: Pre-created regular user
|
||||
- `async_test_superuser`: Pre-created superuser
|
||||
- `user_token` / `superuser_token`: Access tokens for API calls
|
||||
**Configuration:**
|
||||
- Uses 12 workers in non-CI mode (`playwright.config.ts`)
|
||||
- Reduces to 2 workers in CI for stability
|
||||
- Tests are designed to be non-flaky with proper waits
|
||||
|
||||
**Database Mocking for Exception Testing**:
|
||||
### Important Implementation Details
|
||||
|
||||
**Authentication Testing:**
|
||||
- Backend fixtures in `tests/conftest.py`:
|
||||
- `async_test_db`: Fresh SQLite per test
|
||||
- `async_test_user` / `async_test_superuser`: Pre-created users
|
||||
- `user_token` / `superuser_token`: Access tokens for API calls
|
||||
- Always use `@pytest.mark.asyncio` for async tests
|
||||
- Use `@pytest_asyncio.fixture` for async fixtures
|
||||
|
||||
**Database Testing:**
|
||||
```python
|
||||
# Mock database exceptions correctly
|
||||
from unittest.mock import patch, AsyncMock
|
||||
|
||||
# Mock database commit to raise exception
|
||||
async def mock_commit():
|
||||
raise OperationalError("Connection lost", {}, Exception())
|
||||
|
||||
with patch.object(session, 'commit', side_effect=mock_commit):
|
||||
with patch.object(session, 'rollback', new_callable=AsyncMock) as mock_rollback:
|
||||
with pytest.raises(OperationalError):
|
||||
await crud_method(session, obj_in=data)
|
||||
await repo_method(session, obj_in=data)
|
||||
mock_rollback.assert_called_once()
|
||||
```
|
||||
|
||||
**Testing Routes**:
|
||||
```python
|
||||
@pytest.mark.asyncio
|
||||
async def test_endpoint(client, user_token):
|
||||
response = await client.get(
|
||||
"/api/v1/endpoint",
|
||||
headers={"Authorization": f"Bearer {user_token}"}
|
||||
)
|
||||
assert response.status_code == 200
|
||||
```
|
||||
|
||||
**IMPORTANT**: Use `@pytest_asyncio.fixture` for async fixtures, not `@pytest.fixture`
|
||||
|
||||
### Frontend Test Patterns
|
||||
|
||||
**Unit Tests (Jest)**:
|
||||
```typescript
|
||||
// SSR-safe mocking
|
||||
jest.mock('@/lib/stores/authStore', () => ({
|
||||
useAuthStore: jest.fn()
|
||||
}));
|
||||
|
||||
beforeEach(() => {
|
||||
(useAuthStore as jest.Mock).mockReturnValue({
|
||||
user: mockUser,
|
||||
login: mockLogin
|
||||
});
|
||||
});
|
||||
```
|
||||
|
||||
**E2E Tests (Playwright)**:
|
||||
```typescript
|
||||
test('navigation', async ({ page }) => {
|
||||
await page.goto('/');
|
||||
|
||||
const link = page.getByRole('link', { name: 'Login' });
|
||||
await Promise.all([
|
||||
page.waitForURL(/\/auth\/login/, { timeout: 10000 }),
|
||||
link.click()
|
||||
]);
|
||||
|
||||
await expect(page).toHaveURL(/\/auth\/login/);
|
||||
});
|
||||
```
|
||||
|
||||
## Configuration
|
||||
|
||||
### Environment Variables
|
||||
|
||||
**Backend** (`.env`):
|
||||
```bash
|
||||
# Database
|
||||
POSTGRES_USER=postgres
|
||||
POSTGRES_PASSWORD=your_password
|
||||
POSTGRES_HOST=db
|
||||
POSTGRES_PORT=5432
|
||||
POSTGRES_DB=app
|
||||
|
||||
# Security
|
||||
SECRET_KEY=your-secret-key-min-32-chars
|
||||
ENVIRONMENT=development|production
|
||||
CSP_MODE=relaxed|strict|disabled
|
||||
|
||||
# First Superuser (auto-created on init)
|
||||
FIRST_SUPERUSER_EMAIL=admin@example.com
|
||||
FIRST_SUPERUSER_PASSWORD=admin123
|
||||
|
||||
# CORS
|
||||
BACKEND_CORS_ORIGINS=["http://localhost:3000"]
|
||||
```
|
||||
|
||||
**Frontend** (`.env.local`):
|
||||
```bash
|
||||
NEXT_PUBLIC_API_URL=http://localhost:8000/api/v1
|
||||
```
|
||||
|
||||
### Database Connection Pooling
|
||||
Configured in `core/config.py`:
|
||||
- `db_pool_size`: 20 (default connections)
|
||||
- `db_max_overflow`: 50 (max overflow)
|
||||
- `db_pool_timeout`: 30 seconds
|
||||
- `db_pool_recycle`: 3600 seconds (recycle after 1 hour)
|
||||
|
||||
### Security Headers
|
||||
Automatically applied via middleware in `main.py`:
|
||||
- `X-Frame-Options: DENY`
|
||||
- `X-Content-Type-Options: nosniff`
|
||||
- `X-XSS-Protection: 1; mode=block`
|
||||
- `Strict-Transport-Security` (production only)
|
||||
- Content-Security-Policy (configurable via `CSP_MODE`)
|
||||
|
||||
### Rate Limiting
|
||||
- Implemented with `slowapi`
|
||||
- Default: 60 requests/minute per IP
|
||||
- Applied to auth endpoints (login, register, password reset)
|
||||
- Override in route decorators: `@limiter.limit("10/minute")`
|
||||
|
||||
## Common Workflows
|
||||
|
||||
### Adding a New API Endpoint
|
||||
|
||||
1. **Create schema** (`backend/app/schemas/`):
|
||||
```python
|
||||
class ItemCreate(BaseModel):
|
||||
name: str
|
||||
description: Optional[str] = None
|
||||
|
||||
class ItemResponse(BaseModel):
|
||||
id: UUID
|
||||
name: str
|
||||
created_at: datetime
|
||||
```
|
||||
|
||||
2. **Create CRUD operations** (`backend/app/crud/`):
|
||||
```python
|
||||
class CRUDItem(CRUDBase[Item, ItemCreate, ItemUpdate]):
|
||||
async def get_by_name(self, db: AsyncSession, name: str) -> Item | None:
|
||||
result = await db.execute(select(Item).where(Item.name == name))
|
||||
return result.scalar_one_or_none()
|
||||
|
||||
item = CRUDItem(Item)
|
||||
```
|
||||
|
||||
3. **Create route** (`backend/app/api/routes/items.py`):
|
||||
```python
|
||||
from app.api.dependencies.auth import get_current_user
|
||||
|
||||
@router.post("/", response_model=ItemResponse)
|
||||
async def create_item(
|
||||
item_in: ItemCreate,
|
||||
current_user: User = Depends(get_current_user),
|
||||
db: AsyncSession = Depends(get_db)
|
||||
):
|
||||
item = await item_crud.create(db, obj_in=item_in)
|
||||
return item
|
||||
```
|
||||
|
||||
4. **Register router** (`backend/app/api/main.py`):
|
||||
```python
|
||||
from app.api.routes import items
|
||||
api_router.include_router(items.router, prefix="/items", tags=["Items"])
|
||||
```
|
||||
|
||||
5. **Write tests** (`backend/tests/api/test_items.py`):
|
||||
```python
|
||||
@pytest.mark.asyncio
|
||||
async def test_create_item(client, user_token):
|
||||
response = await client.post(
|
||||
"/api/v1/items",
|
||||
headers={"Authorization": f"Bearer {user_token}"},
|
||||
json={"name": "Test Item"}
|
||||
)
|
||||
assert response.status_code == 201
|
||||
```
|
||||
|
||||
6. **Generate frontend client**:
|
||||
```bash
|
||||
cd frontend
|
||||
npm run generate:api
|
||||
```
|
||||
|
||||
### Adding a New React Component
|
||||
|
||||
1. **Create component** (`frontend/src/components/`):
|
||||
```typescript
|
||||
export function MyComponent() {
|
||||
const { user } = useAuthStore();
|
||||
return <div>Hello {user?.firstName}</div>;
|
||||
}
|
||||
```
|
||||
|
||||
2. **Add tests** (`frontend/src/components/__tests__/`):
|
||||
```typescript
|
||||
import { render, screen } from '@testing-library/react';
|
||||
|
||||
test('renders component', () => {
|
||||
render(<MyComponent />);
|
||||
expect(screen.getByText(/Hello/)).toBeInTheDocument();
|
||||
});
|
||||
```
|
||||
|
||||
3. **Add to page** (`frontend/src/app/page.tsx`):
|
||||
```typescript
|
||||
import { MyComponent } from '@/components/MyComponent';
|
||||
|
||||
export default function Page() {
|
||||
return <MyComponent />;
|
||||
}
|
||||
```
|
||||
|
||||
## Current Project Status (Nov 2025)
|
||||
|
||||
### Completed Features
|
||||
- ✅ Authentication system (JWT with refresh tokens)
|
||||
- ✅ Session management (device tracking, revocation)
|
||||
- ✅ User management (CRUD, password change)
|
||||
- ✅ Organization system (multi-tenant with roles)
|
||||
- ✅ Admin panel (user/org management, bulk operations)
|
||||
- ✅ E2E test suite (56 passing, 1 skipped, zero flaky tests)
|
||||
|
||||
### Test Coverage
|
||||
- **Backend**: 97% overall (743 tests, all passing) ✅
|
||||
- Comprehensive security testing (JWT attacks, session hijacking, privilege escalation)
|
||||
- User CRUD: 100% ✅
|
||||
- Session CRUD: 100% ✅
|
||||
- Auth routes: 99% ✅
|
||||
- Organization routes: 100% ✅
|
||||
- Permissions: 100% ✅
|
||||
- 84 missing lines justified (defensive code, error handlers, production-only code)
|
||||
|
||||
- **Frontend E2E**: 56 passing, 1 skipped across 7 files ✅
|
||||
- auth-login.spec.ts (19 tests)
|
||||
- auth-register.spec.ts (14 tests)
|
||||
- auth-password-reset.spec.ts (10 tests)
|
||||
- navigation.spec.ts (10 tests)
|
||||
- settings-password.spec.ts (3 tests)
|
||||
- settings-profile.spec.ts (2 tests)
|
||||
- settings-navigation.spec.ts (5 tests)
|
||||
- settings-sessions.spec.ts (1 skipped - route not yet implemented)
|
||||
|
||||
## Email Service Integration
|
||||
|
||||
The project includes a **placeholder email service** (`backend/app/services/email_service.py`) designed for easy integration with production email providers.
|
||||
|
||||
### Current Implementation
|
||||
|
||||
**Console Backend (Default)**:
|
||||
- Logs email content to console/logs instead of sending
|
||||
- Safe for development and testing
|
||||
- No external dependencies required
|
||||
|
||||
### Production Integration
|
||||
|
||||
To enable email functionality, implement one of these approaches:
|
||||
|
||||
**Option 1: SMTP Integration** (Recommended for most use cases)
|
||||
```python
|
||||
# In app/services/email_service.py, complete the SMTPEmailBackend implementation
|
||||
from aiosmtplib import SMTP
|
||||
from email.mime.text import MIMEText
|
||||
from email.mime.multipart import MIMEMultipart
|
||||
|
||||
# Add environment variables to .env:
|
||||
# SMTP_HOST=smtp.gmail.com
|
||||
# SMTP_PORT=587
|
||||
# SMTP_USERNAME=your-email@gmail.com
|
||||
# SMTP_PASSWORD=your-app-password
|
||||
```
|
||||
|
||||
**Option 2: Third-Party Service** (SendGrid, AWS SES, Mailgun, etc.)
|
||||
```python
|
||||
# Create a new backend class, e.g., SendGridEmailBackend
|
||||
class SendGridEmailBackend(EmailBackend):
|
||||
def __init__(self, api_key: str):
|
||||
self.api_key = api_key
|
||||
self.client = sendgrid.SendGridAPIClient(api_key)
|
||||
|
||||
async def send_email(self, to, subject, html_content, text_content=None):
|
||||
# Implement SendGrid sending logic
|
||||
pass
|
||||
|
||||
# Update global instance in email_service.py:
|
||||
# email_service = EmailService(SendGridEmailBackend(settings.SENDGRID_API_KEY))
|
||||
```
|
||||
|
||||
**Option 3: External Microservice**
|
||||
- Use a dedicated email microservice via HTTP API
|
||||
- Implement `HTTPEmailBackend` that makes async HTTP requests
|
||||
|
||||
### Email Templates Included
|
||||
|
||||
The service includes pre-built templates for:
|
||||
- **Password Reset**: `send_password_reset_email()` - 1 hour expiry
|
||||
- **Email Verification**: `send_email_verification()` - 24 hour expiry
|
||||
|
||||
Both include responsive HTML and plain text versions.
|
||||
|
||||
### Integration Points
|
||||
|
||||
Email sending is called from:
|
||||
- `app/api/routes/auth.py` - Password reset flow (placeholder comments)
|
||||
- Registration flow - Ready for email verification integration
|
||||
|
||||
**Note**: Current auth routes have placeholder comments where email functionality should be integrated. Search for "TODO: Send email" in the codebase.
|
||||
|
||||
## API Documentation
|
||||
|
||||
Once backend is running:
|
||||
- **Swagger UI**: http://localhost:8000/docs
|
||||
- **ReDoc**: http://localhost:8000/redoc
|
||||
- **OpenAPI JSON**: http://localhost:8000/api/v1/openapi.json
|
||||
|
||||
## Troubleshooting
|
||||
|
||||
### Tests failing with "Module was never imported"
|
||||
Run with single process: `pytest -n 0`
|
||||
|
||||
### Coverage not improving despite new tests
|
||||
- Verify tests actually execute endpoints (check response.status_code)
|
||||
- Generate HTML coverage: `pytest --cov=app --cov-report=html -n 0`
|
||||
- Check for dependency override issues in test fixtures
|
||||
|
||||
### Frontend type errors
|
||||
```bash
|
||||
npm run type-check # Check all types
|
||||
npx tsc --noEmit # Same but shorter
|
||||
```
|
||||
|
||||
### E2E tests flaking
|
||||
- Check worker count (should be 4, not 16+)
|
||||
- Use `Promise.all()` for navigation
|
||||
- Use regex for URL assertions
|
||||
- Target specific selectors (avoid generic `[role="alert"]`)
|
||||
|
||||
### Database migration conflicts
|
||||
```bash
|
||||
python migrate.py list # Check migration history
|
||||
alembic downgrade -1 # Downgrade one revision
|
||||
alembic upgrade head # Re-apply
|
||||
```
|
||||
|
||||
## Additional Documentation
|
||||
|
||||
### Backend Documentation
|
||||
- `backend/docs/ARCHITECTURE.md`: System architecture and design patterns
|
||||
- `backend/docs/CODING_STANDARDS.md`: Code quality standards and best practices
|
||||
- `backend/docs/COMMON_PITFALLS.md`: Common mistakes and how to avoid them
|
||||
- `backend/docs/FEATURE_EXAMPLE.md`: Step-by-step feature implementation guide
|
||||
|
||||
### Frontend Documentation
|
||||
- **`frontend/docs/ARCHITECTURE_FIX_REPORT.md`**: ⭐ Critical DI pattern fixes (READ THIS!)
|
||||
- `frontend/e2e/README.md`: E2E testing setup and guidelines
|
||||
- **`frontend/docs/design-system/`**: Comprehensive design system documentation
|
||||
- `README.md`: Hub with learning paths (start here)
|
||||
- `00-quick-start.md`: 5-minute crash course
|
||||
- `01-foundations.md`: Colors (OKLCH), typography, spacing, shadows
|
||||
- `02-components.md`: shadcn/ui component library guide
|
||||
- `03-layouts.md`: Layout patterns (Grid vs Flex decision trees)
|
||||
- `04-spacing-philosophy.md`: Parent-controlled spacing strategy
|
||||
- `05-component-creation.md`: When to create vs compose components
|
||||
- `06-forms.md`: Form patterns with react-hook-form + Zod
|
||||
- `07-accessibility.md`: WCAG AA compliance, keyboard navigation, screen readers
|
||||
- `08-ai-guidelines.md`: **AI code generation rules (read this!)**
|
||||
- `99-reference.md`: Quick reference cheat sheet (bookmark this)
|
||||
**Frontend Component Development:**
|
||||
- Follow design system docs in `frontend/docs/design-system/`
|
||||
- Read `08-ai-guidelines.md` for AI code generation rules
|
||||
- Use parent-controlled spacing (see `04-spacing-philosophy.md`)
|
||||
- WCAG AA compliance required (see `07-accessibility.md`)
|
||||
|
||||
**Security Considerations:**
|
||||
- Backend has comprehensive security tests (JWT attacks, session hijacking)
|
||||
- Never skip security headers in production
|
||||
- Rate limiting is configured in route decorators: `@limiter.limit("10/minute")`
|
||||
- Session revocation is database-backed, not just JWT expiry
|
||||
- Run `make audit` to check for dependency vulnerabilities and license compliance
|
||||
- Run `make check` for the full pipeline: quality + security + tests
|
||||
- Pre-commit hooks enforce Ruff lint/format and detect-secrets on every commit
|
||||
- Setup hooks: `cd backend && uv run pre-commit install`
|
||||
|
||||
### Common Workflows Guidance
|
||||
|
||||
**When Adding a New Feature:**
|
||||
1. Start with backend schema and repository
|
||||
2. Implement API route with proper authorization
|
||||
3. Write backend tests (aim for >90% coverage)
|
||||
4. Generate frontend API client: `bun run generate:api`
|
||||
5. Implement frontend components
|
||||
6. Write frontend unit tests
|
||||
7. Add E2E tests for critical flows
|
||||
8. Update relevant documentation
|
||||
|
||||
**When Fixing Tests:**
|
||||
- Backend: Check test database isolation and async fixture usage
|
||||
- Frontend unit: Verify mocking of `useAuth()` not `useAuthStore`
|
||||
- E2E: Use `Promise.all()` pattern and regex URL assertions
|
||||
|
||||
**When Debugging:**
|
||||
- Backend: Check `IS_TEST=True` environment variable is set
|
||||
- Frontend: Run `bun run type-check` first
|
||||
- E2E: Use `bun run test:e2e:debug` for step-by-step debugging
|
||||
- Check logs: Backend has detailed error logging
|
||||
|
||||
**Demo Mode (Frontend-Only Showcase):**
|
||||
- Enable: `echo "NEXT_PUBLIC_DEMO_MODE=true" > frontend/.env.local`
|
||||
- Uses MSW (Mock Service Worker) to intercept API calls in browser
|
||||
- Zero backend required - perfect for Vercel deployments
|
||||
- **Fully Automated**: MSW handlers auto-generated from OpenAPI spec
|
||||
- Run `bun run generate:api` → updates both API client AND MSW handlers
|
||||
- No manual synchronization needed!
|
||||
- Demo credentials (any password ≥8 chars works):
|
||||
- User: `demo@example.com` / `DemoPass123`
|
||||
- Admin: `admin@example.com` / `AdminPass123`
|
||||
- **Safe**: MSW never runs during tests (Jest or Playwright)
|
||||
- **Coverage**: Mock files excluded from linting and coverage
|
||||
- **Documentation**: `frontend/docs/DEMO_MODE.md` for complete guide
|
||||
|
||||
### Tool Usage Preferences
|
||||
|
||||
**Prefer specialized tools over bash:**
|
||||
- Use Read/Write/Edit tools for file operations
|
||||
- Never use `cat`, `echo >`, or heredoc for file manipulation
|
||||
- Use Task tool with `subagent_type=Explore` for codebase exploration
|
||||
- Use Grep tool for code search, not bash `grep`
|
||||
|
||||
**When to use parallel tool calls:**
|
||||
- Independent git commands: `git status`, `git diff`, `git log`
|
||||
- Reading multiple unrelated files
|
||||
- Running multiple test suites simultaneously
|
||||
- Independent validation steps
|
||||
|
||||
## Custom Skills
|
||||
|
||||
No Claude Code Skills installed yet. To create one, invoke the built-in "skill-creator" skill.
|
||||
|
||||
**Potential skill ideas for this project:**
|
||||
- API endpoint generator workflow (schema → repository → route → tests → frontend client)
|
||||
- Component generator with design system compliance
|
||||
- Database migration troubleshooting helper
|
||||
- Test coverage analyzer and improvement suggester
|
||||
- E2E test generator for new features
|
||||
|
||||
## Additional Resources
|
||||
|
||||
**Comprehensive Documentation:**
|
||||
- [AGENTS.md](./AGENTS.md) - Framework-agnostic AI assistant context
|
||||
- [README.md](./README.md) - User-facing project overview
|
||||
- `backend/docs/` - Backend architecture, coding standards, common pitfalls
|
||||
- `frontend/docs/design-system/` - Complete design system guide
|
||||
|
||||
**API Documentation (when running):**
|
||||
- Swagger UI: http://localhost:8000/docs
|
||||
- ReDoc: http://localhost:8000/redoc
|
||||
- OpenAPI JSON: http://localhost:8000/api/v1/openapi.json
|
||||
|
||||
**Testing Documentation:**
|
||||
- Backend tests: `backend/tests/` (97% coverage)
|
||||
- Frontend E2E: `frontend/e2e/README.md`
|
||||
- Design system: `frontend/docs/design-system/08-ai-guidelines.md`
|
||||
|
||||
---
|
||||
|
||||
**For project architecture, development commands, and general context, see [AGENTS.md](./AGENTS.md).**
|
||||
|
||||
@@ -90,22 +90,27 @@ Ready to write some code? Awesome!
|
||||
```bash
|
||||
cd backend
|
||||
|
||||
# Setup virtual environment
|
||||
python -m venv .venv
|
||||
source .venv/bin/activate
|
||||
# Install dependencies (uv manages virtual environment automatically)
|
||||
make install-dev
|
||||
|
||||
# Install dependencies
|
||||
pip install -r requirements.txt
|
||||
# Setup pre-commit hooks
|
||||
uv run pre-commit install
|
||||
|
||||
# Setup environment
|
||||
cp .env.example .env
|
||||
# Edit .env with your settings
|
||||
|
||||
# Run migrations
|
||||
alembic upgrade head
|
||||
python migrate.py apply
|
||||
|
||||
# Run quality + security checks
|
||||
make validate-all
|
||||
|
||||
# Run tests
|
||||
IS_TEST=True pytest
|
||||
make test
|
||||
|
||||
# Run full pipeline (quality + security + tests)
|
||||
make check
|
||||
|
||||
# Start dev server
|
||||
uvicorn app.main:app --reload
|
||||
@@ -117,20 +122,20 @@ uvicorn app.main:app --reload
|
||||
cd frontend
|
||||
|
||||
# Install dependencies
|
||||
npm install
|
||||
bun install
|
||||
|
||||
# Setup environment
|
||||
cp .env.local.example .env.local
|
||||
|
||||
# Generate API client
|
||||
npm run generate:api
|
||||
bun run generate:api
|
||||
|
||||
# Run tests
|
||||
npm test
|
||||
npm run test:e2e:ui
|
||||
bun run test
|
||||
bun run test:e2e:ui
|
||||
|
||||
# Start dev server
|
||||
npm run dev
|
||||
bun run dev
|
||||
```
|
||||
|
||||
---
|
||||
@@ -199,7 +204,7 @@ export function UserProfile({ userId }: UserProfileProps) {
|
||||
|
||||
### Key Patterns
|
||||
|
||||
- **Backend**: Use CRUD pattern, keep routes thin, business logic in services
|
||||
- **Backend**: Use repository pattern, keep routes thin, business logic in services
|
||||
- **Frontend**: Use React Query for server state, Zustand for client state
|
||||
- **Both**: Handle errors gracefully, log appropriately, write tests
|
||||
|
||||
@@ -320,7 +325,7 @@ Fixed stuff
|
||||
### Before Submitting
|
||||
|
||||
- [ ] Code follows project style guidelines
|
||||
- [ ] All tests pass locally
|
||||
- [ ] `make check` passes (quality + security + tests) in backend
|
||||
- [ ] New tests added for new features
|
||||
- [ ] Documentation updated if needed
|
||||
- [ ] No merge conflicts with `main`
|
||||
|
||||
119
Makefile
119
Makefile
@@ -1,31 +1,124 @@
|
||||
.PHONY: dev prod down clean clean-slate
|
||||
.PHONY: help dev dev-full prod down logs logs-dev clean clean-slate drop-db reset-db push-images deploy scan-images
|
||||
|
||||
VERSION ?= latest
|
||||
REGISTRY := gitea.pragmazest.com/cardosofelipe/app
|
||||
REGISTRY ?= ghcr.io/cardosofelipe/pragma-stack
|
||||
|
||||
# Default target
|
||||
help:
|
||||
@echo "FastAPI + Next.js Full-Stack Template"
|
||||
@echo ""
|
||||
@echo "Development:"
|
||||
@echo " make dev - Start backend + db (frontend runs separately)"
|
||||
@echo " make dev-full - Start all services including frontend"
|
||||
@echo " make down - Stop all services"
|
||||
@echo " make logs-dev - Follow dev container logs"
|
||||
@echo ""
|
||||
@echo "Database:"
|
||||
@echo " make drop-db - Drop and recreate empty database"
|
||||
@echo " make reset-db - Drop database and apply all migrations"
|
||||
@echo ""
|
||||
@echo "Production:"
|
||||
@echo " make prod - Start production stack"
|
||||
@echo " make deploy - Pull and deploy latest images"
|
||||
@echo " make push-images - Build and push images to registry"
|
||||
@echo " make scan-images - Scan production images for CVEs (requires trivy)"
|
||||
@echo " make logs - Follow production container logs"
|
||||
@echo ""
|
||||
@echo "Cleanup:"
|
||||
@echo " make clean - Stop containers"
|
||||
@echo " make clean-slate - Stop containers AND delete volumes (DATA LOSS!)"
|
||||
@echo ""
|
||||
@echo "Subdirectory commands:"
|
||||
@echo " cd backend && make help - Backend-specific commands"
|
||||
@echo " cd frontend && npm run - Frontend-specific commands"
|
||||
|
||||
# ============================================================================
|
||||
# Development
|
||||
# ============================================================================
|
||||
|
||||
dev:
|
||||
docker compose -f docker-compose.dev.yml up --build -d
|
||||
# Bring up all dev services except the frontend
|
||||
docker compose -f docker-compose.dev.yml up --build -d --scale frontend=0
|
||||
@echo ""
|
||||
@echo "Frontend is not started by 'make dev'."
|
||||
@echo "To run the frontend locally, open a new terminal and run:"
|
||||
@echo " cd frontend && npm run dev"
|
||||
|
||||
prod:
|
||||
docker compose up --build -d
|
||||
dev-full:
|
||||
# Bring up all dev services including the frontend (full stack)
|
||||
docker compose -f docker-compose.dev.yml up --build -d
|
||||
|
||||
down:
|
||||
docker compose down
|
||||
|
||||
logs:
|
||||
docker compose logs -f
|
||||
|
||||
logs-dev:
|
||||
docker compose -f docker-compose.dev.yml logs -f
|
||||
|
||||
# ============================================================================
|
||||
# Database Management
|
||||
# ============================================================================
|
||||
|
||||
drop-db:
|
||||
@echo "Dropping local database..."
|
||||
@docker compose -f docker-compose.dev.yml exec -T db psql -U postgres -c "DROP DATABASE IF EXISTS app WITH (FORCE);" 2>/dev/null || \
|
||||
docker compose -f docker-compose.dev.yml exec -T db psql -U postgres -c "DROP DATABASE IF EXISTS app;"
|
||||
@docker compose -f docker-compose.dev.yml exec -T db psql -U postgres -c "CREATE DATABASE app;"
|
||||
@echo "Database dropped and recreated (empty)"
|
||||
|
||||
reset-db: drop-db
|
||||
@echo "Applying migrations..."
|
||||
@cd backend && uv run python migrate.py --local apply
|
||||
@echo "Database reset complete!"
|
||||
|
||||
# ============================================================================
|
||||
# Production / Deployment
|
||||
# ============================================================================
|
||||
|
||||
prod:
|
||||
docker compose up --build -d
|
||||
|
||||
deploy:
|
||||
docker compose -f docker-compose.deploy.yml pull
|
||||
docker compose -f docker-compose.deploy.yml up -d
|
||||
|
||||
clean:
|
||||
docker compose down -
|
||||
|
||||
# WARNING! THIS REMOVES CONTAINERS AND VOLUMES AS WELL - DO NOT USE THIS UNLESS YOU WANT TO START OVER WITH DATA AND ALL
|
||||
clean-slate:
|
||||
docker compose down -v
|
||||
|
||||
push-images:
|
||||
docker build -t $(REGISTRY)/backend:$(VERSION) ./backend
|
||||
docker build -t $(REGISTRY)/frontend:$(VERSION) ./frontend
|
||||
docker push $(REGISTRY)/backend:$(VERSION)
|
||||
docker push $(REGISTRY)/frontend:$(VERSION)
|
||||
docker push $(REGISTRY)/frontend:$(VERSION)
|
||||
|
||||
scan-images:
|
||||
@docker info > /dev/null 2>&1 || (echo "❌ Docker is not running!"; exit 1)
|
||||
@echo "🐳 Building and scanning production images for CVEs..."
|
||||
docker build -t $(REGISTRY)/backend:scan --target production ./backend
|
||||
docker build -t $(REGISTRY)/frontend:scan --target runner ./frontend
|
||||
@echo ""
|
||||
@echo "=== Backend Image Scan ==="
|
||||
@if command -v trivy > /dev/null 2>&1; then \
|
||||
trivy image --severity HIGH,CRITICAL --exit-code 1 $(REGISTRY)/backend:scan; \
|
||||
else \
|
||||
echo "ℹ️ Trivy not found locally, using Docker to run Trivy..."; \
|
||||
docker run --rm -v /var/run/docker.sock:/var/run/docker.sock aquasec/trivy image --severity HIGH,CRITICAL --exit-code 1 $(REGISTRY)/backend:scan; \
|
||||
fi
|
||||
@echo ""
|
||||
@echo "=== Frontend Image Scan ==="
|
||||
@if command -v trivy > /dev/null 2>&1; then \
|
||||
trivy image --severity HIGH,CRITICAL --exit-code 1 $(REGISTRY)/frontend:scan; \
|
||||
else \
|
||||
docker run --rm -v /var/run/docker.sock:/var/run/docker.sock aquasec/trivy image --severity HIGH,CRITICAL --exit-code 1 $(REGISTRY)/frontend:scan; \
|
||||
fi
|
||||
@echo "✅ No HIGH/CRITICAL CVEs found in production images!"
|
||||
|
||||
# ============================================================================
|
||||
# Cleanup
|
||||
# ============================================================================
|
||||
|
||||
clean:
|
||||
docker compose down
|
||||
|
||||
# WARNING! THIS REMOVES CONTAINERS AND VOLUMES AS WELL - DO NOT USE THIS UNLESS YOU WANT TO START OVER WITH DATA AND ALL
|
||||
clean-slate:
|
||||
docker compose -f docker-compose.dev.yml down -v --remove-orphans
|
||||
|
||||
224
README.md
224
README.md
@@ -1,29 +1,29 @@
|
||||
# FastAPI + Next.js Full-Stack Template
|
||||
# <img src="frontend/public/logo.svg" alt="PragmaStack" width="32" height="32" style="vertical-align: middle" /> PragmaStack
|
||||
|
||||
> **Production-ready, security-first, full-stack TypeScript/Python template with authentication, multi-tenancy, and a comprehensive admin panel.**
|
||||
> **The Pragmatic Full-Stack Template. Production-ready, security-first, and opinionated.**
|
||||
|
||||
<!--
|
||||
TODO: Replace these static badges with dynamic CI/CD badges when GitHub Actions is set up
|
||||
Example: https://github.com/YOUR_ORG/YOUR_REPO/actions/workflows/backend-tests.yml/badge.svg
|
||||
-->
|
||||
|
||||
[](./backend/tests)
|
||||
[](./backend/tests)
|
||||
[](./frontend/tests)
|
||||
[](./frontend/tests)
|
||||
[](./frontend/e2e)
|
||||
[](./LICENSE)
|
||||
[](./CONTRIBUTING.md)
|
||||
|
||||

|
||||
|
||||
---
|
||||
|
||||
## Why This Template?
|
||||
## Why PragmaStack?
|
||||
|
||||
Building a modern full-stack application from scratch means solving the same problems over and over: authentication, authorization, multi-tenancy, admin panels, session management, database migrations, API documentation, testing infrastructure...
|
||||
Building a modern full-stack application often leads to "analysis paralysis" or "boilerplate fatigue". You spend weeks setting up authentication, testing, and linting before writing a single line of business logic.
|
||||
|
||||
**This template gives you all of that, battle-tested and ready to go.**
|
||||
**PragmaStack cuts through the noise.**
|
||||
|
||||
Instead of spending weeks on boilerplate, you can focus on building your unique features. Whether you're building a SaaS product, an internal tool, or a side project, this template provides a rock-solid foundation with modern best practices baked in.
|
||||
We provide a **pragmatic**, opinionated foundation that prioritizes:
|
||||
- **Speed**: Ship features, not config files.
|
||||
- **Robustness**: Security and testing are not optional.
|
||||
- **Clarity**: Code that is easy to read and maintain.
|
||||
|
||||
Whether you're building a SaaS, an internal tool, or a side project, PragmaStack gives you a solid starting point without the bloat.
|
||||
|
||||
---
|
||||
|
||||
@@ -31,12 +31,26 @@ Instead of spending weeks on boilerplate, you can focus on building your unique
|
||||
|
||||
### 🔐 **Authentication & Security**
|
||||
- JWT-based authentication with access + refresh tokens
|
||||
- **OAuth/Social Login** (Google, GitHub) with PKCE support
|
||||
- **OAuth 2.0 Authorization Server** (MCP-ready) for third-party integrations
|
||||
- Session management with device tracking and revocation
|
||||
- Password reset flow (email integration ready)
|
||||
- Secure password hashing (bcrypt)
|
||||
- CSRF protection, rate limiting, and security headers
|
||||
- Comprehensive security tests (JWT algorithm attacks, session hijacking, privilege escalation)
|
||||
|
||||
### 🔌 **OAuth Provider Mode (MCP Integration)**
|
||||
Full OAuth 2.0 Authorization Server for Model Context Protocol (MCP) and third-party clients:
|
||||
- **RFC 7636**: Authorization Code Flow with PKCE (S256 only)
|
||||
- **RFC 8414**: Server metadata discovery at `/.well-known/oauth-authorization-server`
|
||||
- **RFC 7662**: Token introspection endpoint
|
||||
- **RFC 7009**: Token revocation endpoint
|
||||
- **JWT access tokens**: Self-contained, configurable lifetime
|
||||
- **Opaque refresh tokens**: Secure rotation, database-backed revocation
|
||||
- **Consent management**: Users can review and revoke app permissions
|
||||
- **Client management**: Admin endpoints for registering OAuth clients
|
||||
- **Scopes**: `openid`, `profile`, `email`, `read:users`, `write:users`, `admin`
|
||||
|
||||
### 👥 **Multi-Tenancy & Organizations**
|
||||
- Full organization system with role-based access control (Owner, Admin, Member)
|
||||
- Invite/remove members, manage permissions
|
||||
@@ -44,18 +58,35 @@ Instead of spending weeks on boilerplate, you can focus on building your unique
|
||||
- User can belong to multiple organizations
|
||||
|
||||
### 🛠️ **Admin Panel**
|
||||
- Complete user management (CRUD, activate/deactivate, bulk operations)
|
||||
- Complete user management (full lifecycle, activate/deactivate, bulk operations)
|
||||
- Organization management (create, edit, delete, member management)
|
||||
- Session monitoring across all users
|
||||
- Real-time statistics dashboard
|
||||
- Admin-only routes with proper authorization
|
||||
|
||||
### 🎨 **Modern Frontend**
|
||||
- Next.js 15 with App Router and React 19
|
||||
- Comprehensive design system built on shadcn/ui + TailwindCSS
|
||||
- Next.js 16 with App Router and React 19
|
||||
- **PragmaStack Design System** built on shadcn/ui + TailwindCSS
|
||||
- Pre-configured theme with dark mode support (coming soon)
|
||||
- Responsive, accessible components (WCAG AA compliant)
|
||||
- Developer documentation at `/dev` (in progress)
|
||||
- Rich marketing landing page with animated components
|
||||
- Live component showcase and documentation at `/dev`
|
||||
|
||||
### 🌍 **Internationalization (i18n)**
|
||||
- Built-in multi-language support with next-intl v4
|
||||
- Locale-based routing (`/en/*`, `/it/*`)
|
||||
- Seamless language switching with LocaleSwitcher component
|
||||
- SEO-friendly URLs and metadata per locale
|
||||
- Translation files for English and Italian (easily extensible)
|
||||
- Type-safe translations throughout the app
|
||||
|
||||
### 🎯 **Content & UX Features**
|
||||
- **Toast notifications** with Sonner for elegant user feedback
|
||||
- **Smooth animations** powered by Framer Motion
|
||||
- **Markdown rendering** with syntax highlighting (GitHub Flavored Markdown)
|
||||
- **Charts and visualizations** ready with Recharts
|
||||
- **SEO optimization** with dynamic sitemap and robots.txt generation
|
||||
- **Session tracking UI** with device information and revocation controls
|
||||
|
||||
### 🧪 **Comprehensive Testing**
|
||||
- **Backend Testing**: ~97% unit test coverage
|
||||
@@ -75,9 +106,10 @@ Instead of spending weeks on boilerplate, you can focus on building your unique
|
||||
### 📚 **Developer Experience**
|
||||
- Auto-generated TypeScript API client from OpenAPI spec
|
||||
- Interactive API documentation (Swagger + ReDoc)
|
||||
- Database migrations with Alembic
|
||||
- Hot reload in development
|
||||
- Comprehensive code documentation
|
||||
- Database migrations with Alembic helper script
|
||||
- Hot reload in development for both frontend and backend
|
||||
- Comprehensive code documentation and design system docs
|
||||
- Live component playground at `/dev` with code examples
|
||||
- Docker support for easy deployment
|
||||
- VSCode workspace settings included
|
||||
|
||||
@@ -89,6 +121,68 @@ Instead of spending weeks on boilerplate, you can focus on building your unique
|
||||
- Health check endpoints
|
||||
- Production security headers
|
||||
- Rate limiting on sensitive endpoints
|
||||
- SEO optimization with dynamic sitemaps and robots.txt
|
||||
- Multi-language SEO with locale-specific metadata
|
||||
- Performance monitoring and bundle analysis
|
||||
|
||||
---
|
||||
|
||||
## 📸 Screenshots
|
||||
|
||||
<details>
|
||||
<summary>Click to view screenshots</summary>
|
||||
|
||||
### Landing Page
|
||||

|
||||
|
||||
|
||||
|
||||
### Authentication
|
||||

|
||||
|
||||
|
||||
|
||||
### Admin Dashboard
|
||||

|
||||
|
||||
|
||||
|
||||
### Design System
|
||||

|
||||
|
||||
</details>
|
||||
|
||||
---
|
||||
|
||||
## 🎭 Demo Mode
|
||||
|
||||
**Try the frontend without a backend!** Perfect for:
|
||||
- **Free deployment** on Vercel (no backend costs)
|
||||
- **Portfolio showcasing** with live demos
|
||||
- **Client presentations** without infrastructure setup
|
||||
|
||||
### Quick Start
|
||||
|
||||
```bash
|
||||
cd frontend
|
||||
echo "NEXT_PUBLIC_DEMO_MODE=true" > .env.local
|
||||
bun run dev
|
||||
```
|
||||
|
||||
**Demo Credentials:**
|
||||
- Regular user: `demo@example.com` / `DemoPass123`
|
||||
- Admin user: `admin@example.com` / `AdminPass123`
|
||||
|
||||
Demo mode uses [Mock Service Worker (MSW)](https://mswjs.io/) to intercept API calls in the browser. Your code remains unchanged - the same components work with both real and mocked backends.
|
||||
|
||||
**Key Features:**
|
||||
- ✅ Zero backend required
|
||||
- ✅ All features functional (auth, admin, stats)
|
||||
- ✅ Realistic network delays and errors
|
||||
- ✅ Does NOT interfere with tests (97%+ coverage maintained)
|
||||
- ✅ One-line toggle: `NEXT_PUBLIC_DEMO_MODE=true`
|
||||
|
||||
📖 **[Complete Demo Mode Documentation](./frontend/docs/DEMO_MODE.md)**
|
||||
|
||||
---
|
||||
|
||||
@@ -103,13 +197,18 @@ Instead of spending weeks on boilerplate, you can focus on building your unique
|
||||
- **[pytest](https://pytest.org/)** - Testing framework with async support
|
||||
|
||||
### Frontend
|
||||
- **[Next.js 15](https://nextjs.org/)** - React framework with App Router
|
||||
- **[Next.js 16](https://nextjs.org/)** - React framework with App Router
|
||||
- **[React 19](https://react.dev/)** - UI library
|
||||
- **[TypeScript](https://www.typescriptlang.org/)** - Type-safe JavaScript
|
||||
- **[TailwindCSS](https://tailwindcss.com/)** - Utility-first CSS framework
|
||||
- **[shadcn/ui](https://ui.shadcn.com/)** - Beautiful, accessible component library
|
||||
- **[next-intl](https://next-intl.dev/)** - Internationalization (i18n) with type safety
|
||||
- **[TanStack Query](https://tanstack.com/query)** - Powerful data fetching/caching
|
||||
- **[Zustand](https://zustand-demo.pmnd.rs/)** - Lightweight state management
|
||||
- **[Framer Motion](https://www.framer.com/motion/)** - Production-ready animation library
|
||||
- **[Sonner](https://sonner.emilkowal.ski/)** - Beautiful toast notifications
|
||||
- **[Recharts](https://recharts.org/)** - Composable charting library
|
||||
- **[React Markdown](https://github.com/remarkjs/react-markdown)** - Markdown rendering with GFM support
|
||||
- **[Playwright](https://playwright.dev/)** - End-to-end testing
|
||||
|
||||
### DevOps
|
||||
@@ -135,12 +234,11 @@ The fastest way to get started is with Docker:
|
||||
|
||||
```bash
|
||||
# Clone the repository
|
||||
git clone https://github.com/yourusername/fast-next-template.git
|
||||
git clone https://github.com/cardosofelipe/pragma-stack.git
|
||||
cd fast-next-template
|
||||
|
||||
# Copy environment files
|
||||
cp backend/.env.example backend/.env
|
||||
cp frontend/.env.local.example frontend/.env.local
|
||||
# Copy environment file
|
||||
cp .env.template .env
|
||||
|
||||
# Start all services (backend, frontend, database)
|
||||
docker-compose up
|
||||
@@ -200,17 +298,17 @@ uvicorn app.main:app --reload --host 0.0.0.0 --port 8000
|
||||
cd frontend
|
||||
|
||||
# Install dependencies
|
||||
npm install
|
||||
bun install
|
||||
|
||||
# Setup environment
|
||||
cp .env.local.example .env.local
|
||||
# Edit .env.local with your backend URL
|
||||
|
||||
# Generate API client
|
||||
npm run generate:api
|
||||
bun run generate:api
|
||||
|
||||
# Start development server
|
||||
npm run dev
|
||||
bun run dev
|
||||
```
|
||||
|
||||
Visit http://localhost:3000 to see your app!
|
||||
@@ -224,7 +322,7 @@ Visit http://localhost:3000 to see your app!
|
||||
│ ├── app/
|
||||
│ │ ├── api/ # API routes and dependencies
|
||||
│ │ ├── core/ # Core functionality (auth, config, database)
|
||||
│ │ ├── crud/ # Database operations
|
||||
│ │ ├── repositories/ # Repository pattern (database operations)
|
||||
│ │ ├── models/ # SQLAlchemy models
|
||||
│ │ ├── schemas/ # Pydantic schemas
|
||||
│ │ ├── services/ # Business logic
|
||||
@@ -279,7 +377,7 @@ open htmlcov/index.html
|
||||
```
|
||||
|
||||
**Test types:**
|
||||
- **Unit tests**: CRUD operations, utilities, business logic
|
||||
- **Unit tests**: Repository operations, utilities, business logic
|
||||
- **Integration tests**: API endpoints with database
|
||||
- **Security tests**: JWT algorithm attacks, session hijacking, privilege escalation
|
||||
- **Error handling tests**: Database failures, validation errors
|
||||
@@ -292,13 +390,13 @@ open htmlcov/index.html
|
||||
cd frontend
|
||||
|
||||
# Run unit tests
|
||||
npm test
|
||||
bun run test
|
||||
|
||||
# Run with coverage
|
||||
npm run test:coverage
|
||||
bun run test:coverage
|
||||
|
||||
# Watch mode
|
||||
npm run test:watch
|
||||
bun run test:watch
|
||||
```
|
||||
|
||||
**Test types:**
|
||||
@@ -316,10 +414,10 @@ npm run test:watch
|
||||
cd frontend
|
||||
|
||||
# Run E2E tests
|
||||
npm run test:e2e
|
||||
bun run test:e2e
|
||||
|
||||
# Run E2E tests in UI mode (recommended for development)
|
||||
npm run test:e2e:ui
|
||||
bun run test:e2e:ui
|
||||
|
||||
# Run specific test file
|
||||
npx playwright test auth-login.spec.ts
|
||||
@@ -338,6 +436,17 @@ npx playwright show-report
|
||||
|
||||
---
|
||||
|
||||
## 🤖 AI-Friendly Documentation
|
||||
|
||||
This project includes comprehensive documentation designed for AI coding assistants:
|
||||
|
||||
- **[AGENTS.md](./AGENTS.md)** - Framework-agnostic AI assistant context for PragmaStack
|
||||
- **[CLAUDE.md](./CLAUDE.md)** - Claude Code-specific guidance
|
||||
|
||||
These files provide AI assistants with the **PragmaStack** architecture, patterns, and best practices.
|
||||
|
||||
---
|
||||
|
||||
## 🗄️ Database Migrations
|
||||
|
||||
The template uses Alembic for database migrations:
|
||||
@@ -365,22 +474,25 @@ python migrate.py current
|
||||
|
||||
## 📖 Documentation
|
||||
|
||||
### AI Assistant Documentation
|
||||
|
||||
- **[AGENTS.md](./AGENTS.md)** - Framework-agnostic AI coding assistant context
|
||||
- **[CLAUDE.md](./CLAUDE.md)** - Claude Code-specific guidance and preferences
|
||||
|
||||
### Backend Documentation
|
||||
|
||||
- **[ARCHITECTURE.md](./backend/docs/ARCHITECTURE.md)** - System architecture and design patterns
|
||||
- **[CODING_STANDARDS.md](./backend/docs/CODING_STANDARDS.md)** - Code quality standards
|
||||
- **[COMMON_PITFALLS.md](./backend/docs/COMMON_PITFALLS.md)** - Common mistakes to avoid
|
||||
- **[FEATURE_EXAMPLE.md](./backend/docs/FEATURE_EXAMPLE.md)** - Step-by-step feature guide
|
||||
- **[CLAUDE.md](./CLAUDE.md)** - Comprehensive development guide
|
||||
|
||||
### Frontend Documentation
|
||||
|
||||
- **[Design System Docs](./frontend/docs/design-system/)** - Complete design system guide
|
||||
- **[PragmaStack Design System](./frontend/docs/design-system/)** - Complete design system guide
|
||||
- Quick start, foundations (colors, typography, spacing)
|
||||
- Component library guide
|
||||
- Layout patterns, spacing philosophy
|
||||
- Forms, accessibility, AI guidelines
|
||||
- **[ARCHITECTURE_FIX_REPORT.md](./frontend/docs/ARCHITECTURE_FIX_REPORT.md)** - Critical dependency injection patterns
|
||||
- **[E2E Testing Guide](./frontend/e2e/README.md)** - E2E testing setup and best practices
|
||||
|
||||
### API Documentation
|
||||
@@ -429,37 +541,43 @@ docker-compose down
|
||||
## 🛣️ Roadmap & Status
|
||||
|
||||
### ✅ Completed
|
||||
- [x] Authentication system (JWT, refresh tokens, session management)
|
||||
- [x] User management (CRUD, profile, password change)
|
||||
- [x] Authentication system (JWT, refresh tokens, session management, OAuth)
|
||||
- [x] User management (full lifecycle, profile, password change)
|
||||
- [x] Organization system with RBAC (Owner, Admin, Member)
|
||||
- [x] Admin panel (users, organizations, sessions, statistics)
|
||||
- [x] **Internationalization (i18n)** with next-intl (English + Italian)
|
||||
- [x] Backend testing infrastructure (~97% coverage)
|
||||
- [x] Frontend unit testing infrastructure (~97% coverage)
|
||||
- [x] Frontend E2E testing (Playwright, zero flaky tests)
|
||||
- [x] Design system documentation
|
||||
- [x] Database migrations
|
||||
- [x] **Marketing landing page** with animated components
|
||||
- [x] **`/dev` documentation portal** with live component examples
|
||||
- [x] **Toast notifications** system (Sonner)
|
||||
- [x] **Charts and visualizations** (Recharts)
|
||||
- [x] **Animation system** (Framer Motion)
|
||||
- [x] **Markdown rendering** with syntax highlighting
|
||||
- [x] **SEO optimization** (sitemap, robots.txt, locale-aware metadata)
|
||||
- [x] Database migrations with helper script
|
||||
- [x] Docker deployment
|
||||
- [x] API documentation (OpenAPI/Swagger)
|
||||
|
||||
### 🚧 In Progress
|
||||
- [ ] Frontend admin pages (70% complete)
|
||||
- [ ] Dark mode theme
|
||||
- [ ] `/dev` documentation page with examples
|
||||
- [ ] Email integration (templates ready, SMTP pending)
|
||||
- [ ] Chart/visualization components
|
||||
|
||||
### 🔮 Planned
|
||||
- [ ] GitHub Actions CI/CD pipelines
|
||||
- [ ] Dynamic test coverage badges from CI
|
||||
- [ ] E2E test coverage reporting
|
||||
- [ ] Additional authentication methods (OAuth, SSO)
|
||||
- [ ] OAuth token encryption at rest (security hardening)
|
||||
- [ ] Additional languages (Spanish, French, German, etc.)
|
||||
- [ ] SSO/SAML authentication
|
||||
- [ ] Real-time notifications with WebSockets
|
||||
- [ ] Webhook system
|
||||
- [ ] Background job processing
|
||||
- [ ] File upload/storage
|
||||
- [ ] Notification system
|
||||
- [ ] Audit logging
|
||||
- [ ] File upload/storage (S3-compatible)
|
||||
- [ ] Audit logging system
|
||||
- [ ] API versioning example
|
||||
|
||||
|
||||
---
|
||||
|
||||
## 🤝 Contributing
|
||||
@@ -489,7 +607,7 @@ Contributions are welcome! Whether you're fixing bugs, improving documentation,
|
||||
|
||||
### Reporting Issues
|
||||
|
||||
Found a bug? Have a suggestion? [Open an issue](https://github.com/yourusername/fast-next-template/issues)!
|
||||
Found a bug? Have a suggestion? [Open an issue](https://github.com/cardosofelipe/pragma-stack/issues)!
|
||||
|
||||
Please include:
|
||||
- Clear description of the issue/suggestion
|
||||
@@ -523,8 +641,8 @@ This template is built on the shoulders of giants:
|
||||
## 💬 Questions?
|
||||
|
||||
- **Documentation**: Check the `/docs` folders in backend and frontend
|
||||
- **Issues**: [GitHub Issues](https://github.com/yourusername/fast-next-template/issues)
|
||||
- **Discussions**: [GitHub Discussions](https://github.com/yourusername/fast-next-template/discussions)
|
||||
- **Issues**: [GitHub Issues](https://github.com/cardosofelipe/pragma-stack/issues)
|
||||
- **Discussions**: [GitHub Discussions](https://github.com/cardosofelipe/pragma-stack/discussions)
|
||||
|
||||
---
|
||||
|
||||
|
||||
@@ -11,16 +11,19 @@ omit =
|
||||
app/utils/auth_test_utils.py
|
||||
|
||||
# Async implementations not yet in use
|
||||
app/crud/base_async.py
|
||||
app/repositories/base_async.py
|
||||
app/core/database_async.py
|
||||
|
||||
# CLI scripts - run manually, not tested
|
||||
app/init_db.py
|
||||
|
||||
# __init__ files with no logic
|
||||
app/__init__.py
|
||||
app/api/__init__.py
|
||||
app/api/routes/__init__.py
|
||||
app/api/dependencies/__init__.py
|
||||
app/core/__init__.py
|
||||
app/crud/__init__.py
|
||||
app/repositories/__init__.py
|
||||
app/models/__init__.py
|
||||
app/schemas/__init__.py
|
||||
app/services/__init__.py
|
||||
|
||||
@@ -1,2 +1,17 @@
|
||||
.venv
|
||||
*.iml
|
||||
*.iml
|
||||
|
||||
# Python build and cache artifacts
|
||||
__pycache__/
|
||||
.pytest_cache/
|
||||
.mypy_cache/
|
||||
.ruff_cache/
|
||||
*.pyc
|
||||
*.pyo
|
||||
|
||||
# Packaging artifacts
|
||||
*.egg-info/
|
||||
build/
|
||||
dist/
|
||||
htmlcov/
|
||||
.uv_cache/
|
||||
44
backend/.pre-commit-config.yaml
Normal file
44
backend/.pre-commit-config.yaml
Normal file
@@ -0,0 +1,44 @@
|
||||
# Pre-commit hooks for backend quality and security checks.
|
||||
#
|
||||
# Install:
|
||||
# cd backend && uv run pre-commit install
|
||||
#
|
||||
# Run manually on all files:
|
||||
# cd backend && uv run pre-commit run --all-files
|
||||
#
|
||||
# Skip hooks temporarily:
|
||||
# git commit --no-verify
|
||||
#
|
||||
repos:
|
||||
# ── Code Quality ──────────────────────────────────────────────────────────
|
||||
- repo: https://github.com/astral-sh/ruff-pre-commit
|
||||
rev: v0.14.4
|
||||
hooks:
|
||||
- id: ruff
|
||||
args: [--fix, --exit-non-zero-on-fix]
|
||||
- id: ruff-format
|
||||
|
||||
# ── General File Hygiene ──────────────────────────────────────────────────
|
||||
- repo: https://github.com/pre-commit/pre-commit-hooks
|
||||
rev: v5.0.0
|
||||
hooks:
|
||||
- id: trailing-whitespace
|
||||
- id: end-of-file-fixer
|
||||
- id: check-yaml
|
||||
- id: check-toml
|
||||
- id: check-merge-conflict
|
||||
- id: check-added-large-files
|
||||
args: [--maxkb=500]
|
||||
- id: debug-statements
|
||||
|
||||
# ── Security ──────────────────────────────────────────────────────────────
|
||||
- repo: https://github.com/Yelp/detect-secrets
|
||||
rev: v1.5.0
|
||||
hooks:
|
||||
- id: detect-secrets
|
||||
args: ['--baseline', '.secrets.baseline']
|
||||
exclude: |
|
||||
(?x)^(
|
||||
.*\.lock$|
|
||||
.*\.svg$
|
||||
)$
|
||||
1073
backend/.secrets.baseline
Normal file
1073
backend/.secrets.baseline
Normal file
File diff suppressed because it is too large
Load Diff
@@ -1,53 +1,67 @@
|
||||
# Development stage
|
||||
FROM python:3.12-slim AS development
|
||||
|
||||
# Create non-root user
|
||||
RUN groupadd -r appuser && useradd -r -g appuser appuser
|
||||
|
||||
WORKDIR /app
|
||||
ENV PYTHONDONTWRITEBYTECODE=1 \
|
||||
PYTHONUNBUFFERED=1 \
|
||||
PYTHONPATH=/app
|
||||
PYTHONPATH=/app \
|
||||
UV_COMPILE_BYTECODE=1 \
|
||||
UV_LINK_MODE=copy \
|
||||
UV_NO_CACHE=1
|
||||
|
||||
# Install system dependencies and uv
|
||||
RUN apt-get update && \
|
||||
apt-get install -y --no-install-recommends gcc postgresql-client curl && \
|
||||
apt-get install -y --no-install-recommends gcc postgresql-client curl ca-certificates && \
|
||||
curl -LsSf https://astral.sh/uv/install.sh | sh && \
|
||||
mv /root/.local/bin/uv* /usr/local/bin/ && \
|
||||
apt-get clean && \
|
||||
rm -rf /var/lib/apt/lists/*
|
||||
|
||||
COPY requirements.txt .
|
||||
RUN pip install --no-cache-dir -r requirements.txt
|
||||
# Copy dependency files
|
||||
COPY pyproject.toml uv.lock ./
|
||||
|
||||
# Install dependencies using uv (development mode with dev dependencies)
|
||||
RUN uv sync --extra dev --frozen
|
||||
|
||||
# Copy application code
|
||||
COPY . .
|
||||
COPY entrypoint.sh /usr/local/bin/
|
||||
RUN chmod +x /usr/local/bin/entrypoint.sh
|
||||
|
||||
# Set ownership to non-root user
|
||||
RUN chown -R appuser:appuser /app
|
||||
|
||||
# Switch to non-root user
|
||||
USER appuser
|
||||
# Note: Running as root in development for bind mount compatibility
|
||||
# Production stage uses non-root user for security
|
||||
|
||||
ENTRYPOINT ["/usr/local/bin/entrypoint.sh"]
|
||||
|
||||
# Production stage
|
||||
FROM python:3.12-slim AS production
|
||||
# Production stage — Alpine eliminates glibc CVEs (e.g. CVE-2026-0861)
|
||||
FROM python:3.12-alpine AS production
|
||||
|
||||
# Create non-root user
|
||||
RUN groupadd -r appuser && useradd -r -g appuser appuser
|
||||
RUN addgroup -S appuser && adduser -S -G appuser appuser
|
||||
|
||||
WORKDIR /app
|
||||
ENV PYTHONDONTWRITEBYTECODE=1 \
|
||||
PYTHONUNBUFFERED=1 \
|
||||
PYTHONPATH=/app
|
||||
PYTHONPATH=/app \
|
||||
UV_COMPILE_BYTECODE=1 \
|
||||
UV_LINK_MODE=copy \
|
||||
UV_NO_CACHE=1
|
||||
|
||||
RUN apt-get update && \
|
||||
apt-get install -y --no-install-recommends postgresql-client curl && \
|
||||
apt-get clean && \
|
||||
rm -rf /var/lib/apt/lists/*
|
||||
# Install system dependencies and uv
|
||||
RUN apk add --no-cache postgresql-client curl ca-certificates && \
|
||||
curl -LsSf https://astral.sh/uv/install.sh | sh && \
|
||||
mv /root/.local/bin/uv* /usr/local/bin/
|
||||
|
||||
COPY requirements.txt .
|
||||
RUN pip install --no-cache-dir -r requirements.txt
|
||||
# Copy dependency files
|
||||
COPY pyproject.toml uv.lock ./
|
||||
|
||||
# Install build dependencies, compile Python packages, then remove build deps
|
||||
RUN apk add --no-cache --virtual .build-deps \
|
||||
gcc g++ musl-dev python3-dev linux-headers libffi-dev openssl-dev && \
|
||||
uv sync --frozen --no-dev && \
|
||||
apk del .build-deps
|
||||
|
||||
# Copy application code
|
||||
COPY . .
|
||||
COPY entrypoint.sh /usr/local/bin/
|
||||
RUN chmod +x /usr/local/bin/entrypoint.sh
|
||||
@@ -63,4 +77,4 @@ HEALTHCHECK --interval=30s --timeout=10s --start-period=40s --retries=3 \
|
||||
CMD curl -f http://localhost:8000/health || exit 1
|
||||
|
||||
ENTRYPOINT ["/usr/local/bin/entrypoint.sh"]
|
||||
CMD ["uvicorn", "app.main:app", "--host", "0.0.0.0", "--port", "8000"]
|
||||
CMD ["uv", "run", "uvicorn", "app.main:app", "--host", "0.0.0.0", "--port", "8000"]
|
||||
|
||||
220
backend/Makefile
Normal file
220
backend/Makefile
Normal file
@@ -0,0 +1,220 @@
|
||||
.PHONY: help lint lint-fix format format-check type-check test test-cov validate clean install-dev sync check-docker install-e2e test-e2e test-e2e-schema test-all dep-audit license-check audit validate-all check benchmark benchmark-check benchmark-save scan-image test-api-security
|
||||
|
||||
# Prevent a stale VIRTUAL_ENV in the caller's shell from confusing uv
|
||||
unexport VIRTUAL_ENV
|
||||
|
||||
# Default target
|
||||
help:
|
||||
@echo "🚀 FastAPI Backend - Development Commands"
|
||||
@echo ""
|
||||
@echo "Setup:"
|
||||
@echo " make install-dev - Install all dependencies with uv (includes dev)"
|
||||
@echo " make install-e2e - Install E2E test dependencies (requires Docker)"
|
||||
@echo " make sync - Sync dependencies from uv.lock"
|
||||
@echo ""
|
||||
@echo "Quality Checks:"
|
||||
@echo " make lint - Run Ruff linter (check only)"
|
||||
@echo " make lint-fix - Run Ruff linter with auto-fix"
|
||||
@echo " make format - Format code with Ruff"
|
||||
@echo " make format-check - Check if code is formatted"
|
||||
@echo " make type-check - Run pyright type checking"
|
||||
@echo " make validate - Run all checks (lint + format + types + schema fuzz)"
|
||||
@echo ""
|
||||
@echo "Performance:"
|
||||
@echo " make benchmark - Run performance benchmarks"
|
||||
@echo " make benchmark-save - Run benchmarks and save as baseline"
|
||||
@echo " make benchmark-check - Run benchmarks and compare against baseline"
|
||||
@echo ""
|
||||
@echo "Security & Audit:"
|
||||
@echo " make dep-audit - Scan dependencies for known vulnerabilities"
|
||||
@echo " make license-check - Check dependency license compliance"
|
||||
@echo " make audit - Run all security audits (deps + licenses)"
|
||||
@echo " make scan-image - Scan Docker image for CVEs (requires trivy)"
|
||||
@echo " make validate-all - Run all quality + security checks"
|
||||
@echo " make check - Full pipeline: quality + security + tests"
|
||||
@echo ""
|
||||
@echo "Testing:"
|
||||
@echo " make test - Run pytest (unit/integration, SQLite)"
|
||||
@echo " make test-cov - Run pytest with coverage report"
|
||||
@echo " make test-e2e - Run E2E tests (PostgreSQL, requires Docker)"
|
||||
@echo " make test-e2e-schema - Run Schemathesis API schema tests"
|
||||
@echo " make test-all - Run all tests (unit + E2E)"
|
||||
@echo " make check-docker - Check if Docker is available"
|
||||
@echo " make check - Full pipeline: quality + security + tests"
|
||||
@echo ""
|
||||
@echo "Cleanup:"
|
||||
@echo " make clean - Remove cache and build artifacts"
|
||||
|
||||
# ============================================================================
|
||||
# Setup & Cleanup
|
||||
# ============================================================================
|
||||
|
||||
install-dev:
|
||||
@echo "📦 Installing all dependencies with uv (includes dev)..."
|
||||
@uv sync --extra dev
|
||||
@echo "✅ Development environment ready!"
|
||||
|
||||
sync:
|
||||
@echo "🔄 Syncing dependencies from uv.lock..."
|
||||
@uv sync --extra dev
|
||||
@echo "✅ Dependencies synced!"
|
||||
|
||||
# ============================================================================
|
||||
# Code Quality
|
||||
# ============================================================================
|
||||
|
||||
lint:
|
||||
@echo "🔍 Running Ruff linter..."
|
||||
@uv run ruff check app/ tests/
|
||||
|
||||
lint-fix:
|
||||
@echo "🔧 Running Ruff linter with auto-fix..."
|
||||
@uv run ruff check --fix app/ tests/
|
||||
|
||||
format:
|
||||
@echo "✨ Formatting code with Ruff..."
|
||||
@uv run ruff format app/ tests/
|
||||
|
||||
format-check:
|
||||
@echo "📋 Checking code formatting..."
|
||||
@uv run ruff format --check app/ tests/
|
||||
|
||||
type-check:
|
||||
@echo "🔎 Running pyright type checking..."
|
||||
@uv run pyright app/
|
||||
|
||||
validate: lint format-check type-check test-api-security
|
||||
@echo "✅ All quality checks passed!"
|
||||
|
||||
# API Security Testing (Schemathesis property-based fuzzing)
|
||||
test-api-security: check-docker
|
||||
@echo "🔐 Running Schemathesis API security fuzzing..."
|
||||
@IS_TEST=True PYTHONPATH=. uv run pytest tests/e2e/ -v -m "schemathesis" --tb=short -n 0
|
||||
@echo "✅ API schema security tests passed!"
|
||||
|
||||
# ============================================================================
|
||||
# Security & Audit
|
||||
# ============================================================================
|
||||
|
||||
dep-audit:
|
||||
@echo "🔒 Scanning dependencies for known vulnerabilities..."
|
||||
@uv run pip-audit --desc --skip-editable
|
||||
@echo "✅ No known vulnerabilities found!"
|
||||
|
||||
license-check:
|
||||
@echo "📜 Checking dependency license compliance..."
|
||||
@uv run pip-licenses --fail-on="GPL-3.0-or-later;AGPL-3.0-or-later" --format=plain > /dev/null
|
||||
@echo "✅ All dependency licenses are compliant!"
|
||||
|
||||
audit: dep-audit license-check
|
||||
@echo "✅ All security audits passed!"
|
||||
|
||||
scan-image: check-docker
|
||||
@echo "🐳 Scanning Docker image for OS-level CVEs with Trivy..."
|
||||
@docker build -t pragma-backend:scan -q --target production .
|
||||
@if command -v trivy > /dev/null 2>&1; then \
|
||||
trivy image --severity HIGH,CRITICAL --exit-code 1 pragma-backend:scan; \
|
||||
else \
|
||||
echo "ℹ️ Trivy not found locally, using Docker to run Trivy..."; \
|
||||
docker run --rm -v /var/run/docker.sock:/var/run/docker.sock aquasec/trivy image --severity HIGH,CRITICAL --exit-code 1 pragma-backend:scan; \
|
||||
fi
|
||||
@echo "✅ No HIGH/CRITICAL CVEs found in Docker image!"
|
||||
|
||||
validate-all: validate audit
|
||||
@echo "✅ All quality + security checks passed!"
|
||||
|
||||
check: validate-all test
|
||||
@echo "✅ Full validation pipeline complete!"
|
||||
|
||||
# ============================================================================
|
||||
# Testing
|
||||
# ============================================================================
|
||||
|
||||
test:
|
||||
@echo "🧪 Running tests..."
|
||||
@IS_TEST=True PYTHONPATH=. uv run pytest
|
||||
|
||||
test-cov:
|
||||
@echo "🧪 Running tests with coverage..."
|
||||
@IS_TEST=True PYTHONPATH=. uv run pytest --cov=app --cov-report=term-missing --cov-report=html -n 16
|
||||
@echo "📊 Coverage report generated in htmlcov/index.html"
|
||||
|
||||
# ============================================================================
|
||||
# E2E Testing (requires Docker)
|
||||
# ============================================================================
|
||||
|
||||
check-docker:
|
||||
@docker info > /dev/null 2>&1 || (echo ""; \
|
||||
echo "Docker is not running!"; \
|
||||
echo ""; \
|
||||
echo "E2E tests require Docker to be running."; \
|
||||
echo "Please start Docker Desktop or Docker Engine and try again."; \
|
||||
echo ""; \
|
||||
echo "Quick start:"; \
|
||||
echo " macOS/Windows: Open Docker Desktop"; \
|
||||
echo " Linux: sudo systemctl start docker"; \
|
||||
echo ""; \
|
||||
exit 1)
|
||||
@echo "Docker is available"
|
||||
|
||||
install-e2e:
|
||||
@echo "📦 Installing E2E test dependencies..."
|
||||
@uv sync --extra dev --extra e2e
|
||||
@echo "✅ E2E dependencies installed!"
|
||||
|
||||
test-e2e: check-docker
|
||||
@echo "🧪 Running E2E tests with PostgreSQL..."
|
||||
@IS_TEST=True PYTHONPATH=. uv run pytest tests/e2e/ -v --tb=short -n 0
|
||||
@echo "✅ E2E tests complete!"
|
||||
|
||||
test-e2e-schema: check-docker
|
||||
@echo "🧪 Running Schemathesis API schema tests..."
|
||||
@IS_TEST=True PYTHONPATH=. uv run pytest tests/e2e/ -v -m "schemathesis" --tb=short -n 0
|
||||
|
||||
# ============================================================================
|
||||
# Performance Benchmarks
|
||||
# ============================================================================
|
||||
|
||||
benchmark:
|
||||
@echo "⏱️ Running performance benchmarks..."
|
||||
@IS_TEST=True PYTHONPATH=. uv run pytest tests/benchmarks/ -v --benchmark-only --benchmark-sort=mean -p no:xdist --override-ini='addopts='
|
||||
|
||||
benchmark-save:
|
||||
@echo "⏱️ Running benchmarks and saving baseline..."
|
||||
@IS_TEST=True PYTHONPATH=. uv run pytest tests/benchmarks/ -v --benchmark-only --benchmark-save=baseline --benchmark-sort=mean -p no:xdist --override-ini='addopts='
|
||||
@echo "✅ Benchmark baseline saved to .benchmarks/"
|
||||
|
||||
benchmark-check:
|
||||
@echo "⏱️ Running benchmarks and comparing against baseline..."
|
||||
@if find .benchmarks -name '*_baseline*' -print -quit 2>/dev/null | grep -q .; then \
|
||||
IS_TEST=True PYTHONPATH=. uv run pytest tests/benchmarks/ -v --benchmark-only --benchmark-compare=0001_baseline --benchmark-sort=mean --benchmark-compare-fail=mean:200% -p no:xdist --override-ini='addopts='; \
|
||||
echo "✅ No performance regressions detected!"; \
|
||||
else \
|
||||
echo "⚠️ No benchmark baseline found. Run 'make benchmark-save' first to create one."; \
|
||||
echo " Running benchmarks without comparison..."; \
|
||||
IS_TEST=True PYTHONPATH=. uv run pytest tests/benchmarks/ -v --benchmark-only --benchmark-save=baseline --benchmark-sort=mean -p no:xdist --override-ini='addopts='; \
|
||||
echo "✅ Benchmark baseline created. Future runs of 'make benchmark-check' will compare against it."; \
|
||||
fi
|
||||
|
||||
test-all:
|
||||
@echo "🧪 Running ALL tests (unit + E2E)..."
|
||||
@$(MAKE) test
|
||||
@$(MAKE) test-e2e
|
||||
|
||||
# ============================================================================
|
||||
# Cleanup
|
||||
# ============================================================================
|
||||
|
||||
clean:
|
||||
@echo "🧹 Cleaning up..."
|
||||
@find . -type d -name "__pycache__" -exec rm -rf {} + 2>/dev/null || true
|
||||
@find . -type d -name ".pytest_cache" -exec rm -rf {} + 2>/dev/null || true
|
||||
@find . -type d -name ".pyright" -exec rm -rf {} + 2>/dev/null || true
|
||||
@find . -type d -name ".ruff_cache" -exec rm -rf {} + 2>/dev/null || true
|
||||
@find . -type d -name "*.egg-info" -exec rm -rf {} + 2>/dev/null || true
|
||||
@find . -type d -name "htmlcov" -exec rm -rf {} + 2>/dev/null || true
|
||||
@find . -type d -name "build" -exec rm -rf {} + 2>/dev/null || true
|
||||
@find . -type d -name ".uv_cache" -exec rm -rf {} + 2>/dev/null || true
|
||||
@find . -type f -name ".coverage" -delete 2>/dev/null || true
|
||||
@find . -type f -name "*.pyc" -delete 2>/dev/null || true
|
||||
@echo "✅ Cleanup complete!"
|
||||
@@ -1,10 +1,12 @@
|
||||
# Backend API
|
||||
# PragmaStack Backend API
|
||||
|
||||
> FastAPI-based REST API with async SQLAlchemy, JWT authentication, and comprehensive testing.
|
||||
> The pragmatic, production-ready FastAPI backend for PragmaStack.
|
||||
|
||||
## Overview
|
||||
|
||||
Production-ready FastAPI backend featuring:
|
||||
Opinionated, secure, and fast. This backend provides the solid foundation you need to ship features, not boilerplate.
|
||||
|
||||
Features:
|
||||
|
||||
- **Authentication**: JWT with refresh tokens, session management, device tracking
|
||||
- **Database**: Async PostgreSQL with SQLAlchemy 2.0, Alembic migrations
|
||||
@@ -12,30 +14,42 @@ Production-ready FastAPI backend featuring:
|
||||
- **Multi-tenancy**: Organization-based access control with roles (Owner/Admin/Member)
|
||||
- **Testing**: 97%+ coverage with security-focused test suite
|
||||
- **Performance**: Async throughout, connection pooling, optimized queries
|
||||
- **Modern Tooling**: uv for dependencies, Ruff for linting/formatting, Pyright for type checking
|
||||
- **Security Auditing**: Automated dependency vulnerability scanning, license compliance, secrets detection
|
||||
- **Pre-commit Hooks**: Ruff, detect-secrets, and standard checks on every commit
|
||||
|
||||
## Quick Start
|
||||
|
||||
### Prerequisites
|
||||
|
||||
- Python 3.11+
|
||||
- Python 3.12+
|
||||
- PostgreSQL 14+ (or SQLite for development)
|
||||
- pip and virtualenv
|
||||
- **[uv](https://docs.astral.sh/uv/)** - Modern Python package manager (replaces pip)
|
||||
|
||||
### Installation
|
||||
|
||||
```bash
|
||||
# Create virtual environment
|
||||
python -m venv .venv
|
||||
source .venv/bin/activate # On Windows: .venv\Scripts\activate
|
||||
# Install uv (if not already installed)
|
||||
curl -LsSf https://astral.sh/uv/install.sh | sh
|
||||
|
||||
# Install dependencies
|
||||
pip install -r requirements.txt
|
||||
# Install all dependencies (production + dev)
|
||||
cd backend
|
||||
uv sync --extra dev
|
||||
|
||||
# Or use the Makefile
|
||||
make install-dev
|
||||
|
||||
# Copy environment template
|
||||
cp .env.example .env
|
||||
# Edit .env with your configuration
|
||||
```
|
||||
|
||||
**Why uv?**
|
||||
- 🚀 10-100x faster than pip
|
||||
- 🔒 Reproducible builds via `uv.lock` lockfile
|
||||
- 📦 Better dependency resolution
|
||||
- ⚡ Built by Astral (creators of Ruff)
|
||||
|
||||
### Database Setup
|
||||
|
||||
```bash
|
||||
@@ -49,6 +63,11 @@ alembic upgrade head
|
||||
### Run Development Server
|
||||
|
||||
```bash
|
||||
# Using uv
|
||||
uv run uvicorn app.main:app --reload --host 0.0.0.0 --port 8000
|
||||
|
||||
# Or activate environment first
|
||||
source .venv/bin/activate # On Windows: .venv\Scripts\activate
|
||||
uvicorn app.main:app --reload --host 0.0.0.0 --port 8000
|
||||
```
|
||||
|
||||
@@ -57,6 +76,180 @@ API will be available at:
|
||||
- **Swagger Docs**: http://localhost:8000/docs
|
||||
- **ReDoc**: http://localhost:8000/redoc
|
||||
|
||||
---
|
||||
|
||||
## Dependency Management with uv
|
||||
|
||||
### Understanding uv
|
||||
|
||||
**uv** is the modern standard for Python dependency management, built in Rust for speed and reliability.
|
||||
|
||||
**Key files:**
|
||||
- `pyproject.toml` - Declares dependencies and tool configurations
|
||||
- `uv.lock` - Locks exact versions for reproducible builds (commit to git)
|
||||
|
||||
### Common Commands
|
||||
|
||||
#### Installing Dependencies
|
||||
|
||||
```bash
|
||||
# Install all dependencies from lockfile
|
||||
uv sync --extra dev
|
||||
|
||||
# Install only production dependencies (no dev tools)
|
||||
uv sync
|
||||
|
||||
# Or use the Makefile
|
||||
make install-dev # Install with dev dependencies
|
||||
make sync # Sync from lockfile
|
||||
```
|
||||
|
||||
#### Adding Dependencies
|
||||
|
||||
```bash
|
||||
# Add a production dependency
|
||||
uv add httpx
|
||||
|
||||
# Add a development dependency
|
||||
uv add --dev pytest-mock
|
||||
|
||||
# Add with version constraint
|
||||
uv add "fastapi>=0.115.0,<0.116.0"
|
||||
|
||||
# Add exact version
|
||||
uv add "pydantic==2.10.6"
|
||||
```
|
||||
|
||||
After adding dependencies, **commit both `pyproject.toml` and `uv.lock`** to git.
|
||||
|
||||
#### Removing Dependencies
|
||||
|
||||
```bash
|
||||
# Remove a package
|
||||
uv remove httpx
|
||||
|
||||
# Remove a dev dependency
|
||||
uv remove --dev pytest-mock
|
||||
```
|
||||
|
||||
#### Updating Dependencies
|
||||
|
||||
```bash
|
||||
# Update all packages to latest compatible versions
|
||||
uv sync --upgrade
|
||||
|
||||
# Update a specific package
|
||||
uv add --upgrade fastapi
|
||||
|
||||
# Check for outdated packages
|
||||
uv pip list --outdated
|
||||
```
|
||||
|
||||
#### Running Commands in uv Environment
|
||||
|
||||
```bash
|
||||
# Run any Python command via uv (no activation needed)
|
||||
uv run python script.py
|
||||
uv run pytest
|
||||
uv run pyright app/
|
||||
|
||||
# Or activate the virtual environment
|
||||
source .venv/bin/activate
|
||||
python script.py
|
||||
pytest
|
||||
```
|
||||
|
||||
### Makefile Commands
|
||||
|
||||
We provide convenient Makefile commands that use uv:
|
||||
|
||||
```bash
|
||||
# Setup
|
||||
make install-dev # Install all dependencies (prod + dev)
|
||||
make sync # Sync from lockfile
|
||||
|
||||
# Code Quality
|
||||
make lint # Run Ruff linter (check only)
|
||||
make lint-fix # Run Ruff with auto-fix
|
||||
make format # Format code with Ruff
|
||||
make format-check # Check if code is formatted
|
||||
make type-check # Run Pyright type checking
|
||||
make validate # Run all checks (lint + format + types)
|
||||
|
||||
# Security & Audit
|
||||
make dep-audit # Scan dependencies for known vulnerabilities (CVEs)
|
||||
make license-check # Check dependency license compliance
|
||||
make audit # Run all security audits (deps + licenses)
|
||||
make validate-all # Run all quality + security checks
|
||||
make check # Full pipeline: quality + security + tests
|
||||
|
||||
# Testing
|
||||
make test # Run all tests
|
||||
make test-cov # Run tests with coverage report
|
||||
make test-e2e # Run E2E tests (PostgreSQL, requires Docker)
|
||||
make test-e2e-schema # Run Schemathesis API schema tests
|
||||
make test-all # Run all tests (unit + E2E)
|
||||
|
||||
# Utilities
|
||||
make clean # Remove cache and build artifacts
|
||||
make help # Show all commands
|
||||
```
|
||||
|
||||
### Dependency Workflow Example
|
||||
|
||||
```bash
|
||||
# 1. Clone repository
|
||||
git clone <repo-url>
|
||||
cd backend
|
||||
|
||||
# 2. Install dependencies
|
||||
make install-dev
|
||||
|
||||
# 3. Make changes, add a new dependency
|
||||
uv add httpx
|
||||
|
||||
# 4. Test your changes
|
||||
make test
|
||||
|
||||
# 5. Commit (includes uv.lock)
|
||||
git add pyproject.toml uv.lock
|
||||
git commit -m "Add httpx dependency"
|
||||
|
||||
# 6. Other developers pull and sync
|
||||
git pull
|
||||
make sync # Uses the committed uv.lock
|
||||
```
|
||||
|
||||
### Troubleshooting uv
|
||||
|
||||
**Dependencies not found after install:**
|
||||
```bash
|
||||
# Make sure you're using uv run or activated environment
|
||||
uv run pytest # Option 1: Run via uv
|
||||
source .venv/bin/activate # Option 2: Activate first
|
||||
pytest
|
||||
```
|
||||
|
||||
**Lockfile out of sync:**
|
||||
```bash
|
||||
# Regenerate lockfile
|
||||
uv lock
|
||||
|
||||
# Force reinstall from lockfile
|
||||
uv sync --reinstall
|
||||
```
|
||||
|
||||
**uv not found:**
|
||||
```bash
|
||||
# Install uv globally
|
||||
curl -LsSf https://astral.sh/uv/install.sh | sh
|
||||
|
||||
# Add to PATH if needed
|
||||
export PATH="$HOME/.cargo/bin:$PATH"
|
||||
```
|
||||
|
||||
---
|
||||
|
||||
## Development
|
||||
|
||||
### Project Structure
|
||||
@@ -71,7 +264,7 @@ app/
|
||||
│ ├── database.py # Database engine setup
|
||||
│ ├── auth.py # JWT token handling
|
||||
│ └── exceptions.py # Custom exceptions
|
||||
├── crud/ # Database operations
|
||||
├── repositories/ # Repository pattern (database operations)
|
||||
├── models/ # SQLAlchemy ORM models
|
||||
├── schemas/ # Pydantic request/response schemas
|
||||
├── services/ # Business logic layer
|
||||
@@ -144,20 +337,22 @@ alembic downgrade -1
|
||||
### Testing
|
||||
|
||||
```bash
|
||||
# Run all tests
|
||||
IS_TEST=True pytest
|
||||
# Using Makefile (recommended)
|
||||
make test # Run all tests
|
||||
make test-cov # Run with coverage report
|
||||
|
||||
# Run with coverage
|
||||
IS_TEST=True pytest --cov=app --cov-report=term-missing -n 0
|
||||
# Using uv directly
|
||||
IS_TEST=True uv run pytest
|
||||
IS_TEST=True uv run pytest --cov=app --cov-report=term-missing -n 0
|
||||
|
||||
# Run specific test file
|
||||
IS_TEST=True pytest tests/api/test_auth.py -v
|
||||
IS_TEST=True uv run pytest tests/api/test_auth.py -v
|
||||
|
||||
# Run single test
|
||||
IS_TEST=True pytest tests/api/test_auth.py::TestLogin::test_login_success -v
|
||||
IS_TEST=True uv run pytest tests/api/test_auth.py::TestLogin::test_login_success -v
|
||||
|
||||
# Generate HTML coverage report
|
||||
IS_TEST=True pytest --cov=app --cov-report=html -n 0
|
||||
IS_TEST=True uv run pytest --cov=app --cov-report=html -n 0
|
||||
open htmlcov/index.html
|
||||
```
|
||||
|
||||
@@ -166,17 +361,37 @@ open htmlcov/index.html
|
||||
### Code Quality
|
||||
|
||||
```bash
|
||||
# Type checking
|
||||
mypy app
|
||||
# Using Makefile (recommended)
|
||||
make lint # Ruff linting
|
||||
make format # Ruff formatting
|
||||
make type-check # Pyright type checking
|
||||
make validate # All checks at once
|
||||
|
||||
# Linting
|
||||
ruff check app
|
||||
# Security audits
|
||||
make dep-audit # Scan dependencies for CVEs
|
||||
make license-check # Check license compliance
|
||||
make audit # All security audits
|
||||
make validate-all # Quality + security checks
|
||||
make check # Full pipeline: quality + security + tests
|
||||
|
||||
# Format code
|
||||
black app
|
||||
isort app
|
||||
# Using uv directly
|
||||
uv run ruff check app/ tests/
|
||||
uv run ruff format app/ tests/
|
||||
uv run pyright app/
|
||||
```
|
||||
|
||||
**Tools:**
|
||||
- **Ruff**: All-in-one linting, formatting, and import sorting (replaces Black, Flake8, isort)
|
||||
- **Pyright**: Static type checking (strict mode)
|
||||
- **pip-audit**: Dependency vulnerability scanning against the OSV database
|
||||
- **pip-licenses**: Dependency license compliance checking
|
||||
- **detect-secrets**: Hardcoded secrets/credentials detection
|
||||
- **pre-commit**: Git hook framework for automated checks on every commit
|
||||
|
||||
All configurations are in `pyproject.toml`.
|
||||
|
||||
---
|
||||
|
||||
## API Documentation
|
||||
|
||||
Once the server is running, interactive API documentation is available:
|
||||
@@ -194,6 +409,8 @@ Once the server is running, interactive API documentation is available:
|
||||
- Raw OpenAPI 3.0 specification
|
||||
- Use for client generation
|
||||
|
||||
---
|
||||
|
||||
## Authentication
|
||||
|
||||
### Token-Based Authentication
|
||||
@@ -229,6 +446,8 @@ curl -H "Authorization: Bearer <access_token>" \
|
||||
- `Admin`: Can manage members (except owners)
|
||||
- `Member`: Read-only access
|
||||
|
||||
---
|
||||
|
||||
## Common Tasks
|
||||
|
||||
### Create a Superuser
|
||||
@@ -243,7 +462,7 @@ See [docs/FEATURE_EXAMPLE.md](docs/FEATURE_EXAMPLE.md) for step-by-step guide.
|
||||
|
||||
Quick overview:
|
||||
1. Create Pydantic schemas in `app/schemas/`
|
||||
2. Create CRUD operations in `app/crud/`
|
||||
2. Create repository in `app/repositories/`
|
||||
3. Create route in `app/api/routes/`
|
||||
4. Register router in `app/api/main.py`
|
||||
5. Write tests in `tests/api/`
|
||||
@@ -258,8 +477,12 @@ python migrate.py check
|
||||
curl http://localhost:8000/health
|
||||
```
|
||||
|
||||
---
|
||||
|
||||
## Docker Support
|
||||
|
||||
The Dockerfile uses **uv** for fast, reproducible builds:
|
||||
|
||||
```bash
|
||||
# Development with hot reload
|
||||
docker-compose -f docker-compose.dev.yml up
|
||||
@@ -271,17 +494,38 @@ docker-compose up -d
|
||||
docker-compose build backend
|
||||
```
|
||||
|
||||
**Docker features:**
|
||||
- Multi-stage builds (development + production)
|
||||
- uv for fast dependency installation
|
||||
- `uv.lock` ensures exact versions in containers
|
||||
- Development stage includes dev dependencies
|
||||
- Production stage optimized for size and security
|
||||
|
||||
---
|
||||
|
||||
## Troubleshooting
|
||||
|
||||
### Common Issues
|
||||
|
||||
**Module Import Errors**
|
||||
```bash
|
||||
# Ensure you're in the backend directory
|
||||
cd backend
|
||||
# Ensure dependencies are installed
|
||||
make install-dev
|
||||
|
||||
# Activate virtual environment
|
||||
source .venv/bin/activate
|
||||
# Or sync from lockfile
|
||||
make sync
|
||||
|
||||
# Verify Python environment
|
||||
uv run python --version
|
||||
```
|
||||
|
||||
**uv command not found**
|
||||
```bash
|
||||
# Install uv globally
|
||||
curl -LsSf https://astral.sh/uv/install.sh | sh
|
||||
|
||||
# Add to PATH (add to ~/.bashrc or ~/.zshrc)
|
||||
export PATH="$HOME/.cargo/bin:$PATH"
|
||||
```
|
||||
|
||||
**Database Connection Failed**
|
||||
@@ -306,10 +550,19 @@ alembic upgrade head
|
||||
**Tests Failing**
|
||||
```bash
|
||||
# Run with verbose output
|
||||
IS_TEST=True pytest -vv
|
||||
make test
|
||||
|
||||
# Run single test to isolate issue
|
||||
IS_TEST=True pytest tests/api/test_auth.py::TestLogin::test_login_success -vv
|
||||
IS_TEST=True uv run pytest tests/api/test_auth.py::TestLogin::test_login_success -vv
|
||||
```
|
||||
|
||||
**Dependencies out of sync**
|
||||
```bash
|
||||
# Regenerate lockfile from pyproject.toml
|
||||
uv lock
|
||||
|
||||
# Reinstall everything
|
||||
make install-dev
|
||||
```
|
||||
|
||||
### Getting Help
|
||||
@@ -321,6 +574,8 @@ See our detailed documentation:
|
||||
- [COMMON_PITFALLS.md](docs/COMMON_PITFALLS.md) - Mistakes to avoid
|
||||
- [FEATURE_EXAMPLE.md](docs/FEATURE_EXAMPLE.md) - Adding new features
|
||||
|
||||
---
|
||||
|
||||
## Performance
|
||||
|
||||
### Database Connection Pooling
|
||||
@@ -343,6 +598,8 @@ Configured in `app/core/config.py`:
|
||||
- Bulk operations for admin actions
|
||||
- Indexed foreign keys and common lookups
|
||||
|
||||
---
|
||||
|
||||
## Security
|
||||
|
||||
### Built-in Security Features
|
||||
@@ -355,13 +612,44 @@ Configured in `app/core/config.py`:
|
||||
- **Security Headers**: CSP, HSTS, X-Frame-Options, etc.
|
||||
- **Input Validation**: Pydantic schemas, SQL injection prevention (ORM)
|
||||
|
||||
### Security Auditing
|
||||
|
||||
Automated, deterministic security checks are built into the development workflow:
|
||||
|
||||
```bash
|
||||
# Scan dependencies for known vulnerabilities (CVEs)
|
||||
make dep-audit
|
||||
|
||||
# Check dependency license compliance (blocks GPL-3.0/AGPL)
|
||||
make license-check
|
||||
|
||||
# Run all security audits
|
||||
make audit
|
||||
|
||||
# Full pipeline: quality + security + tests
|
||||
make check
|
||||
```
|
||||
|
||||
**Pre-commit hooks** automatically run on every commit:
|
||||
- **Ruff** lint + format checks
|
||||
- **detect-secrets** blocks commits containing hardcoded secrets
|
||||
- **Standard checks**: trailing whitespace, YAML/TOML validation, merge conflict detection, large file prevention
|
||||
|
||||
Setup pre-commit hooks:
|
||||
```bash
|
||||
uv run pre-commit install
|
||||
```
|
||||
|
||||
### Security Best Practices
|
||||
|
||||
1. **Never commit secrets**: Use `.env` files (git-ignored)
|
||||
1. **Never commit secrets**: Use `.env` files (git-ignored), enforced by detect-secrets pre-commit hook
|
||||
2. **Strong SECRET_KEY**: Min 32 chars, cryptographically random
|
||||
3. **HTTPS in production**: Required for token security
|
||||
4. **Regular updates**: Keep dependencies current
|
||||
4. **Regular updates**: Keep dependencies current (`uv sync --upgrade`), run `make dep-audit` to check for CVEs
|
||||
5. **Audit logs**: Monitor authentication events
|
||||
6. **Run `make check` before pushing**: Validates quality, security, and tests in one command
|
||||
|
||||
---
|
||||
|
||||
## Monitoring
|
||||
|
||||
@@ -388,13 +676,32 @@ logging.basicConfig(level=logging.INFO)
|
||||
# In production, use JSON logs for log aggregation
|
||||
```
|
||||
|
||||
---
|
||||
|
||||
## Additional Resources
|
||||
|
||||
- **FastAPI Documentation**: https://fastapi.tiangolo.com
|
||||
### Official Documentation
|
||||
- **uv**: https://docs.astral.sh/uv/
|
||||
- **FastAPI**: https://fastapi.tiangolo.com
|
||||
- **SQLAlchemy 2.0**: https://docs.sqlalchemy.org/en/20/
|
||||
- **Pydantic**: https://docs.pydantic.dev/
|
||||
- **Alembic**: https://alembic.sqlalchemy.org/
|
||||
- **Ruff**: https://docs.astral.sh/ruff/
|
||||
|
||||
### Our Documentation
|
||||
- [Root README](../README.md) - Project-wide information
|
||||
- [CLAUDE.md](../CLAUDE.md) - Comprehensive development guide
|
||||
|
||||
---
|
||||
|
||||
**Note**: For project-wide information (license, contributing guidelines, deployment), see the [root README](../README.md).
|
||||
**Built with modern Python tooling:**
|
||||
- 🚀 **uv** - 10-100x faster dependency management
|
||||
- ⚡ **Ruff** - 10-100x faster linting & formatting
|
||||
- 🔍 **Pyright** - Static type checking (strict mode)
|
||||
- ✅ **pytest** - Comprehensive test suite
|
||||
- 🔒 **pip-audit** - Dependency vulnerability scanning
|
||||
- 🔑 **detect-secrets** - Hardcoded secrets detection
|
||||
- 📜 **pip-licenses** - License compliance checking
|
||||
- 🪝 **pre-commit** - Automated git hooks
|
||||
|
||||
**All configured in a single `pyproject.toml` file!**
|
||||
|
||||
@@ -2,6 +2,13 @@
|
||||
script_location = app/alembic
|
||||
sqlalchemy.url = postgresql://postgres:postgres@db:5432/app
|
||||
|
||||
# Use sequential naming: 0001_message.py, 0002_message.py, etc.
|
||||
# The rev_id is still used internally but filename is cleaner
|
||||
file_template = %%(rev)s_%%(slug)s
|
||||
|
||||
# Allow specifying custom revision IDs via --rev-id flag
|
||||
revision_environment = true
|
||||
|
||||
[loggers]
|
||||
keys = root,sqlalchemy,alembic
|
||||
|
||||
|
||||
@@ -2,12 +2,11 @@ import sys
|
||||
from logging.config import fileConfig
|
||||
from pathlib import Path
|
||||
|
||||
from sqlalchemy import engine_from_config, pool, text, create_engine
|
||||
from alembic import context
|
||||
from sqlalchemy import create_engine, engine_from_config, pool, text
|
||||
from sqlalchemy.engine.url import make_url
|
||||
from sqlalchemy.exc import OperationalError
|
||||
|
||||
from alembic import context
|
||||
|
||||
# Get the path to the app directory (parent of 'alembic')
|
||||
app_dir = Path(__file__).resolve().parent.parent
|
||||
# Add the app directory to Python path
|
||||
@@ -23,6 +22,25 @@ from app.models import *
|
||||
# access to the values within the .ini file in use.
|
||||
config = context.config
|
||||
|
||||
|
||||
def include_object(object, name, type_, reflected, compare_to):
|
||||
"""
|
||||
Filter objects for autogenerate.
|
||||
|
||||
Skip comparing functional indexes (like LOWER(column)) and partial indexes
|
||||
(with WHERE clauses) as Alembic cannot reliably detect these from models.
|
||||
These should be managed manually via dedicated performance migrations.
|
||||
|
||||
Convention: Any index starting with "ix_perf_" is automatically excluded.
|
||||
This allows adding new performance indexes without updating this file.
|
||||
"""
|
||||
if type_ == "index" and name:
|
||||
# Convention-based: any index prefixed with ix_perf_ is manual
|
||||
if name.startswith("ix_perf_"):
|
||||
return False
|
||||
return True
|
||||
|
||||
|
||||
# Interpret the config file for Python logging.
|
||||
# This line sets up loggers basically.
|
||||
if config.config_file_name is not None:
|
||||
@@ -66,7 +84,9 @@ def ensure_database_exists(db_url: str) -> None:
|
||||
admin_url = url.set(database="postgres")
|
||||
|
||||
# CREATE DATABASE cannot run inside a transaction
|
||||
admin_engine = create_engine(str(admin_url), isolation_level="AUTOCOMMIT", poolclass=pool.NullPool)
|
||||
admin_engine = create_engine(
|
||||
str(admin_url), isolation_level="AUTOCOMMIT", poolclass=pool.NullPool
|
||||
)
|
||||
try:
|
||||
with admin_engine.connect() as conn:
|
||||
exists = conn.execute(
|
||||
@@ -99,6 +119,8 @@ def run_migrations_offline() -> None:
|
||||
target_metadata=target_metadata,
|
||||
literal_binds=True,
|
||||
dialect_opts={"paramstyle": "named"},
|
||||
compare_type=True,
|
||||
include_object=include_object,
|
||||
)
|
||||
|
||||
with context.begin_transaction():
|
||||
@@ -123,7 +145,10 @@ def run_migrations_online() -> None:
|
||||
|
||||
with connectable.connect() as connection:
|
||||
context.configure(
|
||||
connection=connection, target_metadata=target_metadata
|
||||
connection=connection,
|
||||
target_metadata=target_metadata,
|
||||
compare_type=True,
|
||||
include_object=include_object,
|
||||
)
|
||||
|
||||
with context.begin_transaction():
|
||||
@@ -133,4 +158,4 @@ def run_migrations_online() -> None:
|
||||
if context.is_offline_mode():
|
||||
run_migrations_offline()
|
||||
else:
|
||||
run_migrations_online()
|
||||
run_migrations_online()
|
||||
|
||||
446
backend/app/alembic/versions/0001_initial_models.py
Normal file
446
backend/app/alembic/versions/0001_initial_models.py
Normal file
@@ -0,0 +1,446 @@
|
||||
"""initial models
|
||||
|
||||
Revision ID: 0001
|
||||
Revises:
|
||||
Create Date: 2025-11-27 09:08:09.464506
|
||||
|
||||
"""
|
||||
|
||||
from collections.abc import Sequence
|
||||
|
||||
import sqlalchemy as sa
|
||||
from alembic import op
|
||||
from sqlalchemy.dialects import postgresql
|
||||
|
||||
# revision identifiers, used by Alembic.
|
||||
revision: str = "0001"
|
||||
down_revision: str | None = None
|
||||
branch_labels: str | Sequence[str] | None = None
|
||||
depends_on: str | Sequence[str] | None = None
|
||||
|
||||
|
||||
def upgrade() -> None:
|
||||
# ### commands auto generated by Alembic - please adjust! ###
|
||||
op.create_table(
|
||||
"oauth_states",
|
||||
sa.Column("state", sa.String(length=255), nullable=False),
|
||||
sa.Column("code_verifier", sa.String(length=128), nullable=True),
|
||||
sa.Column("nonce", sa.String(length=255), nullable=True),
|
||||
sa.Column("provider", sa.String(length=50), nullable=False),
|
||||
sa.Column("redirect_uri", sa.String(length=500), nullable=True),
|
||||
sa.Column("user_id", sa.UUID(), nullable=True),
|
||||
sa.Column("expires_at", sa.DateTime(timezone=True), nullable=False),
|
||||
sa.Column("id", sa.UUID(), nullable=False),
|
||||
sa.Column("created_at", sa.DateTime(timezone=True), nullable=False),
|
||||
sa.Column("updated_at", sa.DateTime(timezone=True), nullable=False),
|
||||
sa.PrimaryKeyConstraint("id"),
|
||||
)
|
||||
op.create_index(
|
||||
op.f("ix_oauth_states_state"), "oauth_states", ["state"], unique=True
|
||||
)
|
||||
op.create_table(
|
||||
"organizations",
|
||||
sa.Column("name", sa.String(length=255), nullable=False),
|
||||
sa.Column("slug", sa.String(length=255), nullable=False),
|
||||
sa.Column("description", sa.Text(), nullable=True),
|
||||
sa.Column("is_active", sa.Boolean(), nullable=False),
|
||||
sa.Column("settings", postgresql.JSONB(astext_type=sa.Text()), nullable=True),
|
||||
sa.Column("id", sa.UUID(), nullable=False),
|
||||
sa.Column("created_at", sa.DateTime(timezone=True), nullable=False),
|
||||
sa.Column("updated_at", sa.DateTime(timezone=True), nullable=False),
|
||||
sa.PrimaryKeyConstraint("id"),
|
||||
)
|
||||
op.create_index(
|
||||
op.f("ix_organizations_is_active"), "organizations", ["is_active"], unique=False
|
||||
)
|
||||
op.create_index(
|
||||
op.f("ix_organizations_name"), "organizations", ["name"], unique=False
|
||||
)
|
||||
op.create_index(
|
||||
"ix_organizations_name_active",
|
||||
"organizations",
|
||||
["name", "is_active"],
|
||||
unique=False,
|
||||
)
|
||||
op.create_index(
|
||||
op.f("ix_organizations_slug"), "organizations", ["slug"], unique=True
|
||||
)
|
||||
op.create_index(
|
||||
"ix_organizations_slug_active",
|
||||
"organizations",
|
||||
["slug", "is_active"],
|
||||
unique=False,
|
||||
)
|
||||
op.create_table(
|
||||
"users",
|
||||
sa.Column("email", sa.String(length=255), nullable=False),
|
||||
sa.Column("password_hash", sa.String(length=255), nullable=True),
|
||||
sa.Column("first_name", sa.String(length=100), nullable=False),
|
||||
sa.Column("last_name", sa.String(length=100), nullable=True),
|
||||
sa.Column("phone_number", sa.String(length=20), nullable=True),
|
||||
sa.Column("is_active", sa.Boolean(), nullable=False),
|
||||
sa.Column("is_superuser", sa.Boolean(), nullable=False),
|
||||
sa.Column(
|
||||
"preferences", postgresql.JSONB(astext_type=sa.Text()), nullable=True
|
||||
),
|
||||
sa.Column("locale", sa.String(length=10), nullable=True),
|
||||
sa.Column("deleted_at", sa.DateTime(timezone=True), nullable=True),
|
||||
sa.Column("id", sa.UUID(), nullable=False),
|
||||
sa.Column("created_at", sa.DateTime(timezone=True), nullable=False),
|
||||
sa.Column("updated_at", sa.DateTime(timezone=True), nullable=False),
|
||||
sa.PrimaryKeyConstraint("id"),
|
||||
)
|
||||
op.create_index(op.f("ix_users_deleted_at"), "users", ["deleted_at"], unique=False)
|
||||
op.create_index(op.f("ix_users_email"), "users", ["email"], unique=True)
|
||||
op.create_index(op.f("ix_users_is_active"), "users", ["is_active"], unique=False)
|
||||
op.create_index(
|
||||
op.f("ix_users_is_superuser"), "users", ["is_superuser"], unique=False
|
||||
)
|
||||
op.create_index(op.f("ix_users_locale"), "users", ["locale"], unique=False)
|
||||
op.create_table(
|
||||
"oauth_accounts",
|
||||
sa.Column("user_id", sa.UUID(), nullable=False),
|
||||
sa.Column("provider", sa.String(length=50), nullable=False),
|
||||
sa.Column("provider_user_id", sa.String(length=255), nullable=False),
|
||||
sa.Column("provider_email", sa.String(length=255), nullable=True),
|
||||
sa.Column("access_token_encrypted", sa.String(length=2048), nullable=True),
|
||||
sa.Column("refresh_token_encrypted", sa.String(length=2048), nullable=True),
|
||||
sa.Column("token_expires_at", sa.DateTime(timezone=True), nullable=True),
|
||||
sa.Column("id", sa.UUID(), nullable=False),
|
||||
sa.Column("created_at", sa.DateTime(timezone=True), nullable=False),
|
||||
sa.Column("updated_at", sa.DateTime(timezone=True), nullable=False),
|
||||
sa.ForeignKeyConstraint(["user_id"], ["users.id"], ondelete="CASCADE"),
|
||||
sa.PrimaryKeyConstraint("id"),
|
||||
sa.UniqueConstraint(
|
||||
"provider", "provider_user_id", name="uq_oauth_provider_user"
|
||||
),
|
||||
)
|
||||
op.create_index(
|
||||
op.f("ix_oauth_accounts_provider"), "oauth_accounts", ["provider"], unique=False
|
||||
)
|
||||
op.create_index(
|
||||
op.f("ix_oauth_accounts_provider_email"),
|
||||
"oauth_accounts",
|
||||
["provider_email"],
|
||||
unique=False,
|
||||
)
|
||||
op.create_index(
|
||||
op.f("ix_oauth_accounts_user_id"), "oauth_accounts", ["user_id"], unique=False
|
||||
)
|
||||
op.create_index(
|
||||
"ix_oauth_accounts_user_provider",
|
||||
"oauth_accounts",
|
||||
["user_id", "provider"],
|
||||
unique=False,
|
||||
)
|
||||
op.create_table(
|
||||
"oauth_clients",
|
||||
sa.Column("client_id", sa.String(length=64), nullable=False),
|
||||
sa.Column("client_secret_hash", sa.String(length=255), nullable=True),
|
||||
sa.Column("client_name", sa.String(length=255), nullable=False),
|
||||
sa.Column("client_description", sa.String(length=1000), nullable=True),
|
||||
sa.Column("client_type", sa.String(length=20), nullable=False),
|
||||
sa.Column(
|
||||
"redirect_uris", postgresql.JSONB(astext_type=sa.Text()), nullable=False
|
||||
),
|
||||
sa.Column(
|
||||
"allowed_scopes", postgresql.JSONB(astext_type=sa.Text()), nullable=False
|
||||
),
|
||||
sa.Column("access_token_lifetime", sa.String(length=10), nullable=False),
|
||||
sa.Column("refresh_token_lifetime", sa.String(length=10), nullable=False),
|
||||
sa.Column("is_active", sa.Boolean(), nullable=False),
|
||||
sa.Column("owner_user_id", sa.UUID(), nullable=True),
|
||||
sa.Column("mcp_server_url", sa.String(length=2048), nullable=True),
|
||||
sa.Column("id", sa.UUID(), nullable=False),
|
||||
sa.Column("created_at", sa.DateTime(timezone=True), nullable=False),
|
||||
sa.Column("updated_at", sa.DateTime(timezone=True), nullable=False),
|
||||
sa.ForeignKeyConstraint(["owner_user_id"], ["users.id"], ondelete="SET NULL"),
|
||||
sa.PrimaryKeyConstraint("id"),
|
||||
)
|
||||
op.create_index(
|
||||
op.f("ix_oauth_clients_client_id"), "oauth_clients", ["client_id"], unique=True
|
||||
)
|
||||
op.create_index(
|
||||
op.f("ix_oauth_clients_is_active"), "oauth_clients", ["is_active"], unique=False
|
||||
)
|
||||
op.create_table(
|
||||
"user_organizations",
|
||||
sa.Column("user_id", sa.UUID(), nullable=False),
|
||||
sa.Column("organization_id", sa.UUID(), nullable=False),
|
||||
sa.Column(
|
||||
"role",
|
||||
sa.Enum("OWNER", "ADMIN", "MEMBER", "GUEST", name="organizationrole"),
|
||||
nullable=False,
|
||||
),
|
||||
sa.Column("is_active", sa.Boolean(), nullable=False),
|
||||
sa.Column("custom_permissions", sa.String(length=500), nullable=True),
|
||||
sa.Column("created_at", sa.DateTime(timezone=True), nullable=False),
|
||||
sa.Column("updated_at", sa.DateTime(timezone=True), nullable=False),
|
||||
sa.ForeignKeyConstraint(
|
||||
["organization_id"], ["organizations.id"], ondelete="CASCADE"
|
||||
),
|
||||
sa.ForeignKeyConstraint(["user_id"], ["users.id"], ondelete="CASCADE"),
|
||||
sa.PrimaryKeyConstraint("user_id", "organization_id"),
|
||||
)
|
||||
op.create_index(
|
||||
"ix_user_org_org_active",
|
||||
"user_organizations",
|
||||
["organization_id", "is_active"],
|
||||
unique=False,
|
||||
)
|
||||
op.create_index("ix_user_org_role", "user_organizations", ["role"], unique=False)
|
||||
op.create_index(
|
||||
"ix_user_org_user_active",
|
||||
"user_organizations",
|
||||
["user_id", "is_active"],
|
||||
unique=False,
|
||||
)
|
||||
op.create_index(
|
||||
op.f("ix_user_organizations_is_active"),
|
||||
"user_organizations",
|
||||
["is_active"],
|
||||
unique=False,
|
||||
)
|
||||
op.create_table(
|
||||
"user_sessions",
|
||||
sa.Column("user_id", sa.UUID(), nullable=False),
|
||||
sa.Column("refresh_token_jti", sa.String(length=255), nullable=False),
|
||||
sa.Column("device_name", sa.String(length=255), nullable=True),
|
||||
sa.Column("device_id", sa.String(length=255), nullable=True),
|
||||
sa.Column("ip_address", sa.String(length=45), nullable=True),
|
||||
sa.Column("user_agent", sa.String(length=500), nullable=True),
|
||||
sa.Column("last_used_at", sa.DateTime(timezone=True), nullable=False),
|
||||
sa.Column("expires_at", sa.DateTime(timezone=True), nullable=False),
|
||||
sa.Column("is_active", sa.Boolean(), nullable=False),
|
||||
sa.Column("location_city", sa.String(length=100), nullable=True),
|
||||
sa.Column("location_country", sa.String(length=100), nullable=True),
|
||||
sa.Column("id", sa.UUID(), nullable=False),
|
||||
sa.Column("created_at", sa.DateTime(timezone=True), nullable=False),
|
||||
sa.Column("updated_at", sa.DateTime(timezone=True), nullable=False),
|
||||
sa.ForeignKeyConstraint(["user_id"], ["users.id"], ondelete="CASCADE"),
|
||||
sa.PrimaryKeyConstraint("id"),
|
||||
)
|
||||
op.create_index(
|
||||
op.f("ix_user_sessions_is_active"), "user_sessions", ["is_active"], unique=False
|
||||
)
|
||||
op.create_index(
|
||||
"ix_user_sessions_jti_active",
|
||||
"user_sessions",
|
||||
["refresh_token_jti", "is_active"],
|
||||
unique=False,
|
||||
)
|
||||
op.create_index(
|
||||
op.f("ix_user_sessions_refresh_token_jti"),
|
||||
"user_sessions",
|
||||
["refresh_token_jti"],
|
||||
unique=True,
|
||||
)
|
||||
op.create_index(
|
||||
"ix_user_sessions_user_active",
|
||||
"user_sessions",
|
||||
["user_id", "is_active"],
|
||||
unique=False,
|
||||
)
|
||||
op.create_index(
|
||||
op.f("ix_user_sessions_user_id"), "user_sessions", ["user_id"], unique=False
|
||||
)
|
||||
op.create_table(
|
||||
"oauth_authorization_codes",
|
||||
sa.Column("code", sa.String(length=128), nullable=False),
|
||||
sa.Column("client_id", sa.String(length=64), nullable=False),
|
||||
sa.Column("user_id", sa.UUID(), nullable=False),
|
||||
sa.Column("redirect_uri", sa.String(length=2048), nullable=False),
|
||||
sa.Column("scope", sa.String(length=1000), nullable=False),
|
||||
sa.Column("code_challenge", sa.String(length=128), nullable=True),
|
||||
sa.Column("code_challenge_method", sa.String(length=10), nullable=True),
|
||||
sa.Column("state", sa.String(length=256), nullable=True),
|
||||
sa.Column("nonce", sa.String(length=256), nullable=True),
|
||||
sa.Column("expires_at", sa.DateTime(timezone=True), nullable=False),
|
||||
sa.Column("used", sa.Boolean(), nullable=False),
|
||||
sa.Column("id", sa.UUID(), nullable=False),
|
||||
sa.Column("created_at", sa.DateTime(timezone=True), nullable=False),
|
||||
sa.Column("updated_at", sa.DateTime(timezone=True), nullable=False),
|
||||
sa.ForeignKeyConstraint(
|
||||
["client_id"], ["oauth_clients.client_id"], ondelete="CASCADE"
|
||||
),
|
||||
sa.ForeignKeyConstraint(["user_id"], ["users.id"], ondelete="CASCADE"),
|
||||
sa.PrimaryKeyConstraint("id"),
|
||||
)
|
||||
op.create_index(
|
||||
"ix_oauth_authorization_codes_client_user",
|
||||
"oauth_authorization_codes",
|
||||
["client_id", "user_id"],
|
||||
unique=False,
|
||||
)
|
||||
op.create_index(
|
||||
op.f("ix_oauth_authorization_codes_code"),
|
||||
"oauth_authorization_codes",
|
||||
["code"],
|
||||
unique=True,
|
||||
)
|
||||
op.create_index(
|
||||
"ix_oauth_authorization_codes_expires_at",
|
||||
"oauth_authorization_codes",
|
||||
["expires_at"],
|
||||
unique=False,
|
||||
)
|
||||
op.create_table(
|
||||
"oauth_consents",
|
||||
sa.Column("user_id", sa.UUID(), nullable=False),
|
||||
sa.Column("client_id", sa.String(length=64), nullable=False),
|
||||
sa.Column("granted_scopes", sa.String(length=1000), nullable=False),
|
||||
sa.Column("id", sa.UUID(), nullable=False),
|
||||
sa.Column("created_at", sa.DateTime(timezone=True), nullable=False),
|
||||
sa.Column("updated_at", sa.DateTime(timezone=True), nullable=False),
|
||||
sa.ForeignKeyConstraint(
|
||||
["client_id"], ["oauth_clients.client_id"], ondelete="CASCADE"
|
||||
),
|
||||
sa.ForeignKeyConstraint(["user_id"], ["users.id"], ondelete="CASCADE"),
|
||||
sa.PrimaryKeyConstraint("id"),
|
||||
)
|
||||
op.create_index(
|
||||
"ix_oauth_consents_user_client",
|
||||
"oauth_consents",
|
||||
["user_id", "client_id"],
|
||||
unique=True,
|
||||
)
|
||||
op.create_table(
|
||||
"oauth_provider_refresh_tokens",
|
||||
sa.Column("token_hash", sa.String(length=64), nullable=False),
|
||||
sa.Column("jti", sa.String(length=64), nullable=False),
|
||||
sa.Column("client_id", sa.String(length=64), nullable=False),
|
||||
sa.Column("user_id", sa.UUID(), nullable=False),
|
||||
sa.Column("scope", sa.String(length=1000), nullable=False),
|
||||
sa.Column("expires_at", sa.DateTime(timezone=True), nullable=False),
|
||||
sa.Column("revoked", sa.Boolean(), nullable=False),
|
||||
sa.Column("last_used_at", sa.DateTime(timezone=True), nullable=True),
|
||||
sa.Column("device_info", sa.String(length=500), nullable=True),
|
||||
sa.Column("ip_address", sa.String(length=45), nullable=True),
|
||||
sa.Column("id", sa.UUID(), nullable=False),
|
||||
sa.Column("created_at", sa.DateTime(timezone=True), nullable=False),
|
||||
sa.Column("updated_at", sa.DateTime(timezone=True), nullable=False),
|
||||
sa.ForeignKeyConstraint(
|
||||
["client_id"], ["oauth_clients.client_id"], ondelete="CASCADE"
|
||||
),
|
||||
sa.ForeignKeyConstraint(["user_id"], ["users.id"], ondelete="CASCADE"),
|
||||
sa.PrimaryKeyConstraint("id"),
|
||||
)
|
||||
op.create_index(
|
||||
"ix_oauth_provider_refresh_tokens_client_user",
|
||||
"oauth_provider_refresh_tokens",
|
||||
["client_id", "user_id"],
|
||||
unique=False,
|
||||
)
|
||||
op.create_index(
|
||||
"ix_oauth_provider_refresh_tokens_expires_at",
|
||||
"oauth_provider_refresh_tokens",
|
||||
["expires_at"],
|
||||
unique=False,
|
||||
)
|
||||
op.create_index(
|
||||
op.f("ix_oauth_provider_refresh_tokens_jti"),
|
||||
"oauth_provider_refresh_tokens",
|
||||
["jti"],
|
||||
unique=True,
|
||||
)
|
||||
op.create_index(
|
||||
op.f("ix_oauth_provider_refresh_tokens_revoked"),
|
||||
"oauth_provider_refresh_tokens",
|
||||
["revoked"],
|
||||
unique=False,
|
||||
)
|
||||
op.create_index(
|
||||
op.f("ix_oauth_provider_refresh_tokens_token_hash"),
|
||||
"oauth_provider_refresh_tokens",
|
||||
["token_hash"],
|
||||
unique=True,
|
||||
)
|
||||
op.create_index(
|
||||
"ix_oauth_provider_refresh_tokens_user_revoked",
|
||||
"oauth_provider_refresh_tokens",
|
||||
["user_id", "revoked"],
|
||||
unique=False,
|
||||
)
|
||||
# ### end Alembic commands ###
|
||||
|
||||
|
||||
def downgrade() -> None:
|
||||
# ### commands auto generated by Alembic - please adjust! ###
|
||||
op.drop_index(
|
||||
"ix_oauth_provider_refresh_tokens_user_revoked",
|
||||
table_name="oauth_provider_refresh_tokens",
|
||||
)
|
||||
op.drop_index(
|
||||
op.f("ix_oauth_provider_refresh_tokens_token_hash"),
|
||||
table_name="oauth_provider_refresh_tokens",
|
||||
)
|
||||
op.drop_index(
|
||||
op.f("ix_oauth_provider_refresh_tokens_revoked"),
|
||||
table_name="oauth_provider_refresh_tokens",
|
||||
)
|
||||
op.drop_index(
|
||||
op.f("ix_oauth_provider_refresh_tokens_jti"),
|
||||
table_name="oauth_provider_refresh_tokens",
|
||||
)
|
||||
op.drop_index(
|
||||
"ix_oauth_provider_refresh_tokens_expires_at",
|
||||
table_name="oauth_provider_refresh_tokens",
|
||||
)
|
||||
op.drop_index(
|
||||
"ix_oauth_provider_refresh_tokens_client_user",
|
||||
table_name="oauth_provider_refresh_tokens",
|
||||
)
|
||||
op.drop_table("oauth_provider_refresh_tokens")
|
||||
op.drop_index("ix_oauth_consents_user_client", table_name="oauth_consents")
|
||||
op.drop_table("oauth_consents")
|
||||
op.drop_index(
|
||||
"ix_oauth_authorization_codes_expires_at",
|
||||
table_name="oauth_authorization_codes",
|
||||
)
|
||||
op.drop_index(
|
||||
op.f("ix_oauth_authorization_codes_code"),
|
||||
table_name="oauth_authorization_codes",
|
||||
)
|
||||
op.drop_index(
|
||||
"ix_oauth_authorization_codes_client_user",
|
||||
table_name="oauth_authorization_codes",
|
||||
)
|
||||
op.drop_table("oauth_authorization_codes")
|
||||
op.drop_index(op.f("ix_user_sessions_user_id"), table_name="user_sessions")
|
||||
op.drop_index("ix_user_sessions_user_active", table_name="user_sessions")
|
||||
op.drop_index(
|
||||
op.f("ix_user_sessions_refresh_token_jti"), table_name="user_sessions"
|
||||
)
|
||||
op.drop_index("ix_user_sessions_jti_active", table_name="user_sessions")
|
||||
op.drop_index(op.f("ix_user_sessions_is_active"), table_name="user_sessions")
|
||||
op.drop_table("user_sessions")
|
||||
op.drop_index(
|
||||
op.f("ix_user_organizations_is_active"), table_name="user_organizations"
|
||||
)
|
||||
op.drop_index("ix_user_org_user_active", table_name="user_organizations")
|
||||
op.drop_index("ix_user_org_role", table_name="user_organizations")
|
||||
op.drop_index("ix_user_org_org_active", table_name="user_organizations")
|
||||
op.drop_table("user_organizations")
|
||||
op.drop_index(op.f("ix_oauth_clients_is_active"), table_name="oauth_clients")
|
||||
op.drop_index(op.f("ix_oauth_clients_client_id"), table_name="oauth_clients")
|
||||
op.drop_table("oauth_clients")
|
||||
op.drop_index("ix_oauth_accounts_user_provider", table_name="oauth_accounts")
|
||||
op.drop_index(op.f("ix_oauth_accounts_user_id"), table_name="oauth_accounts")
|
||||
op.drop_index(op.f("ix_oauth_accounts_provider_email"), table_name="oauth_accounts")
|
||||
op.drop_index(op.f("ix_oauth_accounts_provider"), table_name="oauth_accounts")
|
||||
op.drop_table("oauth_accounts")
|
||||
op.drop_index(op.f("ix_users_locale"), table_name="users")
|
||||
op.drop_index(op.f("ix_users_is_superuser"), table_name="users")
|
||||
op.drop_index(op.f("ix_users_is_active"), table_name="users")
|
||||
op.drop_index(op.f("ix_users_email"), table_name="users")
|
||||
op.drop_index(op.f("ix_users_deleted_at"), table_name="users")
|
||||
op.drop_table("users")
|
||||
op.drop_index("ix_organizations_slug_active", table_name="organizations")
|
||||
op.drop_index(op.f("ix_organizations_slug"), table_name="organizations")
|
||||
op.drop_index("ix_organizations_name_active", table_name="organizations")
|
||||
op.drop_index(op.f("ix_organizations_name"), table_name="organizations")
|
||||
op.drop_index(op.f("ix_organizations_is_active"), table_name="organizations")
|
||||
op.drop_table("organizations")
|
||||
op.drop_index(op.f("ix_oauth_states_state"), table_name="oauth_states")
|
||||
op.drop_table("oauth_states")
|
||||
# ### end Alembic commands ###
|
||||
127
backend/app/alembic/versions/0002_add_performance_indexes.py
Normal file
127
backend/app/alembic/versions/0002_add_performance_indexes.py
Normal file
@@ -0,0 +1,127 @@
|
||||
"""Add performance indexes
|
||||
|
||||
Revision ID: 0002
|
||||
Revises: 0001
|
||||
Create Date: 2025-11-27
|
||||
|
||||
Performance indexes that Alembic cannot auto-detect:
|
||||
- Functional indexes (LOWER expressions)
|
||||
- Partial indexes (WHERE clauses)
|
||||
|
||||
These indexes use the ix_perf_ prefix and are excluded from autogenerate
|
||||
via the include_object() function in env.py.
|
||||
"""
|
||||
|
||||
from collections.abc import Sequence
|
||||
|
||||
import sqlalchemy as sa
|
||||
from alembic import op
|
||||
|
||||
# revision identifiers, used by Alembic.
|
||||
revision: str = "0002"
|
||||
down_revision: str | None = "0001"
|
||||
branch_labels: str | Sequence[str] | None = None
|
||||
depends_on: str | Sequence[str] | None = None
|
||||
|
||||
|
||||
def upgrade() -> None:
|
||||
# ==========================================================================
|
||||
# USERS TABLE - Performance indexes for authentication
|
||||
# ==========================================================================
|
||||
|
||||
# Case-insensitive email lookup for login/registration
|
||||
# Query: SELECT * FROM users WHERE LOWER(email) = LOWER(:email) AND deleted_at IS NULL
|
||||
# Impact: High - every login, registration check, password reset
|
||||
op.create_index(
|
||||
"ix_perf_users_email_lower",
|
||||
"users",
|
||||
[sa.text("LOWER(email)")],
|
||||
unique=False,
|
||||
postgresql_where=sa.text("deleted_at IS NULL"),
|
||||
)
|
||||
|
||||
# Active users lookup (non-soft-deleted)
|
||||
# Query: SELECT * FROM users WHERE deleted_at IS NULL AND ...
|
||||
# Impact: Medium - user listings, admin queries
|
||||
op.create_index(
|
||||
"ix_perf_users_active",
|
||||
"users",
|
||||
["is_active"],
|
||||
unique=False,
|
||||
postgresql_where=sa.text("deleted_at IS NULL"),
|
||||
)
|
||||
|
||||
# ==========================================================================
|
||||
# ORGANIZATIONS TABLE - Performance indexes for multi-tenant lookups
|
||||
# ==========================================================================
|
||||
|
||||
# Case-insensitive slug lookup for URL routing
|
||||
# Query: SELECT * FROM organizations WHERE LOWER(slug) = LOWER(:slug) AND is_active = true
|
||||
# Impact: Medium - every organization page load
|
||||
op.create_index(
|
||||
"ix_perf_organizations_slug_lower",
|
||||
"organizations",
|
||||
[sa.text("LOWER(slug)")],
|
||||
unique=False,
|
||||
postgresql_where=sa.text("is_active = true"),
|
||||
)
|
||||
|
||||
# ==========================================================================
|
||||
# USER SESSIONS TABLE - Performance indexes for session management
|
||||
# ==========================================================================
|
||||
|
||||
# Expired session cleanup
|
||||
# Query: SELECT * FROM user_sessions WHERE expires_at < NOW() AND is_active = true
|
||||
# Impact: Medium - background cleanup jobs
|
||||
op.create_index(
|
||||
"ix_perf_user_sessions_expires",
|
||||
"user_sessions",
|
||||
["expires_at"],
|
||||
unique=False,
|
||||
postgresql_where=sa.text("is_active = true"),
|
||||
)
|
||||
|
||||
# ==========================================================================
|
||||
# OAUTH PROVIDER TOKENS - Performance indexes for token management
|
||||
# ==========================================================================
|
||||
|
||||
# Expired refresh token cleanup
|
||||
# Query: SELECT * FROM oauth_provider_refresh_tokens WHERE expires_at < NOW() AND revoked = false
|
||||
# Impact: Medium - OAuth token cleanup, validation
|
||||
op.create_index(
|
||||
"ix_perf_oauth_refresh_tokens_expires",
|
||||
"oauth_provider_refresh_tokens",
|
||||
["expires_at"],
|
||||
unique=False,
|
||||
postgresql_where=sa.text("revoked = false"),
|
||||
)
|
||||
|
||||
# ==========================================================================
|
||||
# OAUTH AUTHORIZATION CODES - Performance indexes for auth flow
|
||||
# ==========================================================================
|
||||
|
||||
# Expired authorization code cleanup
|
||||
# Query: DELETE FROM oauth_authorization_codes WHERE expires_at < NOW() AND used = false
|
||||
# Impact: Low-Medium - OAuth cleanup jobs
|
||||
op.create_index(
|
||||
"ix_perf_oauth_auth_codes_expires",
|
||||
"oauth_authorization_codes",
|
||||
["expires_at"],
|
||||
unique=False,
|
||||
postgresql_where=sa.text("used = false"),
|
||||
)
|
||||
|
||||
|
||||
def downgrade() -> None:
|
||||
# Drop indexes in reverse order
|
||||
op.drop_index(
|
||||
"ix_perf_oauth_auth_codes_expires", table_name="oauth_authorization_codes"
|
||||
)
|
||||
op.drop_index(
|
||||
"ix_perf_oauth_refresh_tokens_expires",
|
||||
table_name="oauth_provider_refresh_tokens",
|
||||
)
|
||||
op.drop_index("ix_perf_user_sessions_expires", table_name="user_sessions")
|
||||
op.drop_index("ix_perf_organizations_slug_lower", table_name="organizations")
|
||||
op.drop_index("ix_perf_users_active", table_name="users")
|
||||
op.drop_index("ix_perf_users_email_lower", table_name="users")
|
||||
@@ -0,0 +1,35 @@
|
||||
"""rename oauth account token fields drop encrypted suffix
|
||||
|
||||
Revision ID: 0003
|
||||
Revises: 0002
|
||||
Create Date: 2026-02-27 01:03:18.869178
|
||||
|
||||
"""
|
||||
|
||||
from collections.abc import Sequence
|
||||
|
||||
from alembic import op
|
||||
|
||||
# revision identifiers, used by Alembic.
|
||||
revision: str = "0003"
|
||||
down_revision: str | None = "0002"
|
||||
branch_labels: str | Sequence[str] | None = None
|
||||
depends_on: str | Sequence[str] | None = None
|
||||
|
||||
|
||||
def upgrade() -> None:
|
||||
op.alter_column(
|
||||
"oauth_accounts", "access_token_encrypted", new_column_name="access_token"
|
||||
)
|
||||
op.alter_column(
|
||||
"oauth_accounts", "refresh_token_encrypted", new_column_name="refresh_token"
|
||||
)
|
||||
|
||||
|
||||
def downgrade() -> None:
|
||||
op.alter_column(
|
||||
"oauth_accounts", "access_token", new_column_name="access_token_encrypted"
|
||||
)
|
||||
op.alter_column(
|
||||
"oauth_accounts", "refresh_token", new_column_name="refresh_token_encrypted"
|
||||
)
|
||||
@@ -1,78 +0,0 @@
|
||||
"""add_performance_indexes
|
||||
|
||||
Revision ID: 1174fffbe3e4
|
||||
Revises: fbf6318a8a36
|
||||
Create Date: 2025-11-01 04:15:25.367010
|
||||
|
||||
"""
|
||||
from typing import Sequence, Union
|
||||
|
||||
import sqlalchemy as sa
|
||||
|
||||
from alembic import op
|
||||
|
||||
# revision identifiers, used by Alembic.
|
||||
revision: str = '1174fffbe3e4'
|
||||
down_revision: Union[str, None] = 'fbf6318a8a36'
|
||||
branch_labels: Union[str, Sequence[str], None] = None
|
||||
depends_on: Union[str, Sequence[str], None] = None
|
||||
|
||||
|
||||
def upgrade() -> None:
|
||||
"""Add performance indexes for optimized queries."""
|
||||
|
||||
# Index for session cleanup queries
|
||||
# Optimizes: DELETE WHERE is_active = FALSE AND expires_at < now AND created_at < cutoff
|
||||
op.create_index(
|
||||
'ix_user_sessions_cleanup',
|
||||
'user_sessions',
|
||||
['is_active', 'expires_at', 'created_at'],
|
||||
unique=False,
|
||||
postgresql_where=sa.text('is_active = false')
|
||||
)
|
||||
|
||||
# Index for user search queries (basic trigram support without pg_trgm extension)
|
||||
# Optimizes: WHERE email ILIKE '%search%' OR first_name ILIKE '%search%'
|
||||
# Note: For better performance, consider enabling pg_trgm extension
|
||||
op.create_index(
|
||||
'ix_users_email_lower',
|
||||
'users',
|
||||
[sa.text('LOWER(email)')],
|
||||
unique=False,
|
||||
postgresql_where=sa.text('deleted_at IS NULL')
|
||||
)
|
||||
|
||||
op.create_index(
|
||||
'ix_users_first_name_lower',
|
||||
'users',
|
||||
[sa.text('LOWER(first_name)')],
|
||||
unique=False,
|
||||
postgresql_where=sa.text('deleted_at IS NULL')
|
||||
)
|
||||
|
||||
op.create_index(
|
||||
'ix_users_last_name_lower',
|
||||
'users',
|
||||
[sa.text('LOWER(last_name)')],
|
||||
unique=False,
|
||||
postgresql_where=sa.text('deleted_at IS NULL')
|
||||
)
|
||||
|
||||
# Index for organization search
|
||||
op.create_index(
|
||||
'ix_organizations_name_lower',
|
||||
'organizations',
|
||||
[sa.text('LOWER(name)')],
|
||||
unique=False
|
||||
)
|
||||
|
||||
|
||||
def downgrade() -> None:
|
||||
"""Remove performance indexes."""
|
||||
|
||||
# Drop indexes in reverse order
|
||||
op.drop_index('ix_organizations_name_lower', table_name='organizations')
|
||||
op.drop_index('ix_users_last_name_lower', table_name='users')
|
||||
op.drop_index('ix_users_first_name_lower', table_name='users')
|
||||
op.drop_index('ix_users_email_lower', table_name='users')
|
||||
op.drop_index('ix_user_sessions_cleanup', table_name='user_sessions')
|
||||
@@ -1,34 +0,0 @@
|
||||
"""add_soft_delete_to_users
|
||||
|
||||
Revision ID: 2d0fcec3b06d
|
||||
Revises: 9e4f2a1b8c7d
|
||||
Create Date: 2025-10-30 16:40:21.000021
|
||||
|
||||
"""
|
||||
from typing import Sequence, Union
|
||||
|
||||
import sqlalchemy as sa
|
||||
|
||||
from alembic import op
|
||||
|
||||
# revision identifiers, used by Alembic.
|
||||
revision: str = '2d0fcec3b06d'
|
||||
down_revision: Union[str, None] = '9e4f2a1b8c7d'
|
||||
branch_labels: Union[str, Sequence[str], None] = None
|
||||
depends_on: Union[str, Sequence[str], None] = None
|
||||
|
||||
|
||||
def upgrade() -> None:
|
||||
# Add deleted_at column for soft deletes
|
||||
op.add_column('users', sa.Column('deleted_at', sa.DateTime(timezone=True), nullable=True))
|
||||
|
||||
# Add index on deleted_at for efficient queries
|
||||
op.create_index('ix_users_deleted_at', 'users', ['deleted_at'])
|
||||
|
||||
|
||||
def downgrade() -> None:
|
||||
# Remove index
|
||||
op.drop_index('ix_users_deleted_at', table_name='users')
|
||||
|
||||
# Remove column
|
||||
op.drop_column('users', 'deleted_at')
|
||||
@@ -1,46 +0,0 @@
|
||||
"""Add all initial models
|
||||
|
||||
Revision ID: 38bf9e7e74b3
|
||||
Revises: 7396957cbe80
|
||||
Create Date: 2025-02-28 09:19:33.212278
|
||||
|
||||
"""
|
||||
from typing import Sequence, Union
|
||||
|
||||
import sqlalchemy as sa
|
||||
|
||||
from alembic import op
|
||||
|
||||
# revision identifiers, used by Alembic.
|
||||
revision: str = '38bf9e7e74b3'
|
||||
down_revision: Union[str, None] = '7396957cbe80'
|
||||
branch_labels: Union[str, Sequence[str], None] = None
|
||||
depends_on: Union[str, Sequence[str], None] = None
|
||||
|
||||
|
||||
def upgrade() -> None:
|
||||
|
||||
op.create_table('users',
|
||||
sa.Column('email', sa.String(), nullable=False),
|
||||
sa.Column('password_hash', sa.String(), nullable=False),
|
||||
sa.Column('first_name', sa.String(), nullable=False),
|
||||
sa.Column('last_name', sa.String(), nullable=True),
|
||||
sa.Column('phone_number', sa.String(), nullable=True),
|
||||
sa.Column('is_active', sa.Boolean(), nullable=False),
|
||||
sa.Column('is_superuser', sa.Boolean(), nullable=False),
|
||||
sa.Column('preferences', sa.JSON(), nullable=True),
|
||||
sa.Column('id', sa.UUID(), nullable=False),
|
||||
sa.Column('created_at', sa.DateTime(timezone=True), nullable=False),
|
||||
sa.Column('updated_at', sa.DateTime(timezone=True), nullable=False),
|
||||
sa.PrimaryKeyConstraint('id')
|
||||
)
|
||||
op.create_index(op.f('ix_users_email'), 'users', ['email'], unique=True)
|
||||
|
||||
# ### end Alembic commands ###
|
||||
|
||||
|
||||
def downgrade() -> None:
|
||||
# ### commands auto generated by Alembic - please adjust! ###
|
||||
op.drop_index(op.f('ix_users_email'), table_name='users')
|
||||
op.drop_table('users')
|
||||
# ### end Alembic commands ###
|
||||
@@ -1,102 +0,0 @@
|
||||
"""add_user_sessions_table
|
||||
|
||||
Revision ID: 549b50ea888d
|
||||
Revises: b76c725fc3cf
|
||||
Create Date: 2025-10-31 07:41:18.729544
|
||||
|
||||
"""
|
||||
from typing import Sequence, Union
|
||||
|
||||
import sqlalchemy as sa
|
||||
|
||||
from alembic import op
|
||||
|
||||
# revision identifiers, used by Alembic.
|
||||
revision: str = '549b50ea888d'
|
||||
down_revision: Union[str, None] = 'b76c725fc3cf'
|
||||
branch_labels: Union[str, Sequence[str], None] = None
|
||||
depends_on: Union[str, Sequence[str], None] = None
|
||||
|
||||
|
||||
def upgrade() -> None:
|
||||
# Create user_sessions table for per-device session management
|
||||
op.create_table(
|
||||
'user_sessions',
|
||||
sa.Column('id', sa.UUID(), nullable=False),
|
||||
sa.Column('user_id', sa.UUID(), nullable=False),
|
||||
sa.Column('refresh_token_jti', sa.String(length=255), nullable=False),
|
||||
sa.Column('device_name', sa.String(length=255), nullable=True),
|
||||
sa.Column('device_id', sa.String(length=255), nullable=True),
|
||||
sa.Column('ip_address', sa.String(length=45), nullable=True),
|
||||
sa.Column('user_agent', sa.String(length=500), nullable=True),
|
||||
sa.Column('last_used_at', sa.DateTime(timezone=True), nullable=False),
|
||||
sa.Column('expires_at', sa.DateTime(timezone=True), nullable=False),
|
||||
sa.Column('is_active', sa.Boolean(), nullable=False, server_default='true'),
|
||||
sa.Column('location_city', sa.String(length=100), nullable=True),
|
||||
sa.Column('location_country', sa.String(length=100), nullable=True),
|
||||
sa.Column('created_at', sa.DateTime(timezone=True), nullable=False),
|
||||
sa.Column('updated_at', sa.DateTime(timezone=True), nullable=False),
|
||||
sa.PrimaryKeyConstraint('id')
|
||||
)
|
||||
|
||||
# Create foreign key to users table
|
||||
op.create_foreign_key(
|
||||
'fk_user_sessions_user_id',
|
||||
'user_sessions',
|
||||
'users',
|
||||
['user_id'],
|
||||
['id'],
|
||||
ondelete='CASCADE'
|
||||
)
|
||||
|
||||
# Create indexes for performance
|
||||
# 1. Lookup session by refresh token JTI (most common query)
|
||||
op.create_index(
|
||||
'ix_user_sessions_jti',
|
||||
'user_sessions',
|
||||
['refresh_token_jti'],
|
||||
unique=True
|
||||
)
|
||||
|
||||
# 2. Lookup sessions by user ID
|
||||
op.create_index(
|
||||
'ix_user_sessions_user_id',
|
||||
'user_sessions',
|
||||
['user_id']
|
||||
)
|
||||
|
||||
# 3. Composite index for active sessions by user
|
||||
op.create_index(
|
||||
'ix_user_sessions_user_active',
|
||||
'user_sessions',
|
||||
['user_id', 'is_active']
|
||||
)
|
||||
|
||||
# 4. Index on expires_at for cleanup job
|
||||
op.create_index(
|
||||
'ix_user_sessions_expires_at',
|
||||
'user_sessions',
|
||||
['expires_at']
|
||||
)
|
||||
|
||||
# 5. Composite index for active session lookup by JTI
|
||||
op.create_index(
|
||||
'ix_user_sessions_jti_active',
|
||||
'user_sessions',
|
||||
['refresh_token_jti', 'is_active']
|
||||
)
|
||||
|
||||
|
||||
def downgrade() -> None:
|
||||
# Drop indexes first
|
||||
op.drop_index('ix_user_sessions_jti_active', table_name='user_sessions')
|
||||
op.drop_index('ix_user_sessions_expires_at', table_name='user_sessions')
|
||||
op.drop_index('ix_user_sessions_user_active', table_name='user_sessions')
|
||||
op.drop_index('ix_user_sessions_user_id', table_name='user_sessions')
|
||||
op.drop_index('ix_user_sessions_jti', table_name='user_sessions')
|
||||
|
||||
# Drop foreign key
|
||||
op.drop_constraint('fk_user_sessions_user_id', 'user_sessions', type_='foreignkey')
|
||||
|
||||
# Drop table
|
||||
op.drop_table('user_sessions')
|
||||
@@ -1,24 +0,0 @@
|
||||
"""Initial empty migration
|
||||
|
||||
Revision ID: 7396957cbe80
|
||||
Revises:
|
||||
Create Date: 2025-02-27 12:47:46.445313
|
||||
|
||||
"""
|
||||
from typing import Sequence, Union
|
||||
|
||||
from alembic import op
|
||||
|
||||
# revision identifiers, used by Alembic.
|
||||
revision: str = '7396957cbe80'
|
||||
down_revision: Union[str, None] = None
|
||||
branch_labels: Union[str, Sequence[str], None] = None
|
||||
depends_on: Union[str, Sequence[str], None] = None
|
||||
|
||||
|
||||
def upgrade() -> None:
|
||||
pass
|
||||
|
||||
|
||||
def downgrade() -> None:
|
||||
pass
|
||||
@@ -1,84 +0,0 @@
|
||||
"""Add missing indexes and fix column types
|
||||
|
||||
Revision ID: 9e4f2a1b8c7d
|
||||
Revises: 38bf9e7e74b3
|
||||
Create Date: 2025-10-30 10:00:00.000000
|
||||
|
||||
"""
|
||||
from typing import Sequence, Union
|
||||
|
||||
import sqlalchemy as sa
|
||||
|
||||
from alembic import op
|
||||
|
||||
# revision identifiers, used by Alembic.
|
||||
revision: str = '9e4f2a1b8c7d'
|
||||
down_revision: Union[str, None] = '38bf9e7e74b3'
|
||||
branch_labels: Union[str, Sequence[str], None] = None
|
||||
depends_on: Union[str, Sequence[str], None] = None
|
||||
|
||||
|
||||
def upgrade() -> None:
|
||||
# Add missing indexes for is_active and is_superuser
|
||||
op.create_index(op.f('ix_users_is_active'), 'users', ['is_active'], unique=False)
|
||||
op.create_index(op.f('ix_users_is_superuser'), 'users', ['is_superuser'], unique=False)
|
||||
|
||||
# Fix column types to match model definitions with explicit lengths
|
||||
op.alter_column('users', 'email',
|
||||
existing_type=sa.String(),
|
||||
type_=sa.String(length=255),
|
||||
nullable=False)
|
||||
|
||||
op.alter_column('users', 'password_hash',
|
||||
existing_type=sa.String(),
|
||||
type_=sa.String(length=255),
|
||||
nullable=False)
|
||||
|
||||
op.alter_column('users', 'first_name',
|
||||
existing_type=sa.String(),
|
||||
type_=sa.String(length=100),
|
||||
nullable=False,
|
||||
server_default='user') # Add server default
|
||||
|
||||
op.alter_column('users', 'last_name',
|
||||
existing_type=sa.String(),
|
||||
type_=sa.String(length=100),
|
||||
nullable=True)
|
||||
|
||||
op.alter_column('users', 'phone_number',
|
||||
existing_type=sa.String(),
|
||||
type_=sa.String(length=20),
|
||||
nullable=True)
|
||||
|
||||
|
||||
def downgrade() -> None:
|
||||
# Revert column types
|
||||
op.alter_column('users', 'phone_number',
|
||||
existing_type=sa.String(length=20),
|
||||
type_=sa.String(),
|
||||
nullable=True)
|
||||
|
||||
op.alter_column('users', 'last_name',
|
||||
existing_type=sa.String(length=100),
|
||||
type_=sa.String(),
|
||||
nullable=True)
|
||||
|
||||
op.alter_column('users', 'first_name',
|
||||
existing_type=sa.String(length=100),
|
||||
type_=sa.String(),
|
||||
nullable=False,
|
||||
server_default=None) # Remove server default
|
||||
|
||||
op.alter_column('users', 'password_hash',
|
||||
existing_type=sa.String(length=255),
|
||||
type_=sa.String(),
|
||||
nullable=False)
|
||||
|
||||
op.alter_column('users', 'email',
|
||||
existing_type=sa.String(length=255),
|
||||
type_=sa.String(),
|
||||
nullable=False)
|
||||
|
||||
# Drop indexes
|
||||
op.drop_index(op.f('ix_users_is_superuser'), table_name='users')
|
||||
op.drop_index(op.f('ix_users_is_active'), table_name='users')
|
||||
@@ -1,52 +0,0 @@
|
||||
"""add_composite_indexes
|
||||
|
||||
Revision ID: b76c725fc3cf
|
||||
Revises: 2d0fcec3b06d
|
||||
Create Date: 2025-10-30 16:41:33.273135
|
||||
|
||||
"""
|
||||
from typing import Sequence, Union
|
||||
|
||||
import sqlalchemy as sa
|
||||
|
||||
from alembic import op
|
||||
|
||||
# revision identifiers, used by Alembic.
|
||||
revision: str = 'b76c725fc3cf'
|
||||
down_revision: Union[str, None] = '2d0fcec3b06d'
|
||||
branch_labels: Union[str, Sequence[str], None] = None
|
||||
depends_on: Union[str, Sequence[str], None] = None
|
||||
|
||||
|
||||
def upgrade() -> None:
|
||||
# Add composite indexes for common query patterns
|
||||
|
||||
# Composite index for filtering active users by role
|
||||
op.create_index(
|
||||
'ix_users_active_superuser',
|
||||
'users',
|
||||
['is_active', 'is_superuser'],
|
||||
postgresql_where=sa.text('deleted_at IS NULL')
|
||||
)
|
||||
|
||||
# Composite index for sorting active users by creation date
|
||||
op.create_index(
|
||||
'ix_users_active_created',
|
||||
'users',
|
||||
['is_active', 'created_at'],
|
||||
postgresql_where=sa.text('deleted_at IS NULL')
|
||||
)
|
||||
|
||||
# Composite index for email lookup of non-deleted users
|
||||
op.create_index(
|
||||
'ix_users_email_not_deleted',
|
||||
'users',
|
||||
['email', 'deleted_at']
|
||||
)
|
||||
|
||||
|
||||
def downgrade() -> None:
|
||||
# Remove composite indexes
|
||||
op.drop_index('ix_users_email_not_deleted', table_name='users')
|
||||
op.drop_index('ix_users_active_created', table_name='users')
|
||||
op.drop_index('ix_users_active_superuser', table_name='users')
|
||||
@@ -1,106 +0,0 @@
|
||||
"""add_organizations_and_user_organizations
|
||||
|
||||
Revision ID: fbf6318a8a36
|
||||
Revises: 549b50ea888d
|
||||
Create Date: 2025-10-31 12:08:05.141353
|
||||
|
||||
"""
|
||||
from typing import Sequence, Union
|
||||
|
||||
import sqlalchemy as sa
|
||||
|
||||
from alembic import op
|
||||
|
||||
# revision identifiers, used by Alembic.
|
||||
revision: str = 'fbf6318a8a36'
|
||||
down_revision: Union[str, None] = '549b50ea888d'
|
||||
branch_labels: Union[str, Sequence[str], None] = None
|
||||
depends_on: Union[str, Sequence[str], None] = None
|
||||
|
||||
|
||||
def upgrade() -> None:
|
||||
# Create organizations table
|
||||
op.create_table(
|
||||
'organizations',
|
||||
sa.Column('id', sa.UUID(), nullable=False),
|
||||
sa.Column('name', sa.String(length=255), nullable=False),
|
||||
sa.Column('slug', sa.String(length=255), nullable=False),
|
||||
sa.Column('description', sa.Text(), nullable=True),
|
||||
sa.Column('is_active', sa.Boolean(), nullable=False, server_default='true'),
|
||||
sa.Column('settings', sa.JSON(), nullable=True),
|
||||
sa.Column('created_at', sa.DateTime(timezone=True), nullable=False),
|
||||
sa.Column('updated_at', sa.DateTime(timezone=True), nullable=False),
|
||||
sa.PrimaryKeyConstraint('id')
|
||||
)
|
||||
|
||||
# Create indexes for organizations
|
||||
op.create_index('ix_organizations_name', 'organizations', ['name'])
|
||||
op.create_index('ix_organizations_slug', 'organizations', ['slug'], unique=True)
|
||||
op.create_index('ix_organizations_is_active', 'organizations', ['is_active'])
|
||||
op.create_index('ix_organizations_name_active', 'organizations', ['name', 'is_active'])
|
||||
op.create_index('ix_organizations_slug_active', 'organizations', ['slug', 'is_active'])
|
||||
|
||||
# Create user_organizations junction table
|
||||
op.create_table(
|
||||
'user_organizations',
|
||||
sa.Column('user_id', sa.UUID(), nullable=False),
|
||||
sa.Column('organization_id', sa.UUID(), nullable=False),
|
||||
sa.Column('role', sa.Enum('OWNER', 'ADMIN', 'MEMBER', 'GUEST', name='organizationrole'), nullable=False, server_default='MEMBER'),
|
||||
sa.Column('is_active', sa.Boolean(), nullable=False, server_default='true'),
|
||||
sa.Column('custom_permissions', sa.String(length=500), nullable=True),
|
||||
sa.Column('created_at', sa.DateTime(timezone=True), nullable=False),
|
||||
sa.Column('updated_at', sa.DateTime(timezone=True), nullable=False),
|
||||
sa.PrimaryKeyConstraint('user_id', 'organization_id')
|
||||
)
|
||||
|
||||
# Create foreign keys
|
||||
op.create_foreign_key(
|
||||
'fk_user_organizations_user_id',
|
||||
'user_organizations',
|
||||
'users',
|
||||
['user_id'],
|
||||
['id'],
|
||||
ondelete='CASCADE'
|
||||
)
|
||||
op.create_foreign_key(
|
||||
'fk_user_organizations_organization_id',
|
||||
'user_organizations',
|
||||
'organizations',
|
||||
['organization_id'],
|
||||
['id'],
|
||||
ondelete='CASCADE'
|
||||
)
|
||||
|
||||
# Create indexes for user_organizations
|
||||
op.create_index('ix_user_organizations_role', 'user_organizations', ['role'])
|
||||
op.create_index('ix_user_organizations_is_active', 'user_organizations', ['is_active'])
|
||||
op.create_index('ix_user_org_user_active', 'user_organizations', ['user_id', 'is_active'])
|
||||
op.create_index('ix_user_org_org_active', 'user_organizations', ['organization_id', 'is_active'])
|
||||
|
||||
|
||||
def downgrade() -> None:
|
||||
# Drop indexes for user_organizations
|
||||
op.drop_index('ix_user_org_org_active', table_name='user_organizations')
|
||||
op.drop_index('ix_user_org_user_active', table_name='user_organizations')
|
||||
op.drop_index('ix_user_organizations_is_active', table_name='user_organizations')
|
||||
op.drop_index('ix_user_organizations_role', table_name='user_organizations')
|
||||
|
||||
# Drop foreign keys
|
||||
op.drop_constraint('fk_user_organizations_organization_id', 'user_organizations', type_='foreignkey')
|
||||
op.drop_constraint('fk_user_organizations_user_id', 'user_organizations', type_='foreignkey')
|
||||
|
||||
# Drop user_organizations table
|
||||
op.drop_table('user_organizations')
|
||||
|
||||
# Drop indexes for organizations
|
||||
op.drop_index('ix_organizations_slug_active', table_name='organizations')
|
||||
op.drop_index('ix_organizations_name_active', table_name='organizations')
|
||||
op.drop_index('ix_organizations_is_active', table_name='organizations')
|
||||
op.drop_index('ix_organizations_slug', table_name='organizations')
|
||||
op.drop_index('ix_organizations_name', table_name='organizations')
|
||||
|
||||
# Drop organizations table
|
||||
op.drop_table('organizations')
|
||||
|
||||
# Drop enum type
|
||||
op.execute('DROP TYPE IF EXISTS organizationrole')
|
||||
@@ -1,22 +1,19 @@
|
||||
from typing import Optional
|
||||
|
||||
from fastapi import Depends, HTTPException, status, Header
|
||||
from fastapi import Depends, Header, HTTPException, status
|
||||
from fastapi.security import OAuth2PasswordBearer
|
||||
from fastapi.security.utils import get_authorization_scheme_param
|
||||
from sqlalchemy import select
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
|
||||
from app.core.auth import get_token_data, TokenExpiredError, TokenInvalidError
|
||||
from app.core.auth import TokenExpiredError, TokenInvalidError, get_token_data
|
||||
from app.core.database import get_db
|
||||
from app.models.user import User
|
||||
from app.repositories.user import user_repo
|
||||
|
||||
# OAuth2 configuration
|
||||
oauth2_scheme = OAuth2PasswordBearer(tokenUrl="/api/v1/auth/login")
|
||||
|
||||
|
||||
async def get_current_user(
|
||||
db: AsyncSession = Depends(get_db),
|
||||
token: str = Depends(oauth2_scheme)
|
||||
db: AsyncSession = Depends(get_db), token: str = Depends(oauth2_scheme)
|
||||
) -> User:
|
||||
"""
|
||||
Get the current authenticated user.
|
||||
@@ -35,22 +32,17 @@ async def get_current_user(
|
||||
# Decode token and get user ID
|
||||
token_data = get_token_data(token)
|
||||
|
||||
# Get user from database
|
||||
result = await db.execute(
|
||||
select(User).where(User.id == token_data.user_id)
|
||||
)
|
||||
user = result.scalar_one_or_none()
|
||||
# Get user from database via repository
|
||||
user = await user_repo.get(db, id=str(token_data.user_id))
|
||||
|
||||
if not user:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_404_NOT_FOUND,
|
||||
detail="User not found"
|
||||
status_code=status.HTTP_404_NOT_FOUND, detail="User not found"
|
||||
)
|
||||
|
||||
if not user.is_active:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_403_FORBIDDEN,
|
||||
detail="Inactive user"
|
||||
status_code=status.HTTP_403_FORBIDDEN, detail="Inactive user"
|
||||
)
|
||||
|
||||
return user
|
||||
@@ -59,19 +51,17 @@ async def get_current_user(
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_401_UNAUTHORIZED,
|
||||
detail="Token expired",
|
||||
headers={"WWW-Authenticate": "Bearer"}
|
||||
headers={"WWW-Authenticate": "Bearer"},
|
||||
)
|
||||
except TokenInvalidError:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_401_UNAUTHORIZED,
|
||||
detail="Could not validate credentials",
|
||||
headers={"WWW-Authenticate": "Bearer"}
|
||||
headers={"WWW-Authenticate": "Bearer"},
|
||||
)
|
||||
|
||||
|
||||
def get_current_active_user(
|
||||
current_user: User = Depends(get_current_user)
|
||||
) -> User:
|
||||
def get_current_active_user(current_user: User = Depends(get_current_user)) -> User:
|
||||
"""
|
||||
Check if the current user is active.
|
||||
|
||||
@@ -86,15 +76,12 @@ def get_current_active_user(
|
||||
"""
|
||||
if not current_user.is_active:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_403_FORBIDDEN,
|
||||
detail="Inactive user"
|
||||
status_code=status.HTTP_403_FORBIDDEN, detail="Inactive user"
|
||||
)
|
||||
return current_user
|
||||
|
||||
|
||||
def get_current_superuser(
|
||||
current_user: User = Depends(get_current_user)
|
||||
) -> User:
|
||||
def get_current_superuser(current_user: User = Depends(get_current_user)) -> User:
|
||||
"""
|
||||
Check if the current user is a superuser.
|
||||
|
||||
@@ -109,13 +96,12 @@ def get_current_superuser(
|
||||
"""
|
||||
if not current_user.is_superuser:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_403_FORBIDDEN,
|
||||
detail="Not enough permissions"
|
||||
status_code=status.HTTP_403_FORBIDDEN, detail="Not enough permissions"
|
||||
)
|
||||
return current_user
|
||||
|
||||
|
||||
async def get_optional_token(authorization: str = Header(None)) -> Optional[str]:
|
||||
async def get_optional_token(authorization: str = Header(None)) -> str | None:
|
||||
"""
|
||||
Get the token from the Authorization header without requiring it.
|
||||
|
||||
@@ -139,9 +125,8 @@ async def get_optional_token(authorization: str = Header(None)) -> Optional[str]
|
||||
|
||||
|
||||
async def get_optional_current_user(
|
||||
db: AsyncSession = Depends(get_db),
|
||||
token: Optional[str] = Depends(get_optional_token)
|
||||
) -> Optional[User]:
|
||||
db: AsyncSession = Depends(get_db), token: str | None = Depends(get_optional_token)
|
||||
) -> User | None:
|
||||
"""
|
||||
Get the current user if authenticated, otherwise return None.
|
||||
Useful for endpoints that work with both authenticated and unauthenticated users.
|
||||
@@ -158,12 +143,9 @@ async def get_optional_current_user(
|
||||
|
||||
try:
|
||||
token_data = get_token_data(token)
|
||||
result = await db.execute(
|
||||
select(User).where(User.id == token_data.user_id)
|
||||
)
|
||||
user = result.scalar_one_or_none()
|
||||
user = await user_repo.get(db, id=str(token_data.user_id))
|
||||
if not user or not user.is_active:
|
||||
return None
|
||||
return user
|
||||
except (TokenExpiredError, TokenInvalidError):
|
||||
return None
|
||||
return None
|
||||
|
||||
132
backend/app/api/dependencies/locale.py
Normal file
132
backend/app/api/dependencies/locale.py
Normal file
@@ -0,0 +1,132 @@
|
||||
# app/api/dependencies/locale.py
|
||||
"""
|
||||
Locale detection dependency for internationalization (i18n).
|
||||
|
||||
Implements a three-tier fallback system:
|
||||
1. User's saved preference (if authenticated and user.locale is set)
|
||||
2. Accept-Language header (for unauthenticated users or no saved preference)
|
||||
3. Default to English ("en")
|
||||
"""
|
||||
|
||||
from fastapi import Depends, Request
|
||||
|
||||
from app.api.dependencies.auth import get_optional_current_user
|
||||
from app.models.user import User
|
||||
|
||||
# Supported locales (BCP 47 format)
|
||||
# Template showcases English and Italian
|
||||
# Users can extend by adding more locales here
|
||||
# Note: Stored in lowercase for case-insensitive matching
|
||||
SUPPORTED_LOCALES = {"en", "it", "en-us", "en-gb", "it-it"}
|
||||
DEFAULT_LOCALE = "en"
|
||||
|
||||
|
||||
def parse_accept_language(accept_language: str) -> str | None:
|
||||
"""
|
||||
Parse the Accept-Language header and return the best matching supported locale.
|
||||
|
||||
The Accept-Language header format is:
|
||||
"it-IT,it;q=0.9,en-US;q=0.8,en;q=0.7"
|
||||
|
||||
This function extracts locales in priority order (by quality value) and returns
|
||||
the first one that matches our supported locales.
|
||||
|
||||
Args:
|
||||
accept_language: The Accept-Language header value
|
||||
|
||||
Returns:
|
||||
The best matching locale code, or None if no match found
|
||||
|
||||
Examples:
|
||||
>>> parse_accept_language("it-IT,it;q=0.9,en;q=0.8")
|
||||
"it-IT" # or "it" if it-IT is not supported
|
||||
>>> parse_accept_language("fr-FR,fr;q=0.9")
|
||||
None # French not supported
|
||||
"""
|
||||
if not accept_language:
|
||||
return None
|
||||
|
||||
# Split by comma to get individual locale entries
|
||||
# Format: "locale;q=weight" or just "locale"
|
||||
locales = []
|
||||
for entry in accept_language.split(","):
|
||||
# Remove quality value (;q=0.9) if present
|
||||
locale = entry.split(";")[0].strip()
|
||||
if locale:
|
||||
locales.append(locale)
|
||||
|
||||
# Check each locale in priority order
|
||||
for locale in locales:
|
||||
locale_lower = locale.lower()
|
||||
|
||||
# Try exact match first (e.g., "it-IT")
|
||||
if locale_lower in SUPPORTED_LOCALES:
|
||||
return locale_lower
|
||||
|
||||
# Try language code only (e.g., "it" from "it-IT")
|
||||
lang_code = locale_lower.split("-")[0]
|
||||
if lang_code in SUPPORTED_LOCALES:
|
||||
return lang_code
|
||||
|
||||
return None
|
||||
|
||||
|
||||
async def get_locale(
|
||||
request: Request,
|
||||
current_user: User | None = Depends(get_optional_current_user),
|
||||
) -> str:
|
||||
"""
|
||||
Detect and return the appropriate locale for the current request.
|
||||
|
||||
Three-tier fallback system:
|
||||
1. **User Preference** (highest priority)
|
||||
- If user is authenticated and has a saved locale preference, use it
|
||||
- This persists across sessions and devices
|
||||
|
||||
2. **Accept-Language Header** (second priority)
|
||||
- Parse the Accept-Language header from the request
|
||||
- Match against supported locales
|
||||
- Common for browser requests
|
||||
|
||||
3. **Default Locale** (fallback)
|
||||
- Return "en" (English) if no user preference and no header match
|
||||
|
||||
Args:
|
||||
request: The FastAPI request object (for accessing headers)
|
||||
current_user: The current authenticated user (optional)
|
||||
|
||||
Returns:
|
||||
A valid locale code from SUPPORTED_LOCALES (guaranteed to be supported)
|
||||
|
||||
Examples:
|
||||
>>> # Authenticated user with saved preference
|
||||
>>> await get_locale(request, user_with_locale_it)
|
||||
"it"
|
||||
|
||||
>>> # Unauthenticated user with Italian browser
|
||||
>>> # (request has Accept-Language: it-IT,it;q=0.9)
|
||||
>>> await get_locale(request, None)
|
||||
"it"
|
||||
|
||||
>>> # Unauthenticated user with unsupported language
|
||||
>>> # (request has Accept-Language: fr-FR,fr;q=0.9)
|
||||
>>> await get_locale(request, None)
|
||||
"en"
|
||||
"""
|
||||
# Priority 1: User's saved preference
|
||||
if current_user and current_user.locale:
|
||||
# Validate that saved locale is still supported
|
||||
# (in case SUPPORTED_LOCALES changed after user set preference)
|
||||
locale_value = str(current_user.locale)
|
||||
if locale_value in SUPPORTED_LOCALES:
|
||||
return locale_value
|
||||
|
||||
# Priority 2: Accept-Language header
|
||||
accept_language = request.headers.get("accept-language", "")
|
||||
if accept_language:
|
||||
detected_locale = parse_accept_language(accept_language)
|
||||
if detected_locale:
|
||||
return detected_locale
|
||||
|
||||
# Priority 3: Default fallback
|
||||
return DEFAULT_LOCALE
|
||||
@@ -7,7 +7,7 @@ These dependencies are optional and flexible:
|
||||
- Use require_org_role for organization-specific access control
|
||||
- Projects can choose to use these or implement their own permission system
|
||||
"""
|
||||
from typing import Optional
|
||||
|
||||
from uuid import UUID
|
||||
|
||||
from fastapi import Depends, HTTPException, status
|
||||
@@ -15,14 +15,12 @@ from sqlalchemy.ext.asyncio import AsyncSession
|
||||
|
||||
from app.api.dependencies.auth import get_current_user
|
||||
from app.core.database import get_db
|
||||
from app.crud.organization import organization as organization_crud
|
||||
from app.models.user import User
|
||||
from app.models.user_organization import OrganizationRole
|
||||
from app.services.organization_service import organization_service
|
||||
|
||||
|
||||
def require_superuser(
|
||||
current_user: User = Depends(get_current_user)
|
||||
) -> User:
|
||||
def require_superuser(current_user: User = Depends(get_current_user)) -> User:
|
||||
"""
|
||||
Dependency to ensure the current user is a superuser.
|
||||
|
||||
@@ -36,7 +34,7 @@ def require_superuser(
|
||||
if not current_user.is_superuser:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_403_FORBIDDEN,
|
||||
detail="Superuser privileges required"
|
||||
detail="Superuser privileges required",
|
||||
)
|
||||
return current_user
|
||||
|
||||
@@ -62,7 +60,7 @@ class OrganizationPermission:
|
||||
self,
|
||||
organization_id: UUID,
|
||||
current_user: User = Depends(get_current_user),
|
||||
db: AsyncSession = Depends(get_db)
|
||||
db: AsyncSession = Depends(get_db),
|
||||
) -> User:
|
||||
"""
|
||||
Check if user has required role in the organization.
|
||||
@@ -83,22 +81,20 @@ class OrganizationPermission:
|
||||
return current_user
|
||||
|
||||
# Get user's role in organization
|
||||
user_role = await organization_crud.get_user_role_in_org(
|
||||
db,
|
||||
user_id=current_user.id,
|
||||
organization_id=organization_id
|
||||
user_role = await organization_service.get_user_role_in_org(
|
||||
db, user_id=current_user.id, organization_id=organization_id
|
||||
)
|
||||
|
||||
if not user_role:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_403_FORBIDDEN,
|
||||
detail="Not a member of this organization"
|
||||
detail="Not a member of this organization",
|
||||
)
|
||||
|
||||
if user_role not in self.allowed_roles:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_403_FORBIDDEN,
|
||||
detail=f"Role {user_role} not authorized. Required: {self.allowed_roles}"
|
||||
detail=f"Role {user_role} not authorized. Required: {self.allowed_roles}",
|
||||
)
|
||||
|
||||
return current_user
|
||||
@@ -106,18 +102,18 @@ class OrganizationPermission:
|
||||
|
||||
# Common permission presets for convenience
|
||||
require_org_owner = OrganizationPermission([OrganizationRole.OWNER])
|
||||
require_org_admin = OrganizationPermission([OrganizationRole.OWNER, OrganizationRole.ADMIN])
|
||||
require_org_member = OrganizationPermission([
|
||||
OrganizationRole.OWNER,
|
||||
OrganizationRole.ADMIN,
|
||||
OrganizationRole.MEMBER
|
||||
])
|
||||
require_org_admin = OrganizationPermission(
|
||||
[OrganizationRole.OWNER, OrganizationRole.ADMIN]
|
||||
)
|
||||
require_org_member = OrganizationPermission(
|
||||
[OrganizationRole.OWNER, OrganizationRole.ADMIN, OrganizationRole.MEMBER]
|
||||
)
|
||||
|
||||
|
||||
async def require_org_membership(
|
||||
organization_id: UUID,
|
||||
current_user: User = Depends(get_current_user),
|
||||
db: AsyncSession = Depends(get_db)
|
||||
db: AsyncSession = Depends(get_db),
|
||||
) -> User:
|
||||
"""
|
||||
Ensure user is a member of the organization (any role).
|
||||
@@ -127,16 +123,14 @@ async def require_org_membership(
|
||||
if current_user.is_superuser:
|
||||
return current_user
|
||||
|
||||
user_role = await organization_crud.get_user_role_in_org(
|
||||
db,
|
||||
user_id=current_user.id,
|
||||
organization_id=organization_id
|
||||
user_role = await organization_service.get_user_role_in_org(
|
||||
db, user_id=current_user.id, organization_id=organization_id
|
||||
)
|
||||
|
||||
if not user_role:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_403_FORBIDDEN,
|
||||
detail="Not a member of this organization"
|
||||
detail="Not a member of this organization",
|
||||
)
|
||||
|
||||
return current_user
|
||||
|
||||
41
backend/app/api/dependencies/services.py
Normal file
41
backend/app/api/dependencies/services.py
Normal file
@@ -0,0 +1,41 @@
|
||||
# app/api/dependencies/services.py
|
||||
"""FastAPI dependency functions for service singletons."""
|
||||
|
||||
from app.services import oauth_provider_service
|
||||
from app.services.auth_service import AuthService
|
||||
from app.services.oauth_service import OAuthService
|
||||
from app.services.organization_service import OrganizationService, organization_service
|
||||
from app.services.session_service import SessionService, session_service
|
||||
from app.services.user_service import UserService, user_service
|
||||
|
||||
|
||||
def get_auth_service() -> AuthService:
|
||||
"""Return the AuthService singleton for dependency injection."""
|
||||
from app.services.auth_service import AuthService as _AuthService
|
||||
|
||||
return _AuthService()
|
||||
|
||||
|
||||
def get_user_service() -> UserService:
|
||||
"""Return the UserService singleton for dependency injection."""
|
||||
return user_service
|
||||
|
||||
|
||||
def get_organization_service() -> OrganizationService:
|
||||
"""Return the OrganizationService singleton for dependency injection."""
|
||||
return organization_service
|
||||
|
||||
|
||||
def get_session_service() -> SessionService:
|
||||
"""Return the SessionService singleton for dependency injection."""
|
||||
return session_service
|
||||
|
||||
|
||||
def get_oauth_service() -> OAuthService:
|
||||
"""Return OAuthService for dependency injection."""
|
||||
return OAuthService()
|
||||
|
||||
|
||||
def get_oauth_provider_service():
|
||||
"""Return the oauth_provider_service module for dependency injection."""
|
||||
return oauth_provider_service
|
||||
@@ -1,10 +1,24 @@
|
||||
from fastapi import APIRouter
|
||||
|
||||
from app.api.routes import auth, users, sessions, admin, organizations
|
||||
from app.api.routes import (
|
||||
admin,
|
||||
auth,
|
||||
oauth,
|
||||
oauth_provider,
|
||||
organizations,
|
||||
sessions,
|
||||
users,
|
||||
)
|
||||
|
||||
api_router = APIRouter()
|
||||
api_router.include_router(auth.router, prefix="/auth", tags=["Authentication"])
|
||||
api_router.include_router(oauth.router, prefix="/oauth", tags=["OAuth"])
|
||||
api_router.include_router(
|
||||
oauth_provider.router, prefix="/oauth", tags=["OAuth Provider"]
|
||||
)
|
||||
api_router.include_router(users.router, prefix="/users", tags=["Users"])
|
||||
api_router.include_router(sessions.router, prefix="/sessions", tags=["Sessions"])
|
||||
api_router.include_router(admin.router, prefix="/admin", tags=["Admin"])
|
||||
api_router.include_router(organizations.router, prefix="/organizations", tags=["Organizations"])
|
||||
api_router.include_router(
|
||||
organizations.router, prefix="/organizations", tags=["Organizations"]
|
||||
)
|
||||
|
||||
File diff suppressed because it is too large
Load Diff
@@ -1,40 +1,44 @@
|
||||
# app/api/routes/auth.py
|
||||
import logging
|
||||
import os
|
||||
from datetime import datetime, timezone
|
||||
from datetime import UTC, datetime
|
||||
from typing import Any
|
||||
|
||||
from fastapi import APIRouter, Depends, HTTPException, status, Request
|
||||
from fastapi import APIRouter, Depends, HTTPException, Request, status
|
||||
from fastapi.security import OAuth2PasswordRequestForm
|
||||
from slowapi import Limiter
|
||||
from slowapi.util import get_remote_address
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
|
||||
from app.api.dependencies.auth import get_current_user
|
||||
from app.core.auth import TokenExpiredError, TokenInvalidError, decode_token
|
||||
from app.core.auth import get_password_hash
|
||||
from app.core.auth import (
|
||||
TokenExpiredError,
|
||||
TokenInvalidError,
|
||||
decode_token,
|
||||
)
|
||||
from app.core.database import get_db
|
||||
from app.core.exceptions import (
|
||||
AuthenticationError as AuthError,
|
||||
DatabaseError,
|
||||
ErrorCode
|
||||
DuplicateError,
|
||||
ErrorCode,
|
||||
)
|
||||
from app.crud.session import session as session_crud
|
||||
from app.crud.user import user as user_crud
|
||||
from app.models.user import User
|
||||
from app.schemas.common import MessageResponse
|
||||
from app.schemas.sessions import SessionCreate, LogoutRequest
|
||||
from app.schemas.sessions import LogoutRequest, SessionCreate
|
||||
from app.schemas.users import (
|
||||
LoginRequest,
|
||||
PasswordResetConfirm,
|
||||
PasswordResetRequest,
|
||||
RefreshTokenRequest,
|
||||
Token,
|
||||
UserCreate,
|
||||
UserResponse,
|
||||
Token,
|
||||
LoginRequest,
|
||||
RefreshTokenRequest,
|
||||
PasswordResetRequest,
|
||||
PasswordResetConfirm
|
||||
)
|
||||
from app.services.auth_service import AuthService, AuthenticationError
|
||||
from app.services.auth_service import AuthenticationError, AuthService
|
||||
from app.services.email_service import email_service
|
||||
from app.services.session_service import session_service
|
||||
from app.services.user_service import user_service
|
||||
from app.utils.device import extract_device_info
|
||||
from app.utils.security import create_password_reset_token, verify_password_reset_token
|
||||
|
||||
@@ -54,7 +58,7 @@ async def _create_login_session(
|
||||
request: Request,
|
||||
user: User,
|
||||
tokens: Token,
|
||||
login_type: str = "login"
|
||||
login_type: str = "login",
|
||||
) -> None:
|
||||
"""
|
||||
Create a session record for successful login.
|
||||
@@ -81,29 +85,35 @@ async def _create_login_session(
|
||||
device_id=device_info.device_id,
|
||||
ip_address=device_info.ip_address,
|
||||
user_agent=device_info.user_agent,
|
||||
last_used_at=datetime.now(timezone.utc),
|
||||
expires_at=datetime.fromtimestamp(refresh_payload.exp, tz=timezone.utc),
|
||||
last_used_at=datetime.now(UTC),
|
||||
expires_at=datetime.fromtimestamp(refresh_payload.exp, tz=UTC),
|
||||
location_city=device_info.location_city,
|
||||
location_country=device_info.location_country,
|
||||
)
|
||||
|
||||
await session_crud.create_session(db, obj_in=session_data)
|
||||
await session_service.create_session(db, obj_in=session_data)
|
||||
|
||||
logger.info(
|
||||
f"{login_type.capitalize()} successful: {user.email} from {device_info.device_name} "
|
||||
f"(IP: {device_info.ip_address})"
|
||||
"%s successful: %s from %s (IP: %s)",
|
||||
login_type.capitalize(),
|
||||
user.email,
|
||||
device_info.device_name,
|
||||
device_info.ip_address,
|
||||
)
|
||||
except Exception as session_err:
|
||||
# Log but don't fail login if session creation fails
|
||||
logger.error(f"Failed to create session for {user.email}: {str(session_err)}", exc_info=True)
|
||||
logger.exception("Failed to create session for %s: %s", user.email, session_err)
|
||||
|
||||
|
||||
@router.post("/register", response_model=UserResponse, status_code=status.HTTP_201_CREATED, operation_id="register")
|
||||
@router.post(
|
||||
"/register",
|
||||
response_model=UserResponse,
|
||||
status_code=status.HTTP_201_CREATED,
|
||||
operation_id="register",
|
||||
)
|
||||
@limiter.limit(f"{5 * RATE_MULTIPLIER}/minute")
|
||||
async def register_user(
|
||||
request: Request,
|
||||
user_data: UserCreate,
|
||||
db: AsyncSession = Depends(get_db)
|
||||
request: Request, user_data: UserCreate, db: AsyncSession = Depends(get_db)
|
||||
) -> Any:
|
||||
"""
|
||||
Register a new user.
|
||||
@@ -114,27 +124,31 @@ async def register_user(
|
||||
try:
|
||||
user = await AuthService.create_user(db, user_data)
|
||||
return user
|
||||
except AuthenticationError as e:
|
||||
except DuplicateError:
|
||||
# SECURITY: Don't reveal if email exists - generic error message
|
||||
logger.warning(f"Registration failed: {str(e)}")
|
||||
logger.warning("Registration failed: duplicate email %s", user_data.email)
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_400_BAD_REQUEST,
|
||||
detail="Registration failed. Please check your information and try again."
|
||||
detail="Registration failed. Please check your information and try again.",
|
||||
)
|
||||
except AuthError as e:
|
||||
logger.warning("Registration failed: %s", e)
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_400_BAD_REQUEST,
|
||||
detail="Registration failed. Please check your information and try again.",
|
||||
)
|
||||
except Exception as e:
|
||||
logger.error(f"Unexpected error during registration: {str(e)}", exc_info=True)
|
||||
logger.exception("Unexpected error during registration: %s", e)
|
||||
raise DatabaseError(
|
||||
message="An unexpected error occurred. Please try again later.",
|
||||
error_code=ErrorCode.INTERNAL_ERROR
|
||||
error_code=ErrorCode.INTERNAL_ERROR,
|
||||
)
|
||||
|
||||
|
||||
@router.post("/login", response_model=Token, operation_id="login")
|
||||
@limiter.limit(f"{10 * RATE_MULTIPLIER}/minute")
|
||||
async def login(
|
||||
request: Request,
|
||||
login_data: LoginRequest,
|
||||
db: AsyncSession = Depends(get_db)
|
||||
request: Request, login_data: LoginRequest, db: AsyncSession = Depends(get_db)
|
||||
) -> Any:
|
||||
"""
|
||||
Login with username and password.
|
||||
@@ -146,14 +160,16 @@ async def login(
|
||||
"""
|
||||
try:
|
||||
# Attempt to authenticate the user
|
||||
user = await AuthService.authenticate_user(db, login_data.email, login_data.password)
|
||||
user = await AuthService.authenticate_user(
|
||||
db, login_data.email, login_data.password
|
||||
)
|
||||
|
||||
# Explicitly check for None result and raise correct exception
|
||||
if user is None:
|
||||
logger.warning(f"Invalid login attempt for: {login_data.email}")
|
||||
logger.warning("Invalid login attempt for: %s", login_data.email)
|
||||
raise AuthError(
|
||||
message="Invalid email or password",
|
||||
error_code=ErrorCode.INVALID_CREDENTIALS
|
||||
error_code=ErrorCode.INVALID_CREDENTIALS,
|
||||
)
|
||||
|
||||
# User is authenticated, generate tokens
|
||||
@@ -166,29 +182,23 @@ async def login(
|
||||
|
||||
except AuthenticationError as e:
|
||||
# Handle specific authentication errors like inactive accounts
|
||||
logger.warning(f"Authentication failed: {str(e)}")
|
||||
raise AuthError(
|
||||
message=str(e),
|
||||
error_code=ErrorCode.INVALID_CREDENTIALS
|
||||
)
|
||||
except AuthError:
|
||||
# Re-raise custom auth exceptions without modification
|
||||
raise
|
||||
logger.warning("Authentication failed: %s", e)
|
||||
raise AuthError(message=str(e), error_code=ErrorCode.INVALID_CREDENTIALS)
|
||||
except Exception as e:
|
||||
# Handle unexpected errors
|
||||
logger.error(f"Unexpected error during login: {str(e)}", exc_info=True)
|
||||
logger.exception("Unexpected error during login: %s", e)
|
||||
raise DatabaseError(
|
||||
message="An unexpected error occurred. Please try again later.",
|
||||
error_code=ErrorCode.INTERNAL_ERROR
|
||||
error_code=ErrorCode.INTERNAL_ERROR,
|
||||
)
|
||||
|
||||
|
||||
@router.post("/login/oauth", response_model=Token, operation_id='login_oauth')
|
||||
@router.post("/login/oauth", response_model=Token, operation_id="login_oauth")
|
||||
@limiter.limit("10/minute")
|
||||
async def login_oauth(
|
||||
request: Request,
|
||||
form_data: OAuth2PasswordRequestForm = Depends(),
|
||||
db: AsyncSession = Depends(get_db)
|
||||
request: Request,
|
||||
form_data: OAuth2PasswordRequestForm = Depends(),
|
||||
db: AsyncSession = Depends(get_db),
|
||||
) -> Any:
|
||||
"""
|
||||
OAuth2-compatible login endpoint, used by the OpenAPI UI.
|
||||
@@ -199,12 +209,14 @@ async def login_oauth(
|
||||
Access and refresh tokens.
|
||||
"""
|
||||
try:
|
||||
user = await AuthService.authenticate_user(db, form_data.username, form_data.password)
|
||||
user = await AuthService.authenticate_user(
|
||||
db, form_data.username, form_data.password
|
||||
)
|
||||
|
||||
if user is None:
|
||||
raise AuthError(
|
||||
message="Invalid email or password",
|
||||
error_code=ErrorCode.INVALID_CREDENTIALS
|
||||
error_code=ErrorCode.INVALID_CREDENTIALS,
|
||||
)
|
||||
|
||||
# Generate tokens
|
||||
@@ -216,28 +228,22 @@ async def login_oauth(
|
||||
# Return full token response with user data
|
||||
return tokens
|
||||
except AuthenticationError as e:
|
||||
logger.warning(f"OAuth authentication failed: {str(e)}")
|
||||
raise AuthError(
|
||||
message=str(e),
|
||||
error_code=ErrorCode.INVALID_CREDENTIALS
|
||||
)
|
||||
except AuthError:
|
||||
# Re-raise custom auth exceptions without modification
|
||||
raise
|
||||
logger.warning("OAuth authentication failed: %s", e)
|
||||
raise AuthError(message=str(e), error_code=ErrorCode.INVALID_CREDENTIALS)
|
||||
except Exception as e:
|
||||
logger.error(f"Unexpected error during OAuth login: {str(e)}", exc_info=True)
|
||||
logger.exception("Unexpected error during OAuth login: %s", e)
|
||||
raise DatabaseError(
|
||||
message="An unexpected error occurred. Please try again later.",
|
||||
error_code=ErrorCode.INTERNAL_ERROR
|
||||
error_code=ErrorCode.INTERNAL_ERROR,
|
||||
)
|
||||
|
||||
|
||||
@router.post("/refresh", response_model=Token, operation_id="refresh_token")
|
||||
@limiter.limit("30/minute")
|
||||
async def refresh_token(
|
||||
request: Request,
|
||||
refresh_data: RefreshTokenRequest,
|
||||
db: AsyncSession = Depends(get_db)
|
||||
request: Request,
|
||||
refresh_data: RefreshTokenRequest,
|
||||
db: AsyncSession = Depends(get_db),
|
||||
) -> Any:
|
||||
"""
|
||||
Refresh access token using a refresh token.
|
||||
@@ -249,13 +255,18 @@ async def refresh_token(
|
||||
"""
|
||||
try:
|
||||
# Decode the refresh token to get the JTI
|
||||
refresh_payload = decode_token(refresh_data.refresh_token, verify_type="refresh")
|
||||
refresh_payload = decode_token(
|
||||
refresh_data.refresh_token, verify_type="refresh"
|
||||
)
|
||||
|
||||
# Check if session exists and is active
|
||||
session = await session_crud.get_active_by_jti(db, jti=refresh_payload.jti)
|
||||
session = await session_service.get_active_by_jti(db, jti=refresh_payload.jti)
|
||||
|
||||
if not session:
|
||||
logger.warning(f"Refresh token used for inactive or non-existent session: {refresh_payload.jti}")
|
||||
logger.warning(
|
||||
"Refresh token used for inactive or non-existent session: %s",
|
||||
refresh_payload.jti,
|
||||
)
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_401_UNAUTHORIZED,
|
||||
detail="Session has been revoked. Please log in again.",
|
||||
@@ -270,14 +281,14 @@ async def refresh_token(
|
||||
|
||||
# Update session with new refresh token JTI and expiration
|
||||
try:
|
||||
await session_crud.update_refresh_token(
|
||||
await session_service.update_refresh_token(
|
||||
db,
|
||||
session=session,
|
||||
new_jti=new_refresh_payload.jti,
|
||||
new_expires_at=datetime.fromtimestamp(new_refresh_payload.exp, tz=timezone.utc)
|
||||
new_expires_at=datetime.fromtimestamp(new_refresh_payload.exp, tz=UTC),
|
||||
)
|
||||
except Exception as session_err:
|
||||
logger.error(f"Failed to update session {session.id}: {str(session_err)}", exc_info=True)
|
||||
logger.exception("Failed to update session %s: %s", session.id, session_err)
|
||||
# Continue anyway - tokens are already issued
|
||||
|
||||
return tokens
|
||||
@@ -300,10 +311,10 @@ async def refresh_token(
|
||||
# Re-raise HTTP exceptions (like session revoked)
|
||||
raise
|
||||
except Exception as e:
|
||||
logger.error(f"Unexpected error during token refresh: {str(e)}")
|
||||
logger.error("Unexpected error during token refresh: %s", e)
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
|
||||
detail="An unexpected error occurred. Please try again later."
|
||||
detail="An unexpected error occurred. Please try again later.",
|
||||
)
|
||||
|
||||
|
||||
@@ -320,13 +331,13 @@ async def refresh_token(
|
||||
|
||||
**Rate Limit**: 3 requests/minute
|
||||
""",
|
||||
operation_id="request_password_reset"
|
||||
operation_id="request_password_reset",
|
||||
)
|
||||
@limiter.limit("3/minute")
|
||||
async def request_password_reset(
|
||||
request: Request,
|
||||
reset_request: PasswordResetRequest,
|
||||
db: AsyncSession = Depends(get_db)
|
||||
db: AsyncSession = Depends(get_db),
|
||||
) -> Any:
|
||||
"""
|
||||
Request a password reset.
|
||||
@@ -336,7 +347,7 @@ async def request_password_reset(
|
||||
"""
|
||||
try:
|
||||
# Look up user by email
|
||||
user = await user_crud.get_by_email(db, email=reset_request.email)
|
||||
user = await user_service.get_by_email(db, email=reset_request.email)
|
||||
|
||||
# Only send email if user exists and is active
|
||||
if user and user.is_active:
|
||||
@@ -345,26 +356,27 @@ async def request_password_reset(
|
||||
|
||||
# Send password reset email
|
||||
await email_service.send_password_reset_email(
|
||||
to_email=user.email,
|
||||
reset_token=reset_token,
|
||||
user_name=user.first_name
|
||||
to_email=user.email, reset_token=reset_token, user_name=user.first_name
|
||||
)
|
||||
logger.info(f"Password reset requested for {user.email}")
|
||||
logger.info("Password reset requested for %s", user.email)
|
||||
else:
|
||||
# Log attempt but don't reveal if email exists
|
||||
logger.warning(f"Password reset requested for non-existent or inactive email: {reset_request.email}")
|
||||
logger.warning(
|
||||
"Password reset requested for non-existent or inactive email: %s",
|
||||
reset_request.email,
|
||||
)
|
||||
|
||||
# Always return success to prevent email enumeration
|
||||
return MessageResponse(
|
||||
success=True,
|
||||
message="If your email is registered, you will receive a password reset link shortly"
|
||||
message="If your email is registered, you will receive a password reset link shortly",
|
||||
)
|
||||
except Exception as e:
|
||||
logger.error(f"Error processing password reset request: {str(e)}", exc_info=True)
|
||||
logger.exception("Error processing password reset request: %s", e)
|
||||
# Still return success to prevent information leakage
|
||||
return MessageResponse(
|
||||
success=True,
|
||||
message="If your email is registered, you will receive a password reset link shortly"
|
||||
message="If your email is registered, you will receive a password reset link shortly",
|
||||
)
|
||||
|
||||
|
||||
@@ -378,13 +390,13 @@ async def request_password_reset(
|
||||
|
||||
**Rate Limit**: 5 requests/minute
|
||||
""",
|
||||
operation_id="confirm_password_reset"
|
||||
operation_id="confirm_password_reset",
|
||||
)
|
||||
@limiter.limit("5/minute")
|
||||
async def confirm_password_reset(
|
||||
request: Request,
|
||||
reset_confirm: PasswordResetConfirm,
|
||||
db: AsyncSession = Depends(get_db)
|
||||
db: AsyncSession = Depends(get_db),
|
||||
) -> Any:
|
||||
"""
|
||||
Confirm password reset with token.
|
||||
@@ -398,55 +410,52 @@ async def confirm_password_reset(
|
||||
if not email:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_400_BAD_REQUEST,
|
||||
detail="Invalid or expired password reset token"
|
||||
detail="Invalid or expired password reset token",
|
||||
)
|
||||
|
||||
# Look up user
|
||||
user = await user_crud.get_by_email(db, email=email)
|
||||
|
||||
if not user:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_404_NOT_FOUND,
|
||||
detail="User not found"
|
||||
# Reset password via service (validates user exists and is active)
|
||||
try:
|
||||
user = await AuthService.reset_password(
|
||||
db, email=email, new_password=reset_confirm.new_password
|
||||
)
|
||||
|
||||
if not user.is_active:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_400_BAD_REQUEST,
|
||||
detail="User account is inactive"
|
||||
)
|
||||
|
||||
# Update password
|
||||
user.password_hash = get_password_hash(reset_confirm.new_password)
|
||||
db.add(user)
|
||||
await db.commit()
|
||||
except AuthenticationError as e:
|
||||
err_msg = str(e)
|
||||
if "inactive" in err_msg.lower():
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_400_BAD_REQUEST, detail=err_msg
|
||||
)
|
||||
raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail=err_msg)
|
||||
|
||||
# SECURITY: Invalidate all existing sessions after password reset
|
||||
# This prevents stolen sessions from being used after password change
|
||||
from app.crud.session import session as session_crud
|
||||
try:
|
||||
deactivated_count = await session_crud.deactivate_all_user_sessions(
|
||||
db,
|
||||
user_id=str(user.id)
|
||||
deactivated_count = await session_service.deactivate_all_user_sessions(
|
||||
db, user_id=str(user.id)
|
||||
)
|
||||
logger.info(
|
||||
"Password reset successful for %s, invalidated %s sessions",
|
||||
user.email,
|
||||
deactivated_count,
|
||||
)
|
||||
logger.info(f"Password reset successful for {user.email}, invalidated {deactivated_count} sessions")
|
||||
except Exception as session_error:
|
||||
# Log but don't fail password reset if session invalidation fails
|
||||
logger.error(f"Failed to invalidate sessions after password reset: {str(session_error)}")
|
||||
logger.error(
|
||||
"Failed to invalidate sessions after password reset: %s", session_error
|
||||
)
|
||||
|
||||
return MessageResponse(
|
||||
success=True,
|
||||
message="Password has been reset successfully. All devices have been logged out for security. You can now log in with your new password."
|
||||
message="Password has been reset successfully. All devices have been logged out for security. You can now log in with your new password.",
|
||||
)
|
||||
|
||||
except HTTPException:
|
||||
raise
|
||||
except Exception as e:
|
||||
logger.error(f"Error confirming password reset: {str(e)}", exc_info=True)
|
||||
logger.exception("Error confirming password reset: %s", e)
|
||||
await db.rollback()
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
|
||||
detail="An error occurred while resetting your password"
|
||||
detail="An error occurred while resetting your password",
|
||||
)
|
||||
|
||||
|
||||
@@ -464,14 +473,14 @@ async def confirm_password_reset(
|
||||
|
||||
**Rate Limit**: 10 requests/minute
|
||||
""",
|
||||
operation_id="logout"
|
||||
operation_id="logout",
|
||||
)
|
||||
@limiter.limit("10/minute")
|
||||
async def logout(
|
||||
request: Request,
|
||||
logout_request: LogoutRequest,
|
||||
current_user: User = Depends(get_current_user),
|
||||
db: AsyncSession = Depends(get_db)
|
||||
db: AsyncSession = Depends(get_db),
|
||||
) -> Any:
|
||||
"""
|
||||
Logout from current device by deactivating the session.
|
||||
@@ -487,57 +496,57 @@ async def logout(
|
||||
try:
|
||||
# Decode refresh token to get JTI
|
||||
try:
|
||||
refresh_payload = decode_token(logout_request.refresh_token, verify_type="refresh")
|
||||
refresh_payload = decode_token(
|
||||
logout_request.refresh_token, verify_type="refresh"
|
||||
)
|
||||
except (TokenExpiredError, TokenInvalidError) as e:
|
||||
# Even if token is expired/invalid, try to deactivate session
|
||||
logger.warning(f"Logout with invalid/expired token: {str(e)}")
|
||||
logger.warning("Logout with invalid/expired token: %s", e)
|
||||
# Don't fail - return success anyway
|
||||
return MessageResponse(
|
||||
success=True,
|
||||
message="Logged out successfully"
|
||||
)
|
||||
return MessageResponse(success=True, message="Logged out successfully")
|
||||
|
||||
# Find the session by JTI
|
||||
session = await session_crud.get_by_jti(db, jti=refresh_payload.jti)
|
||||
session = await session_service.get_by_jti(db, jti=refresh_payload.jti)
|
||||
|
||||
if session:
|
||||
# Verify session belongs to current user (security check)
|
||||
if str(session.user_id) != str(current_user.id):
|
||||
logger.warning(
|
||||
f"User {current_user.id} attempted to logout session {session.id} "
|
||||
f"belonging to user {session.user_id}"
|
||||
"User %s attempted to logout session %s belonging to user %s",
|
||||
current_user.id,
|
||||
session.id,
|
||||
session.user_id,
|
||||
)
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_403_FORBIDDEN,
|
||||
detail="You can only logout your own sessions"
|
||||
detail="You can only logout your own sessions",
|
||||
)
|
||||
|
||||
# Deactivate the session
|
||||
await session_crud.deactivate(db, session_id=str(session.id))
|
||||
await session_service.deactivate(db, session_id=str(session.id))
|
||||
|
||||
logger.info(
|
||||
f"User {current_user.id} logged out from {session.device_name} "
|
||||
f"(session {session.id})"
|
||||
"User %s logged out from %s (session %s)",
|
||||
current_user.id,
|
||||
session.device_name,
|
||||
session.id,
|
||||
)
|
||||
else:
|
||||
# Session not found - maybe already deleted or never existed
|
||||
# Return success anyway (idempotent)
|
||||
logger.info(f"Logout requested for non-existent session (JTI: {refresh_payload.jti})")
|
||||
logger.info(
|
||||
"Logout requested for non-existent session (JTI: %s)",
|
||||
refresh_payload.jti,
|
||||
)
|
||||
|
||||
return MessageResponse(
|
||||
success=True,
|
||||
message="Logged out successfully"
|
||||
)
|
||||
return MessageResponse(success=True, message="Logged out successfully")
|
||||
|
||||
except HTTPException:
|
||||
raise
|
||||
except Exception as e:
|
||||
logger.error(f"Error during logout for user {current_user.id}: {str(e)}", exc_info=True)
|
||||
logger.exception("Error during logout for user %s: %s", current_user.id, e)
|
||||
# Don't expose error details
|
||||
return MessageResponse(
|
||||
success=True,
|
||||
message="Logged out successfully"
|
||||
)
|
||||
return MessageResponse(success=True, message="Logged out successfully")
|
||||
|
||||
|
||||
@router.post(
|
||||
@@ -553,13 +562,13 @@ async def logout(
|
||||
|
||||
**Rate Limit**: 5 requests/minute
|
||||
""",
|
||||
operation_id="logout_all"
|
||||
operation_id="logout_all",
|
||||
)
|
||||
@limiter.limit("5/minute")
|
||||
async def logout_all(
|
||||
request: Request,
|
||||
current_user: User = Depends(get_current_user),
|
||||
db: AsyncSession = Depends(get_db)
|
||||
db: AsyncSession = Depends(get_db),
|
||||
) -> Any:
|
||||
"""
|
||||
Logout from all devices by deactivating all user sessions.
|
||||
@@ -573,19 +582,23 @@ async def logout_all(
|
||||
"""
|
||||
try:
|
||||
# Deactivate all sessions for this user
|
||||
count = await session_crud.deactivate_all_user_sessions(db, user_id=str(current_user.id))
|
||||
count = await session_service.deactivate_all_user_sessions(
|
||||
db, user_id=str(current_user.id)
|
||||
)
|
||||
|
||||
logger.info(f"User {current_user.id} logged out from all devices ({count} sessions)")
|
||||
logger.info(
|
||||
"User %s logged out from all devices (%s sessions)", current_user.id, count
|
||||
)
|
||||
|
||||
return MessageResponse(
|
||||
success=True,
|
||||
message=f"Successfully logged out from all devices ({count} sessions terminated)"
|
||||
message=f"Successfully logged out from all devices ({count} sessions terminated)",
|
||||
)
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error during logout-all for user {current_user.id}: {str(e)}", exc_info=True)
|
||||
logger.exception("Error during logout-all for user %s: %s", current_user.id, e)
|
||||
await db.rollback()
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
|
||||
detail="An error occurred while logging out"
|
||||
detail="An error occurred while logging out",
|
||||
)
|
||||
|
||||
434
backend/app/api/routes/oauth.py
Normal file
434
backend/app/api/routes/oauth.py
Normal file
@@ -0,0 +1,434 @@
|
||||
# app/api/routes/oauth.py
|
||||
"""
|
||||
OAuth routes for social authentication.
|
||||
|
||||
Endpoints:
|
||||
- GET /oauth/providers - List enabled OAuth providers
|
||||
- GET /oauth/authorize/{provider} - Get authorization URL
|
||||
- POST /oauth/callback/{provider} - Handle OAuth callback
|
||||
- GET /oauth/accounts - List linked OAuth accounts
|
||||
- DELETE /oauth/accounts/{provider} - Unlink an OAuth account
|
||||
"""
|
||||
|
||||
import logging
|
||||
import os
|
||||
from datetime import UTC, datetime
|
||||
from typing import Any
|
||||
|
||||
from fastapi import APIRouter, Depends, HTTPException, Query, Request, status
|
||||
from slowapi import Limiter
|
||||
from slowapi.util import get_remote_address
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
|
||||
from app.api.dependencies.auth import get_current_user, get_optional_current_user
|
||||
from app.core.auth import decode_token
|
||||
from app.core.config import settings
|
||||
from app.core.database import get_db
|
||||
from app.core.exceptions import AuthenticationError as AuthError
|
||||
from app.models.user import User
|
||||
from app.schemas.oauth import (
|
||||
OAuthAccountsListResponse,
|
||||
OAuthCallbackRequest,
|
||||
OAuthCallbackResponse,
|
||||
OAuthProvidersResponse,
|
||||
OAuthUnlinkResponse,
|
||||
)
|
||||
from app.schemas.sessions import SessionCreate
|
||||
from app.schemas.users import Token
|
||||
from app.services.oauth_service import OAuthService
|
||||
from app.services.session_service import session_service
|
||||
from app.utils.device import extract_device_info
|
||||
|
||||
router = APIRouter()
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
# Initialize limiter for this router
|
||||
limiter = Limiter(key_func=get_remote_address)
|
||||
|
||||
# Use higher rate limits in test environment
|
||||
IS_TEST = os.getenv("IS_TEST", "False") == "True"
|
||||
RATE_MULTIPLIER = 100 if IS_TEST else 1
|
||||
|
||||
|
||||
async def _create_oauth_login_session(
|
||||
db: AsyncSession,
|
||||
request: Request,
|
||||
user: User,
|
||||
tokens: Token,
|
||||
provider: str,
|
||||
) -> None:
|
||||
"""
|
||||
Create a session record for successful OAuth login.
|
||||
|
||||
This is a best-effort operation - login succeeds even if session creation fails.
|
||||
"""
|
||||
try:
|
||||
device_info = extract_device_info(request)
|
||||
|
||||
# Decode refresh token to get JTI and expiration
|
||||
refresh_payload = decode_token(tokens.refresh_token, verify_type="refresh")
|
||||
|
||||
session_data = SessionCreate(
|
||||
user_id=user.id,
|
||||
refresh_token_jti=refresh_payload.jti,
|
||||
device_name=device_info.device_name or f"OAuth ({provider})",
|
||||
device_id=device_info.device_id,
|
||||
ip_address=device_info.ip_address,
|
||||
user_agent=device_info.user_agent,
|
||||
last_used_at=datetime.now(UTC),
|
||||
expires_at=datetime.fromtimestamp(refresh_payload.exp, tz=UTC),
|
||||
location_city=device_info.location_city,
|
||||
location_country=device_info.location_country,
|
||||
)
|
||||
|
||||
await session_service.create_session(db, obj_in=session_data)
|
||||
|
||||
logger.info(
|
||||
"OAuth login successful: %s via %s from %s (IP: %s)",
|
||||
user.email,
|
||||
provider,
|
||||
device_info.device_name,
|
||||
device_info.ip_address,
|
||||
)
|
||||
except Exception as session_err:
|
||||
# Log but don't fail login if session creation fails
|
||||
logger.exception(
|
||||
"Failed to create session for OAuth login %s: %s", user.email, session_err
|
||||
)
|
||||
|
||||
|
||||
@router.get(
|
||||
"/providers",
|
||||
response_model=OAuthProvidersResponse,
|
||||
summary="List OAuth Providers",
|
||||
description="""
|
||||
Get list of enabled OAuth providers for the login/register UI.
|
||||
|
||||
Returns:
|
||||
List of enabled providers with display info.
|
||||
""",
|
||||
operation_id="list_oauth_providers",
|
||||
)
|
||||
async def list_providers() -> Any:
|
||||
"""
|
||||
Get list of enabled OAuth providers.
|
||||
|
||||
This endpoint is public (no authentication required) as it's needed
|
||||
for the login/register UI to display available social login options.
|
||||
"""
|
||||
return OAuthService.get_enabled_providers()
|
||||
|
||||
|
||||
@router.get(
|
||||
"/authorize/{provider}",
|
||||
response_model=dict,
|
||||
summary="Get OAuth Authorization URL",
|
||||
description="""
|
||||
Get the authorization URL to redirect the user to the OAuth provider.
|
||||
|
||||
The frontend should redirect the user to the returned URL.
|
||||
After authentication, the provider will redirect back to the callback URL.
|
||||
|
||||
**Rate Limit**: 10 requests/minute
|
||||
""",
|
||||
operation_id="get_oauth_authorization_url",
|
||||
)
|
||||
@limiter.limit(f"{10 * RATE_MULTIPLIER}/minute")
|
||||
async def get_authorization_url(
|
||||
request: Request,
|
||||
provider: str,
|
||||
redirect_uri: str = Query(
|
||||
..., description="Frontend callback URL after OAuth completes"
|
||||
),
|
||||
current_user: User | None = Depends(get_optional_current_user),
|
||||
db: AsyncSession = Depends(get_db),
|
||||
) -> Any:
|
||||
"""
|
||||
Get OAuth authorization URL.
|
||||
|
||||
Args:
|
||||
provider: OAuth provider (google, github)
|
||||
redirect_uri: Frontend callback URL
|
||||
current_user: Current user (optional, for account linking)
|
||||
db: Database session
|
||||
|
||||
Returns:
|
||||
dict with authorization_url and state
|
||||
"""
|
||||
if not settings.OAUTH_ENABLED:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_400_BAD_REQUEST,
|
||||
detail="OAuth is not enabled",
|
||||
)
|
||||
|
||||
try:
|
||||
# If user is logged in, this is an account linking flow
|
||||
user_id = str(current_user.id) if current_user else None
|
||||
|
||||
url, state = await OAuthService.create_authorization_url(
|
||||
db,
|
||||
provider=provider,
|
||||
redirect_uri=redirect_uri,
|
||||
user_id=user_id,
|
||||
)
|
||||
|
||||
return {
|
||||
"authorization_url": url,
|
||||
"state": state,
|
||||
}
|
||||
|
||||
except AuthError as e:
|
||||
logger.warning("OAuth authorization failed: %s", e)
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_400_BAD_REQUEST,
|
||||
detail=str(e),
|
||||
)
|
||||
except Exception as e:
|
||||
logger.exception("OAuth authorization error: %s", e)
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
|
||||
detail="Failed to create authorization URL",
|
||||
)
|
||||
|
||||
|
||||
@router.post(
|
||||
"/callback/{provider}",
|
||||
response_model=OAuthCallbackResponse,
|
||||
summary="OAuth Callback",
|
||||
description="""
|
||||
Handle OAuth callback from provider.
|
||||
|
||||
The frontend should call this endpoint with the code and state
|
||||
parameters received from the OAuth provider redirect.
|
||||
|
||||
Returns:
|
||||
JWT tokens for the authenticated user.
|
||||
|
||||
**Rate Limit**: 10 requests/minute
|
||||
""",
|
||||
operation_id="handle_oauth_callback",
|
||||
)
|
||||
@limiter.limit(f"{10 * RATE_MULTIPLIER}/minute")
|
||||
async def handle_callback(
|
||||
request: Request,
|
||||
provider: str,
|
||||
callback_data: OAuthCallbackRequest,
|
||||
redirect_uri: str = Query(
|
||||
..., description="Must match the redirect_uri used in authorization"
|
||||
),
|
||||
db: AsyncSession = Depends(get_db),
|
||||
) -> Any:
|
||||
"""
|
||||
Handle OAuth callback.
|
||||
|
||||
Args:
|
||||
provider: OAuth provider (google, github)
|
||||
callback_data: Code and state from provider
|
||||
redirect_uri: Original redirect URI (for validation)
|
||||
db: Database session
|
||||
|
||||
Returns:
|
||||
OAuthCallbackResponse with tokens
|
||||
"""
|
||||
if not settings.OAUTH_ENABLED:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_400_BAD_REQUEST,
|
||||
detail="OAuth is not enabled",
|
||||
)
|
||||
|
||||
try:
|
||||
result = await OAuthService.handle_callback(
|
||||
db,
|
||||
code=callback_data.code,
|
||||
state=callback_data.state,
|
||||
redirect_uri=redirect_uri,
|
||||
)
|
||||
|
||||
# Create session for the login (need to get the user first)
|
||||
# Note: This requires fetching the user from the token
|
||||
# For now, we skip session creation here as the result doesn't include user info
|
||||
# The session will be created on next request if needed
|
||||
|
||||
return result
|
||||
|
||||
except AuthError as e:
|
||||
logger.warning("OAuth callback failed: %s", e)
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_401_UNAUTHORIZED,
|
||||
detail=str(e),
|
||||
)
|
||||
except Exception as e:
|
||||
logger.exception("OAuth callback error: %s", e)
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
|
||||
detail="OAuth authentication failed",
|
||||
)
|
||||
|
||||
|
||||
@router.get(
|
||||
"/accounts",
|
||||
response_model=OAuthAccountsListResponse,
|
||||
summary="List Linked OAuth Accounts",
|
||||
description="""
|
||||
Get list of OAuth accounts linked to the current user.
|
||||
|
||||
Requires authentication.
|
||||
""",
|
||||
operation_id="list_oauth_accounts",
|
||||
)
|
||||
async def list_accounts(
|
||||
current_user: User = Depends(get_current_user),
|
||||
db: AsyncSession = Depends(get_db),
|
||||
) -> Any:
|
||||
"""
|
||||
List OAuth accounts linked to the current user.
|
||||
|
||||
Args:
|
||||
current_user: Current authenticated user
|
||||
db: Database session
|
||||
|
||||
Returns:
|
||||
List of linked OAuth accounts
|
||||
"""
|
||||
accounts = await OAuthService.get_user_accounts(db, user_id=current_user.id)
|
||||
return OAuthAccountsListResponse(accounts=accounts)
|
||||
|
||||
|
||||
@router.delete(
|
||||
"/accounts/{provider}",
|
||||
response_model=OAuthUnlinkResponse,
|
||||
summary="Unlink OAuth Account",
|
||||
description="""
|
||||
Unlink an OAuth provider from the current user.
|
||||
|
||||
The user must have either a password set or another OAuth provider
|
||||
linked to ensure they can still log in.
|
||||
|
||||
**Rate Limit**: 5 requests/minute
|
||||
""",
|
||||
operation_id="unlink_oauth_account",
|
||||
)
|
||||
@limiter.limit(f"{5 * RATE_MULTIPLIER}/minute")
|
||||
async def unlink_account(
|
||||
request: Request,
|
||||
provider: str,
|
||||
current_user: User = Depends(get_current_user),
|
||||
db: AsyncSession = Depends(get_db),
|
||||
) -> Any:
|
||||
"""
|
||||
Unlink an OAuth provider from the current user.
|
||||
|
||||
Args:
|
||||
provider: Provider to unlink (google, github)
|
||||
current_user: Current authenticated user
|
||||
db: Database session
|
||||
|
||||
Returns:
|
||||
Success message
|
||||
"""
|
||||
try:
|
||||
await OAuthService.unlink_provider(
|
||||
db,
|
||||
user=current_user,
|
||||
provider=provider,
|
||||
)
|
||||
|
||||
return OAuthUnlinkResponse(
|
||||
success=True,
|
||||
message=f"{provider.capitalize()} account unlinked successfully",
|
||||
)
|
||||
|
||||
except AuthError as e:
|
||||
logger.warning("OAuth unlink failed for %s: %s", current_user.email, e)
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_400_BAD_REQUEST,
|
||||
detail=str(e),
|
||||
)
|
||||
except Exception as e:
|
||||
logger.exception("OAuth unlink error: %s", e)
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
|
||||
detail="Failed to unlink OAuth account",
|
||||
)
|
||||
|
||||
|
||||
@router.post(
|
||||
"/link/{provider}",
|
||||
response_model=dict,
|
||||
summary="Start Account Linking",
|
||||
description="""
|
||||
Start the OAuth flow to link a new provider to the current user.
|
||||
|
||||
This is a convenience endpoint that redirects to /authorize/{provider}
|
||||
with the current user context.
|
||||
|
||||
**Rate Limit**: 10 requests/minute
|
||||
""",
|
||||
operation_id="start_oauth_link",
|
||||
)
|
||||
@limiter.limit(f"{10 * RATE_MULTIPLIER}/minute")
|
||||
async def start_link(
|
||||
request: Request,
|
||||
provider: str,
|
||||
redirect_uri: str = Query(
|
||||
..., description="Frontend callback URL after OAuth completes"
|
||||
),
|
||||
current_user: User = Depends(get_current_user),
|
||||
db: AsyncSession = Depends(get_db),
|
||||
) -> Any:
|
||||
"""
|
||||
Start OAuth account linking flow.
|
||||
|
||||
This endpoint requires authentication and will initiate an OAuth flow
|
||||
to link a new provider to the current user's account.
|
||||
|
||||
Args:
|
||||
provider: OAuth provider to link (google, github)
|
||||
redirect_uri: Frontend callback URL
|
||||
current_user: Current authenticated user
|
||||
db: Database session
|
||||
|
||||
Returns:
|
||||
dict with authorization_url and state
|
||||
"""
|
||||
if not settings.OAUTH_ENABLED:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_400_BAD_REQUEST,
|
||||
detail="OAuth is not enabled",
|
||||
)
|
||||
|
||||
# Check if user already has this provider linked
|
||||
existing = await OAuthService.get_user_account_by_provider(
|
||||
db, user_id=current_user.id, provider=provider
|
||||
)
|
||||
if existing:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_400_BAD_REQUEST,
|
||||
detail=f"You already have a {provider} account linked",
|
||||
)
|
||||
|
||||
try:
|
||||
url, state = await OAuthService.create_authorization_url(
|
||||
db,
|
||||
provider=provider,
|
||||
redirect_uri=redirect_uri,
|
||||
user_id=str(current_user.id),
|
||||
)
|
||||
|
||||
return {
|
||||
"authorization_url": url,
|
||||
"state": state,
|
||||
}
|
||||
|
||||
except AuthError as e:
|
||||
logger.warning("OAuth link authorization failed: %s", e)
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_400_BAD_REQUEST,
|
||||
detail=str(e),
|
||||
)
|
||||
except Exception as e:
|
||||
logger.exception("OAuth link error: %s", e)
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
|
||||
detail="Failed to create authorization URL",
|
||||
)
|
||||
824
backend/app/api/routes/oauth_provider.py
Normal file
824
backend/app/api/routes/oauth_provider.py
Normal file
@@ -0,0 +1,824 @@
|
||||
# app/api/routes/oauth_provider.py
|
||||
"""
|
||||
OAuth Provider routes (Authorization Server mode) for MCP integration.
|
||||
|
||||
Implements OAuth 2.0 Authorization Server endpoints:
|
||||
- GET /.well-known/oauth-authorization-server - Server metadata (RFC 8414)
|
||||
- GET /oauth/provider/authorize - Authorization endpoint
|
||||
- POST /oauth/provider/token - Token endpoint
|
||||
- POST /oauth/provider/revoke - Token revocation (RFC 7009)
|
||||
- POST /oauth/provider/introspect - Token introspection (RFC 7662)
|
||||
- Client management endpoints
|
||||
|
||||
Security features:
|
||||
- PKCE required for public clients (S256)
|
||||
- CSRF protection via state parameter
|
||||
- Secure token handling
|
||||
- Rate limiting on sensitive endpoints
|
||||
"""
|
||||
|
||||
import logging
|
||||
from typing import Any
|
||||
from urllib.parse import urlencode
|
||||
|
||||
from fastapi import APIRouter, Depends, Form, HTTPException, Query, Request, status
|
||||
from fastapi.responses import RedirectResponse
|
||||
from slowapi import Limiter
|
||||
from slowapi.util import get_remote_address
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
|
||||
from app.api.dependencies.auth import (
|
||||
get_current_active_user,
|
||||
get_current_superuser,
|
||||
get_optional_current_user,
|
||||
)
|
||||
from app.core.config import settings
|
||||
from app.core.database import get_db
|
||||
from app.models.user import User
|
||||
from app.schemas.oauth import (
|
||||
OAuthClientCreate,
|
||||
OAuthClientResponse,
|
||||
OAuthServerMetadata,
|
||||
OAuthTokenIntrospectionResponse,
|
||||
OAuthTokenResponse,
|
||||
)
|
||||
from app.services import oauth_provider_service as provider_service
|
||||
|
||||
router = APIRouter()
|
||||
# Separate router for RFC 8414 well-known endpoint (registered at root level)
|
||||
wellknown_router = APIRouter()
|
||||
logger = logging.getLogger(__name__)
|
||||
limiter = Limiter(key_func=get_remote_address)
|
||||
|
||||
|
||||
def require_provider_enabled():
|
||||
"""Dependency to check if OAuth provider mode is enabled."""
|
||||
if not settings.OAUTH_PROVIDER_ENABLED:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_404_NOT_FOUND,
|
||||
detail="OAuth provider mode is not enabled. Set OAUTH_PROVIDER_ENABLED=true",
|
||||
)
|
||||
|
||||
|
||||
# ============================================================================
|
||||
# Server Metadata (RFC 8414)
|
||||
# ============================================================================
|
||||
|
||||
|
||||
@wellknown_router.get(
|
||||
"/.well-known/oauth-authorization-server",
|
||||
response_model=OAuthServerMetadata,
|
||||
summary="OAuth Server Metadata",
|
||||
description="""
|
||||
OAuth 2.0 Authorization Server Metadata (RFC 8414).
|
||||
|
||||
Returns server metadata including supported endpoints, scopes,
|
||||
and capabilities. MCP clients use this to discover the server.
|
||||
|
||||
Note: This endpoint is at the root level per RFC 8414.
|
||||
""",
|
||||
operation_id="get_oauth_server_metadata",
|
||||
tags=["OAuth Provider"],
|
||||
)
|
||||
async def get_server_metadata(
|
||||
_: None = Depends(require_provider_enabled),
|
||||
) -> OAuthServerMetadata:
|
||||
"""Get OAuth 2.0 server metadata."""
|
||||
base_url = settings.OAUTH_ISSUER.rstrip("/")
|
||||
|
||||
return OAuthServerMetadata(
|
||||
issuer=base_url,
|
||||
authorization_endpoint=f"{base_url}/api/v1/oauth/provider/authorize",
|
||||
token_endpoint=f"{base_url}/api/v1/oauth/provider/token",
|
||||
revocation_endpoint=f"{base_url}/api/v1/oauth/provider/revoke",
|
||||
introspection_endpoint=f"{base_url}/api/v1/oauth/provider/introspect",
|
||||
registration_endpoint=None, # Dynamic registration not supported
|
||||
scopes_supported=[
|
||||
"openid",
|
||||
"profile",
|
||||
"email",
|
||||
"read:users",
|
||||
"write:users",
|
||||
"read:organizations",
|
||||
"write:organizations",
|
||||
"admin",
|
||||
],
|
||||
response_types_supported=["code"],
|
||||
grant_types_supported=["authorization_code", "refresh_token"],
|
||||
code_challenge_methods_supported=["S256"],
|
||||
token_endpoint_auth_methods_supported=[
|
||||
"client_secret_basic",
|
||||
"client_secret_post",
|
||||
"none", # For public clients with PKCE
|
||||
],
|
||||
)
|
||||
|
||||
|
||||
# ============================================================================
|
||||
# Authorization Endpoint
|
||||
# ============================================================================
|
||||
|
||||
|
||||
@router.get(
|
||||
"/provider/authorize",
|
||||
summary="Authorization Endpoint",
|
||||
description="""
|
||||
OAuth 2.0 Authorization Endpoint.
|
||||
|
||||
Initiates the authorization code flow:
|
||||
1. Validates client and parameters
|
||||
2. Checks if user is authenticated (redirects to login if not)
|
||||
3. Checks existing consent
|
||||
4. Redirects to consent page if needed
|
||||
5. Issues authorization code and redirects back to client
|
||||
|
||||
Required parameters:
|
||||
- response_type: Must be "code"
|
||||
- client_id: Registered client ID
|
||||
- redirect_uri: Must match registered URI
|
||||
|
||||
Recommended parameters:
|
||||
- state: CSRF protection
|
||||
- code_challenge + code_challenge_method: PKCE (required for public clients)
|
||||
- scope: Requested permissions
|
||||
""",
|
||||
operation_id="oauth_provider_authorize",
|
||||
tags=["OAuth Provider"],
|
||||
)
|
||||
@limiter.limit("30/minute")
|
||||
async def authorize(
|
||||
request: Request,
|
||||
response_type: str = Query(..., description="Must be 'code'"),
|
||||
client_id: str = Query(..., description="OAuth client ID"),
|
||||
redirect_uri: str = Query(..., description="Redirect URI"),
|
||||
scope: str = Query(default="", description="Requested scopes (space-separated)"),
|
||||
state: str = Query(default="", description="CSRF state parameter"),
|
||||
code_challenge: str | None = Query(default=None, description="PKCE code challenge"),
|
||||
code_challenge_method: str | None = Query(
|
||||
default=None, description="PKCE method (S256)"
|
||||
),
|
||||
nonce: str | None = Query(default=None, description="OpenID Connect nonce"),
|
||||
db: AsyncSession = Depends(get_db),
|
||||
_: None = Depends(require_provider_enabled),
|
||||
current_user: User | None = Depends(get_optional_current_user),
|
||||
) -> Any:
|
||||
"""
|
||||
Authorization endpoint - initiates OAuth flow.
|
||||
|
||||
If user is not authenticated, redirects to login with return URL.
|
||||
If user has not consented, redirects to consent page.
|
||||
If all checks pass, generates code and redirects to client.
|
||||
"""
|
||||
# Validate response_type
|
||||
if response_type != "code":
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_400_BAD_REQUEST,
|
||||
detail="invalid_request: response_type must be 'code'",
|
||||
)
|
||||
|
||||
# Validate PKCE method if provided - ONLY S256 is allowed (RFC 7636 Section 4.3)
|
||||
# "plain" method provides no security benefit and MUST NOT be used
|
||||
if code_challenge_method and code_challenge_method != "S256":
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_400_BAD_REQUEST,
|
||||
detail="invalid_request: code_challenge_method must be 'S256' (plain is not supported)",
|
||||
)
|
||||
|
||||
# Validate client
|
||||
try:
|
||||
client = await provider_service.get_client(db, client_id)
|
||||
if not client:
|
||||
raise provider_service.InvalidClientError("Unknown client_id")
|
||||
provider_service.validate_redirect_uri(client, redirect_uri)
|
||||
except provider_service.OAuthProviderError as e:
|
||||
# For client/redirect errors, we can't safely redirect - show error
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_400_BAD_REQUEST,
|
||||
detail=f"{e.error}: {e.error_description}",
|
||||
)
|
||||
|
||||
# Validate and filter scopes
|
||||
try:
|
||||
requested_scopes = provider_service.parse_scope(scope)
|
||||
valid_scopes = provider_service.validate_scopes(client, requested_scopes)
|
||||
except provider_service.InvalidScopeError as e:
|
||||
# Redirect with error
|
||||
scope_error_params: dict[str, str] = {"error": e.error}
|
||||
if e.error_description:
|
||||
scope_error_params["error_description"] = e.error_description
|
||||
if state:
|
||||
scope_error_params["state"] = state
|
||||
return RedirectResponse(
|
||||
url=f"{redirect_uri}?{urlencode(scope_error_params)}",
|
||||
status_code=status.HTTP_302_FOUND,
|
||||
)
|
||||
|
||||
# Public clients MUST use PKCE
|
||||
if client.client_type == "public":
|
||||
if not code_challenge or code_challenge_method != "S256":
|
||||
pkce_error_params: dict[str, str] = {
|
||||
"error": "invalid_request",
|
||||
"error_description": "PKCE with S256 is required for public clients",
|
||||
}
|
||||
if state:
|
||||
pkce_error_params["state"] = state
|
||||
return RedirectResponse(
|
||||
url=f"{redirect_uri}?{urlencode(pkce_error_params)}",
|
||||
status_code=status.HTTP_302_FOUND,
|
||||
)
|
||||
|
||||
# If user is not authenticated, redirect to login
|
||||
if not current_user:
|
||||
# Store authorization request in session and redirect to login
|
||||
# The frontend will handle the return URL
|
||||
login_url = f"{settings.FRONTEND_URL}/login"
|
||||
return_params = urlencode(
|
||||
{
|
||||
"oauth_authorize": "true",
|
||||
"client_id": client_id,
|
||||
"redirect_uri": redirect_uri,
|
||||
"scope": " ".join(valid_scopes),
|
||||
"state": state,
|
||||
"code_challenge": code_challenge or "",
|
||||
"code_challenge_method": code_challenge_method or "",
|
||||
"nonce": nonce or "",
|
||||
}
|
||||
)
|
||||
return RedirectResponse(
|
||||
url=f"{login_url}?return_to=/auth/consent?{return_params}",
|
||||
status_code=status.HTTP_302_FOUND,
|
||||
)
|
||||
|
||||
# Check if user has already consented
|
||||
has_consent = await provider_service.check_consent(
|
||||
db, current_user.id, client_id, valid_scopes
|
||||
)
|
||||
|
||||
if not has_consent:
|
||||
# Redirect to consent page
|
||||
consent_params = urlencode(
|
||||
{
|
||||
"client_id": client_id,
|
||||
"client_name": client.client_name,
|
||||
"redirect_uri": redirect_uri,
|
||||
"scope": " ".join(valid_scopes),
|
||||
"state": state,
|
||||
"code_challenge": code_challenge or "",
|
||||
"code_challenge_method": code_challenge_method or "",
|
||||
"nonce": nonce or "",
|
||||
}
|
||||
)
|
||||
return RedirectResponse(
|
||||
url=f"{settings.FRONTEND_URL}/auth/consent?{consent_params}",
|
||||
status_code=status.HTTP_302_FOUND,
|
||||
)
|
||||
|
||||
# User is authenticated and has consented - issue authorization code
|
||||
try:
|
||||
code = await provider_service.create_authorization_code(
|
||||
db=db,
|
||||
client=client,
|
||||
user=current_user,
|
||||
redirect_uri=redirect_uri,
|
||||
scope=" ".join(valid_scopes),
|
||||
code_challenge=code_challenge,
|
||||
code_challenge_method=code_challenge_method,
|
||||
state=state,
|
||||
nonce=nonce,
|
||||
)
|
||||
except provider_service.OAuthProviderError as e:
|
||||
error_params: dict[str, str] = {"error": e.error}
|
||||
if e.error_description:
|
||||
error_params["error_description"] = e.error_description
|
||||
if state:
|
||||
error_params["state"] = state
|
||||
return RedirectResponse(
|
||||
url=f"{redirect_uri}?{urlencode(error_params)}",
|
||||
status_code=status.HTTP_302_FOUND,
|
||||
)
|
||||
|
||||
# Success - redirect with code
|
||||
success_params = {"code": code}
|
||||
if state:
|
||||
success_params["state"] = state
|
||||
|
||||
return RedirectResponse(
|
||||
url=f"{redirect_uri}?{urlencode(success_params)}",
|
||||
status_code=status.HTTP_302_FOUND,
|
||||
)
|
||||
|
||||
|
||||
@router.post(
|
||||
"/provider/authorize/consent",
|
||||
summary="Submit Authorization Consent",
|
||||
description="""
|
||||
Submit user consent for OAuth authorization.
|
||||
|
||||
Called by the consent page after user approves or denies.
|
||||
""",
|
||||
operation_id="oauth_provider_consent",
|
||||
tags=["OAuth Provider"],
|
||||
)
|
||||
@limiter.limit("30/minute")
|
||||
async def submit_consent(
|
||||
request: Request,
|
||||
approved: bool = Form(..., description="Whether user approved"),
|
||||
client_id: str = Form(..., description="OAuth client ID"),
|
||||
redirect_uri: str = Form(..., description="Redirect URI"),
|
||||
scope: str = Form(default="", description="Granted scopes"),
|
||||
state: str = Form(default="", description="CSRF state parameter"),
|
||||
code_challenge: str | None = Form(default=None),
|
||||
code_challenge_method: str | None = Form(default=None),
|
||||
nonce: str | None = Form(default=None),
|
||||
db: AsyncSession = Depends(get_db),
|
||||
_: None = Depends(require_provider_enabled),
|
||||
current_user: User = Depends(get_current_active_user),
|
||||
) -> Any:
|
||||
"""Process consent form submission."""
|
||||
# Validate client
|
||||
try:
|
||||
client = await provider_service.get_client(db, client_id)
|
||||
if not client:
|
||||
raise provider_service.InvalidClientError("Unknown client_id")
|
||||
provider_service.validate_redirect_uri(client, redirect_uri)
|
||||
except provider_service.OAuthProviderError as e:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_400_BAD_REQUEST,
|
||||
detail=f"{e.error}: {e.error_description}",
|
||||
)
|
||||
|
||||
# If user denied, redirect with error
|
||||
if not approved:
|
||||
denied_params: dict[str, str] = {
|
||||
"error": "access_denied",
|
||||
"error_description": "User denied authorization",
|
||||
}
|
||||
if state:
|
||||
denied_params["state"] = state
|
||||
return RedirectResponse(
|
||||
url=f"{redirect_uri}?{urlencode(denied_params)}",
|
||||
status_code=status.HTTP_302_FOUND,
|
||||
)
|
||||
|
||||
# Parse and validate scopes
|
||||
granted_scopes = provider_service.parse_scope(scope)
|
||||
valid_scopes = provider_service.validate_scopes(client, granted_scopes)
|
||||
|
||||
# Record consent
|
||||
await provider_service.grant_consent(db, current_user.id, client_id, valid_scopes)
|
||||
|
||||
# Generate authorization code
|
||||
try:
|
||||
code = await provider_service.create_authorization_code(
|
||||
db=db,
|
||||
client=client,
|
||||
user=current_user,
|
||||
redirect_uri=redirect_uri,
|
||||
scope=" ".join(valid_scopes),
|
||||
code_challenge=code_challenge,
|
||||
code_challenge_method=code_challenge_method,
|
||||
state=state,
|
||||
nonce=nonce,
|
||||
)
|
||||
except provider_service.OAuthProviderError as e:
|
||||
error_params: dict[str, str] = {"error": e.error}
|
||||
if e.error_description:
|
||||
error_params["error_description"] = e.error_description
|
||||
if state:
|
||||
error_params["state"] = state
|
||||
return RedirectResponse(
|
||||
url=f"{redirect_uri}?{urlencode(error_params)}",
|
||||
status_code=status.HTTP_302_FOUND,
|
||||
)
|
||||
|
||||
# Success
|
||||
success_params = {"code": code}
|
||||
if state:
|
||||
success_params["state"] = state
|
||||
|
||||
return RedirectResponse(
|
||||
url=f"{redirect_uri}?{urlencode(success_params)}",
|
||||
status_code=status.HTTP_302_FOUND,
|
||||
)
|
||||
|
||||
|
||||
# ============================================================================
|
||||
# Token Endpoint
|
||||
# ============================================================================
|
||||
|
||||
|
||||
@router.post(
|
||||
"/provider/token",
|
||||
response_model=OAuthTokenResponse,
|
||||
summary="Token Endpoint",
|
||||
description="""
|
||||
OAuth 2.0 Token Endpoint.
|
||||
|
||||
Supports:
|
||||
- authorization_code: Exchange code for tokens
|
||||
- refresh_token: Refresh access token
|
||||
|
||||
Client authentication:
|
||||
- Confidential clients: client_secret (Basic auth or POST body)
|
||||
- Public clients: No secret, but PKCE code_verifier required
|
||||
""",
|
||||
operation_id="oauth_provider_token",
|
||||
tags=["OAuth Provider"],
|
||||
)
|
||||
@limiter.limit("60/minute")
|
||||
async def token(
|
||||
request: Request,
|
||||
grant_type: str = Form(..., description="Grant type"),
|
||||
code: str | None = Form(default=None, description="Authorization code"),
|
||||
redirect_uri: str | None = Form(default=None, description="Redirect URI"),
|
||||
client_id: str | None = Form(default=None, description="Client ID"),
|
||||
client_secret: str | None = Form(default=None, description="Client secret"),
|
||||
code_verifier: str | None = Form(default=None, description="PKCE code verifier"),
|
||||
refresh_token: str | None = Form(default=None, description="Refresh token"),
|
||||
scope: str | None = Form(default=None, description="Scope (for refresh)"),
|
||||
db: AsyncSession = Depends(get_db),
|
||||
_: None = Depends(require_provider_enabled),
|
||||
) -> OAuthTokenResponse:
|
||||
"""Token endpoint - exchange code for tokens or refresh."""
|
||||
# Extract client credentials from Basic auth if not in body
|
||||
if not client_id:
|
||||
auth_header = request.headers.get("Authorization", "")
|
||||
if auth_header.startswith("Basic "):
|
||||
import base64
|
||||
|
||||
try:
|
||||
decoded = base64.b64decode(auth_header[6:]).decode()
|
||||
client_id, client_secret = decoded.split(":", 1)
|
||||
except Exception as e:
|
||||
# Log malformed Basic auth for security monitoring
|
||||
logger.warning(
|
||||
"Malformed Basic auth header in token request: %s", type(e).__name__
|
||||
)
|
||||
# Fall back to form body
|
||||
|
||||
if not client_id:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_401_UNAUTHORIZED,
|
||||
detail="invalid_client: client_id required",
|
||||
headers={"WWW-Authenticate": "Basic"},
|
||||
)
|
||||
|
||||
# Get device info
|
||||
device_info = request.headers.get("User-Agent", "")[:500]
|
||||
ip_address = get_remote_address(request)
|
||||
|
||||
try:
|
||||
if grant_type == "authorization_code":
|
||||
if not code:
|
||||
raise provider_service.InvalidRequestError("code required")
|
||||
if not redirect_uri:
|
||||
raise provider_service.InvalidRequestError("redirect_uri required")
|
||||
|
||||
result = await provider_service.exchange_authorization_code(
|
||||
db=db,
|
||||
code=code,
|
||||
client_id=client_id,
|
||||
redirect_uri=redirect_uri,
|
||||
code_verifier=code_verifier,
|
||||
client_secret=client_secret,
|
||||
device_info=device_info,
|
||||
ip_address=ip_address,
|
||||
)
|
||||
|
||||
elif grant_type == "refresh_token":
|
||||
if not refresh_token:
|
||||
raise provider_service.InvalidRequestError("refresh_token required")
|
||||
|
||||
result = await provider_service.refresh_tokens(
|
||||
db=db,
|
||||
refresh_token=refresh_token,
|
||||
client_id=client_id,
|
||||
client_secret=client_secret,
|
||||
scope=scope,
|
||||
device_info=device_info,
|
||||
ip_address=ip_address,
|
||||
)
|
||||
|
||||
else:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_400_BAD_REQUEST,
|
||||
detail="unsupported_grant_type: Must be authorization_code or refresh_token",
|
||||
)
|
||||
|
||||
return OAuthTokenResponse(**result)
|
||||
|
||||
except provider_service.InvalidClientError as e:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_401_UNAUTHORIZED,
|
||||
detail=f"{e.error}: {e.error_description}",
|
||||
headers={"WWW-Authenticate": "Basic"},
|
||||
)
|
||||
except provider_service.OAuthProviderError as e:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_400_BAD_REQUEST,
|
||||
detail=f"{e.error}: {e.error_description}",
|
||||
)
|
||||
|
||||
|
||||
# ============================================================================
|
||||
# Token Revocation (RFC 7009)
|
||||
# ============================================================================
|
||||
|
||||
|
||||
@router.post(
|
||||
"/provider/revoke",
|
||||
status_code=status.HTTP_200_OK,
|
||||
summary="Token Revocation Endpoint",
|
||||
description="""
|
||||
OAuth 2.0 Token Revocation Endpoint (RFC 7009).
|
||||
|
||||
Revokes an access token or refresh token.
|
||||
Always returns 200 OK (even if token is invalid) per spec.
|
||||
""",
|
||||
operation_id="oauth_provider_revoke",
|
||||
tags=["OAuth Provider"],
|
||||
)
|
||||
@limiter.limit("30/minute")
|
||||
async def revoke(
|
||||
request: Request,
|
||||
token: str = Form(..., description="Token to revoke"),
|
||||
token_type_hint: str | None = Form(
|
||||
default=None, description="Token type hint (access_token, refresh_token)"
|
||||
),
|
||||
client_id: str | None = Form(default=None, description="Client ID"),
|
||||
client_secret: str | None = Form(default=None, description="Client secret"),
|
||||
db: AsyncSession = Depends(get_db),
|
||||
_: None = Depends(require_provider_enabled),
|
||||
) -> dict[str, str]:
|
||||
"""Revoke a token."""
|
||||
# Extract client credentials from Basic auth if not in body
|
||||
if not client_id:
|
||||
auth_header = request.headers.get("Authorization", "")
|
||||
if auth_header.startswith("Basic "):
|
||||
import base64
|
||||
|
||||
try:
|
||||
decoded = base64.b64decode(auth_header[6:]).decode()
|
||||
client_id, client_secret = decoded.split(":", 1)
|
||||
except Exception as e:
|
||||
# Log malformed Basic auth for security monitoring
|
||||
logger.warning(
|
||||
"Malformed Basic auth header in revoke request: %s",
|
||||
type(e).__name__,
|
||||
)
|
||||
# Fall back to form body
|
||||
|
||||
try:
|
||||
await provider_service.revoke_token(
|
||||
db=db,
|
||||
token=token,
|
||||
token_type_hint=token_type_hint,
|
||||
client_id=client_id,
|
||||
client_secret=client_secret,
|
||||
)
|
||||
except provider_service.InvalidClientError:
|
||||
# Per RFC 7009, we should return 200 OK even for errors
|
||||
# But client authentication errors can return 401
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_401_UNAUTHORIZED,
|
||||
detail="invalid_client",
|
||||
headers={"WWW-Authenticate": "Basic"},
|
||||
)
|
||||
except Exception as e:
|
||||
# Log but don't expose errors per RFC 7009
|
||||
logger.warning("Token revocation error: %s", e)
|
||||
|
||||
# Always return 200 OK per RFC 7009
|
||||
return {"status": "ok"}
|
||||
|
||||
|
||||
# ============================================================================
|
||||
# Token Introspection (RFC 7662)
|
||||
# ============================================================================
|
||||
|
||||
|
||||
@router.post(
|
||||
"/provider/introspect",
|
||||
response_model=OAuthTokenIntrospectionResponse,
|
||||
summary="Token Introspection Endpoint",
|
||||
description="""
|
||||
OAuth 2.0 Token Introspection Endpoint (RFC 7662).
|
||||
|
||||
Allows resource servers to query the authorization server
|
||||
to determine the active state and metadata of a token.
|
||||
""",
|
||||
operation_id="oauth_provider_introspect",
|
||||
tags=["OAuth Provider"],
|
||||
)
|
||||
@limiter.limit("120/minute")
|
||||
async def introspect(
|
||||
request: Request,
|
||||
token: str = Form(..., description="Token to introspect"),
|
||||
token_type_hint: str | None = Form(
|
||||
default=None, description="Token type hint (access_token, refresh_token)"
|
||||
),
|
||||
client_id: str | None = Form(default=None, description="Client ID"),
|
||||
client_secret: str | None = Form(default=None, description="Client secret"),
|
||||
db: AsyncSession = Depends(get_db),
|
||||
_: None = Depends(require_provider_enabled),
|
||||
) -> OAuthTokenIntrospectionResponse:
|
||||
"""Introspect a token."""
|
||||
# Extract client credentials from Basic auth if not in body
|
||||
if not client_id:
|
||||
auth_header = request.headers.get("Authorization", "")
|
||||
if auth_header.startswith("Basic "):
|
||||
import base64
|
||||
|
||||
try:
|
||||
decoded = base64.b64decode(auth_header[6:]).decode()
|
||||
client_id, client_secret = decoded.split(":", 1)
|
||||
except Exception as e:
|
||||
# Log malformed Basic auth for security monitoring
|
||||
logger.warning(
|
||||
"Malformed Basic auth header in introspect request: %s",
|
||||
type(e).__name__,
|
||||
)
|
||||
# Fall back to form body
|
||||
|
||||
try:
|
||||
result = await provider_service.introspect_token(
|
||||
db=db,
|
||||
token=token,
|
||||
token_type_hint=token_type_hint,
|
||||
client_id=client_id,
|
||||
client_secret=client_secret,
|
||||
)
|
||||
return OAuthTokenIntrospectionResponse(**result)
|
||||
except provider_service.InvalidClientError:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_401_UNAUTHORIZED,
|
||||
detail="invalid_client",
|
||||
headers={"WWW-Authenticate": "Basic"},
|
||||
)
|
||||
except Exception as e:
|
||||
logger.warning("Token introspection error: %s", e)
|
||||
return OAuthTokenIntrospectionResponse(active=False) # pyright: ignore[reportCallIssue]
|
||||
|
||||
|
||||
# ============================================================================
|
||||
# Client Management (Admin)
|
||||
# ============================================================================
|
||||
|
||||
|
||||
@router.post(
|
||||
"/provider/clients",
|
||||
response_model=dict,
|
||||
summary="Register OAuth Client",
|
||||
description="""
|
||||
Register a new OAuth client (admin only).
|
||||
|
||||
Creates an MCP client that can authenticate against this API.
|
||||
Returns client_id and client_secret (for confidential clients).
|
||||
|
||||
**Important:** Store the client_secret securely - it won't be shown again!
|
||||
""",
|
||||
operation_id="register_oauth_client",
|
||||
tags=["OAuth Provider Admin"],
|
||||
)
|
||||
async def register_client(
|
||||
client_name: str = Form(..., description="Client application name"),
|
||||
redirect_uris: str = Form(..., description="Comma-separated redirect URIs"),
|
||||
client_type: str = Form(default="public", description="public or confidential"),
|
||||
scopes: str = Form(
|
||||
default="openid profile email",
|
||||
description="Allowed scopes (space-separated)",
|
||||
),
|
||||
mcp_server_url: str | None = Form(default=None, description="MCP server URL"),
|
||||
db: AsyncSession = Depends(get_db),
|
||||
_: None = Depends(require_provider_enabled),
|
||||
current_user: User = Depends(get_current_superuser),
|
||||
) -> dict:
|
||||
"""Register a new OAuth client."""
|
||||
# Parse redirect URIs
|
||||
uris = [uri.strip() for uri in redirect_uris.split(",") if uri.strip()]
|
||||
if not uris:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_400_BAD_REQUEST,
|
||||
detail="At least one redirect_uri is required",
|
||||
)
|
||||
|
||||
# Parse scopes
|
||||
allowed_scopes = [s.strip() for s in scopes.split() if s.strip()]
|
||||
|
||||
client_data = OAuthClientCreate(
|
||||
client_name=client_name,
|
||||
client_description=None,
|
||||
redirect_uris=uris,
|
||||
allowed_scopes=allowed_scopes,
|
||||
client_type=client_type,
|
||||
)
|
||||
|
||||
client, secret = await provider_service.register_client(db, client_data)
|
||||
|
||||
# Update MCP server URL if provided
|
||||
if mcp_server_url:
|
||||
client.mcp_server_url = mcp_server_url
|
||||
await db.commit()
|
||||
|
||||
result = {
|
||||
"client_id": client.client_id,
|
||||
"client_name": client.client_name,
|
||||
"client_type": client.client_type,
|
||||
"redirect_uris": client.redirect_uris,
|
||||
"allowed_scopes": client.allowed_scopes,
|
||||
}
|
||||
|
||||
if secret:
|
||||
result["client_secret"] = secret
|
||||
result["warning"] = (
|
||||
"Store the client_secret securely! It will not be shown again."
|
||||
)
|
||||
|
||||
return result
|
||||
|
||||
|
||||
@router.get(
|
||||
"/provider/clients",
|
||||
response_model=list[OAuthClientResponse],
|
||||
summary="List OAuth Clients",
|
||||
description="List all registered OAuth clients (admin only).",
|
||||
operation_id="list_oauth_clients",
|
||||
tags=["OAuth Provider Admin"],
|
||||
)
|
||||
async def list_clients(
|
||||
db: AsyncSession = Depends(get_db),
|
||||
_: None = Depends(require_provider_enabled),
|
||||
current_user: User = Depends(get_current_superuser),
|
||||
) -> list[OAuthClientResponse]:
|
||||
"""List all OAuth clients."""
|
||||
clients = await provider_service.list_clients(db)
|
||||
return [OAuthClientResponse.model_validate(c) for c in clients]
|
||||
|
||||
|
||||
@router.delete(
|
||||
"/provider/clients/{client_id}",
|
||||
status_code=status.HTTP_204_NO_CONTENT,
|
||||
summary="Delete OAuth Client",
|
||||
description="Delete an OAuth client (admin only). Revokes all tokens.",
|
||||
operation_id="delete_oauth_client",
|
||||
tags=["OAuth Provider Admin"],
|
||||
)
|
||||
async def delete_client(
|
||||
client_id: str,
|
||||
db: AsyncSession = Depends(get_db),
|
||||
_: None = Depends(require_provider_enabled),
|
||||
current_user: User = Depends(get_current_superuser),
|
||||
) -> None:
|
||||
"""Delete an OAuth client."""
|
||||
client = await provider_service.get_client(db, client_id)
|
||||
if not client:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_404_NOT_FOUND,
|
||||
detail="Client not found",
|
||||
)
|
||||
|
||||
await provider_service.delete_client_by_id(db, client_id=client_id)
|
||||
|
||||
|
||||
# ============================================================================
|
||||
# User Consent Management
|
||||
# ============================================================================
|
||||
|
||||
|
||||
@router.get(
|
||||
"/provider/consents",
|
||||
summary="List My Consents",
|
||||
description="List OAuth applications the current user has authorized.",
|
||||
operation_id="list_my_oauth_consents",
|
||||
tags=["OAuth Provider"],
|
||||
)
|
||||
async def list_my_consents(
|
||||
db: AsyncSession = Depends(get_db),
|
||||
_: None = Depends(require_provider_enabled),
|
||||
current_user: User = Depends(get_current_active_user),
|
||||
) -> list[dict]:
|
||||
"""List applications the user has authorized."""
|
||||
return await provider_service.list_user_consents(db, user_id=current_user.id)
|
||||
|
||||
|
||||
@router.delete(
|
||||
"/provider/consents/{client_id}",
|
||||
status_code=status.HTTP_204_NO_CONTENT,
|
||||
summary="Revoke My Consent",
|
||||
description="Revoke authorization for an OAuth application. Also revokes all tokens.",
|
||||
operation_id="revoke_my_oauth_consent",
|
||||
tags=["OAuth Provider"],
|
||||
)
|
||||
async def revoke_my_consent(
|
||||
client_id: str,
|
||||
db: AsyncSession = Depends(get_db),
|
||||
_: None = Depends(require_provider_enabled),
|
||||
current_user: User = Depends(get_current_active_user),
|
||||
) -> None:
|
||||
"""Revoke consent for an application."""
|
||||
revoked = await provider_service.revoke_consent(db, current_user.id, client_id)
|
||||
if not revoked:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_404_NOT_FOUND,
|
||||
detail="No consent found for this client",
|
||||
)
|
||||
@@ -4,8 +4,9 @@ Organization endpoints for regular users.
|
||||
|
||||
These endpoints allow users to view and manage organizations they belong to.
|
||||
"""
|
||||
|
||||
import logging
|
||||
from typing import Any, List
|
||||
from typing import Any
|
||||
from uuid import UUID
|
||||
|
||||
from fastapi import APIRouter, Depends, Query
|
||||
@@ -14,19 +15,18 @@ from sqlalchemy.ext.asyncio import AsyncSession
|
||||
from app.api.dependencies.auth import get_current_user
|
||||
from app.api.dependencies.permissions import require_org_admin, require_org_membership
|
||||
from app.core.database import get_db
|
||||
from app.core.exceptions import NotFoundError, ErrorCode
|
||||
from app.crud.organization import organization as organization_crud
|
||||
from app.models.user import User
|
||||
from app.schemas.common import (
|
||||
PaginationParams,
|
||||
PaginatedResponse,
|
||||
create_pagination_meta
|
||||
PaginationParams,
|
||||
create_pagination_meta,
|
||||
)
|
||||
from app.schemas.organizations import (
|
||||
OrganizationResponse,
|
||||
OrganizationMemberResponse,
|
||||
OrganizationUpdate
|
||||
OrganizationResponse,
|
||||
OrganizationUpdate,
|
||||
)
|
||||
from app.services.organization_service import organization_service
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
@@ -35,15 +35,15 @@ router = APIRouter()
|
||||
|
||||
@router.get(
|
||||
"/me",
|
||||
response_model=List[OrganizationResponse],
|
||||
response_model=list[OrganizationResponse],
|
||||
summary="Get My Organizations",
|
||||
description="Get all organizations the current user belongs to",
|
||||
operation_id="get_my_organizations"
|
||||
operation_id="get_my_organizations",
|
||||
)
|
||||
async def get_my_organizations(
|
||||
is_active: bool = Query(True, description="Filter by active membership"),
|
||||
current_user: User = Depends(get_current_user),
|
||||
db: AsyncSession = Depends(get_db)
|
||||
db: AsyncSession = Depends(get_db),
|
||||
) -> Any:
|
||||
"""
|
||||
Get all organizations the current user belongs to.
|
||||
@@ -53,16 +53,14 @@ async def get_my_organizations(
|
||||
"""
|
||||
try:
|
||||
# Get all org data in single query with JOIN and subquery
|
||||
orgs_data = await organization_crud.get_user_organizations_with_details(
|
||||
db,
|
||||
user_id=current_user.id,
|
||||
is_active=is_active
|
||||
orgs_data = await organization_service.get_user_organizations_with_details(
|
||||
db, user_id=current_user.id, is_active=is_active
|
||||
)
|
||||
|
||||
# Transform to response objects
|
||||
orgs_with_data = []
|
||||
for item in orgs_data:
|
||||
org = item['organization']
|
||||
org = item["organization"]
|
||||
org_dict = {
|
||||
"id": org.id,
|
||||
"name": org.name,
|
||||
@@ -72,14 +70,14 @@ async def get_my_organizations(
|
||||
"settings": org.settings,
|
||||
"created_at": org.created_at,
|
||||
"updated_at": org.updated_at,
|
||||
"member_count": item['member_count']
|
||||
"member_count": item["member_count"],
|
||||
}
|
||||
orgs_with_data.append(OrganizationResponse(**org_dict))
|
||||
|
||||
return orgs_with_data
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error getting user organizations: {str(e)}", exc_info=True)
|
||||
logger.exception("Error getting user organizations: %s", e)
|
||||
raise
|
||||
|
||||
|
||||
@@ -88,12 +86,12 @@ async def get_my_organizations(
|
||||
response_model=OrganizationResponse,
|
||||
summary="Get Organization Details",
|
||||
description="Get details of an organization the user belongs to",
|
||||
operation_id="get_organization"
|
||||
operation_id="get_organization",
|
||||
)
|
||||
async def get_organization(
|
||||
organization_id: UUID,
|
||||
current_user: User = Depends(require_org_membership),
|
||||
db: AsyncSession = Depends(get_db)
|
||||
db: AsyncSession = Depends(get_db),
|
||||
) -> Any:
|
||||
"""
|
||||
Get details of a specific organization.
|
||||
@@ -101,13 +99,7 @@ async def get_organization(
|
||||
User must be a member of the organization.
|
||||
"""
|
||||
try:
|
||||
org = await organization_crud.get(db, id=organization_id)
|
||||
if not org: # pragma: no cover - Permission check prevents this (see docs/UNREACHABLE_DEFENSIVE_CODE_ANALYSIS.md)
|
||||
raise NotFoundError(
|
||||
detail=f"Organization {organization_id} not found",
|
||||
error_code=ErrorCode.NOT_FOUND
|
||||
)
|
||||
|
||||
org = await organization_service.get_organization(db, str(organization_id))
|
||||
org_dict = {
|
||||
"id": org.id,
|
||||
"name": org.name,
|
||||
@@ -117,14 +109,14 @@ async def get_organization(
|
||||
"settings": org.settings,
|
||||
"created_at": org.created_at,
|
||||
"updated_at": org.updated_at,
|
||||
"member_count": await organization_crud.get_member_count(db, organization_id=org.id)
|
||||
"member_count": await organization_service.get_member_count(
|
||||
db, organization_id=org.id
|
||||
),
|
||||
}
|
||||
return OrganizationResponse(**org_dict)
|
||||
|
||||
except NotFoundError: # pragma: no cover - See above
|
||||
raise
|
||||
except Exception as e:
|
||||
logger.error(f"Error getting organization: {str(e)}", exc_info=True)
|
||||
logger.exception("Error getting organization: %s", e)
|
||||
raise
|
||||
|
||||
|
||||
@@ -133,14 +125,14 @@ async def get_organization(
|
||||
response_model=PaginatedResponse[OrganizationMemberResponse],
|
||||
summary="Get Organization Members",
|
||||
description="Get all members of an organization (members can view)",
|
||||
operation_id="get_organization_members"
|
||||
operation_id="get_organization_members",
|
||||
)
|
||||
async def get_organization_members(
|
||||
organization_id: UUID,
|
||||
pagination: PaginationParams = Depends(),
|
||||
is_active: bool = Query(True, description="Filter by active status"),
|
||||
current_user: User = Depends(require_org_membership),
|
||||
db: AsyncSession = Depends(get_db)
|
||||
db: AsyncSession = Depends(get_db),
|
||||
) -> Any:
|
||||
"""
|
||||
Get all members of an organization.
|
||||
@@ -148,12 +140,12 @@ async def get_organization_members(
|
||||
User must be a member of the organization to view members.
|
||||
"""
|
||||
try:
|
||||
members, total = await organization_crud.get_organization_members(
|
||||
members, total = await organization_service.get_organization_members(
|
||||
db,
|
||||
organization_id=organization_id,
|
||||
skip=pagination.offset,
|
||||
limit=pagination.limit,
|
||||
is_active=is_active
|
||||
is_active=is_active,
|
||||
)
|
||||
|
||||
member_responses = [OrganizationMemberResponse(**member) for member in members]
|
||||
@@ -162,13 +154,13 @@ async def get_organization_members(
|
||||
total=total,
|
||||
page=pagination.page,
|
||||
limit=pagination.limit,
|
||||
items_count=len(member_responses)
|
||||
items_count=len(member_responses),
|
||||
)
|
||||
|
||||
return PaginatedResponse(data=member_responses, pagination=pagination_meta)
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error getting organization members: {str(e)}", exc_info=True)
|
||||
logger.exception("Error getting organization members: %s", e)
|
||||
raise
|
||||
|
||||
|
||||
@@ -177,13 +169,13 @@ async def get_organization_members(
|
||||
response_model=OrganizationResponse,
|
||||
summary="Update Organization",
|
||||
description="Update organization details (admin/owner only)",
|
||||
operation_id="update_organization"
|
||||
operation_id="update_organization",
|
||||
)
|
||||
async def update_organization(
|
||||
organization_id: UUID,
|
||||
org_in: OrganizationUpdate,
|
||||
current_user: User = Depends(require_org_admin),
|
||||
db: AsyncSession = Depends(get_db)
|
||||
db: AsyncSession = Depends(get_db),
|
||||
) -> Any:
|
||||
"""
|
||||
Update organization details.
|
||||
@@ -191,15 +183,13 @@ async def update_organization(
|
||||
Requires owner or admin role in the organization.
|
||||
"""
|
||||
try:
|
||||
org = await organization_crud.get(db, id=organization_id)
|
||||
if not org: # pragma: no cover - Permission check prevents this (see docs/UNREACHABLE_DEFENSIVE_CODE_ANALYSIS.md)
|
||||
raise NotFoundError(
|
||||
detail=f"Organization {organization_id} not found",
|
||||
error_code=ErrorCode.NOT_FOUND
|
||||
)
|
||||
|
||||
updated_org = await organization_crud.update(db, db_obj=org, obj_in=org_in)
|
||||
logger.info(f"User {current_user.email} updated organization {updated_org.name}")
|
||||
org = await organization_service.get_organization(db, str(organization_id))
|
||||
updated_org = await organization_service.update_organization(
|
||||
db, org=org, obj_in=org_in
|
||||
)
|
||||
logger.info(
|
||||
"User %s updated organization %s", current_user.email, updated_org.name
|
||||
)
|
||||
|
||||
org_dict = {
|
||||
"id": updated_org.id,
|
||||
@@ -210,12 +200,12 @@ async def update_organization(
|
||||
"settings": updated_org.settings,
|
||||
"created_at": updated_org.created_at,
|
||||
"updated_at": updated_org.updated_at,
|
||||
"member_count": await organization_crud.get_member_count(db, organization_id=updated_org.id)
|
||||
"member_count": await organization_service.get_member_count(
|
||||
db, organization_id=updated_org.id
|
||||
),
|
||||
}
|
||||
return OrganizationResponse(**org_dict)
|
||||
|
||||
except NotFoundError: # pragma: no cover - See above
|
||||
raise
|
||||
except Exception as e:
|
||||
logger.error(f"Error updating organization: {str(e)}", exc_info=True)
|
||||
logger.exception("Error updating organization: %s", e)
|
||||
raise
|
||||
|
||||
@@ -3,11 +3,12 @@ Session management endpoints.
|
||||
|
||||
Allows users to view and manage their active sessions across devices.
|
||||
"""
|
||||
|
||||
import logging
|
||||
from typing import Any
|
||||
from uuid import UUID
|
||||
|
||||
from fastapi import APIRouter, Depends, HTTPException, status, Request
|
||||
from fastapi import APIRouter, Depends, HTTPException, Request, status
|
||||
from slowapi import Limiter
|
||||
from slowapi.util import get_remote_address
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
@@ -15,11 +16,11 @@ from sqlalchemy.ext.asyncio import AsyncSession
|
||||
from app.api.dependencies.auth import get_current_user
|
||||
from app.core.auth import decode_token
|
||||
from app.core.database import get_db
|
||||
from app.core.exceptions import NotFoundError, AuthorizationError, ErrorCode
|
||||
from app.crud.session import session as session_crud
|
||||
from app.core.exceptions import AuthorizationError, ErrorCode, NotFoundError
|
||||
from app.models.user import User
|
||||
from app.schemas.common import MessageResponse
|
||||
from app.schemas.sessions import SessionResponse, SessionListResponse
|
||||
from app.schemas.sessions import SessionListResponse, SessionResponse
|
||||
from app.services.session_service import session_service
|
||||
|
||||
router = APIRouter()
|
||||
logger = logging.getLogger(__name__)
|
||||
@@ -39,13 +40,13 @@ limiter = Limiter(key_func=get_remote_address)
|
||||
|
||||
**Rate Limit**: 30 requests/minute
|
||||
""",
|
||||
operation_id="list_my_sessions"
|
||||
operation_id="list_my_sessions",
|
||||
)
|
||||
@limiter.limit("30/minute")
|
||||
async def list_my_sessions(
|
||||
request: Request,
|
||||
current_user: User = Depends(get_current_user),
|
||||
db: AsyncSession = Depends(get_db)
|
||||
db: AsyncSession = Depends(get_db),
|
||||
) -> Any:
|
||||
"""
|
||||
List all active sessions for the current user.
|
||||
@@ -59,23 +60,21 @@ async def list_my_sessions(
|
||||
"""
|
||||
try:
|
||||
# Get all active sessions for user
|
||||
sessions = await session_crud.get_user_sessions(
|
||||
db,
|
||||
user_id=str(current_user.id),
|
||||
active_only=True
|
||||
sessions = await session_service.get_user_sessions(
|
||||
db, user_id=str(current_user.id), active_only=True
|
||||
)
|
||||
|
||||
# Try to identify current session from Authorization header
|
||||
current_session_jti = None
|
||||
auth_header = request.headers.get("authorization")
|
||||
if auth_header and auth_header.startswith("Bearer "):
|
||||
try:
|
||||
access_token = auth_header.split(" ")[1]
|
||||
token_payload = decode_token(access_token)
|
||||
decode_token(access_token)
|
||||
# Note: Access tokens don't have JTI by default, but we can try
|
||||
# For now, we'll mark current based on most recent activity
|
||||
except Exception:
|
||||
pass
|
||||
except Exception as e:
|
||||
# Optional token parsing - silently ignore failures
|
||||
logger.debug("Failed to decode access token for session marking: %s", e)
|
||||
|
||||
# Convert to response format
|
||||
session_responses = []
|
||||
@@ -90,22 +89,25 @@ async def list_my_sessions(
|
||||
last_used_at=s.last_used_at,
|
||||
created_at=s.created_at,
|
||||
expires_at=s.expires_at,
|
||||
is_current=(s == sessions[0] if sessions else False) # Most recent = current
|
||||
is_current=(
|
||||
s == sessions[0] if sessions else False
|
||||
), # Most recent = current
|
||||
)
|
||||
session_responses.append(session_response)
|
||||
|
||||
logger.info(f"User {current_user.id} listed {len(session_responses)} active sessions")
|
||||
logger.info(
|
||||
"User %s listed %s active sessions", current_user.id, len(session_responses)
|
||||
)
|
||||
|
||||
return SessionListResponse(
|
||||
sessions=session_responses,
|
||||
total=len(session_responses)
|
||||
sessions=session_responses, total=len(session_responses)
|
||||
)
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error listing sessions for user {current_user.id}: {str(e)}", exc_info=True)
|
||||
logger.exception("Error listing sessions for user %s: %s", current_user.id, e)
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
|
||||
detail="Failed to retrieve sessions"
|
||||
detail="Failed to retrieve sessions",
|
||||
)
|
||||
|
||||
|
||||
@@ -122,14 +124,14 @@ async def list_my_sessions(
|
||||
|
||||
**Rate Limit**: 10 requests/minute
|
||||
""",
|
||||
operation_id="revoke_session"
|
||||
operation_id="revoke_session",
|
||||
)
|
||||
@limiter.limit("10/minute")
|
||||
async def revoke_session(
|
||||
request: Request,
|
||||
session_id: UUID,
|
||||
current_user: User = Depends(get_current_user),
|
||||
db: AsyncSession = Depends(get_db)
|
||||
db: AsyncSession = Depends(get_db),
|
||||
) -> Any:
|
||||
"""
|
||||
Revoke a specific session by ID.
|
||||
@@ -144,45 +146,49 @@ async def revoke_session(
|
||||
"""
|
||||
try:
|
||||
# Get the session
|
||||
session = await session_crud.get(db, id=str(session_id))
|
||||
session = await session_service.get_session(db, str(session_id))
|
||||
|
||||
if not session:
|
||||
raise NotFoundError(
|
||||
message=f"Session {session_id} not found",
|
||||
error_code=ErrorCode.NOT_FOUND
|
||||
error_code=ErrorCode.NOT_FOUND,
|
||||
)
|
||||
|
||||
# Verify session belongs to current user
|
||||
if str(session.user_id) != str(current_user.id):
|
||||
logger.warning(
|
||||
f"User {current_user.id} attempted to revoke session {session_id} "
|
||||
f"belonging to user {session.user_id}"
|
||||
"User %s attempted to revoke session %s belonging to user %s",
|
||||
current_user.id,
|
||||
session_id,
|
||||
session.user_id,
|
||||
)
|
||||
raise AuthorizationError(
|
||||
message="You can only revoke your own sessions",
|
||||
error_code=ErrorCode.INSUFFICIENT_PERMISSIONS
|
||||
error_code=ErrorCode.INSUFFICIENT_PERMISSIONS,
|
||||
)
|
||||
|
||||
# Deactivate the session
|
||||
await session_crud.deactivate(db, session_id=str(session_id))
|
||||
await session_service.deactivate(db, session_id=str(session_id))
|
||||
|
||||
logger.info(
|
||||
f"User {current_user.id} revoked session {session_id} "
|
||||
f"({session.device_name})"
|
||||
"User %s revoked session %s (%s)",
|
||||
current_user.id,
|
||||
session_id,
|
||||
session.device_name,
|
||||
)
|
||||
|
||||
return MessageResponse(
|
||||
success=True,
|
||||
message=f"Session revoked: {session.device_name or 'Unknown device'}"
|
||||
message=f"Session revoked: {session.device_name or 'Unknown device'}",
|
||||
)
|
||||
|
||||
except (NotFoundError, AuthorizationError):
|
||||
raise
|
||||
except Exception as e:
|
||||
logger.error(f"Error revoking session {session_id}: {str(e)}", exc_info=True)
|
||||
logger.exception("Error revoking session %s: %s", session_id, e)
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
|
||||
detail="Failed to revoke session"
|
||||
detail="Failed to revoke session",
|
||||
)
|
||||
|
||||
|
||||
@@ -198,13 +204,13 @@ async def revoke_session(
|
||||
|
||||
**Rate Limit**: 5 requests/minute
|
||||
""",
|
||||
operation_id="cleanup_expired_sessions"
|
||||
operation_id="cleanup_expired_sessions",
|
||||
)
|
||||
@limiter.limit("5/minute")
|
||||
async def cleanup_expired_sessions(
|
||||
request: Request,
|
||||
current_user: User = Depends(get_current_user),
|
||||
db: AsyncSession = Depends(get_db)
|
||||
db: AsyncSession = Depends(get_db),
|
||||
) -> Any:
|
||||
"""
|
||||
Cleanup expired sessions for the current user.
|
||||
@@ -218,22 +224,24 @@ async def cleanup_expired_sessions(
|
||||
"""
|
||||
try:
|
||||
# Use optimized bulk DELETE instead of N individual deletes
|
||||
deleted_count = await session_crud.cleanup_expired_for_user(
|
||||
db,
|
||||
user_id=str(current_user.id)
|
||||
deleted_count = await session_service.cleanup_expired_for_user(
|
||||
db, user_id=str(current_user.id)
|
||||
)
|
||||
|
||||
logger.info(f"User {current_user.id} cleaned up {deleted_count} expired sessions")
|
||||
logger.info(
|
||||
"User %s cleaned up %s expired sessions", current_user.id, deleted_count
|
||||
)
|
||||
|
||||
return MessageResponse(
|
||||
success=True,
|
||||
message=f"Cleaned up {deleted_count} expired sessions"
|
||||
success=True, message=f"Cleaned up {deleted_count} expired sessions"
|
||||
)
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error cleaning up sessions for user {current_user.id}: {str(e)}", exc_info=True)
|
||||
logger.exception(
|
||||
"Error cleaning up sessions for user %s: %s", current_user.id, e
|
||||
)
|
||||
await db.rollback()
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
|
||||
detail="Failed to cleanup sessions"
|
||||
detail="Failed to cleanup sessions",
|
||||
)
|
||||
|
||||
@@ -1,33 +1,30 @@
|
||||
"""
|
||||
User management endpoints for CRUD operations.
|
||||
User management endpoints for database operations.
|
||||
"""
|
||||
|
||||
import logging
|
||||
from typing import Any, Optional
|
||||
from typing import Any
|
||||
from uuid import UUID
|
||||
|
||||
from fastapi import APIRouter, Depends, Query, status, Request
|
||||
from fastapi import APIRouter, Depends, Query, Request, status
|
||||
from slowapi import Limiter
|
||||
from slowapi.util import get_remote_address
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
|
||||
from app.api.dependencies.auth import get_current_user, get_current_superuser
|
||||
from app.api.dependencies.auth import get_current_superuser, get_current_user
|
||||
from app.core.database import get_db
|
||||
from app.core.exceptions import (
|
||||
NotFoundError,
|
||||
AuthorizationError,
|
||||
ErrorCode
|
||||
)
|
||||
from app.crud.user import user as user_crud
|
||||
from app.core.exceptions import AuthorizationError, ErrorCode
|
||||
from app.models.user import User
|
||||
from app.schemas.common import (
|
||||
PaginationParams,
|
||||
PaginatedResponse,
|
||||
MessageResponse,
|
||||
PaginatedResponse,
|
||||
PaginationParams,
|
||||
SortParams,
|
||||
create_pagination_meta
|
||||
create_pagination_meta,
|
||||
)
|
||||
from app.schemas.users import UserResponse, UserUpdate, PasswordChange
|
||||
from app.services.auth_service import AuthService, AuthenticationError
|
||||
from app.schemas.users import PasswordChange, UserResponse, UserUpdate
|
||||
from app.services.auth_service import AuthenticationError, AuthService
|
||||
from app.services.user_service import user_service
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
@@ -50,15 +47,15 @@ limiter = Limiter(key_func=get_remote_address)
|
||||
|
||||
**Rate Limit**: 60 requests/minute
|
||||
""",
|
||||
operation_id="list_users"
|
||||
operation_id="list_users",
|
||||
)
|
||||
async def list_users(
|
||||
pagination: PaginationParams = Depends(),
|
||||
sort: SortParams = Depends(),
|
||||
is_active: Optional[bool] = Query(None, description="Filter by active status"),
|
||||
is_superuser: Optional[bool] = Query(None, description="Filter by superuser status"),
|
||||
is_active: bool | None = Query(None, description="Filter by active status"),
|
||||
is_superuser: bool | None = Query(None, description="Filter by superuser status"),
|
||||
current_user: User = Depends(get_current_superuser),
|
||||
db: AsyncSession = Depends(get_db)
|
||||
db: AsyncSession = Depends(get_db),
|
||||
) -> Any:
|
||||
"""
|
||||
List all users with pagination, filtering, and sorting.
|
||||
@@ -74,13 +71,13 @@ async def list_users(
|
||||
filters["is_superuser"] = is_superuser
|
||||
|
||||
# Get paginated users with total count
|
||||
users, total = await user_crud.get_multi_with_total(
|
||||
users, total = await user_service.list_users(
|
||||
db,
|
||||
skip=pagination.offset,
|
||||
limit=pagination.limit,
|
||||
sort_by=sort.sort_by,
|
||||
sort_order=sort.sort_order.value if sort.sort_order else "asc",
|
||||
filters=filters if filters else None
|
||||
filters=filters if filters else None,
|
||||
)
|
||||
|
||||
# Create pagination metadata
|
||||
@@ -88,15 +85,12 @@ async def list_users(
|
||||
total=total,
|
||||
page=pagination.page,
|
||||
limit=pagination.limit,
|
||||
items_count=len(users)
|
||||
items_count=len(users),
|
||||
)
|
||||
|
||||
return PaginatedResponse(
|
||||
data=users,
|
||||
pagination=pagination_meta
|
||||
)
|
||||
return PaginatedResponse(data=users, pagination=pagination_meta)
|
||||
except Exception as e:
|
||||
logger.error(f"Error listing users: {str(e)}", exc_info=True)
|
||||
logger.exception("Error listing users: %s", e)
|
||||
raise
|
||||
|
||||
|
||||
@@ -111,10 +105,10 @@ async def list_users(
|
||||
|
||||
**Rate Limit**: 60 requests/minute
|
||||
""",
|
||||
operation_id="get_current_user_profile"
|
||||
operation_id="get_current_user_profile",
|
||||
)
|
||||
def get_current_user_profile(
|
||||
current_user: User = Depends(get_current_user)
|
||||
async def get_current_user_profile(
|
||||
current_user: User = Depends(get_current_user),
|
||||
) -> Any:
|
||||
"""Get current user's profile."""
|
||||
return current_user
|
||||
@@ -133,12 +127,12 @@ def get_current_user_profile(
|
||||
|
||||
**Rate Limit**: 30 requests/minute
|
||||
""",
|
||||
operation_id="update_current_user"
|
||||
operation_id="update_current_user",
|
||||
)
|
||||
async def update_current_user(
|
||||
user_update: UserUpdate,
|
||||
current_user: User = Depends(get_current_user),
|
||||
db: AsyncSession = Depends(get_db)
|
||||
db: AsyncSession = Depends(get_db),
|
||||
) -> Any:
|
||||
"""
|
||||
Update current user's profile.
|
||||
@@ -146,18 +140,16 @@ async def update_current_user(
|
||||
Users cannot elevate their own permissions (protected by UserUpdate schema validator).
|
||||
"""
|
||||
try:
|
||||
updated_user = await user_crud.update(
|
||||
db,
|
||||
db_obj=current_user,
|
||||
obj_in=user_update
|
||||
updated_user = await user_service.update_user(
|
||||
db, user=current_user, obj_in=user_update
|
||||
)
|
||||
logger.info(f"User {current_user.id} updated their profile")
|
||||
logger.info("User %s updated their profile", current_user.id)
|
||||
return updated_user
|
||||
except ValueError as e:
|
||||
logger.error(f"Error updating user {current_user.id}: {str(e)}")
|
||||
logger.error("Error updating user %s: %s", current_user.id, e)
|
||||
raise
|
||||
except Exception as e:
|
||||
logger.error(f"Unexpected error updating user {current_user.id}: {str(e)}", exc_info=True)
|
||||
logger.exception("Unexpected error updating user %s: %s", current_user.id, e)
|
||||
raise
|
||||
|
||||
|
||||
@@ -175,12 +167,12 @@ async def update_current_user(
|
||||
|
||||
**Rate Limit**: 60 requests/minute
|
||||
""",
|
||||
operation_id="get_user_by_id"
|
||||
operation_id="get_user_by_id",
|
||||
)
|
||||
async def get_user_by_id(
|
||||
user_id: UUID,
|
||||
current_user: User = Depends(get_current_user),
|
||||
db: AsyncSession = Depends(get_db)
|
||||
db: AsyncSession = Depends(get_db),
|
||||
) -> Any:
|
||||
"""
|
||||
Get user by ID.
|
||||
@@ -190,21 +182,17 @@ async def get_user_by_id(
|
||||
# Check permissions
|
||||
if str(user_id) != str(current_user.id) and not current_user.is_superuser:
|
||||
logger.warning(
|
||||
f"User {current_user.id} attempted to access user {user_id} without permission"
|
||||
"User %s attempted to access user %s without permission",
|
||||
current_user.id,
|
||||
user_id,
|
||||
)
|
||||
raise AuthorizationError(
|
||||
message="Not enough permissions to view this user",
|
||||
error_code=ErrorCode.INSUFFICIENT_PERMISSIONS
|
||||
error_code=ErrorCode.INSUFFICIENT_PERMISSIONS,
|
||||
)
|
||||
|
||||
# Get user
|
||||
user = await user_crud.get(db, id=str(user_id))
|
||||
if not user:
|
||||
raise NotFoundError(
|
||||
message=f"User with id {user_id} not found",
|
||||
error_code=ErrorCode.USER_NOT_FOUND
|
||||
)
|
||||
|
||||
user = await user_service.get_user(db, str(user_id))
|
||||
return user
|
||||
|
||||
|
||||
@@ -222,13 +210,13 @@ async def get_user_by_id(
|
||||
|
||||
**Rate Limit**: 30 requests/minute
|
||||
""",
|
||||
operation_id="update_user"
|
||||
operation_id="update_user",
|
||||
)
|
||||
async def update_user(
|
||||
user_id: UUID,
|
||||
user_update: UserUpdate,
|
||||
current_user: User = Depends(get_current_user),
|
||||
db: AsyncSession = Depends(get_db)
|
||||
db: AsyncSession = Depends(get_db),
|
||||
) -> Any:
|
||||
"""
|
||||
Update user by ID.
|
||||
@@ -241,30 +229,27 @@ async def update_user(
|
||||
|
||||
if not is_own_profile and not current_user.is_superuser:
|
||||
logger.warning(
|
||||
f"User {current_user.id} attempted to update user {user_id} without permission"
|
||||
"User %s attempted to update user %s without permission",
|
||||
current_user.id,
|
||||
user_id,
|
||||
)
|
||||
raise AuthorizationError(
|
||||
message="Not enough permissions to update this user",
|
||||
error_code=ErrorCode.INSUFFICIENT_PERMISSIONS
|
||||
error_code=ErrorCode.INSUFFICIENT_PERMISSIONS,
|
||||
)
|
||||
|
||||
# Get user
|
||||
user = await user_crud.get(db, id=str(user_id))
|
||||
if not user:
|
||||
raise NotFoundError(
|
||||
message=f"User with id {user_id} not found",
|
||||
error_code=ErrorCode.USER_NOT_FOUND
|
||||
)
|
||||
user = await user_service.get_user(db, str(user_id))
|
||||
|
||||
try:
|
||||
updated_user = await user_crud.update(db, db_obj=user, obj_in=user_update)
|
||||
logger.info(f"User {user_id} updated by {current_user.id}")
|
||||
updated_user = await user_service.update_user(db, user=user, obj_in=user_update)
|
||||
logger.info("User %s updated by %s", user_id, current_user.id)
|
||||
return updated_user
|
||||
except ValueError as e:
|
||||
logger.error(f"Error updating user {user_id}: {str(e)}")
|
||||
logger.error("Error updating user %s: %s", user_id, e)
|
||||
raise
|
||||
except Exception as e:
|
||||
logger.error(f"Unexpected error updating user {user_id}: {str(e)}", exc_info=True)
|
||||
logger.exception("Unexpected error updating user %s: %s", user_id, e)
|
||||
raise
|
||||
|
||||
|
||||
@@ -281,14 +266,14 @@ async def update_user(
|
||||
|
||||
**Rate Limit**: 5 requests/minute
|
||||
""",
|
||||
operation_id="change_current_user_password"
|
||||
operation_id="change_current_user_password",
|
||||
)
|
||||
@limiter.limit("5/minute")
|
||||
async def change_current_user_password(
|
||||
request: Request,
|
||||
password_change: PasswordChange,
|
||||
current_user: User = Depends(get_current_user),
|
||||
db: AsyncSession = Depends(get_db)
|
||||
db: AsyncSession = Depends(get_db),
|
||||
) -> Any:
|
||||
"""
|
||||
Change current user's password.
|
||||
@@ -300,23 +285,23 @@ async def change_current_user_password(
|
||||
db=db,
|
||||
user_id=current_user.id,
|
||||
current_password=password_change.current_password,
|
||||
new_password=password_change.new_password
|
||||
new_password=password_change.new_password,
|
||||
)
|
||||
|
||||
if success:
|
||||
logger.info(f"User {current_user.id} changed their password")
|
||||
logger.info("User %s changed their password", current_user.id)
|
||||
return MessageResponse(
|
||||
success=True,
|
||||
message="Password changed successfully"
|
||||
success=True, message="Password changed successfully"
|
||||
)
|
||||
except AuthenticationError as e:
|
||||
logger.warning(f"Failed password change attempt for user {current_user.id}: {str(e)}")
|
||||
logger.warning(
|
||||
"Failed password change attempt for user %s: %s", current_user.id, e
|
||||
)
|
||||
raise AuthorizationError(
|
||||
message=str(e),
|
||||
error_code=ErrorCode.INVALID_CREDENTIALS
|
||||
message=str(e), error_code=ErrorCode.INVALID_CREDENTIALS
|
||||
)
|
||||
except Exception as e:
|
||||
logger.error(f"Error changing password for user {current_user.id}: {str(e)}")
|
||||
logger.error("Error changing password for user %s: %s", current_user.id, e)
|
||||
raise
|
||||
|
||||
|
||||
@@ -335,12 +320,12 @@ async def change_current_user_password(
|
||||
|
||||
**Note**: This performs a hard delete. Consider implementing soft deletes for production.
|
||||
""",
|
||||
operation_id="delete_user"
|
||||
operation_id="delete_user",
|
||||
)
|
||||
async def delete_user(
|
||||
user_id: UUID,
|
||||
current_user: User = Depends(get_current_superuser),
|
||||
db: AsyncSession = Depends(get_db)
|
||||
db: AsyncSession = Depends(get_db),
|
||||
) -> Any:
|
||||
"""
|
||||
Delete user by ID (superuser only).
|
||||
@@ -351,28 +336,22 @@ async def delete_user(
|
||||
if str(user_id) == str(current_user.id):
|
||||
raise AuthorizationError(
|
||||
message="Cannot delete your own account",
|
||||
error_code=ErrorCode.INSUFFICIENT_PERMISSIONS
|
||||
error_code=ErrorCode.INSUFFICIENT_PERMISSIONS,
|
||||
)
|
||||
|
||||
# Get user
|
||||
user = await user_crud.get(db, id=str(user_id))
|
||||
if not user:
|
||||
raise NotFoundError(
|
||||
message=f"User with id {user_id} not found",
|
||||
error_code=ErrorCode.USER_NOT_FOUND
|
||||
)
|
||||
# Get user (raises NotFoundError if not found)
|
||||
await user_service.get_user(db, str(user_id))
|
||||
|
||||
try:
|
||||
# Use soft delete instead of hard delete
|
||||
await user_crud.soft_delete(db, id=str(user_id))
|
||||
logger.info(f"User {user_id} soft-deleted by {current_user.id}")
|
||||
await user_service.soft_delete_user(db, str(user_id))
|
||||
logger.info("User %s soft-deleted by %s", user_id, current_user.id)
|
||||
return MessageResponse(
|
||||
success=True,
|
||||
message=f"User {user_id} deleted successfully"
|
||||
success=True, message=f"User {user_id} deleted successfully"
|
||||
)
|
||||
except ValueError as e:
|
||||
logger.error(f"Error deleting user {user_id}: {str(e)}")
|
||||
logger.error("Error deleting user %s: %s", user_id, e)
|
||||
raise
|
||||
except Exception as e:
|
||||
logger.error(f"Unexpected error deleting user {user_id}: {str(e)}", exc_info=True)
|
||||
logger.exception("Unexpected error deleting user %s: %s", user_id, e)
|
||||
raise
|
||||
|
||||
@@ -1,49 +1,50 @@
|
||||
import logging
|
||||
logging.getLogger('passlib').setLevel(logging.ERROR)
|
||||
|
||||
from datetime import datetime, timedelta, timezone
|
||||
from typing import Any, Dict, Optional, Union
|
||||
import uuid
|
||||
import asyncio
|
||||
import uuid
|
||||
from datetime import UTC, datetime, timedelta
|
||||
from functools import partial
|
||||
from typing import Any
|
||||
|
||||
from jose import jwt, JWTError
|
||||
from passlib.context import CryptContext
|
||||
import bcrypt
|
||||
import jwt
|
||||
from jwt.exceptions import (
|
||||
ExpiredSignatureError,
|
||||
InvalidTokenError,
|
||||
MissingRequiredClaimError,
|
||||
)
|
||||
from pydantic import ValidationError
|
||||
|
||||
from app.core.config import settings
|
||||
from app.schemas.users import TokenData, TokenPayload
|
||||
|
||||
|
||||
# Password hashing context
|
||||
pwd_context = CryptContext(schemes=["bcrypt"], deprecated="auto")
|
||||
|
||||
# Custom exceptions for auth
|
||||
class AuthError(Exception):
|
||||
"""Base authentication error"""
|
||||
pass
|
||||
|
||||
|
||||
class TokenExpiredError(AuthError):
|
||||
"""Token has expired"""
|
||||
pass
|
||||
|
||||
|
||||
class TokenInvalidError(AuthError):
|
||||
"""Token is invalid"""
|
||||
pass
|
||||
|
||||
|
||||
class TokenMissingClaimError(AuthError):
|
||||
"""Token is missing a required claim"""
|
||||
pass
|
||||
|
||||
|
||||
def verify_password(plain_password: str, hashed_password: str) -> bool:
|
||||
"""Verify a password against a hash."""
|
||||
return pwd_context.verify(plain_password, hashed_password)
|
||||
"""Verify a password against a bcrypt hash."""
|
||||
return bcrypt.checkpw(
|
||||
plain_password.encode("utf-8"), hashed_password.encode("utf-8")
|
||||
)
|
||||
|
||||
|
||||
def get_password_hash(password: str) -> str:
|
||||
"""Generate a password hash."""
|
||||
return pwd_context.hash(password)
|
||||
"""Generate a bcrypt password hash."""
|
||||
salt = bcrypt.gensalt()
|
||||
return bcrypt.hashpw(password.encode("utf-8"), salt).decode("utf-8")
|
||||
|
||||
|
||||
async def verify_password_async(plain_password: str, hashed_password: str) -> bool:
|
||||
@@ -60,10 +61,9 @@ async def verify_password_async(plain_password: str, hashed_password: str) -> bo
|
||||
Returns:
|
||||
True if password matches, False otherwise
|
||||
"""
|
||||
loop = asyncio.get_event_loop()
|
||||
loop = asyncio.get_running_loop()
|
||||
return await loop.run_in_executor(
|
||||
None,
|
||||
partial(pwd_context.verify, plain_password, hashed_password)
|
||||
None, partial(verify_password, plain_password, hashed_password)
|
||||
)
|
||||
|
||||
|
||||
@@ -81,18 +81,14 @@ async def get_password_hash_async(password: str) -> str:
|
||||
Returns:
|
||||
Hashed password string
|
||||
"""
|
||||
loop = asyncio.get_event_loop()
|
||||
return await loop.run_in_executor(
|
||||
None,
|
||||
pwd_context.hash,
|
||||
password
|
||||
)
|
||||
loop = asyncio.get_running_loop()
|
||||
return await loop.run_in_executor(None, get_password_hash, password)
|
||||
|
||||
|
||||
def create_access_token(
|
||||
subject: Union[str, Any],
|
||||
expires_delta: Optional[timedelta] = None,
|
||||
claims: Optional[Dict[str, Any]] = None
|
||||
subject: str | Any,
|
||||
expires_delta: timedelta | None = None,
|
||||
claims: dict[str, Any] | None = None,
|
||||
) -> str:
|
||||
"""
|
||||
Create a JWT access token.
|
||||
@@ -106,17 +102,19 @@ def create_access_token(
|
||||
Encoded JWT token
|
||||
"""
|
||||
if expires_delta:
|
||||
expire = datetime.now(timezone.utc) + expires_delta
|
||||
expire = datetime.now(UTC) + expires_delta
|
||||
else:
|
||||
expire = datetime.now(timezone.utc) + timedelta(minutes=settings.ACCESS_TOKEN_EXPIRE_MINUTES)
|
||||
expire = datetime.now(UTC) + timedelta(
|
||||
minutes=settings.ACCESS_TOKEN_EXPIRE_MINUTES
|
||||
)
|
||||
|
||||
# Base token data
|
||||
to_encode = {
|
||||
"sub": str(subject),
|
||||
"exp": expire,
|
||||
"iat": datetime.now(tz=timezone.utc),
|
||||
"iat": datetime.now(tz=UTC),
|
||||
"jti": str(uuid.uuid4()),
|
||||
"type": "access"
|
||||
"type": "access",
|
||||
}
|
||||
|
||||
# Add custom claims
|
||||
@@ -124,18 +122,11 @@ def create_access_token(
|
||||
to_encode.update(claims)
|
||||
|
||||
# Create the JWT
|
||||
encoded_jwt = jwt.encode(
|
||||
to_encode,
|
||||
settings.SECRET_KEY,
|
||||
algorithm=settings.ALGORITHM
|
||||
)
|
||||
|
||||
return encoded_jwt
|
||||
return jwt.encode(to_encode, settings.SECRET_KEY, algorithm=settings.ALGORITHM)
|
||||
|
||||
|
||||
def create_refresh_token(
|
||||
subject: Union[str, Any],
|
||||
expires_delta: Optional[timedelta] = None
|
||||
subject: str | Any, expires_delta: timedelta | None = None
|
||||
) -> str:
|
||||
"""
|
||||
Create a JWT refresh token.
|
||||
@@ -148,28 +139,22 @@ def create_refresh_token(
|
||||
Encoded JWT refresh token
|
||||
"""
|
||||
if expires_delta:
|
||||
expire = datetime.now(timezone.utc) + expires_delta
|
||||
expire = datetime.now(UTC) + expires_delta
|
||||
else:
|
||||
expire = datetime.now(timezone.utc) + timedelta(days=settings.REFRESH_TOKEN_EXPIRE_DAYS)
|
||||
expire = datetime.now(UTC) + timedelta(days=settings.REFRESH_TOKEN_EXPIRE_DAYS)
|
||||
|
||||
to_encode = {
|
||||
"sub": str(subject),
|
||||
"exp": expire,
|
||||
"iat": datetime.now(timezone.utc),
|
||||
"iat": datetime.now(UTC),
|
||||
"jti": str(uuid.uuid4()),
|
||||
"type": "refresh"
|
||||
"type": "refresh",
|
||||
}
|
||||
|
||||
encoded_jwt = jwt.encode(
|
||||
to_encode,
|
||||
settings.SECRET_KEY,
|
||||
algorithm=settings.ALGORITHM
|
||||
)
|
||||
|
||||
return encoded_jwt
|
||||
return jwt.encode(to_encode, settings.SECRET_KEY, algorithm=settings.ALGORITHM)
|
||||
|
||||
|
||||
def decode_token(token: str, verify_type: Optional[str] = None) -> TokenPayload:
|
||||
def decode_token(token: str, verify_type: str | None = None) -> TokenPayload:
|
||||
"""
|
||||
Decode and verify a JWT token.
|
||||
|
||||
@@ -195,8 +180,8 @@ def decode_token(token: str, verify_type: Optional[str] = None) -> TokenPayload:
|
||||
"verify_signature": True,
|
||||
"verify_exp": True,
|
||||
"verify_iat": True,
|
||||
"require": ["exp", "sub", "iat"]
|
||||
}
|
||||
"require": ["exp", "sub", "iat"],
|
||||
},
|
||||
)
|
||||
|
||||
# SECURITY: Explicitly verify the algorithm to prevent algorithm confusion attacks
|
||||
@@ -206,7 +191,7 @@ def decode_token(token: str, verify_type: Optional[str] = None) -> TokenPayload:
|
||||
|
||||
# Reject weak or unexpected algorithms
|
||||
# NOTE: These are defensive checks that provide defense-in-depth.
|
||||
# The python-jose library rejects these tokens BEFORE we reach here,
|
||||
# PyJWT rejects these tokens BEFORE we reach here,
|
||||
# but we keep these checks in case the library changes or is misconfigured.
|
||||
# Coverage: Marked as pragma since library catches first (see tests/core/test_auth_security.py)
|
||||
if token_algorithm == "NONE": # pragma: no cover
|
||||
@@ -227,10 +212,11 @@ def decode_token(token: str, verify_type: Optional[str] = None) -> TokenPayload:
|
||||
token_data = TokenPayload(**payload)
|
||||
return token_data
|
||||
|
||||
except JWTError as e:
|
||||
# Check if the error is due to an expired token
|
||||
if "expired" in str(e).lower():
|
||||
raise TokenExpiredError("Token has expired")
|
||||
except ExpiredSignatureError:
|
||||
raise TokenExpiredError("Token has expired")
|
||||
except MissingRequiredClaimError as e:
|
||||
raise TokenMissingClaimError(f"Token missing required claim: {e}")
|
||||
except InvalidTokenError:
|
||||
raise TokenInvalidError("Invalid authentication token")
|
||||
except ValidationError:
|
||||
raise TokenInvalidError("Invalid token payload")
|
||||
@@ -250,4 +236,4 @@ def get_token_data(token: str) -> TokenData:
|
||||
user_id = payload.sub
|
||||
is_superuser = payload.is_superuser or False
|
||||
|
||||
return TokenData(user_id=uuid.UUID(user_id), is_superuser=is_superuser)
|
||||
return TokenData(user_id=uuid.UUID(user_id), is_superuser=is_superuser)
|
||||
|
||||
@@ -1,19 +1,22 @@
|
||||
import logging
|
||||
from typing import Optional, List
|
||||
|
||||
from pydantic import Field, field_validator
|
||||
from pydantic_settings import BaseSettings
|
||||
|
||||
|
||||
class Settings(BaseSettings):
|
||||
PROJECT_NAME: str = "App"
|
||||
PROJECT_NAME: str = "PragmaStack"
|
||||
VERSION: str = "1.0.0"
|
||||
API_V1_STR: str = "/api/v1"
|
||||
|
||||
# Environment (must be before SECRET_KEY for validation)
|
||||
ENVIRONMENT: str = Field(
|
||||
default="development",
|
||||
description="Environment: development, staging, or production"
|
||||
description="Environment: development, staging, or production",
|
||||
)
|
||||
DEMO_MODE: bool = Field(
|
||||
default=False,
|
||||
description="Enable demo mode (relaxed security, demo users)",
|
||||
)
|
||||
|
||||
# Security: Content Security Policy
|
||||
@@ -21,8 +24,7 @@ class Settings(BaseSettings):
|
||||
# Set to True for strict CSP (blocks most external resources)
|
||||
# Set to "relaxed" for modern frontend development
|
||||
CSP_MODE: str = Field(
|
||||
default="relaxed",
|
||||
description="CSP mode: 'strict', 'relaxed', or 'disabled'"
|
||||
default="relaxed", description="CSP mode: 'strict', 'relaxed', or 'disabled'"
|
||||
)
|
||||
|
||||
# Database configuration
|
||||
@@ -31,7 +33,7 @@ class Settings(BaseSettings):
|
||||
POSTGRES_HOST: str = "localhost"
|
||||
POSTGRES_PORT: str = "5432"
|
||||
POSTGRES_DB: str = "app"
|
||||
DATABASE_URL: Optional[str] = None
|
||||
DATABASE_URL: str | None = None
|
||||
db_pool_size: int = 20 # Default connection pool size
|
||||
db_max_overflow: int = 50 # Maximum overflow connections
|
||||
db_pool_timeout: int = 30 # Seconds to wait for a connection
|
||||
@@ -59,38 +61,90 @@ class Settings(BaseSettings):
|
||||
SECRET_KEY: str = Field(
|
||||
default="dev_only_insecure_key_change_in_production_32chars_min",
|
||||
min_length=32,
|
||||
description="JWT signing key. MUST be changed in production. Generate with: python -c 'import secrets; print(secrets.token_urlsafe(32))'"
|
||||
description="JWT signing key. MUST be changed in production. Generate with: python -c 'import secrets; print(secrets.token_urlsafe(32))'",
|
||||
)
|
||||
ALGORITHM: str = "HS256"
|
||||
ACCESS_TOKEN_EXPIRE_MINUTES: int = 15 # 15 minutes (production standard)
|
||||
REFRESH_TOKEN_EXPIRE_DAYS: int = 7 # 7 days
|
||||
|
||||
# CORS configuration
|
||||
BACKEND_CORS_ORIGINS: List[str] = ["http://localhost:3000"]
|
||||
BACKEND_CORS_ORIGINS: list[str] = ["http://localhost:3000"]
|
||||
|
||||
# Frontend URL for email links
|
||||
FRONTEND_URL: str = Field(
|
||||
default="http://localhost:3000",
|
||||
description="Frontend application URL for email links"
|
||||
description="Frontend application URL for email links",
|
||||
)
|
||||
|
||||
# OAuth Configuration
|
||||
OAUTH_ENABLED: bool = Field(
|
||||
default=False,
|
||||
description="Enable OAuth authentication (social login)",
|
||||
)
|
||||
OAUTH_AUTO_LINK_BY_EMAIL: bool = Field(
|
||||
default=True,
|
||||
description="Automatically link OAuth accounts to existing users with matching email",
|
||||
)
|
||||
OAUTH_STATE_EXPIRE_MINUTES: int = Field(
|
||||
default=10,
|
||||
description="OAuth state parameter expiration time in minutes",
|
||||
)
|
||||
|
||||
# Google OAuth
|
||||
OAUTH_GOOGLE_CLIENT_ID: str | None = Field(
|
||||
default=None,
|
||||
description="Google OAuth client ID from Google Cloud Console",
|
||||
)
|
||||
OAUTH_GOOGLE_CLIENT_SECRET: str | None = Field(
|
||||
default=None,
|
||||
description="Google OAuth client secret from Google Cloud Console",
|
||||
)
|
||||
|
||||
# GitHub OAuth
|
||||
OAUTH_GITHUB_CLIENT_ID: str | None = Field(
|
||||
default=None,
|
||||
description="GitHub OAuth client ID from GitHub Developer Settings",
|
||||
)
|
||||
OAUTH_GITHUB_CLIENT_SECRET: str | None = Field(
|
||||
default=None,
|
||||
description="GitHub OAuth client secret from GitHub Developer Settings",
|
||||
)
|
||||
|
||||
# OAuth Provider Mode (for MCP clients - skeleton)
|
||||
OAUTH_PROVIDER_ENABLED: bool = Field(
|
||||
default=False,
|
||||
description="Enable OAuth provider mode (act as authorization server for MCP clients)",
|
||||
)
|
||||
OAUTH_ISSUER: str = Field(
|
||||
default="http://localhost:8000",
|
||||
description="OAuth issuer URL (your API base URL)",
|
||||
)
|
||||
|
||||
@property
|
||||
def enabled_oauth_providers(self) -> list[str]:
|
||||
"""Get list of enabled OAuth providers based on configured credentials."""
|
||||
providers = []
|
||||
if self.OAUTH_GOOGLE_CLIENT_ID and self.OAUTH_GOOGLE_CLIENT_SECRET:
|
||||
providers.append("google")
|
||||
if self.OAUTH_GITHUB_CLIENT_ID and self.OAUTH_GITHUB_CLIENT_SECRET:
|
||||
providers.append("github")
|
||||
return providers
|
||||
|
||||
# Admin user
|
||||
FIRST_SUPERUSER_EMAIL: Optional[str] = Field(
|
||||
default=None,
|
||||
description="Email for first superuser account"
|
||||
FIRST_SUPERUSER_EMAIL: str | None = Field(
|
||||
default=None, description="Email for first superuser account"
|
||||
)
|
||||
FIRST_SUPERUSER_PASSWORD: Optional[str] = Field(
|
||||
default=None,
|
||||
description="Password for first superuser (min 12 characters)"
|
||||
FIRST_SUPERUSER_PASSWORD: str | None = Field(
|
||||
default=None, description="Password for first superuser (min 12 characters)"
|
||||
)
|
||||
|
||||
@field_validator('SECRET_KEY')
|
||||
@field_validator("SECRET_KEY")
|
||||
@classmethod
|
||||
def validate_secret_key(cls, v: str, info) -> str:
|
||||
"""Validate SECRET_KEY is secure, especially in production."""
|
||||
# Get environment from values if available
|
||||
values_data = info.data if info.data else {}
|
||||
env = values_data.get('ENVIRONMENT', 'development')
|
||||
env = values_data.get("ENVIRONMENT", "development")
|
||||
|
||||
if v.startswith("your_secret_key_here"):
|
||||
if env == "production":
|
||||
@@ -106,22 +160,40 @@ class Settings(BaseSettings):
|
||||
)
|
||||
|
||||
if len(v) < 32:
|
||||
raise ValueError("SECRET_KEY must be at least 32 characters long for security")
|
||||
raise ValueError(
|
||||
"SECRET_KEY must be at least 32 characters long for security"
|
||||
)
|
||||
|
||||
return v
|
||||
|
||||
@field_validator('FIRST_SUPERUSER_PASSWORD')
|
||||
@field_validator("FIRST_SUPERUSER_PASSWORD")
|
||||
@classmethod
|
||||
def validate_superuser_password(cls, v: Optional[str]) -> Optional[str]:
|
||||
def validate_superuser_password(cls, v: str | None, info) -> str | None:
|
||||
"""Validate superuser password strength."""
|
||||
if v is None:
|
||||
return v
|
||||
|
||||
# Get environment from values if available
|
||||
values_data = info.data if info.data else {}
|
||||
demo_mode = values_data.get("DEMO_MODE", False)
|
||||
|
||||
if demo_mode:
|
||||
# In demo mode, allow specific weak passwords for demo accounts
|
||||
demo_passwords = {"Demo123!", "Admin123!"}
|
||||
if v in demo_passwords:
|
||||
return v
|
||||
|
||||
if len(v) < 12:
|
||||
raise ValueError("FIRST_SUPERUSER_PASSWORD must be at least 12 characters")
|
||||
|
||||
# Check for common weak passwords
|
||||
weak_passwords = {'admin123', 'Admin123', 'password123', 'Password123', '123456789012'}
|
||||
weak_passwords = {
|
||||
"admin123",
|
||||
"Admin123",
|
||||
"password123",
|
||||
"Password123",
|
||||
"123456789012",
|
||||
}
|
||||
if v in weak_passwords:
|
||||
raise ValueError(
|
||||
"FIRST_SUPERUSER_PASSWORD is too weak. "
|
||||
@@ -144,8 +216,8 @@ class Settings(BaseSettings):
|
||||
"env_file": "../.env",
|
||||
"env_file_encoding": "utf-8",
|
||||
"case_sensitive": True,
|
||||
"extra": "ignore" # Ignore extra fields from .env (e.g., frontend-specific vars)
|
||||
"extra": "ignore", # Ignore extra fields from .env (e.g., frontend-specific vars)
|
||||
}
|
||||
|
||||
|
||||
settings = Settings()
|
||||
settings = Settings()
|
||||
|
||||
@@ -5,17 +5,18 @@ Database configuration using SQLAlchemy 2.0 and asyncpg.
|
||||
This module provides async database connectivity with proper connection pooling
|
||||
and session management for FastAPI endpoints.
|
||||
"""
|
||||
|
||||
import logging
|
||||
from collections.abc import AsyncGenerator
|
||||
from contextlib import asynccontextmanager
|
||||
from typing import AsyncGenerator
|
||||
|
||||
from sqlalchemy import text
|
||||
from sqlalchemy.dialects.postgresql import JSONB, UUID
|
||||
from sqlalchemy.ext.asyncio import (
|
||||
AsyncSession,
|
||||
AsyncEngine,
|
||||
create_async_engine,
|
||||
AsyncSession,
|
||||
async_sessionmaker,
|
||||
create_async_engine,
|
||||
)
|
||||
from sqlalchemy.ext.compiler import compiles
|
||||
from sqlalchemy.orm import DeclarativeBase
|
||||
@@ -27,12 +28,12 @@ logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
# SQLite compatibility for testing
|
||||
@compiles(JSONB, 'sqlite')
|
||||
@compiles(JSONB, "sqlite")
|
||||
def compile_jsonb_sqlite(type_, compiler, **kw):
|
||||
return "TEXT"
|
||||
|
||||
|
||||
@compiles(UUID, 'sqlite')
|
||||
@compiles(UUID, "sqlite")
|
||||
def compile_uuid_sqlite(type_, compiler, **kw):
|
||||
return "TEXT"
|
||||
|
||||
@@ -40,7 +41,6 @@ def compile_uuid_sqlite(type_, compiler, **kw):
|
||||
# Declarative base for models (SQLAlchemy 2.0 style)
|
||||
class Base(DeclarativeBase):
|
||||
"""Base class for all database models."""
|
||||
pass
|
||||
|
||||
|
||||
def get_async_database_url(url: str) -> str:
|
||||
@@ -75,7 +75,7 @@ def create_async_production_engine() -> AsyncEngine:
|
||||
|
||||
# Add PostgreSQL-specific connect_args
|
||||
if "postgresql" in async_url:
|
||||
engine_config["connect_args"] = {
|
||||
engine_config["connect_args"] = { # type: ignore[assignment]
|
||||
"server_settings": {
|
||||
"application_name": settings.PROJECT_NAME,
|
||||
"timezone": "UTC",
|
||||
@@ -128,8 +128,8 @@ async def async_transaction_scope() -> AsyncGenerator[AsyncSession, None]:
|
||||
|
||||
Usage:
|
||||
async with async_transaction_scope() as db:
|
||||
user = await user_crud.create(db, obj_in=user_create)
|
||||
profile = await profile_crud.create(db, obj_in=profile_create)
|
||||
user = await user_repo.create(db, obj_in=user_create)
|
||||
profile = await profile_repo.create(db, obj_in=profile_create)
|
||||
# Both operations committed together
|
||||
"""
|
||||
async with SessionLocal() as session:
|
||||
@@ -139,7 +139,7 @@ async def async_transaction_scope() -> AsyncGenerator[AsyncSession, None]:
|
||||
logger.debug("Async transaction committed successfully")
|
||||
except Exception as e:
|
||||
await session.rollback()
|
||||
logger.error(f"Async transaction failed, rolling back: {str(e)}")
|
||||
logger.error("Async transaction failed, rolling back: %s", e)
|
||||
raise
|
||||
finally:
|
||||
await session.close()
|
||||
@@ -155,7 +155,7 @@ async def check_async_database_health() -> bool:
|
||||
await db.execute(text("SELECT 1"))
|
||||
return True
|
||||
except Exception as e:
|
||||
logger.error(f"Async database health check failed: {str(e)}")
|
||||
logger.error("Async database health check failed: %s", e)
|
||||
return False
|
||||
|
||||
|
||||
|
||||
366
backend/app/core/demo_data.json
Normal file
366
backend/app/core/demo_data.json
Normal file
@@ -0,0 +1,366 @@
|
||||
{
|
||||
"organizations": [
|
||||
{
|
||||
"name": "Acme Corp",
|
||||
"slug": "acme-corp",
|
||||
"description": "A leading provider of coyote-catching equipment."
|
||||
},
|
||||
{
|
||||
"name": "Globex Corporation",
|
||||
"slug": "globex",
|
||||
"description": "We own the East Coast."
|
||||
},
|
||||
{
|
||||
"name": "Soylent Corp",
|
||||
"slug": "soylent",
|
||||
"description": "Making food for the future."
|
||||
},
|
||||
{
|
||||
"name": "Initech",
|
||||
"slug": "initech",
|
||||
"description": "Software for the soul."
|
||||
},
|
||||
{
|
||||
"name": "Umbrella Corporation",
|
||||
"slug": "umbrella",
|
||||
"description": "Our business is life itself."
|
||||
},
|
||||
{
|
||||
"name": "Massive Dynamic",
|
||||
"slug": "massive-dynamic",
|
||||
"description": "What don't we do?"
|
||||
}
|
||||
],
|
||||
"users": [
|
||||
{
|
||||
"email": "demo@example.com",
|
||||
"password": "DemoPass1234!",
|
||||
"first_name": "Demo",
|
||||
"last_name": "User",
|
||||
"is_superuser": false,
|
||||
"organization_slug": "acme-corp",
|
||||
"role": "member",
|
||||
"is_active": true
|
||||
},
|
||||
{
|
||||
"email": "alice@acme.com",
|
||||
"password": "Demo123!",
|
||||
"first_name": "Alice",
|
||||
"last_name": "Smith",
|
||||
"is_superuser": false,
|
||||
"organization_slug": "acme-corp",
|
||||
"role": "admin",
|
||||
"is_active": true
|
||||
},
|
||||
{
|
||||
"email": "bob@acme.com",
|
||||
"password": "Demo123!",
|
||||
"first_name": "Bob",
|
||||
"last_name": "Jones",
|
||||
"is_superuser": false,
|
||||
"organization_slug": "acme-corp",
|
||||
"role": "member",
|
||||
"is_active": true
|
||||
},
|
||||
{
|
||||
"email": "charlie@acme.com",
|
||||
"password": "Demo123!",
|
||||
"first_name": "Charlie",
|
||||
"last_name": "Brown",
|
||||
"is_superuser": false,
|
||||
"organization_slug": "acme-corp",
|
||||
"role": "member",
|
||||
"is_active": false
|
||||
},
|
||||
{
|
||||
"email": "diana@acme.com",
|
||||
"password": "Demo123!",
|
||||
"first_name": "Diana",
|
||||
"last_name": "Prince",
|
||||
"is_superuser": false,
|
||||
"organization_slug": "acme-corp",
|
||||
"role": "member",
|
||||
"is_active": true
|
||||
},
|
||||
{
|
||||
"email": "carol@globex.com",
|
||||
"password": "Demo123!",
|
||||
"first_name": "Carol",
|
||||
"last_name": "Williams",
|
||||
"is_superuser": false,
|
||||
"organization_slug": "globex",
|
||||
"role": "owner",
|
||||
"is_active": true
|
||||
},
|
||||
{
|
||||
"email": "dan@globex.com",
|
||||
"password": "Demo123!",
|
||||
"first_name": "Dan",
|
||||
"last_name": "Miller",
|
||||
"is_superuser": false,
|
||||
"organization_slug": "globex",
|
||||
"role": "member",
|
||||
"is_active": true
|
||||
},
|
||||
{
|
||||
"email": "ellen@globex.com",
|
||||
"password": "Demo123!",
|
||||
"first_name": "Ellen",
|
||||
"last_name": "Ripley",
|
||||
"is_superuser": false,
|
||||
"organization_slug": "globex",
|
||||
"role": "member",
|
||||
"is_active": true
|
||||
},
|
||||
{
|
||||
"email": "fred@globex.com",
|
||||
"password": "Demo123!",
|
||||
"first_name": "Fred",
|
||||
"last_name": "Flintstone",
|
||||
"is_superuser": false,
|
||||
"organization_slug": "globex",
|
||||
"role": "member",
|
||||
"is_active": true
|
||||
},
|
||||
{
|
||||
"email": "dave@soylent.com",
|
||||
"password": "Demo123!",
|
||||
"first_name": "Dave",
|
||||
"last_name": "Brown",
|
||||
"is_superuser": false,
|
||||
"organization_slug": "soylent",
|
||||
"role": "member",
|
||||
"is_active": true
|
||||
},
|
||||
{
|
||||
"email": "gina@soylent.com",
|
||||
"password": "Demo123!",
|
||||
"first_name": "Gina",
|
||||
"last_name": "Torres",
|
||||
"is_superuser": false,
|
||||
"organization_slug": "soylent",
|
||||
"role": "member",
|
||||
"is_active": true
|
||||
},
|
||||
{
|
||||
"email": "harry@soylent.com",
|
||||
"password": "Demo123!",
|
||||
"first_name": "Harry",
|
||||
"last_name": "Potter",
|
||||
"is_superuser": false,
|
||||
"organization_slug": "soylent",
|
||||
"role": "admin",
|
||||
"is_active": true
|
||||
},
|
||||
{
|
||||
"email": "eve@initech.com",
|
||||
"password": "Demo123!",
|
||||
"first_name": "Eve",
|
||||
"last_name": "Davis",
|
||||
"is_superuser": false,
|
||||
"organization_slug": "initech",
|
||||
"role": "admin",
|
||||
"is_active": true
|
||||
},
|
||||
{
|
||||
"email": "iris@initech.com",
|
||||
"password": "Demo123!",
|
||||
"first_name": "Iris",
|
||||
"last_name": "West",
|
||||
"is_superuser": false,
|
||||
"organization_slug": "initech",
|
||||
"role": "member",
|
||||
"is_active": true
|
||||
},
|
||||
{
|
||||
"email": "jack@initech.com",
|
||||
"password": "Demo123!",
|
||||
"first_name": "Jack",
|
||||
"last_name": "Sparrow",
|
||||
"is_superuser": false,
|
||||
"organization_slug": "initech",
|
||||
"role": "member",
|
||||
"is_active": false
|
||||
},
|
||||
{
|
||||
"email": "frank@umbrella.com",
|
||||
"password": "Demo123!",
|
||||
"first_name": "Frank",
|
||||
"last_name": "Miller",
|
||||
"is_superuser": false,
|
||||
"organization_slug": "umbrella",
|
||||
"role": "member",
|
||||
"is_active": true
|
||||
},
|
||||
{
|
||||
"email": "george@umbrella.com",
|
||||
"password": "Demo123!",
|
||||
"first_name": "George",
|
||||
"last_name": "Costanza",
|
||||
"is_superuser": false,
|
||||
"organization_slug": "umbrella",
|
||||
"role": "member",
|
||||
"is_active": false
|
||||
},
|
||||
{
|
||||
"email": "kate@umbrella.com",
|
||||
"password": "Demo123!",
|
||||
"first_name": "Kate",
|
||||
"last_name": "Bishop",
|
||||
"is_superuser": false,
|
||||
"organization_slug": "umbrella",
|
||||
"role": "member",
|
||||
"is_active": true
|
||||
},
|
||||
{
|
||||
"email": "leo@massive.com",
|
||||
"password": "Demo123!",
|
||||
"first_name": "Leo",
|
||||
"last_name": "Messi",
|
||||
"is_superuser": false,
|
||||
"organization_slug": "massive-dynamic",
|
||||
"role": "owner",
|
||||
"is_active": true
|
||||
},
|
||||
{
|
||||
"email": "mary@massive.com",
|
||||
"password": "Demo123!",
|
||||
"first_name": "Mary",
|
||||
"last_name": "Jane",
|
||||
"is_superuser": false,
|
||||
"organization_slug": "massive-dynamic",
|
||||
"role": "member",
|
||||
"is_active": true
|
||||
},
|
||||
{
|
||||
"email": "nathan@massive.com",
|
||||
"password": "Demo123!",
|
||||
"first_name": "Nathan",
|
||||
"last_name": "Drake",
|
||||
"is_superuser": false,
|
||||
"organization_slug": "massive-dynamic",
|
||||
"role": "member",
|
||||
"is_active": true
|
||||
},
|
||||
{
|
||||
"email": "olivia@massive.com",
|
||||
"password": "Demo123!",
|
||||
"first_name": "Olivia",
|
||||
"last_name": "Dunham",
|
||||
"is_superuser": false,
|
||||
"organization_slug": "massive-dynamic",
|
||||
"role": "admin",
|
||||
"is_active": true
|
||||
},
|
||||
{
|
||||
"email": "peter@massive.com",
|
||||
"password": "Demo123!",
|
||||
"first_name": "Peter",
|
||||
"last_name": "Parker",
|
||||
"is_superuser": false,
|
||||
"organization_slug": "massive-dynamic",
|
||||
"role": "member",
|
||||
"is_active": true
|
||||
},
|
||||
{
|
||||
"email": "quinn@massive.com",
|
||||
"password": "Demo123!",
|
||||
"first_name": "Quinn",
|
||||
"last_name": "Mallory",
|
||||
"is_superuser": false,
|
||||
"organization_slug": "massive-dynamic",
|
||||
"role": "member",
|
||||
"is_active": true
|
||||
},
|
||||
{
|
||||
"email": "grace@example.com",
|
||||
"password": "Demo123!",
|
||||
"first_name": "Grace",
|
||||
"last_name": "Hopper",
|
||||
"is_superuser": false,
|
||||
"organization_slug": null,
|
||||
"role": null,
|
||||
"is_active": true
|
||||
},
|
||||
{
|
||||
"email": "heidi@example.com",
|
||||
"password": "Demo123!",
|
||||
"first_name": "Heidi",
|
||||
"last_name": "Klum",
|
||||
"is_superuser": false,
|
||||
"organization_slug": null,
|
||||
"role": null,
|
||||
"is_active": true
|
||||
},
|
||||
{
|
||||
"email": "ivan@example.com",
|
||||
"password": "Demo123!",
|
||||
"first_name": "Ivan",
|
||||
"last_name": "Drago",
|
||||
"is_superuser": false,
|
||||
"organization_slug": null,
|
||||
"role": null,
|
||||
"is_active": false
|
||||
},
|
||||
{
|
||||
"email": "rachel@example.com",
|
||||
"password": "Demo123!",
|
||||
"first_name": "Rachel",
|
||||
"last_name": "Green",
|
||||
"is_superuser": false,
|
||||
"organization_slug": null,
|
||||
"role": null,
|
||||
"is_active": true
|
||||
},
|
||||
{
|
||||
"email": "sam@example.com",
|
||||
"password": "Demo123!",
|
||||
"first_name": "Sam",
|
||||
"last_name": "Wilson",
|
||||
"is_superuser": false,
|
||||
"organization_slug": null,
|
||||
"role": null,
|
||||
"is_active": true
|
||||
},
|
||||
{
|
||||
"email": "tony@example.com",
|
||||
"password": "Demo123!",
|
||||
"first_name": "Tony",
|
||||
"last_name": "Stark",
|
||||
"is_superuser": false,
|
||||
"organization_slug": null,
|
||||
"role": null,
|
||||
"is_active": true
|
||||
},
|
||||
{
|
||||
"email": "una@example.com",
|
||||
"password": "Demo123!",
|
||||
"first_name": "Una",
|
||||
"last_name": "Chin-Riley",
|
||||
"is_superuser": false,
|
||||
"organization_slug": null,
|
||||
"role": null,
|
||||
"is_active": false
|
||||
},
|
||||
{
|
||||
"email": "victor@example.com",
|
||||
"password": "Demo123!",
|
||||
"first_name": "Victor",
|
||||
"last_name": "Von Doom",
|
||||
"is_superuser": false,
|
||||
"organization_slug": null,
|
||||
"role": null,
|
||||
"is_active": true
|
||||
},
|
||||
{
|
||||
"email": "wanda@example.com",
|
||||
"password": "Demo123!",
|
||||
"first_name": "Wanda",
|
||||
"last_name": "Maximoff",
|
||||
"is_superuser": false,
|
||||
"organization_slug": null,
|
||||
"role": null,
|
||||
"is_active": true
|
||||
}
|
||||
]
|
||||
}
|
||||
@@ -1,8 +1,8 @@
|
||||
"""
|
||||
Custom exceptions and global exception handlers for the API.
|
||||
"""
|
||||
|
||||
import logging
|
||||
from typing import Optional, Union
|
||||
|
||||
from fastapi import HTTPException, Request, status
|
||||
from fastapi.exceptions import RequestValidationError
|
||||
@@ -27,17 +27,13 @@ class APIException(HTTPException):
|
||||
status_code: int,
|
||||
error_code: ErrorCode,
|
||||
message: str,
|
||||
field: Optional[str] = None,
|
||||
headers: Optional[dict] = None
|
||||
field: str | None = None,
|
||||
headers: dict | None = None,
|
||||
):
|
||||
self.error_code = error_code
|
||||
self.field = field
|
||||
self.message = message
|
||||
super().__init__(
|
||||
status_code=status_code,
|
||||
detail=message,
|
||||
headers=headers
|
||||
)
|
||||
super().__init__(status_code=status_code, detail=message, headers=headers)
|
||||
|
||||
|
||||
class AuthenticationError(APIException):
|
||||
@@ -47,14 +43,14 @@ class AuthenticationError(APIException):
|
||||
self,
|
||||
message: str = "Authentication failed",
|
||||
error_code: ErrorCode = ErrorCode.INVALID_CREDENTIALS,
|
||||
field: Optional[str] = None
|
||||
field: str | None = None,
|
||||
):
|
||||
super().__init__(
|
||||
status_code=status.HTTP_401_UNAUTHORIZED,
|
||||
error_code=error_code,
|
||||
message=message,
|
||||
field=field,
|
||||
headers={"WWW-Authenticate": "Bearer"}
|
||||
headers={"WWW-Authenticate": "Bearer"},
|
||||
)
|
||||
|
||||
|
||||
@@ -64,12 +60,12 @@ class AuthorizationError(APIException):
|
||||
def __init__(
|
||||
self,
|
||||
message: str = "Insufficient permissions",
|
||||
error_code: ErrorCode = ErrorCode.INSUFFICIENT_PERMISSIONS
|
||||
error_code: ErrorCode = ErrorCode.INSUFFICIENT_PERMISSIONS,
|
||||
):
|
||||
super().__init__(
|
||||
status_code=status.HTTP_403_FORBIDDEN,
|
||||
error_code=error_code,
|
||||
message=message
|
||||
message=message,
|
||||
)
|
||||
|
||||
|
||||
@@ -79,12 +75,12 @@ class NotFoundError(APIException):
|
||||
def __init__(
|
||||
self,
|
||||
message: str = "Resource not found",
|
||||
error_code: ErrorCode = ErrorCode.NOT_FOUND
|
||||
error_code: ErrorCode = ErrorCode.NOT_FOUND,
|
||||
):
|
||||
super().__init__(
|
||||
status_code=status.HTTP_404_NOT_FOUND,
|
||||
error_code=error_code,
|
||||
message=message
|
||||
message=message,
|
||||
)
|
||||
|
||||
|
||||
@@ -95,13 +91,13 @@ class DuplicateError(APIException):
|
||||
self,
|
||||
message: str = "Resource already exists",
|
||||
error_code: ErrorCode = ErrorCode.DUPLICATE_ENTRY,
|
||||
field: Optional[str] = None
|
||||
field: str | None = None,
|
||||
):
|
||||
super().__init__(
|
||||
status_code=status.HTTP_409_CONFLICT,
|
||||
error_code=error_code,
|
||||
message=message,
|
||||
field=field
|
||||
field=field,
|
||||
)
|
||||
|
||||
|
||||
@@ -112,13 +108,13 @@ class ValidationException(APIException):
|
||||
self,
|
||||
message: str = "Validation error",
|
||||
error_code: ErrorCode = ErrorCode.VALIDATION_ERROR,
|
||||
field: Optional[str] = None
|
||||
field: str | None = None,
|
||||
):
|
||||
super().__init__(
|
||||
status_code=status.HTTP_422_UNPROCESSABLE_ENTITY,
|
||||
error_code=error_code,
|
||||
message=message,
|
||||
field=field
|
||||
field=field,
|
||||
)
|
||||
|
||||
|
||||
@@ -128,12 +124,12 @@ class DatabaseError(APIException):
|
||||
def __init__(
|
||||
self,
|
||||
message: str = "Database operation failed",
|
||||
error_code: ErrorCode = ErrorCode.DATABASE_ERROR
|
||||
error_code: ErrorCode = ErrorCode.DATABASE_ERROR,
|
||||
):
|
||||
super().__init__(
|
||||
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
|
||||
error_code=error_code,
|
||||
message=message
|
||||
message=message,
|
||||
)
|
||||
|
||||
|
||||
@@ -147,28 +143,26 @@ async def api_exception_handler(request: Request, exc: APIException) -> JSONResp
|
||||
Returns a standardized error response with error code and message.
|
||||
"""
|
||||
logger.warning(
|
||||
f"API exception: {exc.error_code} - {exc.message} "
|
||||
f"(status: {exc.status_code}, path: {request.url.path})"
|
||||
"API exception: %s - %s (status: %s, path: %s)",
|
||||
exc.error_code,
|
||||
exc.message,
|
||||
exc.status_code,
|
||||
request.url.path,
|
||||
)
|
||||
|
||||
error_response = ErrorResponse(
|
||||
errors=[ErrorDetail(
|
||||
code=exc.error_code,
|
||||
message=exc.message,
|
||||
field=exc.field
|
||||
)]
|
||||
errors=[ErrorDetail(code=exc.error_code, message=exc.message, field=exc.field)]
|
||||
)
|
||||
|
||||
return JSONResponse(
|
||||
status_code=exc.status_code,
|
||||
content=error_response.model_dump(),
|
||||
headers=exc.headers
|
||||
headers=exc.headers,
|
||||
)
|
||||
|
||||
|
||||
async def validation_exception_handler(
|
||||
request: Request,
|
||||
exc: Union[RequestValidationError, ValidationError]
|
||||
request: Request, exc: RequestValidationError | ValidationError
|
||||
) -> JSONResponse:
|
||||
"""
|
||||
Handler for Pydantic validation errors.
|
||||
@@ -189,22 +183,21 @@ async def validation_exception_handler(
|
||||
# Skip 'body' or 'query' prefix in location
|
||||
field = ".".join(str(x) for x in error["loc"][1:])
|
||||
|
||||
errors.append(ErrorDetail(
|
||||
code=ErrorCode.VALIDATION_ERROR,
|
||||
message=error["msg"],
|
||||
field=field
|
||||
))
|
||||
errors.append(
|
||||
ErrorDetail(
|
||||
code=ErrorCode.VALIDATION_ERROR, message=error["msg"], field=field
|
||||
)
|
||||
)
|
||||
|
||||
logger.warning(
|
||||
f"Validation error: {len(errors)} errors "
|
||||
f"(path: {request.url.path})"
|
||||
"Validation error: %s errors (path: %s)", len(errors), request.url.path
|
||||
)
|
||||
|
||||
error_response = ErrorResponse(errors=errors)
|
||||
|
||||
return JSONResponse(
|
||||
status_code=status.HTTP_422_UNPROCESSABLE_ENTITY,
|
||||
content=error_response.model_dump()
|
||||
content=error_response.model_dump(),
|
||||
)
|
||||
|
||||
|
||||
@@ -226,26 +219,24 @@ async def http_exception_handler(request: Request, exc: HTTPException) -> JSONRe
|
||||
}
|
||||
|
||||
error_code = status_code_to_error_code.get(
|
||||
exc.status_code,
|
||||
ErrorCode.INTERNAL_ERROR
|
||||
exc.status_code, ErrorCode.INTERNAL_ERROR
|
||||
)
|
||||
|
||||
logger.warning(
|
||||
f"HTTP exception: {exc.status_code} - {exc.detail} "
|
||||
f"(path: {request.url.path})"
|
||||
"HTTP exception: %s - %s (path: %s)",
|
||||
exc.status_code,
|
||||
exc.detail,
|
||||
request.url.path,
|
||||
)
|
||||
|
||||
error_response = ErrorResponse(
|
||||
errors=[ErrorDetail(
|
||||
code=error_code,
|
||||
message=str(exc.detail)
|
||||
)]
|
||||
errors=[ErrorDetail(code=error_code, message=str(exc.detail), field=None)]
|
||||
)
|
||||
|
||||
return JSONResponse(
|
||||
status_code=exc.status_code,
|
||||
content=error_response.model_dump(),
|
||||
headers=exc.headers
|
||||
headers=exc.headers,
|
||||
)
|
||||
|
||||
|
||||
@@ -256,27 +247,26 @@ async def unhandled_exception_handler(request: Request, exc: Exception) -> JSONR
|
||||
Logs the full exception and returns a generic error response to avoid
|
||||
leaking sensitive information in production.
|
||||
"""
|
||||
logger.error(
|
||||
f"Unhandled exception: {type(exc).__name__} - {str(exc)} "
|
||||
f"(path: {request.url.path})",
|
||||
exc_info=True
|
||||
logger.exception(
|
||||
"Unhandled exception: %s - %s (path: %s)",
|
||||
type(exc).__name__,
|
||||
exc,
|
||||
request.url.path,
|
||||
)
|
||||
|
||||
# In production, don't expose internal error details
|
||||
from app.core.config import settings
|
||||
|
||||
if settings.ENVIRONMENT == "production":
|
||||
message = "An internal error occurred. Please try again later."
|
||||
else:
|
||||
message = f"{type(exc).__name__}: {str(exc)}"
|
||||
message = f"{type(exc).__name__}: {exc!s}"
|
||||
|
||||
error_response = ErrorResponse(
|
||||
errors=[ErrorDetail(
|
||||
code=ErrorCode.INTERNAL_ERROR,
|
||||
message=message
|
||||
)]
|
||||
errors=[ErrorDetail(code=ErrorCode.INTERNAL_ERROR, message=message, field=None)]
|
||||
)
|
||||
|
||||
return JSONResponse(
|
||||
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
|
||||
content=error_response.model_dump()
|
||||
content=error_response.model_dump(),
|
||||
)
|
||||
|
||||
26
backend/app/core/repository_exceptions.py
Normal file
26
backend/app/core/repository_exceptions.py
Normal file
@@ -0,0 +1,26 @@
|
||||
"""
|
||||
Custom exceptions for the repository layer.
|
||||
|
||||
These exceptions allow services and routes to handle database-level errors
|
||||
with proper semantics, without leaking SQLAlchemy internals.
|
||||
"""
|
||||
|
||||
|
||||
class RepositoryError(Exception):
|
||||
"""Base for all repository-layer errors."""
|
||||
|
||||
|
||||
class DuplicateEntryError(RepositoryError):
|
||||
"""Raised on unique constraint violations. Maps to HTTP 409 Conflict."""
|
||||
|
||||
|
||||
class IntegrityConstraintError(RepositoryError):
|
||||
"""Raised on FK or check constraint violations."""
|
||||
|
||||
|
||||
class RecordNotFoundError(RepositoryError):
|
||||
"""Raised when an expected record doesn't exist."""
|
||||
|
||||
|
||||
class InvalidInputError(RepositoryError):
|
||||
"""Raised on bad pagination params, invalid UUIDs, or other invalid inputs."""
|
||||
@@ -1,6 +0,0 @@
|
||||
# app/crud/__init__.py
|
||||
from .organization import organization
|
||||
from .session import session as session_crud
|
||||
from .user import user
|
||||
|
||||
__all__ = ["user", "session_crud", "organization"]
|
||||
@@ -1,478 +0,0 @@
|
||||
"""
|
||||
Async CRUD operations for user sessions using SQLAlchemy 2.0 patterns.
|
||||
"""
|
||||
import logging
|
||||
import uuid
|
||||
from datetime import datetime, timezone, timedelta
|
||||
from typing import List, Optional
|
||||
from uuid import UUID
|
||||
|
||||
from sqlalchemy import and_, select, update, delete, func
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
from sqlalchemy.orm import joinedload
|
||||
|
||||
from app.crud.base import CRUDBase
|
||||
from app.models.user_session import UserSession
|
||||
from app.schemas.sessions import SessionCreate, SessionUpdate
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class CRUDSession(CRUDBase[UserSession, SessionCreate, SessionUpdate]):
|
||||
"""Async CRUD operations for user sessions."""
|
||||
|
||||
async def get_by_jti(self, db: AsyncSession, *, jti: str) -> Optional[UserSession]:
|
||||
"""
|
||||
Get session by refresh token JTI.
|
||||
|
||||
Args:
|
||||
db: Database session
|
||||
jti: Refresh token JWT ID
|
||||
|
||||
Returns:
|
||||
UserSession if found, None otherwise
|
||||
"""
|
||||
try:
|
||||
result = await db.execute(
|
||||
select(UserSession).where(UserSession.refresh_token_jti == jti)
|
||||
)
|
||||
return result.scalar_one_or_none()
|
||||
except Exception as e:
|
||||
logger.error(f"Error getting session by JTI {jti}: {str(e)}")
|
||||
raise
|
||||
|
||||
async def get_active_by_jti(self, db: AsyncSession, *, jti: str) -> Optional[UserSession]:
|
||||
"""
|
||||
Get active session by refresh token JTI.
|
||||
|
||||
Args:
|
||||
db: Database session
|
||||
jti: Refresh token JWT ID
|
||||
|
||||
Returns:
|
||||
Active UserSession if found, None otherwise
|
||||
"""
|
||||
try:
|
||||
result = await db.execute(
|
||||
select(UserSession).where(
|
||||
and_(
|
||||
UserSession.refresh_token_jti == jti,
|
||||
UserSession.is_active == True
|
||||
)
|
||||
)
|
||||
)
|
||||
return result.scalar_one_or_none()
|
||||
except Exception as e:
|
||||
logger.error(f"Error getting active session by JTI {jti}: {str(e)}")
|
||||
raise
|
||||
|
||||
async def get_user_sessions(
|
||||
self,
|
||||
db: AsyncSession,
|
||||
*,
|
||||
user_id: str,
|
||||
active_only: bool = True,
|
||||
with_user: bool = False
|
||||
) -> List[UserSession]:
|
||||
"""
|
||||
Get all sessions for a user with optional eager loading.
|
||||
|
||||
Args:
|
||||
db: Database session
|
||||
user_id: User ID
|
||||
active_only: If True, return only active sessions
|
||||
with_user: If True, eager load user relationship to prevent N+1
|
||||
|
||||
Returns:
|
||||
List of UserSession objects
|
||||
"""
|
||||
try:
|
||||
# Convert user_id string to UUID if needed
|
||||
user_uuid = UUID(user_id) if isinstance(user_id, str) else user_id
|
||||
|
||||
query = select(UserSession).where(UserSession.user_id == user_uuid)
|
||||
|
||||
# Add eager loading if requested to prevent N+1 queries
|
||||
if with_user:
|
||||
query = query.options(joinedload(UserSession.user))
|
||||
|
||||
if active_only:
|
||||
query = query.where(UserSession.is_active == True)
|
||||
|
||||
query = query.order_by(UserSession.last_used_at.desc())
|
||||
result = await db.execute(query)
|
||||
return list(result.scalars().all())
|
||||
except Exception as e:
|
||||
logger.error(f"Error getting sessions for user {user_id}: {str(e)}")
|
||||
raise
|
||||
|
||||
async def create_session(
|
||||
self,
|
||||
db: AsyncSession,
|
||||
*,
|
||||
obj_in: SessionCreate
|
||||
) -> UserSession:
|
||||
"""
|
||||
Create a new user session.
|
||||
|
||||
Args:
|
||||
db: Database session
|
||||
obj_in: SessionCreate schema with session data
|
||||
|
||||
Returns:
|
||||
Created UserSession
|
||||
|
||||
Raises:
|
||||
ValueError: If session creation fails
|
||||
"""
|
||||
try:
|
||||
db_obj = UserSession(
|
||||
user_id=obj_in.user_id,
|
||||
refresh_token_jti=obj_in.refresh_token_jti,
|
||||
device_name=obj_in.device_name,
|
||||
device_id=obj_in.device_id,
|
||||
ip_address=obj_in.ip_address,
|
||||
user_agent=obj_in.user_agent,
|
||||
last_used_at=obj_in.last_used_at,
|
||||
expires_at=obj_in.expires_at,
|
||||
is_active=True,
|
||||
location_city=obj_in.location_city,
|
||||
location_country=obj_in.location_country,
|
||||
)
|
||||
db.add(db_obj)
|
||||
await db.commit()
|
||||
await db.refresh(db_obj)
|
||||
|
||||
logger.info(
|
||||
f"Session created for user {obj_in.user_id} from {obj_in.device_name} "
|
||||
f"(IP: {obj_in.ip_address})"
|
||||
)
|
||||
|
||||
return db_obj
|
||||
except Exception as e:
|
||||
await db.rollback()
|
||||
logger.error(f"Error creating session: {str(e)}", exc_info=True)
|
||||
raise ValueError(f"Failed to create session: {str(e)}")
|
||||
|
||||
async def deactivate(self, db: AsyncSession, *, session_id: str) -> Optional[UserSession]:
|
||||
"""
|
||||
Deactivate a session (logout from device).
|
||||
|
||||
Args:
|
||||
db: Database session
|
||||
session_id: Session UUID
|
||||
|
||||
Returns:
|
||||
Deactivated UserSession if found, None otherwise
|
||||
"""
|
||||
try:
|
||||
session = await self.get(db, id=session_id)
|
||||
if not session:
|
||||
logger.warning(f"Session {session_id} not found for deactivation")
|
||||
return None
|
||||
|
||||
session.is_active = False
|
||||
db.add(session)
|
||||
await db.commit()
|
||||
await db.refresh(session)
|
||||
|
||||
logger.info(
|
||||
f"Session {session_id} deactivated for user {session.user_id} "
|
||||
f"({session.device_name})"
|
||||
)
|
||||
|
||||
return session
|
||||
except Exception as e:
|
||||
await db.rollback()
|
||||
logger.error(f"Error deactivating session {session_id}: {str(e)}")
|
||||
raise
|
||||
|
||||
async def deactivate_all_user_sessions(
|
||||
self,
|
||||
db: AsyncSession,
|
||||
*,
|
||||
user_id: str
|
||||
) -> int:
|
||||
"""
|
||||
Deactivate all active sessions for a user (logout from all devices).
|
||||
|
||||
Args:
|
||||
db: Database session
|
||||
user_id: User ID
|
||||
|
||||
Returns:
|
||||
Number of sessions deactivated
|
||||
"""
|
||||
try:
|
||||
# Convert user_id string to UUID if needed
|
||||
user_uuid = UUID(user_id) if isinstance(user_id, str) else user_id
|
||||
|
||||
stmt = (
|
||||
update(UserSession)
|
||||
.where(
|
||||
and_(
|
||||
UserSession.user_id == user_uuid,
|
||||
UserSession.is_active == True
|
||||
)
|
||||
)
|
||||
.values(is_active=False)
|
||||
)
|
||||
|
||||
result = await db.execute(stmt)
|
||||
await db.commit()
|
||||
|
||||
count = result.rowcount
|
||||
|
||||
logger.info(f"Deactivated {count} sessions for user {user_id}")
|
||||
|
||||
return count
|
||||
except Exception as e:
|
||||
await db.rollback()
|
||||
logger.error(f"Error deactivating all sessions for user {user_id}: {str(e)}")
|
||||
raise
|
||||
|
||||
async def update_last_used(
|
||||
self,
|
||||
db: AsyncSession,
|
||||
*,
|
||||
session: UserSession
|
||||
) -> UserSession:
|
||||
"""
|
||||
Update the last_used_at timestamp for a session.
|
||||
|
||||
Args:
|
||||
db: Database session
|
||||
session: UserSession object
|
||||
|
||||
Returns:
|
||||
Updated UserSession
|
||||
"""
|
||||
try:
|
||||
session.last_used_at = datetime.now(timezone.utc)
|
||||
db.add(session)
|
||||
await db.commit()
|
||||
await db.refresh(session)
|
||||
return session
|
||||
except Exception as e:
|
||||
await db.rollback()
|
||||
logger.error(f"Error updating last_used for session {session.id}: {str(e)}")
|
||||
raise
|
||||
|
||||
async def update_refresh_token(
|
||||
self,
|
||||
db: AsyncSession,
|
||||
*,
|
||||
session: UserSession,
|
||||
new_jti: str,
|
||||
new_expires_at: datetime
|
||||
) -> UserSession:
|
||||
"""
|
||||
Update session with new refresh token JTI and expiration.
|
||||
|
||||
Called during token refresh.
|
||||
|
||||
Args:
|
||||
db: Database session
|
||||
session: UserSession object
|
||||
new_jti: New refresh token JTI
|
||||
new_expires_at: New expiration datetime
|
||||
|
||||
Returns:
|
||||
Updated UserSession
|
||||
"""
|
||||
try:
|
||||
session.refresh_token_jti = new_jti
|
||||
session.expires_at = new_expires_at
|
||||
session.last_used_at = datetime.now(timezone.utc)
|
||||
db.add(session)
|
||||
await db.commit()
|
||||
await db.refresh(session)
|
||||
return session
|
||||
except Exception as e:
|
||||
await db.rollback()
|
||||
logger.error(f"Error updating refresh token for session {session.id}: {str(e)}")
|
||||
raise
|
||||
|
||||
async def cleanup_expired(self, db: AsyncSession, *, keep_days: int = 30) -> int:
|
||||
"""
|
||||
Clean up expired sessions using optimized bulk DELETE.
|
||||
|
||||
Deletes sessions that are:
|
||||
- Expired AND inactive
|
||||
- Older than keep_days
|
||||
|
||||
Uses single DELETE query instead of N individual deletes for efficiency.
|
||||
|
||||
Args:
|
||||
db: Database session
|
||||
keep_days: Keep inactive sessions for this many days (for audit)
|
||||
|
||||
Returns:
|
||||
Number of sessions deleted
|
||||
"""
|
||||
try:
|
||||
cutoff_date = datetime.now(timezone.utc) - timedelta(days=keep_days)
|
||||
now = datetime.now(timezone.utc)
|
||||
|
||||
# Use bulk DELETE with WHERE clause - single query
|
||||
stmt = delete(UserSession).where(
|
||||
and_(
|
||||
UserSession.is_active == False,
|
||||
UserSession.expires_at < now,
|
||||
UserSession.created_at < cutoff_date
|
||||
)
|
||||
)
|
||||
|
||||
result = await db.execute(stmt)
|
||||
await db.commit()
|
||||
|
||||
count = result.rowcount
|
||||
|
||||
if count > 0:
|
||||
logger.info(f"Cleaned up {count} expired sessions using bulk DELETE")
|
||||
|
||||
return count
|
||||
except Exception as e:
|
||||
await db.rollback()
|
||||
logger.error(f"Error cleaning up expired sessions: {str(e)}")
|
||||
raise
|
||||
|
||||
async def cleanup_expired_for_user(
|
||||
self,
|
||||
db: AsyncSession,
|
||||
*,
|
||||
user_id: str
|
||||
) -> int:
|
||||
"""
|
||||
Clean up expired and inactive sessions for a specific user.
|
||||
|
||||
Uses single bulk DELETE query for efficiency instead of N individual deletes.
|
||||
|
||||
Args:
|
||||
db: Database session
|
||||
user_id: User ID to cleanup sessions for
|
||||
|
||||
Returns:
|
||||
Number of sessions deleted
|
||||
"""
|
||||
try:
|
||||
# Validate UUID
|
||||
try:
|
||||
uuid_obj = uuid.UUID(user_id)
|
||||
except (ValueError, AttributeError):
|
||||
logger.error(f"Invalid UUID format: {user_id}")
|
||||
raise ValueError(f"Invalid user ID format: {user_id}")
|
||||
|
||||
now = datetime.now(timezone.utc)
|
||||
|
||||
# Use bulk DELETE with WHERE clause - single query
|
||||
stmt = delete(UserSession).where(
|
||||
and_(
|
||||
UserSession.user_id == uuid_obj,
|
||||
UserSession.is_active == False,
|
||||
UserSession.expires_at < now
|
||||
)
|
||||
)
|
||||
|
||||
result = await db.execute(stmt)
|
||||
await db.commit()
|
||||
|
||||
count = result.rowcount
|
||||
|
||||
if count > 0:
|
||||
logger.info(
|
||||
f"Cleaned up {count} expired sessions for user {user_id} using bulk DELETE"
|
||||
)
|
||||
|
||||
return count
|
||||
except Exception as e:
|
||||
await db.rollback()
|
||||
logger.error(
|
||||
f"Error cleaning up expired sessions for user {user_id}: {str(e)}"
|
||||
)
|
||||
raise
|
||||
|
||||
async def get_user_session_count(self, db: AsyncSession, *, user_id: str) -> int:
|
||||
"""
|
||||
Get count of active sessions for a user.
|
||||
|
||||
Args:
|
||||
db: Database session
|
||||
user_id: User ID
|
||||
|
||||
Returns:
|
||||
Number of active sessions
|
||||
"""
|
||||
try:
|
||||
# Convert user_id string to UUID if needed
|
||||
user_uuid = UUID(user_id) if isinstance(user_id, str) else user_id
|
||||
|
||||
result = await db.execute(
|
||||
select(func.count(UserSession.id)).where(
|
||||
and_(
|
||||
UserSession.user_id == user_uuid,
|
||||
UserSession.is_active == True
|
||||
)
|
||||
)
|
||||
)
|
||||
return result.scalar_one()
|
||||
except Exception as e:
|
||||
logger.error(f"Error counting sessions for user {user_id}: {str(e)}")
|
||||
raise
|
||||
|
||||
async def get_all_sessions(
|
||||
self,
|
||||
db: AsyncSession,
|
||||
*,
|
||||
skip: int = 0,
|
||||
limit: int = 100,
|
||||
active_only: bool = True,
|
||||
with_user: bool = True
|
||||
) -> tuple[List[UserSession], int]:
|
||||
"""
|
||||
Get all sessions across all users with pagination (admin only).
|
||||
|
||||
Args:
|
||||
db: Database session
|
||||
skip: Number of records to skip
|
||||
limit: Maximum number of records to return
|
||||
active_only: If True, return only active sessions
|
||||
with_user: If True, eager load user relationship to prevent N+1
|
||||
|
||||
Returns:
|
||||
Tuple of (list of UserSession objects, total count)
|
||||
"""
|
||||
try:
|
||||
# Build query
|
||||
query = select(UserSession)
|
||||
|
||||
# Add eager loading if requested to prevent N+1 queries
|
||||
if with_user:
|
||||
query = query.options(joinedload(UserSession.user))
|
||||
|
||||
if active_only:
|
||||
query = query.where(UserSession.is_active == True)
|
||||
|
||||
# Get total count
|
||||
count_query = select(func.count(UserSession.id))
|
||||
if active_only:
|
||||
count_query = count_query.where(UserSession.is_active == True)
|
||||
|
||||
count_result = await db.execute(count_query)
|
||||
total = count_result.scalar_one()
|
||||
|
||||
# Apply pagination and ordering
|
||||
query = query.order_by(UserSession.last_used_at.desc()).offset(skip).limit(limit)
|
||||
|
||||
result = await db.execute(query)
|
||||
sessions = list(result.scalars().all())
|
||||
|
||||
return sessions, total
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error getting all sessions: {str(e)}", exc_info=True)
|
||||
raise
|
||||
|
||||
|
||||
# Create singleton instance
|
||||
session = CRUDSession(UserSession)
|
||||
@@ -4,20 +4,28 @@ Async database initialization script.
|
||||
|
||||
Creates the first superuser if configured and doesn't already exist.
|
||||
"""
|
||||
|
||||
import asyncio
|
||||
import json
|
||||
import logging
|
||||
from typing import Optional
|
||||
import random
|
||||
from datetime import UTC, datetime, timedelta
|
||||
from pathlib import Path
|
||||
|
||||
from sqlalchemy import select, text
|
||||
|
||||
from app.core.config import settings
|
||||
from app.core.database import SessionLocal, engine
|
||||
from app.crud.user import user as user_crud
|
||||
from app.models.organization import Organization
|
||||
from app.models.user import User
|
||||
from app.models.user_organization import UserOrganization
|
||||
from app.repositories.user import user_repo as user_repo
|
||||
from app.schemas.users import UserCreate
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
async def init_db() -> Optional[User]:
|
||||
async def init_db() -> User | None:
|
||||
"""
|
||||
Initialize database with first superuser if settings are configured and user doesn't exist.
|
||||
|
||||
@@ -26,21 +34,27 @@ async def init_db() -> Optional[User]:
|
||||
"""
|
||||
# Use default values if not set in environment variables
|
||||
superuser_email = settings.FIRST_SUPERUSER_EMAIL or "admin@example.com"
|
||||
superuser_password = settings.FIRST_SUPERUSER_PASSWORD or "AdminPassword123!"
|
||||
|
||||
default_password = "AdminPassword123!"
|
||||
if settings.DEMO_MODE:
|
||||
default_password = "AdminPass1234!"
|
||||
|
||||
superuser_password = settings.FIRST_SUPERUSER_PASSWORD or default_password
|
||||
|
||||
if not settings.FIRST_SUPERUSER_EMAIL or not settings.FIRST_SUPERUSER_PASSWORD:
|
||||
logger.warning(
|
||||
"First superuser credentials not configured in settings. "
|
||||
f"Using defaults: {superuser_email}"
|
||||
"Using defaults: %s",
|
||||
superuser_email,
|
||||
)
|
||||
|
||||
async with SessionLocal() as session:
|
||||
try:
|
||||
# Check if superuser already exists
|
||||
existing_user = await user_crud.get_by_email(session, email=superuser_email)
|
||||
existing_user = await user_repo.get_by_email(session, email=superuser_email)
|
||||
|
||||
if existing_user:
|
||||
logger.info(f"Superuser already exists: {existing_user.email}")
|
||||
logger.info("Superuser already exists: %s", existing_user.email)
|
||||
return existing_user
|
||||
|
||||
# Create superuser if doesn't exist
|
||||
@@ -49,34 +63,158 @@ async def init_db() -> Optional[User]:
|
||||
password=superuser_password,
|
||||
first_name="Admin",
|
||||
last_name="User",
|
||||
is_superuser=True
|
||||
is_superuser=True,
|
||||
)
|
||||
|
||||
user = await user_crud.create(session, obj_in=user_in)
|
||||
user = await user_repo.create(session, obj_in=user_in)
|
||||
await session.commit()
|
||||
await session.refresh(user)
|
||||
|
||||
logger.info(f"Created first superuser: {user.email}")
|
||||
logger.info("Created first superuser: %s", user.email)
|
||||
|
||||
# Create demo data if in demo mode
|
||||
if settings.DEMO_MODE:
|
||||
await load_demo_data(session)
|
||||
|
||||
return user
|
||||
|
||||
except Exception as e:
|
||||
await session.rollback()
|
||||
logger.error(f"Error initializing database: {e}")
|
||||
logger.error("Error initializing database: %s", e)
|
||||
raise
|
||||
|
||||
|
||||
def _load_json_file(path: Path):
|
||||
with open(path) as f:
|
||||
return json.load(f)
|
||||
|
||||
|
||||
async def load_demo_data(session):
|
||||
"""Load demo data from JSON file."""
|
||||
demo_data_path = Path(__file__).parent / "core" / "demo_data.json"
|
||||
if not demo_data_path.exists():
|
||||
logger.warning("Demo data file not found: %s", demo_data_path)
|
||||
return
|
||||
|
||||
try:
|
||||
# Use asyncio.to_thread to avoid blocking the event loop
|
||||
data = await asyncio.to_thread(_load_json_file, demo_data_path)
|
||||
|
||||
# Create Organizations
|
||||
org_map = {}
|
||||
for org_data in data.get("organizations", []):
|
||||
# Check if org exists
|
||||
result = await session.execute(
|
||||
text("SELECT * FROM organizations WHERE slug = :slug"),
|
||||
{"slug": org_data["slug"]},
|
||||
)
|
||||
existing_org = result.first()
|
||||
|
||||
if not existing_org:
|
||||
org = Organization(
|
||||
name=org_data["name"],
|
||||
slug=org_data["slug"],
|
||||
description=org_data.get("description"),
|
||||
is_active=True,
|
||||
)
|
||||
session.add(org)
|
||||
await session.flush() # Flush to get ID
|
||||
org_map[org.slug] = org
|
||||
logger.info("Created demo organization: %s", org.name)
|
||||
else:
|
||||
# We can't easily get the ORM object from raw SQL result for map without querying again or mapping
|
||||
# So let's just query it properly if we need it for relationships
|
||||
# But for simplicity in this script, let's just assume we created it or it exists.
|
||||
# To properly map for users, we need the ID.
|
||||
# Let's use a simpler approach: just try to create, if slug conflict, skip.
|
||||
pass
|
||||
|
||||
# Re-query all orgs to build map for users
|
||||
result = await session.execute(select(Organization))
|
||||
orgs = result.scalars().all()
|
||||
org_map = {org.slug: org for org in orgs}
|
||||
|
||||
# Create Users
|
||||
for user_data in data.get("users", []):
|
||||
existing_user = await user_repo.get_by_email(
|
||||
session, email=user_data["email"]
|
||||
)
|
||||
if not existing_user:
|
||||
# Create user
|
||||
user_in = UserCreate(
|
||||
email=user_data["email"],
|
||||
password=user_data["password"],
|
||||
first_name=user_data["first_name"],
|
||||
last_name=user_data["last_name"],
|
||||
is_superuser=user_data["is_superuser"],
|
||||
is_active=user_data.get("is_active", True),
|
||||
)
|
||||
user = await user_repo.create(session, obj_in=user_in)
|
||||
|
||||
# Randomize created_at for demo data (last 30 days)
|
||||
# This makes the charts look more realistic
|
||||
days_ago = random.randint(0, 30) # noqa: S311
|
||||
random_time = datetime.now(UTC) - timedelta(days=days_ago)
|
||||
# Add some random hours/minutes variation
|
||||
random_time = random_time.replace(
|
||||
hour=random.randint(0, 23), # noqa: S311
|
||||
minute=random.randint(0, 59), # noqa: S311
|
||||
)
|
||||
|
||||
# Update the timestamp and is_active directly in the database
|
||||
# We do this to ensure the values are persisted correctly
|
||||
await session.execute(
|
||||
text(
|
||||
"UPDATE users SET created_at = :created_at, is_active = :is_active WHERE id = :user_id"
|
||||
),
|
||||
{
|
||||
"created_at": random_time,
|
||||
"is_active": user_data.get("is_active", True),
|
||||
"user_id": user.id,
|
||||
},
|
||||
)
|
||||
|
||||
logger.info(
|
||||
"Created demo user: %s (created %s days ago, active=%s)",
|
||||
user.email,
|
||||
days_ago,
|
||||
user_data.get("is_active", True),
|
||||
)
|
||||
|
||||
# Add to organization if specified
|
||||
org_slug = user_data.get("organization_slug")
|
||||
role = user_data.get("role")
|
||||
if org_slug and org_slug in org_map and role:
|
||||
org = org_map[org_slug]
|
||||
# Check if membership exists (it shouldn't for new user)
|
||||
member = UserOrganization(
|
||||
user_id=user.id, organization_id=org.id, role=role
|
||||
)
|
||||
session.add(member)
|
||||
logger.info("Added %s to %s as %s", user.email, org.name, role)
|
||||
else:
|
||||
logger.info("Demo user already exists: %s", existing_user.email)
|
||||
|
||||
await session.commit()
|
||||
logger.info("Demo data loaded successfully")
|
||||
|
||||
except Exception as e:
|
||||
logger.error("Error loading demo data: %s", e)
|
||||
raise
|
||||
|
||||
|
||||
async def main():
|
||||
"""Main entry point for database initialization."""
|
||||
# Configure logging to show info logs
|
||||
logging.basicConfig(
|
||||
level=logging.INFO,
|
||||
format="%(asctime)s - %(name)s - %(levelname)s - %(message)s"
|
||||
format="%(asctime)s - %(name)s - %(levelname)s - %(message)s",
|
||||
)
|
||||
|
||||
try:
|
||||
user = await init_db()
|
||||
if user:
|
||||
print(f"✓ Database initialized successfully")
|
||||
print("✓ Database initialized successfully")
|
||||
print(f"✓ Superuser: {user.email}")
|
||||
else:
|
||||
print("✗ Failed to initialize database")
|
||||
|
||||
@@ -1,11 +1,11 @@
|
||||
import logging
|
||||
import os
|
||||
from contextlib import asynccontextmanager
|
||||
from datetime import datetime
|
||||
from typing import Dict, Any
|
||||
from datetime import UTC, datetime
|
||||
from typing import Any
|
||||
|
||||
from apscheduler.schedulers.asyncio import AsyncIOScheduler
|
||||
from fastapi import FastAPI, status, Request, HTTPException
|
||||
from fastapi import FastAPI, HTTPException, Request, status
|
||||
from fastapi.exceptions import RequestValidationError
|
||||
from fastapi.middleware.cors import CORSMiddleware
|
||||
from fastapi.responses import HTMLResponse, JSONResponse
|
||||
@@ -14,14 +14,15 @@ from slowapi.errors import RateLimitExceeded
|
||||
from slowapi.util import get_remote_address
|
||||
|
||||
from app.api.main import api_router
|
||||
from app.api.routes.oauth_provider import wellknown_router as oauth_wellknown_router
|
||||
from app.core.config import settings
|
||||
from app.core.database import check_database_health
|
||||
from app.core.database import check_database_health, close_async_db
|
||||
from app.core.exceptions import (
|
||||
APIException,
|
||||
api_exception_handler,
|
||||
validation_exception_handler,
|
||||
http_exception_handler,
|
||||
unhandled_exception_handler
|
||||
unhandled_exception_handler,
|
||||
validation_exception_handler,
|
||||
)
|
||||
|
||||
scheduler = AsyncIOScheduler()
|
||||
@@ -52,11 +53,11 @@ async def lifespan(app: FastAPI):
|
||||
# Runs daily at 2:00 AM server time
|
||||
scheduler.add_job(
|
||||
cleanup_expired_sessions,
|
||||
'cron',
|
||||
"cron",
|
||||
hour=2,
|
||||
minute=0,
|
||||
id='cleanup_expired_sessions',
|
||||
replace_existing=True
|
||||
id="cleanup_expired_sessions",
|
||||
replace_existing=True,
|
||||
)
|
||||
|
||||
scheduler.start()
|
||||
@@ -71,14 +72,15 @@ async def lifespan(app: FastAPI):
|
||||
if os.getenv("IS_TEST", "False") != "True":
|
||||
scheduler.shutdown()
|
||||
logger.info("Scheduled jobs stopped")
|
||||
await close_async_db()
|
||||
|
||||
|
||||
logger.info(f"Starting app!!!")
|
||||
logger.info("Starting app!!!")
|
||||
app = FastAPI(
|
||||
title=settings.PROJECT_NAME,
|
||||
version=settings.VERSION,
|
||||
openapi_url=f"{settings.API_V1_STR}/openapi.json",
|
||||
lifespan=lifespan
|
||||
lifespan=lifespan,
|
||||
)
|
||||
|
||||
# Add rate limiter state to app
|
||||
@@ -96,7 +98,14 @@ app.add_middleware(
|
||||
CORSMiddleware,
|
||||
allow_origins=settings.BACKEND_CORS_ORIGINS,
|
||||
allow_credentials=True,
|
||||
allow_methods=["GET", "POST", "PUT", "PATCH", "DELETE", "OPTIONS"], # Explicit methods only
|
||||
allow_methods=[
|
||||
"GET",
|
||||
"POST",
|
||||
"PUT",
|
||||
"PATCH",
|
||||
"DELETE",
|
||||
"OPTIONS",
|
||||
], # Explicit methods only
|
||||
allow_headers=[
|
||||
"Content-Type",
|
||||
"Authorization",
|
||||
@@ -129,12 +138,14 @@ async def limit_request_size(request: Request, call_next):
|
||||
status_code=status.HTTP_413_REQUEST_ENTITY_TOO_LARGE,
|
||||
content={
|
||||
"success": False,
|
||||
"errors": [{
|
||||
"code": "REQUEST_TOO_LARGE",
|
||||
"message": f"Request body too large. Maximum size is {MAX_REQUEST_SIZE // (1024 * 1024)}MB",
|
||||
"field": None
|
||||
}]
|
||||
}
|
||||
"errors": [
|
||||
{
|
||||
"code": "REQUEST_TOO_LARGE",
|
||||
"message": f"Request body too large. Maximum size is {MAX_REQUEST_SIZE // (1024 * 1024)}MB",
|
||||
"field": None,
|
||||
}
|
||||
],
|
||||
},
|
||||
)
|
||||
|
||||
response = await call_next(request)
|
||||
@@ -165,15 +176,19 @@ async def add_security_headers(request: Request, call_next):
|
||||
|
||||
# Enforce HTTPS in production
|
||||
if settings.ENVIRONMENT == "production":
|
||||
response.headers["Strict-Transport-Security"] = "max-age=31536000; includeSubDomains"
|
||||
response.headers["Strict-Transport-Security"] = (
|
||||
"max-age=31536000; includeSubDomains"
|
||||
)
|
||||
|
||||
# Content Security Policy
|
||||
csp_mode = settings.CSP_MODE.lower()
|
||||
|
||||
# Special handling for API docs
|
||||
is_docs = request.url.path in ["/docs", "/redoc"] or \
|
||||
request.url.path.startswith("/docs/") or \
|
||||
request.url.path.startswith("/redoc/")
|
||||
is_docs = (
|
||||
request.url.path in ["/docs", "/redoc"]
|
||||
or request.url.path.startswith("/docs/")
|
||||
or request.url.path.startswith("/redoc/")
|
||||
)
|
||||
|
||||
if csp_mode == "disabled":
|
||||
# No CSP (only for local development/debugging)
|
||||
@@ -264,7 +279,7 @@ async def root():
|
||||
description="Check the health status of the API and its dependencies",
|
||||
response_description="Health status information",
|
||||
tags=["Health"],
|
||||
operation_id="health_check"
|
||||
operation_id="health_check",
|
||||
)
|
||||
async def health_check() -> JSONResponse:
|
||||
"""
|
||||
@@ -278,12 +293,12 @@ async def health_check() -> JSONResponse:
|
||||
- environment: Current environment (development, staging, production)
|
||||
- database: Database connectivity status
|
||||
"""
|
||||
health_status: Dict[str, Any] = {
|
||||
health_status: dict[str, Any] = {
|
||||
"status": "healthy",
|
||||
"timestamp": datetime.utcnow().isoformat() + "Z",
|
||||
"timestamp": datetime.now(UTC).isoformat().replace("+00:00", "Z"),
|
||||
"version": settings.VERSION,
|
||||
"environment": settings.ENVIRONMENT,
|
||||
"checks": {}
|
||||
"checks": {},
|
||||
}
|
||||
|
||||
response_status = status.HTTP_200_OK
|
||||
@@ -294,7 +309,7 @@ async def health_check() -> JSONResponse:
|
||||
if db_healthy:
|
||||
health_status["checks"]["database"] = {
|
||||
"status": "healthy",
|
||||
"message": "Database connection successful"
|
||||
"message": "Database connection successful",
|
||||
}
|
||||
else:
|
||||
raise Exception("Database health check returned unhealthy status")
|
||||
@@ -302,15 +317,16 @@ async def health_check() -> JSONResponse:
|
||||
health_status["status"] = "unhealthy"
|
||||
health_status["checks"]["database"] = {
|
||||
"status": "unhealthy",
|
||||
"message": f"Database connection failed: {str(e)}"
|
||||
"message": f"Database connection failed: {e!s}",
|
||||
}
|
||||
response_status = status.HTTP_503_SERVICE_UNAVAILABLE
|
||||
logger.error(f"Health check failed - database error: {e}")
|
||||
logger.error("Health check failed - database error: %s", e)
|
||||
|
||||
return JSONResponse(
|
||||
status_code=response_status,
|
||||
content=health_status
|
||||
)
|
||||
return JSONResponse(status_code=response_status, content=health_status)
|
||||
|
||||
|
||||
app.include_router(api_router, prefix=settings.API_V1_STR)
|
||||
|
||||
# OAuth 2.0 well-known endpoint at root level per RFC 8414
|
||||
# This allows MCP clients to discover the OAuth server metadata at /.well-known/oauth-authorization-server
|
||||
app.include_router(oauth_wellknown_router)
|
||||
|
||||
@@ -2,17 +2,40 @@
|
||||
Models package initialization.
|
||||
Imports all models to ensure they're registered with SQLAlchemy.
|
||||
"""
|
||||
|
||||
# First import Base to avoid circular imports
|
||||
from app.core.database import Base
|
||||
|
||||
from .base import TimestampMixin, UUIDMixin
|
||||
|
||||
# OAuth models (client mode - authenticate via Google/GitHub)
|
||||
from .oauth_account import OAuthAccount
|
||||
|
||||
# OAuth provider models (server mode - act as authorization server for MCP)
|
||||
from .oauth_authorization_code import OAuthAuthorizationCode
|
||||
from .oauth_client import OAuthClient
|
||||
from .oauth_provider_token import OAuthConsent, OAuthProviderRefreshToken
|
||||
from .oauth_state import OAuthState
|
||||
from .organization import Organization
|
||||
|
||||
# Import models
|
||||
from .user import User
|
||||
from .user_organization import UserOrganization, OrganizationRole
|
||||
from .user_organization import OrganizationRole, UserOrganization
|
||||
from .user_session import UserSession
|
||||
|
||||
__all__ = [
|
||||
'Base', 'TimestampMixin', 'UUIDMixin',
|
||||
'User', 'UserSession',
|
||||
'Organization', 'UserOrganization', 'OrganizationRole',
|
||||
]
|
||||
"Base",
|
||||
"OAuthAccount",
|
||||
"OAuthAuthorizationCode",
|
||||
"OAuthClient",
|
||||
"OAuthConsent",
|
||||
"OAuthProviderRefreshToken",
|
||||
"OAuthState",
|
||||
"Organization",
|
||||
"OrganizationRole",
|
||||
"TimestampMixin",
|
||||
"UUIDMixin",
|
||||
"User",
|
||||
"UserOrganization",
|
||||
"UserSession",
|
||||
]
|
||||
|
||||
@@ -1,20 +1,28 @@
|
||||
import uuid
|
||||
from datetime import datetime, timezone
|
||||
from datetime import UTC, datetime
|
||||
|
||||
from sqlalchemy import Column, DateTime
|
||||
from sqlalchemy.dialects.postgresql import UUID
|
||||
|
||||
# noinspection PyUnresolvedReferences
|
||||
from app.core.database import Base
|
||||
from app.core.database import Base # Re-exported for other models
|
||||
|
||||
|
||||
class TimestampMixin:
|
||||
"""Mixin to add created_at and updated_at timestamps to models"""
|
||||
created_at = Column(DateTime(timezone=True), default=lambda: datetime.now(timezone.utc), nullable=False)
|
||||
updated_at = Column(DateTime(timezone=True), default=lambda: datetime.now(timezone.utc),
|
||||
onupdate=lambda: datetime.now(timezone.utc), nullable=False)
|
||||
|
||||
created_at = Column(
|
||||
DateTime(timezone=True), default=lambda: datetime.now(UTC), nullable=False
|
||||
)
|
||||
updated_at = Column(
|
||||
DateTime(timezone=True),
|
||||
default=lambda: datetime.now(UTC),
|
||||
onupdate=lambda: datetime.now(UTC),
|
||||
nullable=False,
|
||||
)
|
||||
|
||||
|
||||
class UUIDMixin:
|
||||
"""Mixin to add UUID primary keys to models"""
|
||||
|
||||
id = Column(UUID(as_uuid=True), primary_key=True, default=uuid.uuid4)
|
||||
|
||||
55
backend/app/models/oauth_account.py
Executable file
55
backend/app/models/oauth_account.py
Executable file
@@ -0,0 +1,55 @@
|
||||
"""OAuth account model for linking external OAuth providers to users."""
|
||||
|
||||
from sqlalchemy import Column, DateTime, ForeignKey, Index, String, UniqueConstraint
|
||||
from sqlalchemy.dialects.postgresql import UUID
|
||||
from sqlalchemy.orm import relationship
|
||||
|
||||
from .base import Base, TimestampMixin, UUIDMixin
|
||||
|
||||
|
||||
class OAuthAccount(Base, UUIDMixin, TimestampMixin):
|
||||
"""
|
||||
Links OAuth provider accounts to users.
|
||||
|
||||
Supports multiple OAuth providers per user (e.g., user can have both
|
||||
Google and GitHub connected). Each provider account is uniquely identified
|
||||
by (provider, provider_user_id).
|
||||
"""
|
||||
|
||||
__tablename__ = "oauth_accounts"
|
||||
|
||||
# Link to user
|
||||
user_id = Column(
|
||||
UUID(as_uuid=True),
|
||||
ForeignKey("users.id", ondelete="CASCADE"),
|
||||
nullable=False,
|
||||
index=True,
|
||||
)
|
||||
|
||||
# OAuth provider identification
|
||||
provider = Column(
|
||||
String(50), nullable=False, index=True
|
||||
) # google, github, microsoft
|
||||
provider_user_id = Column(String(255), nullable=False) # Provider's unique user ID
|
||||
provider_email = Column(
|
||||
String(255), nullable=True, index=True
|
||||
) # Email from provider (for reference)
|
||||
|
||||
# Optional: store provider tokens for API access
|
||||
# TODO: Encrypt these at rest in production (requires key management infrastructure)
|
||||
access_token = Column(String(2048), nullable=True)
|
||||
refresh_token = Column(String(2048), nullable=True)
|
||||
token_expires_at = Column(DateTime(timezone=True), nullable=True)
|
||||
|
||||
# Relationship
|
||||
user = relationship("User", back_populates="oauth_accounts")
|
||||
|
||||
__table_args__ = (
|
||||
# Each provider account can only be linked to one user
|
||||
UniqueConstraint("provider", "provider_user_id", name="uq_oauth_provider_user"),
|
||||
# Index for finding all OAuth accounts for a user + provider
|
||||
Index("ix_oauth_accounts_user_provider", "user_id", "provider"),
|
||||
)
|
||||
|
||||
def __repr__(self):
|
||||
return f"<OAuthAccount {self.provider}:{self.provider_user_id}>"
|
||||
100
backend/app/models/oauth_authorization_code.py
Executable file
100
backend/app/models/oauth_authorization_code.py
Executable file
@@ -0,0 +1,100 @@
|
||||
"""OAuth authorization code model for OAuth provider mode."""
|
||||
|
||||
from datetime import UTC, datetime
|
||||
|
||||
from sqlalchemy import Boolean, Column, DateTime, ForeignKey, Index, String
|
||||
from sqlalchemy.dialects.postgresql import UUID
|
||||
from sqlalchemy.orm import relationship
|
||||
|
||||
from .base import Base, TimestampMixin, UUIDMixin
|
||||
|
||||
|
||||
class OAuthAuthorizationCode(Base, UUIDMixin, TimestampMixin):
|
||||
"""
|
||||
OAuth 2.0 Authorization Code for the authorization code flow.
|
||||
|
||||
Authorization codes are:
|
||||
- Single-use (marked as used after exchange)
|
||||
- Short-lived (10 minutes default)
|
||||
- Bound to specific client, user, redirect_uri
|
||||
- Support PKCE (code_challenge/code_challenge_method)
|
||||
|
||||
Security considerations:
|
||||
- Code must be cryptographically random (64 chars, URL-safe)
|
||||
- Must validate redirect_uri matches exactly
|
||||
- Must verify PKCE code_verifier for public clients
|
||||
- Must be consumed within expiration time
|
||||
|
||||
Performance indexes (defined in migration 0002_add_performance_indexes.py):
|
||||
- ix_perf_oauth_auth_codes_expires: expires_at WHERE used = false
|
||||
"""
|
||||
|
||||
__tablename__ = "oauth_authorization_codes"
|
||||
|
||||
# The authorization code (cryptographically random, URL-safe)
|
||||
code = Column(String(128), unique=True, nullable=False, index=True)
|
||||
|
||||
# Client that requested the code
|
||||
client_id = Column(
|
||||
String(64),
|
||||
ForeignKey("oauth_clients.client_id", ondelete="CASCADE"),
|
||||
nullable=False,
|
||||
)
|
||||
|
||||
# User who authorized the request
|
||||
user_id = Column(
|
||||
UUID(as_uuid=True),
|
||||
ForeignKey("users.id", ondelete="CASCADE"),
|
||||
nullable=False,
|
||||
)
|
||||
|
||||
# Redirect URI (must match exactly on token exchange)
|
||||
redirect_uri = Column(String(2048), nullable=False)
|
||||
|
||||
# Granted scopes (space-separated)
|
||||
scope = Column(String(1000), nullable=False, default="")
|
||||
|
||||
# PKCE support (required for public clients)
|
||||
code_challenge = Column(String(128), nullable=True)
|
||||
code_challenge_method = Column(String(10), nullable=True) # "S256" or "plain"
|
||||
|
||||
# State parameter (for CSRF protection, returned to client)
|
||||
state = Column(String(256), nullable=True)
|
||||
|
||||
# Nonce (for OpenID Connect, included in ID token)
|
||||
nonce = Column(String(256), nullable=True)
|
||||
|
||||
# Expiration (codes are short-lived, typically 10 minutes)
|
||||
expires_at = Column(DateTime(timezone=True), nullable=False)
|
||||
|
||||
# Single-use flag (set to True after successful exchange)
|
||||
used = Column(Boolean, default=False, nullable=False)
|
||||
|
||||
# Relationships
|
||||
client = relationship("OAuthClient", backref="authorization_codes")
|
||||
user = relationship("User", backref="oauth_authorization_codes")
|
||||
|
||||
# Indexes for efficient cleanup queries
|
||||
__table_args__ = (
|
||||
Index("ix_oauth_authorization_codes_expires_at", "expires_at"),
|
||||
Index("ix_oauth_authorization_codes_client_user", "client_id", "user_id"),
|
||||
)
|
||||
|
||||
def __repr__(self):
|
||||
return f"<OAuthAuthorizationCode {self.code[:8]}... for {self.client_id}>"
|
||||
|
||||
@property
|
||||
def is_expired(self) -> bool:
|
||||
"""Check if the authorization code has expired."""
|
||||
# Use timezone-aware comparison (datetime.utcnow() is deprecated)
|
||||
now = datetime.now(UTC)
|
||||
expires_at = self.expires_at
|
||||
# Handle both timezone-aware and naive datetimes from DB
|
||||
if expires_at.tzinfo is None:
|
||||
expires_at = expires_at.replace(tzinfo=UTC)
|
||||
return bool(now > expires_at)
|
||||
|
||||
@property
|
||||
def is_valid(self) -> bool:
|
||||
"""Check if the authorization code is valid (not used, not expired)."""
|
||||
return not self.used and not self.is_expired
|
||||
67
backend/app/models/oauth_client.py
Executable file
67
backend/app/models/oauth_client.py
Executable file
@@ -0,0 +1,67 @@
|
||||
"""OAuth client model for OAuth provider mode (MCP clients)."""
|
||||
|
||||
from sqlalchemy import Boolean, Column, ForeignKey, String
|
||||
from sqlalchemy.dialects.postgresql import JSONB, UUID
|
||||
from sqlalchemy.orm import relationship
|
||||
|
||||
from .base import Base, TimestampMixin, UUIDMixin
|
||||
|
||||
|
||||
class OAuthClient(Base, UUIDMixin, TimestampMixin):
|
||||
"""
|
||||
Registered OAuth clients (for OAuth provider mode).
|
||||
|
||||
This model stores third-party applications that can authenticate
|
||||
against this API using OAuth 2.0. Used for MCP (Model Context Protocol)
|
||||
client authentication and API access.
|
||||
|
||||
NOTE: This is a skeleton implementation. The full OAuth provider
|
||||
functionality (authorization endpoint, token endpoint, etc.) can be
|
||||
expanded when needed.
|
||||
"""
|
||||
|
||||
__tablename__ = "oauth_clients"
|
||||
|
||||
# Client credentials
|
||||
client_id = Column(String(64), unique=True, nullable=False, index=True)
|
||||
client_secret_hash = Column(
|
||||
String(255), nullable=True
|
||||
) # NULL for public clients (PKCE)
|
||||
|
||||
# Client metadata
|
||||
client_name = Column(String(255), nullable=False)
|
||||
client_description = Column(String(1000), nullable=True)
|
||||
|
||||
# Client type: "public" (SPA, mobile) or "confidential" (server-side)
|
||||
client_type = Column(String(20), nullable=False, default="public")
|
||||
|
||||
# Allowed redirect URIs (JSON array)
|
||||
redirect_uris = Column(JSONB, nullable=False, default=list)
|
||||
|
||||
# Allowed scopes (JSON array of scope names)
|
||||
allowed_scopes = Column(JSONB, nullable=False, default=list)
|
||||
|
||||
# Token lifetimes (in seconds)
|
||||
access_token_lifetime = Column(String(10), nullable=False, default="3600") # 1 hour
|
||||
refresh_token_lifetime = Column(
|
||||
String(10), nullable=False, default="604800"
|
||||
) # 7 days
|
||||
|
||||
# Status
|
||||
is_active = Column(Boolean, default=True, nullable=False, index=True)
|
||||
|
||||
# Optional: owner user (for user-registered applications)
|
||||
owner_user_id = Column(
|
||||
UUID(as_uuid=True),
|
||||
ForeignKey("users.id", ondelete="SET NULL"),
|
||||
nullable=True,
|
||||
)
|
||||
|
||||
# MCP-specific: URL of the MCP server this client represents
|
||||
mcp_server_url = Column(String(2048), nullable=True)
|
||||
|
||||
# Relationship
|
||||
owner = relationship("User", backref="owned_oauth_clients")
|
||||
|
||||
def __repr__(self):
|
||||
return f"<OAuthClient {self.client_name} ({self.client_id[:8]}...)>"
|
||||
162
backend/app/models/oauth_provider_token.py
Executable file
162
backend/app/models/oauth_provider_token.py
Executable file
@@ -0,0 +1,162 @@
|
||||
"""OAuth provider token models for OAuth provider mode."""
|
||||
|
||||
from datetime import UTC, datetime
|
||||
|
||||
from sqlalchemy import Boolean, Column, DateTime, ForeignKey, Index, String
|
||||
from sqlalchemy.dialects.postgresql import UUID
|
||||
from sqlalchemy.orm import relationship
|
||||
|
||||
from .base import Base, TimestampMixin, UUIDMixin
|
||||
|
||||
|
||||
class OAuthProviderRefreshToken(Base, UUIDMixin, TimestampMixin):
|
||||
"""
|
||||
OAuth 2.0 Refresh Token for the OAuth provider.
|
||||
|
||||
Refresh tokens are:
|
||||
- Opaque (stored as hash in DB, actual token given to client)
|
||||
- Long-lived (configurable, default 30 days)
|
||||
- Revocable (via revoked flag or deletion)
|
||||
- Bound to specific client, user, and scope
|
||||
|
||||
Access tokens are JWTs and not stored in DB (self-contained).
|
||||
This model only tracks refresh tokens for revocation support.
|
||||
|
||||
Security considerations:
|
||||
- Store token hash, not plaintext
|
||||
- Support token rotation (new refresh token on use)
|
||||
- Track last used time for security auditing
|
||||
- Support revocation by user, client, or admin
|
||||
|
||||
Performance indexes (defined in migration 0002_add_performance_indexes.py):
|
||||
- ix_perf_oauth_refresh_tokens_expires: expires_at WHERE revoked = false
|
||||
"""
|
||||
|
||||
__tablename__ = "oauth_provider_refresh_tokens"
|
||||
|
||||
# Hash of the refresh token (SHA-256)
|
||||
# We store hash, not plaintext, for security
|
||||
token_hash = Column(String(64), unique=True, nullable=False, index=True)
|
||||
|
||||
# Unique token ID (JTI) - used in JWT access tokens to reference this refresh token
|
||||
jti = Column(String(64), unique=True, nullable=False, index=True)
|
||||
|
||||
# Client that owns this token
|
||||
client_id = Column(
|
||||
String(64),
|
||||
ForeignKey("oauth_clients.client_id", ondelete="CASCADE"),
|
||||
nullable=False,
|
||||
)
|
||||
|
||||
# User who authorized this token
|
||||
user_id = Column(
|
||||
UUID(as_uuid=True),
|
||||
ForeignKey("users.id", ondelete="CASCADE"),
|
||||
nullable=False,
|
||||
)
|
||||
|
||||
# Granted scopes (space-separated)
|
||||
scope = Column(String(1000), nullable=False, default="")
|
||||
|
||||
# Token expiration
|
||||
expires_at = Column(DateTime(timezone=True), nullable=False)
|
||||
|
||||
# Revocation flag
|
||||
revoked = Column(Boolean, default=False, nullable=False, index=True)
|
||||
|
||||
# Last used timestamp (for security auditing)
|
||||
last_used_at = Column(DateTime(timezone=True), nullable=True)
|
||||
|
||||
# Device/session info (optional, for user visibility)
|
||||
device_info = Column(String(500), nullable=True)
|
||||
ip_address = Column(String(45), nullable=True)
|
||||
|
||||
# Relationships
|
||||
client = relationship("OAuthClient", backref="refresh_tokens")
|
||||
user = relationship("User", backref="oauth_provider_refresh_tokens")
|
||||
|
||||
# Indexes
|
||||
__table_args__ = (
|
||||
Index("ix_oauth_provider_refresh_tokens_expires_at", "expires_at"),
|
||||
Index("ix_oauth_provider_refresh_tokens_client_user", "client_id", "user_id"),
|
||||
Index(
|
||||
"ix_oauth_provider_refresh_tokens_user_revoked",
|
||||
"user_id",
|
||||
"revoked",
|
||||
),
|
||||
)
|
||||
|
||||
def __repr__(self):
|
||||
status = "revoked" if self.revoked else "active"
|
||||
return f"<OAuthProviderRefreshToken {self.jti[:8]}... ({status})>"
|
||||
|
||||
@property
|
||||
def is_expired(self) -> bool:
|
||||
"""Check if the refresh token has expired."""
|
||||
# Use timezone-aware comparison (datetime.utcnow() is deprecated)
|
||||
now = datetime.now(UTC)
|
||||
expires_at = self.expires_at
|
||||
# Handle both timezone-aware and naive datetimes from DB
|
||||
if expires_at.tzinfo is None:
|
||||
expires_at = expires_at.replace(tzinfo=UTC)
|
||||
return bool(now > expires_at)
|
||||
|
||||
@property
|
||||
def is_valid(self) -> bool:
|
||||
"""Check if the refresh token is valid (not revoked, not expired)."""
|
||||
return not self.revoked and not self.is_expired
|
||||
|
||||
|
||||
class OAuthConsent(Base, UUIDMixin, TimestampMixin):
|
||||
"""
|
||||
OAuth consent record - remembers user consent for a client.
|
||||
|
||||
When a user grants consent to an OAuth client, we store the record
|
||||
so they don't have to re-consent on subsequent authorizations
|
||||
(unless scopes change).
|
||||
|
||||
This enables a better UX - users only see consent screen once per client,
|
||||
unless the client requests additional scopes.
|
||||
"""
|
||||
|
||||
__tablename__ = "oauth_consents"
|
||||
|
||||
# User who granted consent
|
||||
user_id = Column(
|
||||
UUID(as_uuid=True),
|
||||
ForeignKey("users.id", ondelete="CASCADE"),
|
||||
nullable=False,
|
||||
)
|
||||
|
||||
# Client that received consent
|
||||
client_id = Column(
|
||||
String(64),
|
||||
ForeignKey("oauth_clients.client_id", ondelete="CASCADE"),
|
||||
nullable=False,
|
||||
)
|
||||
|
||||
# Granted scopes (space-separated)
|
||||
granted_scopes = Column(String(1000), nullable=False, default="")
|
||||
|
||||
# Relationships
|
||||
client = relationship("OAuthClient", backref="consents")
|
||||
user = relationship("User", backref="oauth_consents")
|
||||
|
||||
# Unique constraint: one consent record per user+client
|
||||
__table_args__ = (
|
||||
Index(
|
||||
"ix_oauth_consents_user_client",
|
||||
"user_id",
|
||||
"client_id",
|
||||
unique=True,
|
||||
),
|
||||
)
|
||||
|
||||
def __repr__(self):
|
||||
return f"<OAuthConsent user={self.user_id} client={self.client_id}>"
|
||||
|
||||
def has_scopes(self, requested_scopes: list[str]) -> bool:
|
||||
"""Check if all requested scopes are already granted."""
|
||||
granted = set(self.granted_scopes.split()) if self.granted_scopes else set()
|
||||
requested = set(requested_scopes)
|
||||
return requested.issubset(granted)
|
||||
45
backend/app/models/oauth_state.py
Executable file
45
backend/app/models/oauth_state.py
Executable file
@@ -0,0 +1,45 @@
|
||||
"""OAuth state model for CSRF protection during OAuth flows."""
|
||||
|
||||
from sqlalchemy import Column, DateTime, String
|
||||
from sqlalchemy.dialects.postgresql import UUID
|
||||
|
||||
from .base import Base, TimestampMixin, UUIDMixin
|
||||
|
||||
|
||||
class OAuthState(Base, UUIDMixin, TimestampMixin):
|
||||
"""
|
||||
Temporary storage for OAuth state parameters.
|
||||
|
||||
Prevents CSRF attacks during OAuth flows by storing a random state
|
||||
value that must match on callback. Also stores PKCE code_verifier
|
||||
for the Authorization Code flow with PKCE.
|
||||
|
||||
These records are short-lived (10 minutes by default) and should
|
||||
be deleted after use or expiration.
|
||||
"""
|
||||
|
||||
__tablename__ = "oauth_states"
|
||||
|
||||
# Random state parameter (CSRF protection)
|
||||
state = Column(String(255), unique=True, nullable=False, index=True)
|
||||
|
||||
# PKCE code_verifier (used to generate code_challenge)
|
||||
code_verifier = Column(String(128), nullable=True)
|
||||
|
||||
# OIDC nonce for ID token replay protection
|
||||
nonce = Column(String(255), nullable=True)
|
||||
|
||||
# OAuth provider (google, github, etc.)
|
||||
provider = Column(String(50), nullable=False)
|
||||
|
||||
# Original redirect URI (for callback validation)
|
||||
redirect_uri = Column(String(500), nullable=True)
|
||||
|
||||
# User ID if this is an account linking flow (user is already logged in)
|
||||
user_id = Column(UUID(as_uuid=True), nullable=True)
|
||||
|
||||
# Expiration time
|
||||
expires_at = Column(DateTime(timezone=True), nullable=False)
|
||||
|
||||
def __repr__(self):
|
||||
return f"<OAuthState {self.state[:8]}... ({self.provider})>"
|
||||
@@ -1,5 +1,5 @@
|
||||
# app/models/organization.py
|
||||
from sqlalchemy import Column, String, Boolean, Text, Index
|
||||
from sqlalchemy import Boolean, Column, Index, String, Text
|
||||
from sqlalchemy.dialects.postgresql import JSONB
|
||||
from sqlalchemy.orm import relationship
|
||||
|
||||
@@ -10,8 +10,12 @@ class Organization(Base, UUIDMixin, TimestampMixin):
|
||||
"""
|
||||
Organization model for multi-tenant support.
|
||||
Users can belong to multiple organizations with different roles.
|
||||
|
||||
Performance indexes (defined in migration 0002_add_performance_indexes.py):
|
||||
- ix_perf_organizations_slug_lower: LOWER(slug) WHERE is_active = true
|
||||
"""
|
||||
__tablename__ = 'organizations'
|
||||
|
||||
__tablename__ = "organizations"
|
||||
|
||||
name = Column(String(255), nullable=False, index=True)
|
||||
slug = Column(String(255), unique=True, nullable=False, index=True)
|
||||
@@ -20,11 +24,13 @@ class Organization(Base, UUIDMixin, TimestampMixin):
|
||||
settings = Column(JSONB, default={})
|
||||
|
||||
# Relationships
|
||||
user_organizations = relationship("UserOrganization", back_populates="organization", cascade="all, delete-orphan")
|
||||
user_organizations = relationship(
|
||||
"UserOrganization", back_populates="organization", cascade="all, delete-orphan"
|
||||
)
|
||||
|
||||
__table_args__ = (
|
||||
Index('ix_organizations_name_active', 'name', 'is_active'),
|
||||
Index('ix_organizations_slug_active', 'slug', 'is_active'),
|
||||
Index("ix_organizations_name_active", "name", "is_active"),
|
||||
Index("ix_organizations_slug_active", "slug", "is_active"),
|
||||
)
|
||||
|
||||
def __repr__(self):
|
||||
|
||||
@@ -1,4 +1,4 @@
|
||||
from sqlalchemy import Column, String, Boolean, DateTime
|
||||
from sqlalchemy import Boolean, Column, DateTime, String
|
||||
from sqlalchemy.dialects.postgresql import JSONB
|
||||
from sqlalchemy.orm import relationship
|
||||
|
||||
@@ -6,20 +6,45 @@ from .base import Base, TimestampMixin, UUIDMixin
|
||||
|
||||
|
||||
class User(Base, UUIDMixin, TimestampMixin):
|
||||
__tablename__ = 'users'
|
||||
"""
|
||||
User model for authentication and profile data.
|
||||
|
||||
Performance indexes (defined in migration 0002_add_performance_indexes.py):
|
||||
- ix_perf_users_email_lower: LOWER(email) WHERE deleted_at IS NULL
|
||||
- ix_perf_users_active: is_active WHERE deleted_at IS NULL
|
||||
"""
|
||||
|
||||
__tablename__ = "users"
|
||||
|
||||
email = Column(String(255), unique=True, nullable=False, index=True)
|
||||
password_hash = Column(String(255), nullable=False)
|
||||
# Nullable to support OAuth-only users who never set a password
|
||||
password_hash = Column(String(255), nullable=True)
|
||||
first_name = Column(String(100), nullable=False, default="user")
|
||||
last_name = Column(String(100), nullable=True)
|
||||
phone_number = Column(String(20))
|
||||
is_active = Column(Boolean, default=True, nullable=False, index=True)
|
||||
is_superuser = Column(Boolean, default=False, nullable=False, index=True)
|
||||
preferences = Column(JSONB)
|
||||
locale = Column(String(10), nullable=True, index=True)
|
||||
deleted_at = Column(DateTime(timezone=True), nullable=True, index=True)
|
||||
|
||||
# Relationships
|
||||
user_organizations = relationship("UserOrganization", back_populates="user", cascade="all, delete-orphan")
|
||||
user_organizations = relationship(
|
||||
"UserOrganization", back_populates="user", cascade="all, delete-orphan"
|
||||
)
|
||||
oauth_accounts = relationship(
|
||||
"OAuthAccount", back_populates="user", cascade="all, delete-orphan"
|
||||
)
|
||||
|
||||
@property
|
||||
def has_password(self) -> bool:
|
||||
"""Check if user can login with password (not OAuth-only)."""
|
||||
return self.password_hash is not None
|
||||
|
||||
@property
|
||||
def can_remove_oauth(self) -> bool:
|
||||
"""Check if user can safely remove an OAuth account link."""
|
||||
return self.has_password or len(self.oauth_accounts) > 1
|
||||
|
||||
def __repr__(self):
|
||||
return f"<User {self.email}>"
|
||||
return f"<User {self.email}>"
|
||||
|
||||
@@ -1,7 +1,7 @@
|
||||
# app/models/user_organization.py
|
||||
from enum import Enum as PyEnum
|
||||
|
||||
from sqlalchemy import Column, ForeignKey, Boolean, String, Index, Enum
|
||||
from sqlalchemy import Boolean, Column, Enum, ForeignKey, Index, String
|
||||
from sqlalchemy.dialects.postgresql import UUID as PGUUID
|
||||
from sqlalchemy.orm import relationship
|
||||
|
||||
@@ -14,6 +14,7 @@ class OrganizationRole(str, PyEnum):
|
||||
These provide a baseline role system that can be optionally used.
|
||||
Projects can extend this or implement their own permission system.
|
||||
"""
|
||||
|
||||
OWNER = "owner" # Full control over organization
|
||||
ADMIN = "admin" # Can manage users and settings
|
||||
MEMBER = "member" # Regular member with standard access
|
||||
@@ -25,25 +26,41 @@ class UserOrganization(Base, TimestampMixin):
|
||||
Junction table for many-to-many relationship between Users and Organizations.
|
||||
Includes role information for flexible RBAC.
|
||||
"""
|
||||
__tablename__ = 'user_organizations'
|
||||
|
||||
user_id = Column(PGUUID(as_uuid=True), ForeignKey('users.id', ondelete='CASCADE'), primary_key=True)
|
||||
organization_id = Column(PGUUID(as_uuid=True), ForeignKey('organizations.id', ondelete='CASCADE'), primary_key=True)
|
||||
__tablename__ = "user_organizations"
|
||||
|
||||
role = Column(Enum(OrganizationRole), default=OrganizationRole.MEMBER, nullable=False, index=True)
|
||||
user_id = Column(
|
||||
PGUUID(as_uuid=True),
|
||||
ForeignKey("users.id", ondelete="CASCADE"),
|
||||
primary_key=True,
|
||||
)
|
||||
organization_id = Column(
|
||||
PGUUID(as_uuid=True),
|
||||
ForeignKey("organizations.id", ondelete="CASCADE"),
|
||||
primary_key=True,
|
||||
)
|
||||
|
||||
role: Column[OrganizationRole] = Column(
|
||||
Enum(OrganizationRole),
|
||||
default=OrganizationRole.MEMBER,
|
||||
nullable=False,
|
||||
# Note: index defined in __table_args__ as ix_user_org_role
|
||||
)
|
||||
is_active = Column(Boolean, default=True, nullable=False, index=True)
|
||||
|
||||
# Optional: Custom permissions override for specific users
|
||||
custom_permissions = Column(String(500), nullable=True) # JSON array of permission strings
|
||||
custom_permissions = Column(
|
||||
String(500), nullable=True
|
||||
) # JSON array of permission strings
|
||||
|
||||
# Relationships
|
||||
user = relationship("User", back_populates="user_organizations")
|
||||
organization = relationship("Organization", back_populates="user_organizations")
|
||||
|
||||
__table_args__ = (
|
||||
Index('ix_user_org_user_active', 'user_id', 'is_active'),
|
||||
Index('ix_user_org_org_active', 'organization_id', 'is_active'),
|
||||
Index('ix_user_org_role', 'role'),
|
||||
Index("ix_user_org_user_active", "user_id", "is_active"),
|
||||
Index("ix_user_org_org_active", "organization_id", "is_active"),
|
||||
Index("ix_user_org_role", "role"),
|
||||
)
|
||||
|
||||
def __repr__(self):
|
||||
|
||||
@@ -6,7 +6,10 @@ This allows users to:
|
||||
- Logout from specific devices
|
||||
- Manage their active sessions
|
||||
"""
|
||||
from sqlalchemy import Column, String, Boolean, DateTime, ForeignKey, Index
|
||||
|
||||
from datetime import UTC
|
||||
|
||||
from sqlalchemy import Boolean, Column, DateTime, ForeignKey, Index, String
|
||||
from sqlalchemy.dialects.postgresql import UUID
|
||||
from sqlalchemy.orm import relationship
|
||||
|
||||
@@ -19,20 +22,31 @@ class UserSession(Base, UUIDMixin, TimestampMixin):
|
||||
|
||||
Each time a user logs in from a device, a new session is created.
|
||||
Sessions are identified by the refresh token JTI (JWT ID).
|
||||
|
||||
Performance indexes (defined in migration 0002_add_performance_indexes.py):
|
||||
- ix_perf_user_sessions_expires: expires_at WHERE is_active = true
|
||||
"""
|
||||
__tablename__ = 'user_sessions'
|
||||
|
||||
__tablename__ = "user_sessions"
|
||||
|
||||
# Foreign key to user
|
||||
user_id = Column(UUID(as_uuid=True), ForeignKey('users.id', ondelete='CASCADE'), nullable=False, index=True)
|
||||
user_id = Column(
|
||||
UUID(as_uuid=True),
|
||||
ForeignKey("users.id", ondelete="CASCADE"),
|
||||
nullable=False,
|
||||
index=True,
|
||||
)
|
||||
|
||||
# Refresh token identifier (JWT ID from the refresh token)
|
||||
refresh_token_jti = Column(String(255), unique=True, nullable=False, index=True)
|
||||
|
||||
# Device information
|
||||
device_name = Column(String(255), nullable=True) # "iPhone 14", "Chrome on MacBook"
|
||||
device_id = Column(String(255), nullable=True) # Persistent device identifier (from client)
|
||||
ip_address = Column(String(45), nullable=True) # IPv4 (15 chars) or IPv6 (45 chars)
|
||||
user_agent = Column(String(500), nullable=True) # Browser/app user agent
|
||||
device_id = Column(
|
||||
String(255), nullable=True
|
||||
) # Persistent device identifier (from client)
|
||||
ip_address = Column(String(45), nullable=True) # IPv4 (15 chars) or IPv6 (45 chars)
|
||||
user_agent = Column(String(500), nullable=True) # Browser/app user agent
|
||||
|
||||
# Session timing
|
||||
last_used_at = Column(DateTime(timezone=True), nullable=False)
|
||||
@@ -50,8 +64,8 @@ class UserSession(Base, UUIDMixin, TimestampMixin):
|
||||
|
||||
# Composite indexes for performance (defined in migration)
|
||||
__table_args__ = (
|
||||
Index('ix_user_sessions_user_active', 'user_id', 'is_active'),
|
||||
Index('ix_user_sessions_jti_active', 'refresh_token_jti', 'is_active'),
|
||||
Index("ix_user_sessions_user_active", "user_id", "is_active"),
|
||||
Index("ix_user_sessions_jti_active", "refresh_token_jti", "is_active"),
|
||||
)
|
||||
|
||||
def __repr__(self):
|
||||
@@ -60,21 +74,28 @@ class UserSession(Base, UUIDMixin, TimestampMixin):
|
||||
@property
|
||||
def is_expired(self) -> bool:
|
||||
"""Check if session has expired."""
|
||||
from datetime import datetime, timezone
|
||||
return self.expires_at < datetime.now(timezone.utc)
|
||||
from datetime import datetime
|
||||
|
||||
now = datetime.now(UTC)
|
||||
expires_at = self.expires_at
|
||||
if expires_at.tzinfo is None:
|
||||
expires_at = expires_at.replace(tzinfo=UTC)
|
||||
return bool(expires_at < now)
|
||||
|
||||
def to_dict(self):
|
||||
"""Convert session to dictionary for serialization."""
|
||||
return {
|
||||
'id': str(self.id),
|
||||
'user_id': str(self.user_id),
|
||||
'device_name': self.device_name,
|
||||
'device_id': self.device_id,
|
||||
'ip_address': self.ip_address,
|
||||
'last_used_at': self.last_used_at.isoformat() if self.last_used_at else None,
|
||||
'expires_at': self.expires_at.isoformat() if self.expires_at else None,
|
||||
'is_active': self.is_active,
|
||||
'location_city': self.location_city,
|
||||
'location_country': self.location_country,
|
||||
'created_at': self.created_at.isoformat() if self.created_at else None,
|
||||
"id": str(self.id),
|
||||
"user_id": str(self.user_id),
|
||||
"device_name": self.device_name,
|
||||
"device_id": self.device_id,
|
||||
"ip_address": self.ip_address,
|
||||
"last_used_at": self.last_used_at.isoformat()
|
||||
if self.last_used_at
|
||||
else None,
|
||||
"expires_at": self.expires_at.isoformat() if self.expires_at else None,
|
||||
"is_active": self.is_active,
|
||||
"location_city": self.location_city,
|
||||
"location_country": self.location_country,
|
||||
"created_at": self.created_at.isoformat() if self.created_at else None,
|
||||
}
|
||||
|
||||
39
backend/app/repositories/__init__.py
Normal file
39
backend/app/repositories/__init__.py
Normal file
@@ -0,0 +1,39 @@
|
||||
# app/repositories/__init__.py
|
||||
"""Repository layer — all database access goes through these classes."""
|
||||
|
||||
from app.repositories.oauth_account import OAuthAccountRepository, oauth_account_repo
|
||||
from app.repositories.oauth_authorization_code import (
|
||||
OAuthAuthorizationCodeRepository,
|
||||
oauth_authorization_code_repo,
|
||||
)
|
||||
from app.repositories.oauth_client import OAuthClientRepository, oauth_client_repo
|
||||
from app.repositories.oauth_consent import OAuthConsentRepository, oauth_consent_repo
|
||||
from app.repositories.oauth_provider_token import (
|
||||
OAuthProviderTokenRepository,
|
||||
oauth_provider_token_repo,
|
||||
)
|
||||
from app.repositories.oauth_state import OAuthStateRepository, oauth_state_repo
|
||||
from app.repositories.organization import OrganizationRepository, organization_repo
|
||||
from app.repositories.session import SessionRepository, session_repo
|
||||
from app.repositories.user import UserRepository, user_repo
|
||||
|
||||
__all__ = [
|
||||
"OAuthAccountRepository",
|
||||
"OAuthAuthorizationCodeRepository",
|
||||
"OAuthClientRepository",
|
||||
"OAuthConsentRepository",
|
||||
"OAuthProviderTokenRepository",
|
||||
"OAuthStateRepository",
|
||||
"OrganizationRepository",
|
||||
"SessionRepository",
|
||||
"UserRepository",
|
||||
"oauth_account_repo",
|
||||
"oauth_authorization_code_repo",
|
||||
"oauth_client_repo",
|
||||
"oauth_consent_repo",
|
||||
"oauth_provider_token_repo",
|
||||
"oauth_state_repo",
|
||||
"organization_repo",
|
||||
"session_repo",
|
||||
"user_repo",
|
||||
]
|
||||
255
backend/app/crud/base.py → backend/app/repositories/base.py
Executable file → Normal file
255
backend/app/crud/base.py → backend/app/repositories/base.py
Executable file → Normal file
@@ -1,21 +1,28 @@
|
||||
# app/crud/base_async.py
|
||||
# app/repositories/base.py
|
||||
"""
|
||||
Async CRUD operations base class using SQLAlchemy 2.0 async patterns.
|
||||
Base repository class for async database operations using SQLAlchemy 2.0 async patterns.
|
||||
|
||||
Provides reusable create, read, update, and delete operations for all models.
|
||||
"""
|
||||
|
||||
import logging
|
||||
import uuid
|
||||
from typing import Any, Dict, Generic, List, Optional, Type, TypeVar, Union, Tuple
|
||||
from datetime import UTC
|
||||
from typing import Any, TypeVar
|
||||
|
||||
from fastapi.encoders import jsonable_encoder
|
||||
from pydantic import BaseModel
|
||||
from sqlalchemy import func, select
|
||||
from sqlalchemy.exc import IntegrityError, OperationalError, DataError
|
||||
from sqlalchemy.exc import DataError, IntegrityError, OperationalError
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
from sqlalchemy.orm import Load
|
||||
|
||||
from app.core.database import Base
|
||||
from app.core.repository_exceptions import (
|
||||
DuplicateEntryError,
|
||||
IntegrityConstraintError,
|
||||
InvalidInputError,
|
||||
)
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
@@ -24,12 +31,16 @@ CreateSchemaType = TypeVar("CreateSchemaType", bound=BaseModel)
|
||||
UpdateSchemaType = TypeVar("UpdateSchemaType", bound=BaseModel)
|
||||
|
||||
|
||||
class CRUDBase(Generic[ModelType, CreateSchemaType, UpdateSchemaType]):
|
||||
"""Async CRUD operations for a model."""
|
||||
class BaseRepository[
|
||||
ModelType: Base,
|
||||
CreateSchemaType: BaseModel,
|
||||
UpdateSchemaType: BaseModel,
|
||||
]:
|
||||
"""Async repository operations for a model."""
|
||||
|
||||
def __init__(self, model: Type[ModelType]):
|
||||
def __init__(self, model: type[ModelType]):
|
||||
"""
|
||||
CRUD object with default async methods to Create, Read, Update, Delete.
|
||||
Repository object with default async methods to Create, Read, Update, Delete.
|
||||
|
||||
Parameters:
|
||||
model: A SQLAlchemy model class
|
||||
@@ -37,11 +48,8 @@ class CRUDBase(Generic[ModelType, CreateSchemaType, UpdateSchemaType]):
|
||||
self.model = model
|
||||
|
||||
async def get(
|
||||
self,
|
||||
db: AsyncSession,
|
||||
id: str,
|
||||
options: Optional[List[Load]] = None
|
||||
) -> Optional[ModelType]:
|
||||
self, db: AsyncSession, id: str, options: list[Load] | None = None
|
||||
) -> ModelType | None:
|
||||
"""
|
||||
Get a single record by ID with UUID validation and optional eager loading.
|
||||
|
||||
@@ -53,26 +61,19 @@ class CRUDBase(Generic[ModelType, CreateSchemaType, UpdateSchemaType]):
|
||||
|
||||
Returns:
|
||||
Model instance or None if not found
|
||||
|
||||
Example:
|
||||
# Eager load user relationship
|
||||
from sqlalchemy.orm import joinedload
|
||||
session = await session_crud.get(db, id=session_id, options=[joinedload(UserSession.user)])
|
||||
"""
|
||||
# Validate UUID format and convert to UUID object if string
|
||||
try:
|
||||
if isinstance(id, uuid.UUID):
|
||||
uuid_obj = id
|
||||
else:
|
||||
uuid_obj = uuid.UUID(str(id))
|
||||
except (ValueError, AttributeError, TypeError) as e:
|
||||
logger.warning(f"Invalid UUID format: {id} - {str(e)}")
|
||||
logger.warning("Invalid UUID format: %s - %s", id, e)
|
||||
return None
|
||||
|
||||
try:
|
||||
query = select(self.model).where(self.model.id == uuid_obj)
|
||||
|
||||
# Apply eager loading options if provided
|
||||
if options:
|
||||
for option in options:
|
||||
query = query.options(option)
|
||||
@@ -80,7 +81,9 @@ class CRUDBase(Generic[ModelType, CreateSchemaType, UpdateSchemaType]):
|
||||
result = await db.execute(query)
|
||||
return result.scalar_one_or_none()
|
||||
except Exception as e:
|
||||
logger.error(f"Error retrieving {self.model.__name__} with id {id}: {str(e)}")
|
||||
logger.error(
|
||||
"Error retrieving %s with id %s: %s", self.model.__name__, id, e
|
||||
)
|
||||
raise
|
||||
|
||||
async def get_multi(
|
||||
@@ -89,32 +92,21 @@ class CRUDBase(Generic[ModelType, CreateSchemaType, UpdateSchemaType]):
|
||||
*,
|
||||
skip: int = 0,
|
||||
limit: int = 100,
|
||||
options: Optional[List[Load]] = None
|
||||
) -> List[ModelType]:
|
||||
options: list[Load] | None = None,
|
||||
) -> list[ModelType]:
|
||||
"""
|
||||
Get multiple records with pagination validation and optional eager loading.
|
||||
|
||||
Args:
|
||||
db: Database session
|
||||
skip: Number of records to skip
|
||||
limit: Maximum number of records to return
|
||||
options: Optional list of SQLAlchemy load options for eager loading
|
||||
|
||||
Returns:
|
||||
List of model instances
|
||||
"""
|
||||
# Validate pagination parameters
|
||||
if skip < 0:
|
||||
raise ValueError("skip must be non-negative")
|
||||
raise InvalidInputError("skip must be non-negative")
|
||||
if limit < 0:
|
||||
raise ValueError("limit must be non-negative")
|
||||
raise InvalidInputError("limit must be non-negative")
|
||||
if limit > 1000:
|
||||
raise ValueError("Maximum limit is 1000")
|
||||
raise InvalidInputError("Maximum limit is 1000")
|
||||
|
||||
try:
|
||||
query = select(self.model).offset(skip).limit(limit)
|
||||
query = select(self.model).order_by(self.model.id).offset(skip).limit(limit)
|
||||
|
||||
# Apply eager loading options if provided
|
||||
if options:
|
||||
for option in options:
|
||||
query = query.options(option)
|
||||
@@ -122,16 +114,19 @@ class CRUDBase(Generic[ModelType, CreateSchemaType, UpdateSchemaType]):
|
||||
result = await db.execute(query)
|
||||
return list(result.scalars().all())
|
||||
except Exception as e:
|
||||
logger.error(f"Error retrieving multiple {self.model.__name__} records: {str(e)}")
|
||||
logger.error(
|
||||
"Error retrieving multiple %s records: %s", self.model.__name__, e
|
||||
)
|
||||
raise
|
||||
|
||||
async def create(self, db: AsyncSession, *, obj_in: CreateSchemaType) -> ModelType: # pragma: no cover
|
||||
async def create(
|
||||
self, db: AsyncSession, *, obj_in: CreateSchemaType
|
||||
) -> ModelType: # pragma: no cover
|
||||
"""Create a new record with error handling.
|
||||
|
||||
NOTE: This method is defensive code that's never called in practice.
|
||||
All CRUD subclasses (CRUDUser, CRUDOrganization, CRUDSession) override this method
|
||||
with their own implementations, so the base implementation and its exception handlers
|
||||
are never executed. Marked as pragma: no cover to avoid false coverage gaps.
|
||||
All repository subclasses override this method with their own implementations.
|
||||
Marked as pragma: no cover to avoid false coverage gaps.
|
||||
"""
|
||||
try: # pragma: no cover
|
||||
obj_in_data = jsonable_encoder(obj_in)
|
||||
@@ -142,19 +137,27 @@ class CRUDBase(Generic[ModelType, CreateSchemaType, UpdateSchemaType]):
|
||||
return db_obj
|
||||
except IntegrityError as e: # pragma: no cover
|
||||
await db.rollback()
|
||||
error_msg = str(e.orig) if hasattr(e, 'orig') else str(e)
|
||||
error_msg = str(e.orig) if hasattr(e, "orig") else str(e)
|
||||
if "unique" in error_msg.lower() or "duplicate" in error_msg.lower():
|
||||
logger.warning(f"Duplicate entry attempted for {self.model.__name__}: {error_msg}")
|
||||
raise ValueError(f"A {self.model.__name__} with this data already exists")
|
||||
logger.error(f"Integrity error creating {self.model.__name__}: {error_msg}")
|
||||
raise ValueError(f"Database integrity error: {error_msg}")
|
||||
logger.warning(
|
||||
"Duplicate entry attempted for %s: %s",
|
||||
self.model.__name__,
|
||||
error_msg,
|
||||
)
|
||||
raise DuplicateEntryError(
|
||||
f"A {self.model.__name__} with this data already exists"
|
||||
)
|
||||
logger.error(
|
||||
"Integrity error creating %s: %s", self.model.__name__, error_msg
|
||||
)
|
||||
raise IntegrityConstraintError(f"Database integrity error: {error_msg}")
|
||||
except (OperationalError, DataError) as e: # pragma: no cover
|
||||
await db.rollback()
|
||||
logger.error(f"Database error creating {self.model.__name__}: {str(e)}")
|
||||
raise ValueError(f"Database operation failed: {str(e)}")
|
||||
logger.error("Database error creating %s: %s", self.model.__name__, e)
|
||||
raise IntegrityConstraintError(f"Database operation failed: {e!s}")
|
||||
except Exception as e: # pragma: no cover
|
||||
await db.rollback()
|
||||
logger.error(f"Unexpected error creating {self.model.__name__}: {str(e)}", exc_info=True)
|
||||
logger.exception("Unexpected error creating %s: %s", self.model.__name__, e)
|
||||
raise
|
||||
|
||||
async def update(
|
||||
@@ -162,7 +165,7 @@ class CRUDBase(Generic[ModelType, CreateSchemaType, UpdateSchemaType]):
|
||||
db: AsyncSession,
|
||||
*,
|
||||
db_obj: ModelType,
|
||||
obj_in: Union[UpdateSchemaType, Dict[str, Any]]
|
||||
obj_in: UpdateSchemaType | dict[str, Any],
|
||||
) -> ModelType:
|
||||
"""Update a record with error handling."""
|
||||
try:
|
||||
@@ -182,31 +185,38 @@ class CRUDBase(Generic[ModelType, CreateSchemaType, UpdateSchemaType]):
|
||||
return db_obj
|
||||
except IntegrityError as e:
|
||||
await db.rollback()
|
||||
error_msg = str(e.orig) if hasattr(e, 'orig') else str(e)
|
||||
error_msg = str(e.orig) if hasattr(e, "orig") else str(e)
|
||||
if "unique" in error_msg.lower() or "duplicate" in error_msg.lower():
|
||||
logger.warning(f"Duplicate entry attempted for {self.model.__name__}: {error_msg}")
|
||||
raise ValueError(f"A {self.model.__name__} with this data already exists")
|
||||
logger.error(f"Integrity error updating {self.model.__name__}: {error_msg}")
|
||||
raise ValueError(f"Database integrity error: {error_msg}")
|
||||
logger.warning(
|
||||
"Duplicate entry attempted for %s: %s",
|
||||
self.model.__name__,
|
||||
error_msg,
|
||||
)
|
||||
raise DuplicateEntryError(
|
||||
f"A {self.model.__name__} with this data already exists"
|
||||
)
|
||||
logger.error(
|
||||
"Integrity error updating %s: %s", self.model.__name__, error_msg
|
||||
)
|
||||
raise IntegrityConstraintError(f"Database integrity error: {error_msg}")
|
||||
except (OperationalError, DataError) as e:
|
||||
await db.rollback()
|
||||
logger.error(f"Database error updating {self.model.__name__}: {str(e)}")
|
||||
raise ValueError(f"Database operation failed: {str(e)}")
|
||||
logger.error("Database error updating %s: %s", self.model.__name__, e)
|
||||
raise IntegrityConstraintError(f"Database operation failed: {e!s}")
|
||||
except Exception as e:
|
||||
await db.rollback()
|
||||
logger.error(f"Unexpected error updating {self.model.__name__}: {str(e)}", exc_info=True)
|
||||
logger.exception("Unexpected error updating %s: %s", self.model.__name__, e)
|
||||
raise
|
||||
|
||||
async def remove(self, db: AsyncSession, *, id: str) -> Optional[ModelType]:
|
||||
async def remove(self, db: AsyncSession, *, id: str) -> ModelType | None:
|
||||
"""Delete a record with error handling and null check."""
|
||||
# Validate UUID format and convert to UUID object if string
|
||||
try:
|
||||
if isinstance(id, uuid.UUID):
|
||||
uuid_obj = id
|
||||
else:
|
||||
uuid_obj = uuid.UUID(str(id))
|
||||
except (ValueError, AttributeError, TypeError) as e:
|
||||
logger.warning(f"Invalid UUID format for deletion: {id} - {str(e)}")
|
||||
logger.warning("Invalid UUID format for deletion: %s - %s", id, e)
|
||||
return None
|
||||
|
||||
try:
|
||||
@@ -216,7 +226,9 @@ class CRUDBase(Generic[ModelType, CreateSchemaType, UpdateSchemaType]):
|
||||
obj = result.scalar_one_or_none()
|
||||
|
||||
if obj is None:
|
||||
logger.warning(f"{self.model.__name__} with id {id} not found for deletion")
|
||||
logger.warning(
|
||||
"%s with id %s not found for deletion", self.model.__name__, id
|
||||
)
|
||||
return None
|
||||
|
||||
await db.delete(obj)
|
||||
@@ -224,12 +236,18 @@ class CRUDBase(Generic[ModelType, CreateSchemaType, UpdateSchemaType]):
|
||||
return obj
|
||||
except IntegrityError as e:
|
||||
await db.rollback()
|
||||
error_msg = str(e.orig) if hasattr(e, 'orig') else str(e)
|
||||
logger.error(f"Integrity error deleting {self.model.__name__}: {error_msg}")
|
||||
raise ValueError(f"Cannot delete {self.model.__name__}: referenced by other records")
|
||||
error_msg = str(e.orig) if hasattr(e, "orig") else str(e)
|
||||
logger.error(
|
||||
"Integrity error deleting %s: %s", self.model.__name__, error_msg
|
||||
)
|
||||
raise IntegrityConstraintError(
|
||||
f"Cannot delete {self.model.__name__}: referenced by other records"
|
||||
)
|
||||
except Exception as e:
|
||||
await db.rollback()
|
||||
logger.error(f"Error deleting {self.model.__name__} with id {id}: {str(e)}", exc_info=True)
|
||||
logger.exception(
|
||||
"Error deleting %s with id %s: %s", self.model.__name__, id, e
|
||||
)
|
||||
raise
|
||||
|
||||
async def get_multi_with_total(
|
||||
@@ -238,67 +256,57 @@ class CRUDBase(Generic[ModelType, CreateSchemaType, UpdateSchemaType]):
|
||||
*,
|
||||
skip: int = 0,
|
||||
limit: int = 100,
|
||||
sort_by: Optional[str] = None,
|
||||
sort_by: str | None = None,
|
||||
sort_order: str = "asc",
|
||||
filters: Optional[Dict[str, Any]] = None
|
||||
) -> Tuple[List[ModelType], int]:
|
||||
filters: dict[str, Any] | None = None,
|
||||
) -> tuple[list[ModelType], int]: # pragma: no cover
|
||||
"""
|
||||
Get multiple records with total count, filtering, and sorting.
|
||||
|
||||
Args:
|
||||
db: Database session
|
||||
skip: Number of records to skip
|
||||
limit: Maximum number of records to return
|
||||
sort_by: Field name to sort by (must be a valid model attribute)
|
||||
sort_order: Sort order ("asc" or "desc")
|
||||
filters: Dictionary of filters (field_name: value)
|
||||
|
||||
Returns:
|
||||
Tuple of (items, total_count)
|
||||
NOTE: This method is defensive code that's never called in practice.
|
||||
All repository subclasses override this method with their own implementations.
|
||||
Marked as pragma: no cover to avoid false coverage gaps.
|
||||
"""
|
||||
# Validate pagination parameters
|
||||
if skip < 0:
|
||||
raise ValueError("skip must be non-negative")
|
||||
raise InvalidInputError("skip must be non-negative")
|
||||
if limit < 0:
|
||||
raise ValueError("limit must be non-negative")
|
||||
raise InvalidInputError("limit must be non-negative")
|
||||
if limit > 1000:
|
||||
raise ValueError("Maximum limit is 1000")
|
||||
raise InvalidInputError("Maximum limit is 1000")
|
||||
|
||||
try:
|
||||
# Build base query
|
||||
query = select(self.model)
|
||||
|
||||
# Exclude soft-deleted records by default
|
||||
if hasattr(self.model, 'deleted_at'):
|
||||
if hasattr(self.model, "deleted_at"):
|
||||
query = query.where(self.model.deleted_at.is_(None))
|
||||
|
||||
# Apply filters
|
||||
if filters:
|
||||
for field, value in filters.items():
|
||||
if hasattr(self.model, field) and value is not None:
|
||||
query = query.where(getattr(self.model, field) == value)
|
||||
|
||||
# Get total count (before pagination)
|
||||
count_query = select(func.count()).select_from(query.alias())
|
||||
count_result = await db.execute(count_query)
|
||||
total = count_result.scalar_one()
|
||||
|
||||
# Apply sorting
|
||||
if sort_by and hasattr(self.model, sort_by):
|
||||
sort_column = getattr(self.model, sort_by)
|
||||
if sort_order.lower() == "desc":
|
||||
query = query.order_by(sort_column.desc())
|
||||
else:
|
||||
query = query.order_by(sort_column.asc())
|
||||
else:
|
||||
query = query.order_by(self.model.id)
|
||||
|
||||
# Apply pagination
|
||||
query = query.offset(skip).limit(limit)
|
||||
items_result = await db.execute(query)
|
||||
items = list(items_result.scalars().all())
|
||||
|
||||
return items, total
|
||||
except Exception as e:
|
||||
logger.error(f"Error retrieving paginated {self.model.__name__} records: {str(e)}")
|
||||
except Exception as e: # pragma: no cover
|
||||
logger.error(
|
||||
"Error retrieving paginated %s records: %s", self.model.__name__, e
|
||||
)
|
||||
raise
|
||||
|
||||
async def count(self, db: AsyncSession) -> int:
|
||||
@@ -307,7 +315,7 @@ class CRUDBase(Generic[ModelType, CreateSchemaType, UpdateSchemaType]):
|
||||
result = await db.execute(select(func.count(self.model.id)))
|
||||
return result.scalar_one()
|
||||
except Exception as e:
|
||||
logger.error(f"Error counting {self.model.__name__} records: {str(e)}")
|
||||
logger.error("Error counting %s records: %s", self.model.__name__, e)
|
||||
raise
|
||||
|
||||
async def exists(self, db: AsyncSession, id: str) -> bool:
|
||||
@@ -315,22 +323,21 @@ class CRUDBase(Generic[ModelType, CreateSchemaType, UpdateSchemaType]):
|
||||
obj = await self.get(db, id=id)
|
||||
return obj is not None
|
||||
|
||||
async def soft_delete(self, db: AsyncSession, *, id: str) -> Optional[ModelType]:
|
||||
async def soft_delete(self, db: AsyncSession, *, id: str) -> ModelType | None:
|
||||
"""
|
||||
Soft delete a record by setting deleted_at timestamp.
|
||||
|
||||
Only works if the model has a 'deleted_at' column.
|
||||
"""
|
||||
from datetime import datetime, timezone
|
||||
from datetime import datetime
|
||||
|
||||
# Validate UUID format and convert to UUID object if string
|
||||
try:
|
||||
if isinstance(id, uuid.UUID):
|
||||
uuid_obj = id
|
||||
else:
|
||||
uuid_obj = uuid.UUID(str(id))
|
||||
except (ValueError, AttributeError, TypeError) as e:
|
||||
logger.warning(f"Invalid UUID format for soft deletion: {id} - {str(e)}")
|
||||
logger.warning("Invalid UUID format for soft deletion: %s - %s", id, e)
|
||||
return None
|
||||
|
||||
try:
|
||||
@@ -340,60 +347,66 @@ class CRUDBase(Generic[ModelType, CreateSchemaType, UpdateSchemaType]):
|
||||
obj = result.scalar_one_or_none()
|
||||
|
||||
if obj is None:
|
||||
logger.warning(f"{self.model.__name__} with id {id} not found for soft deletion")
|
||||
logger.warning(
|
||||
"%s with id %s not found for soft deletion", self.model.__name__, id
|
||||
)
|
||||
return None
|
||||
|
||||
# Check if model supports soft deletes
|
||||
if not hasattr(self.model, 'deleted_at'):
|
||||
logger.error(f"{self.model.__name__} does not support soft deletes")
|
||||
raise ValueError(f"{self.model.__name__} does not have a deleted_at column")
|
||||
if not hasattr(self.model, "deleted_at"):
|
||||
logger.error("%s does not support soft deletes", self.model.__name__)
|
||||
raise InvalidInputError(
|
||||
f"{self.model.__name__} does not have a deleted_at column"
|
||||
)
|
||||
|
||||
# Set deleted_at timestamp
|
||||
obj.deleted_at = datetime.now(timezone.utc)
|
||||
obj.deleted_at = datetime.now(UTC)
|
||||
db.add(obj)
|
||||
await db.commit()
|
||||
await db.refresh(obj)
|
||||
return obj
|
||||
except Exception as e:
|
||||
await db.rollback()
|
||||
logger.error(f"Error soft deleting {self.model.__name__} with id {id}: {str(e)}", exc_info=True)
|
||||
logger.exception(
|
||||
"Error soft deleting %s with id %s: %s", self.model.__name__, id, e
|
||||
)
|
||||
raise
|
||||
|
||||
async def restore(self, db: AsyncSession, *, id: str) -> Optional[ModelType]:
|
||||
async def restore(self, db: AsyncSession, *, id: str) -> ModelType | None:
|
||||
"""
|
||||
Restore a soft-deleted record by clearing the deleted_at timestamp.
|
||||
|
||||
Only works if the model has a 'deleted_at' column.
|
||||
"""
|
||||
# Validate UUID format
|
||||
try:
|
||||
if isinstance(id, uuid.UUID):
|
||||
uuid_obj = id
|
||||
else:
|
||||
uuid_obj = uuid.UUID(str(id))
|
||||
except (ValueError, AttributeError, TypeError) as e:
|
||||
logger.warning(f"Invalid UUID format for restoration: {id} - {str(e)}")
|
||||
logger.warning("Invalid UUID format for restoration: %s - %s", id, e)
|
||||
return None
|
||||
|
||||
try:
|
||||
# Find the soft-deleted record
|
||||
if hasattr(self.model, 'deleted_at'):
|
||||
if hasattr(self.model, "deleted_at"):
|
||||
result = await db.execute(
|
||||
select(self.model).where(
|
||||
self.model.id == uuid_obj,
|
||||
self.model.deleted_at.isnot(None)
|
||||
self.model.id == uuid_obj, self.model.deleted_at.isnot(None)
|
||||
)
|
||||
)
|
||||
obj = result.scalar_one_or_none()
|
||||
else:
|
||||
logger.error(f"{self.model.__name__} does not support soft deletes")
|
||||
raise ValueError(f"{self.model.__name__} does not have a deleted_at column")
|
||||
logger.error("%s does not support soft deletes", self.model.__name__)
|
||||
raise InvalidInputError(
|
||||
f"{self.model.__name__} does not have a deleted_at column"
|
||||
)
|
||||
|
||||
if obj is None:
|
||||
logger.warning(f"Soft-deleted {self.model.__name__} with id {id} not found for restoration")
|
||||
logger.warning(
|
||||
"Soft-deleted %s with id %s not found for restoration",
|
||||
self.model.__name__,
|
||||
id,
|
||||
)
|
||||
return None
|
||||
|
||||
# Clear deleted_at timestamp
|
||||
obj.deleted_at = None
|
||||
db.add(obj)
|
||||
await db.commit()
|
||||
@@ -401,5 +414,7 @@ class CRUDBase(Generic[ModelType, CreateSchemaType, UpdateSchemaType]):
|
||||
return obj
|
||||
except Exception as e:
|
||||
await db.rollback()
|
||||
logger.error(f"Error restoring {self.model.__name__} with id {id}: {str(e)}", exc_info=True)
|
||||
logger.exception(
|
||||
"Error restoring %s with id %s: %s", self.model.__name__, id, e
|
||||
)
|
||||
raise
|
||||
249
backend/app/repositories/oauth_account.py
Normal file
249
backend/app/repositories/oauth_account.py
Normal file
@@ -0,0 +1,249 @@
|
||||
# app/repositories/oauth_account.py
|
||||
"""Repository for OAuthAccount model async database operations."""
|
||||
|
||||
import logging
|
||||
from datetime import datetime
|
||||
from uuid import UUID
|
||||
|
||||
from pydantic import BaseModel
|
||||
from sqlalchemy import and_, delete, select
|
||||
from sqlalchemy.exc import IntegrityError
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
from sqlalchemy.orm import joinedload
|
||||
|
||||
from app.core.repository_exceptions import DuplicateEntryError
|
||||
from app.models.oauth_account import OAuthAccount
|
||||
from app.repositories.base import BaseRepository
|
||||
from app.schemas.oauth import OAuthAccountCreate
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class EmptySchema(BaseModel):
|
||||
"""Placeholder schema for repository operations that don't need update schemas."""
|
||||
|
||||
|
||||
class OAuthAccountRepository(
|
||||
BaseRepository[OAuthAccount, OAuthAccountCreate, EmptySchema]
|
||||
):
|
||||
"""Repository for OAuth account links."""
|
||||
|
||||
async def get_by_provider_id(
|
||||
self,
|
||||
db: AsyncSession,
|
||||
*,
|
||||
provider: str,
|
||||
provider_user_id: str,
|
||||
) -> OAuthAccount | None:
|
||||
"""Get OAuth account by provider and provider user ID."""
|
||||
try:
|
||||
result = await db.execute(
|
||||
select(OAuthAccount)
|
||||
.where(
|
||||
and_(
|
||||
OAuthAccount.provider == provider,
|
||||
OAuthAccount.provider_user_id == provider_user_id,
|
||||
)
|
||||
)
|
||||
.options(joinedload(OAuthAccount.user))
|
||||
)
|
||||
return result.scalar_one_or_none()
|
||||
except Exception as e: # pragma: no cover
|
||||
logger.error(
|
||||
"Error getting OAuth account for %s:%s: %s",
|
||||
provider,
|
||||
provider_user_id,
|
||||
e,
|
||||
)
|
||||
raise
|
||||
|
||||
async def get_by_provider_email(
|
||||
self,
|
||||
db: AsyncSession,
|
||||
*,
|
||||
provider: str,
|
||||
email: str,
|
||||
) -> OAuthAccount | None:
|
||||
"""Get OAuth account by provider and email."""
|
||||
try:
|
||||
result = await db.execute(
|
||||
select(OAuthAccount)
|
||||
.where(
|
||||
and_(
|
||||
OAuthAccount.provider == provider,
|
||||
OAuthAccount.provider_email == email,
|
||||
)
|
||||
)
|
||||
.options(joinedload(OAuthAccount.user))
|
||||
)
|
||||
return result.scalar_one_or_none()
|
||||
except Exception as e: # pragma: no cover
|
||||
logger.error(
|
||||
"Error getting OAuth account for %s email %s: %s", provider, email, e
|
||||
)
|
||||
raise
|
||||
|
||||
async def get_user_accounts(
|
||||
self,
|
||||
db: AsyncSession,
|
||||
*,
|
||||
user_id: str | UUID,
|
||||
) -> list[OAuthAccount]:
|
||||
"""Get all OAuth accounts linked to a user."""
|
||||
try:
|
||||
user_uuid = UUID(str(user_id)) if isinstance(user_id, str) else user_id
|
||||
|
||||
result = await db.execute(
|
||||
select(OAuthAccount)
|
||||
.where(OAuthAccount.user_id == user_uuid)
|
||||
.order_by(OAuthAccount.created_at.desc())
|
||||
)
|
||||
return list(result.scalars().all())
|
||||
except Exception as e: # pragma: no cover
|
||||
logger.error("Error getting OAuth accounts for user %s: %s", user_id, e)
|
||||
raise
|
||||
|
||||
async def get_user_account_by_provider(
|
||||
self,
|
||||
db: AsyncSession,
|
||||
*,
|
||||
user_id: str | UUID,
|
||||
provider: str,
|
||||
) -> OAuthAccount | None:
|
||||
"""Get a specific OAuth account for a user and provider."""
|
||||
try:
|
||||
user_uuid = UUID(str(user_id)) if isinstance(user_id, str) else user_id
|
||||
|
||||
result = await db.execute(
|
||||
select(OAuthAccount).where(
|
||||
and_(
|
||||
OAuthAccount.user_id == user_uuid,
|
||||
OAuthAccount.provider == provider,
|
||||
)
|
||||
)
|
||||
)
|
||||
return result.scalar_one_or_none()
|
||||
except Exception as e: # pragma: no cover
|
||||
logger.error(
|
||||
"Error getting OAuth account for user %s, provider %s: %s",
|
||||
user_id,
|
||||
provider,
|
||||
e,
|
||||
)
|
||||
raise
|
||||
|
||||
async def create_account(
|
||||
self, db: AsyncSession, *, obj_in: OAuthAccountCreate
|
||||
) -> OAuthAccount:
|
||||
"""Create a new OAuth account link."""
|
||||
try:
|
||||
db_obj = OAuthAccount(
|
||||
user_id=obj_in.user_id,
|
||||
provider=obj_in.provider,
|
||||
provider_user_id=obj_in.provider_user_id,
|
||||
provider_email=obj_in.provider_email,
|
||||
access_token=obj_in.access_token,
|
||||
refresh_token=obj_in.refresh_token,
|
||||
token_expires_at=obj_in.token_expires_at,
|
||||
)
|
||||
db.add(db_obj)
|
||||
await db.commit()
|
||||
await db.refresh(db_obj)
|
||||
|
||||
logger.info(
|
||||
"OAuth account created: %s linked to user %s",
|
||||
obj_in.provider,
|
||||
obj_in.user_id,
|
||||
)
|
||||
return db_obj
|
||||
except IntegrityError as e: # pragma: no cover
|
||||
await db.rollback()
|
||||
error_msg = str(e.orig) if hasattr(e, "orig") else str(e)
|
||||
if "uq_oauth_provider_user" in error_msg.lower():
|
||||
logger.warning(
|
||||
"OAuth account already exists: %s:%s",
|
||||
obj_in.provider,
|
||||
obj_in.provider_user_id,
|
||||
)
|
||||
raise DuplicateEntryError(
|
||||
f"This {obj_in.provider} account is already linked to another user"
|
||||
)
|
||||
logger.error("Integrity error creating OAuth account: %s", error_msg)
|
||||
raise DuplicateEntryError(f"Failed to create OAuth account: {error_msg}")
|
||||
except Exception as e: # pragma: no cover
|
||||
await db.rollback()
|
||||
logger.exception("Error creating OAuth account: %s", e)
|
||||
raise
|
||||
|
||||
async def delete_account(
|
||||
self,
|
||||
db: AsyncSession,
|
||||
*,
|
||||
user_id: str | UUID,
|
||||
provider: str,
|
||||
) -> bool:
|
||||
"""Delete an OAuth account link."""
|
||||
try:
|
||||
user_uuid = UUID(str(user_id)) if isinstance(user_id, str) else user_id
|
||||
|
||||
result = await db.execute(
|
||||
delete(OAuthAccount).where(
|
||||
and_(
|
||||
OAuthAccount.user_id == user_uuid,
|
||||
OAuthAccount.provider == provider,
|
||||
)
|
||||
)
|
||||
)
|
||||
await db.commit()
|
||||
|
||||
deleted = result.rowcount > 0
|
||||
if deleted:
|
||||
logger.info(
|
||||
"OAuth account deleted: %s unlinked from user %s", provider, user_id
|
||||
)
|
||||
else:
|
||||
logger.warning(
|
||||
"OAuth account not found for deletion: %s for user %s",
|
||||
provider,
|
||||
user_id,
|
||||
)
|
||||
|
||||
return deleted
|
||||
except Exception as e: # pragma: no cover
|
||||
await db.rollback()
|
||||
logger.error(
|
||||
"Error deleting OAuth account %s for user %s: %s", provider, user_id, e
|
||||
)
|
||||
raise
|
||||
|
||||
async def update_tokens(
|
||||
self,
|
||||
db: AsyncSession,
|
||||
*,
|
||||
account: OAuthAccount,
|
||||
access_token: str | None = None,
|
||||
refresh_token: str | None = None,
|
||||
token_expires_at: datetime | None = None,
|
||||
) -> OAuthAccount:
|
||||
"""Update OAuth tokens for an account."""
|
||||
try:
|
||||
if access_token is not None:
|
||||
account.access_token = access_token
|
||||
if refresh_token is not None:
|
||||
account.refresh_token = refresh_token
|
||||
if token_expires_at is not None:
|
||||
account.token_expires_at = token_expires_at
|
||||
|
||||
db.add(account)
|
||||
await db.commit()
|
||||
await db.refresh(account)
|
||||
|
||||
return account
|
||||
except Exception as e: # pragma: no cover
|
||||
await db.rollback()
|
||||
logger.error("Error updating OAuth tokens: %s", e)
|
||||
raise
|
||||
|
||||
|
||||
# Singleton instance
|
||||
oauth_account_repo = OAuthAccountRepository(OAuthAccount)
|
||||
108
backend/app/repositories/oauth_authorization_code.py
Normal file
108
backend/app/repositories/oauth_authorization_code.py
Normal file
@@ -0,0 +1,108 @@
|
||||
# app/repositories/oauth_authorization_code.py
|
||||
"""Repository for OAuthAuthorizationCode model."""
|
||||
|
||||
import logging
|
||||
from datetime import UTC, datetime
|
||||
from uuid import UUID
|
||||
|
||||
from sqlalchemy import and_, delete, select, update
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
|
||||
from app.models.oauth_authorization_code import OAuthAuthorizationCode
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class OAuthAuthorizationCodeRepository:
|
||||
"""Repository for OAuth 2.0 authorization codes."""
|
||||
|
||||
async def create_code(
|
||||
self,
|
||||
db: AsyncSession,
|
||||
*,
|
||||
code: str,
|
||||
client_id: str,
|
||||
user_id: UUID,
|
||||
redirect_uri: str,
|
||||
scope: str,
|
||||
expires_at: datetime,
|
||||
code_challenge: str | None = None,
|
||||
code_challenge_method: str | None = None,
|
||||
state: str | None = None,
|
||||
nonce: str | None = None,
|
||||
) -> OAuthAuthorizationCode:
|
||||
"""Create and persist a new authorization code."""
|
||||
auth_code = OAuthAuthorizationCode(
|
||||
code=code,
|
||||
client_id=client_id,
|
||||
user_id=user_id,
|
||||
redirect_uri=redirect_uri,
|
||||
scope=scope,
|
||||
code_challenge=code_challenge,
|
||||
code_challenge_method=code_challenge_method,
|
||||
state=state,
|
||||
nonce=nonce,
|
||||
expires_at=expires_at,
|
||||
used=False,
|
||||
)
|
||||
db.add(auth_code)
|
||||
await db.commit()
|
||||
return auth_code
|
||||
|
||||
async def consume_code_atomically(
|
||||
self, db: AsyncSession, *, code: str
|
||||
) -> UUID | None:
|
||||
"""
|
||||
Atomically mark a code as used and return its UUID.
|
||||
|
||||
Returns the UUID if the code was found and not yet used, None otherwise.
|
||||
This prevents race conditions per RFC 6749 Section 4.1.2.
|
||||
"""
|
||||
stmt = (
|
||||
update(OAuthAuthorizationCode)
|
||||
.where(
|
||||
and_(
|
||||
OAuthAuthorizationCode.code == code,
|
||||
OAuthAuthorizationCode.used == False, # noqa: E712
|
||||
)
|
||||
)
|
||||
.values(used=True)
|
||||
.returning(OAuthAuthorizationCode.id)
|
||||
)
|
||||
result = await db.execute(stmt)
|
||||
row_id = result.scalar_one_or_none()
|
||||
if row_id is not None:
|
||||
await db.commit()
|
||||
return row_id
|
||||
|
||||
async def get_by_id(
|
||||
self, db: AsyncSession, *, code_id: UUID
|
||||
) -> OAuthAuthorizationCode | None:
|
||||
"""Get authorization code by its UUID primary key."""
|
||||
result = await db.execute(
|
||||
select(OAuthAuthorizationCode).where(OAuthAuthorizationCode.id == code_id)
|
||||
)
|
||||
return result.scalar_one_or_none()
|
||||
|
||||
async def get_by_code(
|
||||
self, db: AsyncSession, *, code: str
|
||||
) -> OAuthAuthorizationCode | None:
|
||||
"""Get authorization code by the code string value."""
|
||||
result = await db.execute(
|
||||
select(OAuthAuthorizationCode).where(OAuthAuthorizationCode.code == code)
|
||||
)
|
||||
return result.scalar_one_or_none()
|
||||
|
||||
async def cleanup_expired(self, db: AsyncSession) -> int:
|
||||
"""Delete all expired authorization codes. Returns count deleted."""
|
||||
result = await db.execute(
|
||||
delete(OAuthAuthorizationCode).where(
|
||||
OAuthAuthorizationCode.expires_at < datetime.now(UTC)
|
||||
)
|
||||
)
|
||||
await db.commit()
|
||||
return result.rowcount # type: ignore[attr-defined]
|
||||
|
||||
|
||||
# Singleton instance
|
||||
oauth_authorization_code_repo = OAuthAuthorizationCodeRepository()
|
||||
201
backend/app/repositories/oauth_client.py
Normal file
201
backend/app/repositories/oauth_client.py
Normal file
@@ -0,0 +1,201 @@
|
||||
# app/repositories/oauth_client.py
|
||||
"""Repository for OAuthClient model async database operations."""
|
||||
|
||||
import logging
|
||||
import secrets
|
||||
from uuid import UUID
|
||||
|
||||
from pydantic import BaseModel
|
||||
from sqlalchemy import and_, delete, select
|
||||
from sqlalchemy.exc import IntegrityError
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
|
||||
from app.core.repository_exceptions import DuplicateEntryError
|
||||
from app.models.oauth_client import OAuthClient
|
||||
from app.repositories.base import BaseRepository
|
||||
from app.schemas.oauth import OAuthClientCreate
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class EmptySchema(BaseModel):
|
||||
"""Placeholder schema for repository operations that don't need update schemas."""
|
||||
|
||||
|
||||
class OAuthClientRepository(
|
||||
BaseRepository[OAuthClient, OAuthClientCreate, EmptySchema]
|
||||
):
|
||||
"""Repository for OAuth clients (provider mode)."""
|
||||
|
||||
async def get_by_client_id(
|
||||
self, db: AsyncSession, *, client_id: str
|
||||
) -> OAuthClient | None:
|
||||
"""Get OAuth client by client_id."""
|
||||
try:
|
||||
result = await db.execute(
|
||||
select(OAuthClient).where(
|
||||
and_(
|
||||
OAuthClient.client_id == client_id,
|
||||
OAuthClient.is_active == True, # noqa: E712
|
||||
)
|
||||
)
|
||||
)
|
||||
return result.scalar_one_or_none()
|
||||
except Exception as e: # pragma: no cover
|
||||
logger.error("Error getting OAuth client %s: %s", client_id, e)
|
||||
raise
|
||||
|
||||
async def create_client(
|
||||
self,
|
||||
db: AsyncSession,
|
||||
*,
|
||||
obj_in: OAuthClientCreate,
|
||||
owner_user_id: UUID | None = None,
|
||||
) -> tuple[OAuthClient, str | None]:
|
||||
"""Create a new OAuth client."""
|
||||
try:
|
||||
client_id = secrets.token_urlsafe(32)
|
||||
|
||||
client_secret = None
|
||||
client_secret_hash = None
|
||||
if obj_in.client_type == "confidential":
|
||||
client_secret = secrets.token_urlsafe(48)
|
||||
from app.core.auth import get_password_hash
|
||||
|
||||
client_secret_hash = get_password_hash(client_secret)
|
||||
|
||||
db_obj = OAuthClient(
|
||||
client_id=client_id,
|
||||
client_secret_hash=client_secret_hash,
|
||||
client_name=obj_in.client_name,
|
||||
client_description=obj_in.client_description,
|
||||
client_type=obj_in.client_type,
|
||||
redirect_uris=obj_in.redirect_uris,
|
||||
allowed_scopes=obj_in.allowed_scopes,
|
||||
owner_user_id=owner_user_id,
|
||||
is_active=True,
|
||||
)
|
||||
db.add(db_obj)
|
||||
await db.commit()
|
||||
await db.refresh(db_obj)
|
||||
|
||||
logger.info(
|
||||
"OAuth client created: %s (%s...)", obj_in.client_name, client_id[:8]
|
||||
)
|
||||
return db_obj, client_secret
|
||||
except IntegrityError as e: # pragma: no cover
|
||||
await db.rollback()
|
||||
error_msg = str(e.orig) if hasattr(e, "orig") else str(e)
|
||||
logger.error("Error creating OAuth client: %s", error_msg)
|
||||
raise DuplicateEntryError(f"Failed to create OAuth client: {error_msg}")
|
||||
except Exception as e: # pragma: no cover
|
||||
await db.rollback()
|
||||
logger.exception("Error creating OAuth client: %s", e)
|
||||
raise
|
||||
|
||||
async def deactivate_client(
|
||||
self, db: AsyncSession, *, client_id: str
|
||||
) -> OAuthClient | None:
|
||||
"""Deactivate an OAuth client."""
|
||||
try:
|
||||
client = await self.get_by_client_id(db, client_id=client_id)
|
||||
if client is None:
|
||||
return None
|
||||
|
||||
client.is_active = False
|
||||
db.add(client)
|
||||
await db.commit()
|
||||
await db.refresh(client)
|
||||
|
||||
logger.info("OAuth client deactivated: %s", client.client_name)
|
||||
return client
|
||||
except Exception as e: # pragma: no cover
|
||||
await db.rollback()
|
||||
logger.error("Error deactivating OAuth client %s: %s", client_id, e)
|
||||
raise
|
||||
|
||||
async def validate_redirect_uri(
|
||||
self, db: AsyncSession, *, client_id: str, redirect_uri: str
|
||||
) -> bool:
|
||||
"""Validate that a redirect URI is allowed for a client."""
|
||||
try:
|
||||
client = await self.get_by_client_id(db, client_id=client_id)
|
||||
if client is None:
|
||||
return False
|
||||
|
||||
return redirect_uri in (client.redirect_uris or [])
|
||||
except Exception as e: # pragma: no cover
|
||||
logger.error("Error validating redirect URI: %s", e)
|
||||
return False
|
||||
|
||||
async def verify_client_secret(
|
||||
self, db: AsyncSession, *, client_id: str, client_secret: str
|
||||
) -> bool:
|
||||
"""Verify client credentials."""
|
||||
try:
|
||||
result = await db.execute(
|
||||
select(OAuthClient).where(
|
||||
and_(
|
||||
OAuthClient.client_id == client_id,
|
||||
OAuthClient.is_active == True, # noqa: E712
|
||||
)
|
||||
)
|
||||
)
|
||||
client = result.scalar_one_or_none()
|
||||
|
||||
if client is None or client.client_secret_hash is None:
|
||||
return False
|
||||
|
||||
from app.core.auth import verify_password
|
||||
|
||||
stored_hash: str = str(client.client_secret_hash)
|
||||
|
||||
if stored_hash.startswith("$2"):
|
||||
return verify_password(client_secret, stored_hash)
|
||||
else:
|
||||
import hashlib
|
||||
|
||||
secret_hash = hashlib.sha256(client_secret.encode()).hexdigest()
|
||||
return secrets.compare_digest(stored_hash, secret_hash)
|
||||
except Exception as e: # pragma: no cover
|
||||
logger.error("Error verifying client secret: %s", e)
|
||||
return False
|
||||
|
||||
async def get_all_clients(
|
||||
self, db: AsyncSession, *, include_inactive: bool = False
|
||||
) -> list[OAuthClient]:
|
||||
"""Get all OAuth clients."""
|
||||
try:
|
||||
query = select(OAuthClient).order_by(OAuthClient.created_at.desc())
|
||||
if not include_inactive:
|
||||
query = query.where(OAuthClient.is_active == True) # noqa: E712
|
||||
|
||||
result = await db.execute(query)
|
||||
return list(result.scalars().all())
|
||||
except Exception as e: # pragma: no cover
|
||||
logger.error("Error getting all OAuth clients: %s", e)
|
||||
raise
|
||||
|
||||
async def delete_client(self, db: AsyncSession, *, client_id: str) -> bool:
|
||||
"""Delete an OAuth client permanently."""
|
||||
try:
|
||||
result = await db.execute(
|
||||
delete(OAuthClient).where(OAuthClient.client_id == client_id)
|
||||
)
|
||||
await db.commit()
|
||||
|
||||
deleted = result.rowcount > 0
|
||||
if deleted:
|
||||
logger.info("OAuth client deleted: %s", client_id)
|
||||
else:
|
||||
logger.warning("OAuth client not found for deletion: %s", client_id)
|
||||
|
||||
return deleted
|
||||
except Exception as e: # pragma: no cover
|
||||
await db.rollback()
|
||||
logger.error("Error deleting OAuth client %s: %s", client_id, e)
|
||||
raise
|
||||
|
||||
|
||||
# Singleton instance
|
||||
oauth_client_repo = OAuthClientRepository(OAuthClient)
|
||||
113
backend/app/repositories/oauth_consent.py
Normal file
113
backend/app/repositories/oauth_consent.py
Normal file
@@ -0,0 +1,113 @@
|
||||
# app/repositories/oauth_consent.py
|
||||
"""Repository for OAuthConsent model."""
|
||||
|
||||
import logging
|
||||
from typing import Any
|
||||
from uuid import UUID
|
||||
|
||||
from sqlalchemy import and_, delete, select
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
|
||||
from app.models.oauth_client import OAuthClient
|
||||
from app.models.oauth_provider_token import OAuthConsent
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class OAuthConsentRepository:
|
||||
"""Repository for OAuth consent records (user grants to clients)."""
|
||||
|
||||
async def get_consent(
|
||||
self, db: AsyncSession, *, user_id: UUID, client_id: str
|
||||
) -> OAuthConsent | None:
|
||||
"""Get the consent record for a user-client pair, or None if not found."""
|
||||
result = await db.execute(
|
||||
select(OAuthConsent).where(
|
||||
and_(
|
||||
OAuthConsent.user_id == user_id,
|
||||
OAuthConsent.client_id == client_id,
|
||||
)
|
||||
)
|
||||
)
|
||||
return result.scalar_one_or_none()
|
||||
|
||||
async def grant_consent(
|
||||
self,
|
||||
db: AsyncSession,
|
||||
*,
|
||||
user_id: UUID,
|
||||
client_id: str,
|
||||
scopes: list[str],
|
||||
) -> OAuthConsent:
|
||||
"""
|
||||
Create or update consent for a user-client pair.
|
||||
|
||||
If consent already exists, the new scopes are merged with existing ones.
|
||||
Returns the created or updated consent record.
|
||||
"""
|
||||
consent = await self.get_consent(db, user_id=user_id, client_id=client_id)
|
||||
|
||||
if consent:
|
||||
existing = (
|
||||
set(consent.granted_scopes.split()) if consent.granted_scopes else set()
|
||||
)
|
||||
merged = existing | set(scopes)
|
||||
consent.granted_scopes = " ".join(sorted(merged)) # type: ignore[assignment]
|
||||
else:
|
||||
consent = OAuthConsent(
|
||||
user_id=user_id,
|
||||
client_id=client_id,
|
||||
granted_scopes=" ".join(sorted(set(scopes))),
|
||||
)
|
||||
db.add(consent)
|
||||
|
||||
await db.commit()
|
||||
await db.refresh(consent)
|
||||
return consent
|
||||
|
||||
async def get_user_consents_with_clients(
|
||||
self, db: AsyncSession, *, user_id: UUID
|
||||
) -> list[dict[str, Any]]:
|
||||
"""Get all consent records for a user joined with client details."""
|
||||
result = await db.execute(
|
||||
select(OAuthConsent, OAuthClient)
|
||||
.join(OAuthClient, OAuthConsent.client_id == OAuthClient.client_id)
|
||||
.where(OAuthConsent.user_id == user_id)
|
||||
)
|
||||
rows = result.all()
|
||||
return [
|
||||
{
|
||||
"client_id": consent.client_id,
|
||||
"client_name": client.client_name,
|
||||
"client_description": client.client_description,
|
||||
"granted_scopes": consent.granted_scopes.split()
|
||||
if consent.granted_scopes
|
||||
else [],
|
||||
"granted_at": consent.created_at.isoformat(),
|
||||
}
|
||||
for consent, client in rows
|
||||
]
|
||||
|
||||
async def revoke_consent(
|
||||
self, db: AsyncSession, *, user_id: UUID, client_id: str
|
||||
) -> bool:
|
||||
"""
|
||||
Delete the consent record for a user-client pair.
|
||||
|
||||
Returns True if a record was found and deleted.
|
||||
Note: Callers are responsible for also revoking associated tokens.
|
||||
"""
|
||||
result = await db.execute(
|
||||
delete(OAuthConsent).where(
|
||||
and_(
|
||||
OAuthConsent.user_id == user_id,
|
||||
OAuthConsent.client_id == client_id,
|
||||
)
|
||||
)
|
||||
)
|
||||
await db.commit()
|
||||
return result.rowcount > 0 # type: ignore[attr-defined]
|
||||
|
||||
|
||||
# Singleton instance
|
||||
oauth_consent_repo = OAuthConsentRepository()
|
||||
142
backend/app/repositories/oauth_provider_token.py
Normal file
142
backend/app/repositories/oauth_provider_token.py
Normal file
@@ -0,0 +1,142 @@
|
||||
# app/repositories/oauth_provider_token.py
|
||||
"""Repository for OAuthProviderRefreshToken model."""
|
||||
|
||||
import logging
|
||||
from datetime import UTC, datetime, timedelta
|
||||
from uuid import UUID
|
||||
|
||||
from sqlalchemy import and_, delete, select, update
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
|
||||
from app.models.oauth_provider_token import OAuthProviderRefreshToken
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class OAuthProviderTokenRepository:
|
||||
"""Repository for OAuth provider refresh tokens."""
|
||||
|
||||
async def create_token(
|
||||
self,
|
||||
db: AsyncSession,
|
||||
*,
|
||||
token_hash: str,
|
||||
jti: str,
|
||||
client_id: str,
|
||||
user_id: UUID,
|
||||
scope: str,
|
||||
expires_at: datetime,
|
||||
device_info: str | None = None,
|
||||
ip_address: str | None = None,
|
||||
) -> OAuthProviderRefreshToken:
|
||||
"""Create and persist a new refresh token record."""
|
||||
token = OAuthProviderRefreshToken(
|
||||
token_hash=token_hash,
|
||||
jti=jti,
|
||||
client_id=client_id,
|
||||
user_id=user_id,
|
||||
scope=scope,
|
||||
expires_at=expires_at,
|
||||
device_info=device_info,
|
||||
ip_address=ip_address,
|
||||
)
|
||||
db.add(token)
|
||||
await db.commit()
|
||||
return token
|
||||
|
||||
async def get_by_token_hash(
|
||||
self, db: AsyncSession, *, token_hash: str
|
||||
) -> OAuthProviderRefreshToken | None:
|
||||
"""Get refresh token record by SHA-256 token hash."""
|
||||
result = await db.execute(
|
||||
select(OAuthProviderRefreshToken).where(
|
||||
OAuthProviderRefreshToken.token_hash == token_hash
|
||||
)
|
||||
)
|
||||
return result.scalar_one_or_none()
|
||||
|
||||
async def get_by_jti(
|
||||
self, db: AsyncSession, *, jti: str
|
||||
) -> OAuthProviderRefreshToken | None:
|
||||
"""Get refresh token record by JWT ID (JTI)."""
|
||||
result = await db.execute(
|
||||
select(OAuthProviderRefreshToken).where(
|
||||
OAuthProviderRefreshToken.jti == jti
|
||||
)
|
||||
)
|
||||
return result.scalar_one_or_none()
|
||||
|
||||
async def revoke(
|
||||
self, db: AsyncSession, *, token: OAuthProviderRefreshToken
|
||||
) -> None:
|
||||
"""Mark a specific token record as revoked."""
|
||||
token.revoked = True # type: ignore[assignment]
|
||||
token.last_used_at = datetime.now(UTC) # type: ignore[assignment]
|
||||
await db.commit()
|
||||
|
||||
async def revoke_all_for_user_client(
|
||||
self, db: AsyncSession, *, user_id: UUID, client_id: str
|
||||
) -> int:
|
||||
"""
|
||||
Revoke all active tokens for a specific user-client pair.
|
||||
|
||||
Used when security incidents are detected (e.g., authorization code reuse).
|
||||
Returns the number of tokens revoked.
|
||||
"""
|
||||
result = await db.execute(
|
||||
update(OAuthProviderRefreshToken)
|
||||
.where(
|
||||
and_(
|
||||
OAuthProviderRefreshToken.user_id == user_id,
|
||||
OAuthProviderRefreshToken.client_id == client_id,
|
||||
OAuthProviderRefreshToken.revoked == False, # noqa: E712
|
||||
)
|
||||
)
|
||||
.values(revoked=True)
|
||||
)
|
||||
count = result.rowcount # type: ignore[attr-defined]
|
||||
if count > 0:
|
||||
await db.commit()
|
||||
return count
|
||||
|
||||
async def revoke_all_for_user(self, db: AsyncSession, *, user_id: UUID) -> int:
|
||||
"""
|
||||
Revoke all active tokens for a user across all clients.
|
||||
|
||||
Used when user changes password or logs out everywhere.
|
||||
Returns the number of tokens revoked.
|
||||
"""
|
||||
result = await db.execute(
|
||||
update(OAuthProviderRefreshToken)
|
||||
.where(
|
||||
and_(
|
||||
OAuthProviderRefreshToken.user_id == user_id,
|
||||
OAuthProviderRefreshToken.revoked == False, # noqa: E712
|
||||
)
|
||||
)
|
||||
.values(revoked=True)
|
||||
)
|
||||
count = result.rowcount # type: ignore[attr-defined]
|
||||
if count > 0:
|
||||
await db.commit()
|
||||
return count
|
||||
|
||||
async def cleanup_expired(self, db: AsyncSession, *, cutoff_days: int = 7) -> int:
|
||||
"""
|
||||
Delete expired refresh tokens older than cutoff_days.
|
||||
|
||||
Should be called periodically (e.g., daily).
|
||||
Returns the number of tokens deleted.
|
||||
"""
|
||||
cutoff = datetime.now(UTC) - timedelta(days=cutoff_days)
|
||||
result = await db.execute(
|
||||
delete(OAuthProviderRefreshToken).where(
|
||||
OAuthProviderRefreshToken.expires_at < cutoff
|
||||
)
|
||||
)
|
||||
await db.commit()
|
||||
return result.rowcount # type: ignore[attr-defined]
|
||||
|
||||
|
||||
# Singleton instance
|
||||
oauth_provider_token_repo = OAuthProviderTokenRepository()
|
||||
113
backend/app/repositories/oauth_state.py
Normal file
113
backend/app/repositories/oauth_state.py
Normal file
@@ -0,0 +1,113 @@
|
||||
# app/repositories/oauth_state.py
|
||||
"""Repository for OAuthState model async database operations."""
|
||||
|
||||
import logging
|
||||
from datetime import UTC, datetime
|
||||
|
||||
from pydantic import BaseModel
|
||||
from sqlalchemy import delete, select
|
||||
from sqlalchemy.exc import IntegrityError
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
|
||||
from app.core.repository_exceptions import DuplicateEntryError
|
||||
from app.models.oauth_state import OAuthState
|
||||
from app.repositories.base import BaseRepository
|
||||
from app.schemas.oauth import OAuthStateCreate
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class EmptySchema(BaseModel):
|
||||
"""Placeholder schema for repository operations that don't need update schemas."""
|
||||
|
||||
|
||||
class OAuthStateRepository(BaseRepository[OAuthState, OAuthStateCreate, EmptySchema]):
|
||||
"""Repository for OAuth state (CSRF protection)."""
|
||||
|
||||
async def create_state(
|
||||
self, db: AsyncSession, *, obj_in: OAuthStateCreate
|
||||
) -> OAuthState:
|
||||
"""Create a new OAuth state for CSRF protection."""
|
||||
try:
|
||||
db_obj = OAuthState(
|
||||
state=obj_in.state,
|
||||
code_verifier=obj_in.code_verifier,
|
||||
nonce=obj_in.nonce,
|
||||
provider=obj_in.provider,
|
||||
redirect_uri=obj_in.redirect_uri,
|
||||
user_id=obj_in.user_id,
|
||||
expires_at=obj_in.expires_at,
|
||||
)
|
||||
db.add(db_obj)
|
||||
await db.commit()
|
||||
await db.refresh(db_obj)
|
||||
|
||||
logger.debug("OAuth state created for %s", obj_in.provider)
|
||||
return db_obj
|
||||
except IntegrityError as e: # pragma: no cover
|
||||
await db.rollback()
|
||||
error_msg = str(e.orig) if hasattr(e, "orig") else str(e)
|
||||
logger.error("OAuth state collision: %s", error_msg)
|
||||
raise DuplicateEntryError("Failed to create OAuth state, please retry")
|
||||
except Exception as e: # pragma: no cover
|
||||
await db.rollback()
|
||||
logger.exception("Error creating OAuth state: %s", e)
|
||||
raise
|
||||
|
||||
async def get_and_consume_state(
|
||||
self, db: AsyncSession, *, state: str
|
||||
) -> OAuthState | None:
|
||||
"""Get and delete OAuth state (consume it)."""
|
||||
try:
|
||||
result = await db.execute(
|
||||
select(OAuthState).where(OAuthState.state == state)
|
||||
)
|
||||
db_obj = result.scalar_one_or_none()
|
||||
|
||||
if db_obj is None:
|
||||
logger.warning("OAuth state not found: %s...", state[:8])
|
||||
return None
|
||||
|
||||
now = datetime.now(UTC)
|
||||
expires_at = db_obj.expires_at
|
||||
if expires_at.tzinfo is None:
|
||||
expires_at = expires_at.replace(tzinfo=UTC)
|
||||
|
||||
if expires_at < now:
|
||||
logger.warning("OAuth state expired: %s...", state[:8])
|
||||
await db.delete(db_obj)
|
||||
await db.commit()
|
||||
return None
|
||||
|
||||
await db.delete(db_obj)
|
||||
await db.commit()
|
||||
|
||||
logger.debug("OAuth state consumed: %s...", state[:8])
|
||||
return db_obj
|
||||
except Exception as e: # pragma: no cover
|
||||
await db.rollback()
|
||||
logger.error("Error consuming OAuth state: %s", e)
|
||||
raise
|
||||
|
||||
async def cleanup_expired(self, db: AsyncSession) -> int:
|
||||
"""Clean up expired OAuth states."""
|
||||
try:
|
||||
now = datetime.now(UTC)
|
||||
|
||||
stmt = delete(OAuthState).where(OAuthState.expires_at < now)
|
||||
result = await db.execute(stmt)
|
||||
await db.commit()
|
||||
|
||||
count = result.rowcount
|
||||
if count > 0:
|
||||
logger.info("Cleaned up %s expired OAuth states", count)
|
||||
|
||||
return count
|
||||
except Exception as e: # pragma: no cover
|
||||
await db.rollback()
|
||||
logger.error("Error cleaning up expired OAuth states: %s", e)
|
||||
raise
|
||||
|
||||
|
||||
# Singleton instance
|
||||
oauth_state_repo = OAuthStateRepository(OAuthState)
|
||||
311
backend/app/crud/organization.py → backend/app/repositories/organization.py
Executable file → Normal file
311
backend/app/crud/organization.py → backend/app/repositories/organization.py
Executable file → Normal file
@@ -1,17 +1,19 @@
|
||||
# app/crud/organization_async.py
|
||||
"""Async CRUD operations for Organization model using SQLAlchemy 2.0 patterns."""
|
||||
# app/repositories/organization.py
|
||||
"""Repository for Organization model async database operations using SQLAlchemy 2.0 patterns."""
|
||||
|
||||
import logging
|
||||
from typing import Optional, List, Dict, Any
|
||||
from typing import Any
|
||||
from uuid import UUID
|
||||
|
||||
from sqlalchemy import func, or_, and_, select, case
|
||||
from sqlalchemy import and_, case, func, or_, select
|
||||
from sqlalchemy.exc import IntegrityError
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
|
||||
from app.crud.base import CRUDBase
|
||||
from app.core.repository_exceptions import DuplicateEntryError, IntegrityConstraintError
|
||||
from app.models.organization import Organization
|
||||
from app.models.user import User
|
||||
from app.models.user_organization import UserOrganization, OrganizationRole
|
||||
from app.models.user_organization import OrganizationRole, UserOrganization
|
||||
from app.repositories.base import BaseRepository
|
||||
from app.schemas.organizations import (
|
||||
OrganizationCreate,
|
||||
OrganizationUpdate,
|
||||
@@ -20,10 +22,12 @@ from app.schemas.organizations import (
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class CRUDOrganization(CRUDBase[Organization, OrganizationCreate, OrganizationUpdate]):
|
||||
"""Async CRUD operations for Organization model."""
|
||||
class OrganizationRepository(
|
||||
BaseRepository[Organization, OrganizationCreate, OrganizationUpdate]
|
||||
):
|
||||
"""Repository for Organization model."""
|
||||
|
||||
async def get_by_slug(self, db: AsyncSession, *, slug: str) -> Optional[Organization]:
|
||||
async def get_by_slug(self, db: AsyncSession, *, slug: str) -> Organization | None:
|
||||
"""Get organization by slug."""
|
||||
try:
|
||||
result = await db.execute(
|
||||
@@ -31,10 +35,12 @@ class CRUDOrganization(CRUDBase[Organization, OrganizationCreate, OrganizationUp
|
||||
)
|
||||
return result.scalar_one_or_none()
|
||||
except Exception as e:
|
||||
logger.error(f"Error getting organization by slug {slug}: {str(e)}")
|
||||
logger.error("Error getting organization by slug %s: %s", slug, e)
|
||||
raise
|
||||
|
||||
async def create(self, db: AsyncSession, *, obj_in: OrganizationCreate) -> Organization:
|
||||
async def create(
|
||||
self, db: AsyncSession, *, obj_in: OrganizationCreate
|
||||
) -> Organization:
|
||||
"""Create a new organization with error handling."""
|
||||
try:
|
||||
db_obj = Organization(
|
||||
@@ -42,7 +48,7 @@ class CRUDOrganization(CRUDBase[Organization, OrganizationCreate, OrganizationUp
|
||||
slug=obj_in.slug,
|
||||
description=obj_in.description,
|
||||
is_active=obj_in.is_active,
|
||||
settings=obj_in.settings or {}
|
||||
settings=obj_in.settings or {},
|
||||
)
|
||||
db.add(db_obj)
|
||||
await db.commit()
|
||||
@@ -50,15 +56,21 @@ class CRUDOrganization(CRUDBase[Organization, OrganizationCreate, OrganizationUp
|
||||
return db_obj
|
||||
except IntegrityError as e:
|
||||
await db.rollback()
|
||||
error_msg = str(e.orig) if hasattr(e, 'orig') else str(e)
|
||||
if "slug" in error_msg.lower():
|
||||
logger.warning(f"Duplicate slug attempted: {obj_in.slug}")
|
||||
raise ValueError(f"Organization with slug '{obj_in.slug}' already exists")
|
||||
logger.error(f"Integrity error creating organization: {error_msg}")
|
||||
raise ValueError(f"Database integrity error: {error_msg}")
|
||||
error_msg = str(e.orig) if hasattr(e, "orig") else str(e)
|
||||
if (
|
||||
"slug" in error_msg.lower()
|
||||
or "unique" in error_msg.lower()
|
||||
or "duplicate" in error_msg.lower()
|
||||
):
|
||||
logger.warning("Duplicate slug attempted: %s", obj_in.slug)
|
||||
raise DuplicateEntryError(
|
||||
f"Organization with slug '{obj_in.slug}' already exists"
|
||||
)
|
||||
logger.error("Integrity error creating organization: %s", error_msg)
|
||||
raise IntegrityConstraintError(f"Database integrity error: {error_msg}")
|
||||
except Exception as e:
|
||||
await db.rollback()
|
||||
logger.error(f"Unexpected error creating organization: {str(e)}", exc_info=True)
|
||||
logger.exception("Unexpected error creating organization: %s", e)
|
||||
raise
|
||||
|
||||
async def get_multi_with_filters(
|
||||
@@ -67,21 +79,15 @@ class CRUDOrganization(CRUDBase[Organization, OrganizationCreate, OrganizationUp
|
||||
*,
|
||||
skip: int = 0,
|
||||
limit: int = 100,
|
||||
is_active: Optional[bool] = None,
|
||||
search: Optional[str] = None,
|
||||
is_active: bool | None = None,
|
||||
search: str | None = None,
|
||||
sort_by: str = "created_at",
|
||||
sort_order: str = "desc"
|
||||
) -> tuple[List[Organization], int]:
|
||||
"""
|
||||
Get multiple organizations with filtering, searching, and sorting.
|
||||
|
||||
Returns:
|
||||
Tuple of (organizations list, total count)
|
||||
"""
|
||||
sort_order: str = "desc",
|
||||
) -> tuple[list[Organization], int]:
|
||||
"""Get multiple organizations with filtering, searching, and sorting."""
|
||||
try:
|
||||
query = select(Organization)
|
||||
|
||||
# Apply filters
|
||||
if is_active is not None:
|
||||
query = query.where(Organization.is_active == is_active)
|
||||
|
||||
@@ -89,30 +95,27 @@ class CRUDOrganization(CRUDBase[Organization, OrganizationCreate, OrganizationUp
|
||||
search_filter = or_(
|
||||
Organization.name.ilike(f"%{search}%"),
|
||||
Organization.slug.ilike(f"%{search}%"),
|
||||
Organization.description.ilike(f"%{search}%")
|
||||
Organization.description.ilike(f"%{search}%"),
|
||||
)
|
||||
query = query.where(search_filter)
|
||||
|
||||
# Get total count before pagination
|
||||
count_query = select(func.count()).select_from(query.alias())
|
||||
count_result = await db.execute(count_query)
|
||||
total = count_result.scalar_one()
|
||||
|
||||
# Apply sorting
|
||||
sort_column = getattr(Organization, sort_by, Organization.created_at)
|
||||
if sort_order == "desc":
|
||||
query = query.order_by(sort_column.desc())
|
||||
else:
|
||||
query = query.order_by(sort_column.asc())
|
||||
|
||||
# Apply pagination
|
||||
query = query.offset(skip).limit(limit)
|
||||
result = await db.execute(query)
|
||||
organizations = list(result.scalars().all())
|
||||
|
||||
return organizations, total
|
||||
except Exception as e:
|
||||
logger.error(f"Error getting organizations with filters: {str(e)}")
|
||||
logger.error("Error getting organizations with filters: %s", e)
|
||||
raise
|
||||
|
||||
async def get_member_count(self, db: AsyncSession, *, organization_id: UUID) -> int:
|
||||
@@ -122,13 +125,15 @@ class CRUDOrganization(CRUDBase[Organization, OrganizationCreate, OrganizationUp
|
||||
select(func.count(UserOrganization.user_id)).where(
|
||||
and_(
|
||||
UserOrganization.organization_id == organization_id,
|
||||
UserOrganization.is_active == True
|
||||
UserOrganization.is_active,
|
||||
)
|
||||
)
|
||||
)
|
||||
return result.scalar_one() or 0
|
||||
except Exception as e:
|
||||
logger.error(f"Error getting member count for organization {organization_id}: {str(e)}")
|
||||
logger.error(
|
||||
"Error getting member count for organization %s: %s", organization_id, e
|
||||
)
|
||||
raise
|
||||
|
||||
async def get_multi_with_member_counts(
|
||||
@@ -137,76 +142,70 @@ class CRUDOrganization(CRUDBase[Organization, OrganizationCreate, OrganizationUp
|
||||
*,
|
||||
skip: int = 0,
|
||||
limit: int = 100,
|
||||
is_active: Optional[bool] = None,
|
||||
search: Optional[str] = None
|
||||
) -> tuple[List[Dict[str, Any]], int]:
|
||||
"""
|
||||
Get organizations with member counts in a SINGLE QUERY using JOIN and GROUP BY.
|
||||
This eliminates the N+1 query problem.
|
||||
|
||||
Returns:
|
||||
Tuple of (list of dicts with org and member_count, total count)
|
||||
"""
|
||||
is_active: bool | None = None,
|
||||
search: str | None = None,
|
||||
) -> tuple[list[dict[str, Any]], int]:
|
||||
"""Get organizations with member counts in a SINGLE QUERY using JOIN and GROUP BY."""
|
||||
try:
|
||||
# Build base query with LEFT JOIN and GROUP BY
|
||||
# Use CASE statement to count only active members
|
||||
query = (
|
||||
select(
|
||||
Organization,
|
||||
func.count(
|
||||
func.distinct(
|
||||
case(
|
||||
(UserOrganization.is_active == True, UserOrganization.user_id),
|
||||
else_=None
|
||||
(
|
||||
UserOrganization.is_active,
|
||||
UserOrganization.user_id,
|
||||
),
|
||||
else_=None,
|
||||
)
|
||||
)
|
||||
).label('member_count')
|
||||
).label("member_count"),
|
||||
)
|
||||
.outerjoin(
|
||||
UserOrganization,
|
||||
Organization.id == UserOrganization.organization_id,
|
||||
)
|
||||
.outerjoin(UserOrganization, Organization.id == UserOrganization.organization_id)
|
||||
.group_by(Organization.id)
|
||||
)
|
||||
|
||||
# Apply filters
|
||||
if is_active is not None:
|
||||
query = query.where(Organization.is_active == is_active)
|
||||
|
||||
search_filter = None
|
||||
if search:
|
||||
search_filter = or_(
|
||||
Organization.name.ilike(f"%{search}%"),
|
||||
Organization.slug.ilike(f"%{search}%"),
|
||||
Organization.description.ilike(f"%{search}%")
|
||||
Organization.description.ilike(f"%{search}%"),
|
||||
)
|
||||
query = query.where(search_filter)
|
||||
|
||||
# Get total count
|
||||
count_query = select(func.count(Organization.id))
|
||||
if is_active is not None:
|
||||
count_query = count_query.where(Organization.is_active == is_active)
|
||||
if search:
|
||||
if search_filter is not None:
|
||||
count_query = count_query.where(search_filter)
|
||||
|
||||
count_result = await db.execute(count_query)
|
||||
total = count_result.scalar_one()
|
||||
|
||||
# Apply pagination and ordering
|
||||
query = query.order_by(Organization.created_at.desc()).offset(skip).limit(limit)
|
||||
query = (
|
||||
query.order_by(Organization.created_at.desc()).offset(skip).limit(limit)
|
||||
)
|
||||
|
||||
result = await db.execute(query)
|
||||
rows = result.all()
|
||||
|
||||
# Convert to list of dicts
|
||||
orgs_with_counts = [
|
||||
{
|
||||
'organization': org,
|
||||
'member_count': member_count
|
||||
}
|
||||
{"organization": org, "member_count": member_count}
|
||||
for org, member_count in rows
|
||||
]
|
||||
|
||||
return orgs_with_counts, total
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error getting organizations with member counts: {str(e)}", exc_info=True)
|
||||
logger.exception("Error getting organizations with member counts: %s", e)
|
||||
raise
|
||||
|
||||
async def add_user(
|
||||
@@ -216,23 +215,21 @@ class CRUDOrganization(CRUDBase[Organization, OrganizationCreate, OrganizationUp
|
||||
organization_id: UUID,
|
||||
user_id: UUID,
|
||||
role: OrganizationRole = OrganizationRole.MEMBER,
|
||||
custom_permissions: Optional[str] = None
|
||||
custom_permissions: str | None = None,
|
||||
) -> UserOrganization:
|
||||
"""Add a user to an organization with a specific role."""
|
||||
try:
|
||||
# Check if relationship already exists
|
||||
result = await db.execute(
|
||||
select(UserOrganization).where(
|
||||
and_(
|
||||
UserOrganization.user_id == user_id,
|
||||
UserOrganization.organization_id == organization_id
|
||||
UserOrganization.organization_id == organization_id,
|
||||
)
|
||||
)
|
||||
)
|
||||
existing = result.scalar_one_or_none()
|
||||
|
||||
if existing:
|
||||
# Reactivate if inactive, or raise error if already active
|
||||
if not existing.is_active:
|
||||
existing.is_active = True
|
||||
existing.role = role
|
||||
@@ -241,15 +238,16 @@ class CRUDOrganization(CRUDBase[Organization, OrganizationCreate, OrganizationUp
|
||||
await db.refresh(existing)
|
||||
return existing
|
||||
else:
|
||||
raise ValueError("User is already a member of this organization")
|
||||
raise DuplicateEntryError(
|
||||
"User is already a member of this organization"
|
||||
)
|
||||
|
||||
# Create new relationship
|
||||
user_org = UserOrganization(
|
||||
user_id=user_id,
|
||||
organization_id=organization_id,
|
||||
role=role,
|
||||
is_active=True,
|
||||
custom_permissions=custom_permissions
|
||||
custom_permissions=custom_permissions,
|
||||
)
|
||||
db.add(user_org)
|
||||
await db.commit()
|
||||
@@ -257,19 +255,15 @@ class CRUDOrganization(CRUDBase[Organization, OrganizationCreate, OrganizationUp
|
||||
return user_org
|
||||
except IntegrityError as e:
|
||||
await db.rollback()
|
||||
logger.error(f"Integrity error adding user to organization: {str(e)}")
|
||||
raise ValueError("Failed to add user to organization")
|
||||
logger.error("Integrity error adding user to organization: %s", e)
|
||||
raise IntegrityConstraintError("Failed to add user to organization")
|
||||
except Exception as e:
|
||||
await db.rollback()
|
||||
logger.error(f"Error adding user to organization: {str(e)}", exc_info=True)
|
||||
logger.exception("Error adding user to organization: %s", e)
|
||||
raise
|
||||
|
||||
async def remove_user(
|
||||
self,
|
||||
db: AsyncSession,
|
||||
*,
|
||||
organization_id: UUID,
|
||||
user_id: UUID
|
||||
self, db: AsyncSession, *, organization_id: UUID, user_id: UUID
|
||||
) -> bool:
|
||||
"""Remove a user from an organization (soft delete)."""
|
||||
try:
|
||||
@@ -277,7 +271,7 @@ class CRUDOrganization(CRUDBase[Organization, OrganizationCreate, OrganizationUp
|
||||
select(UserOrganization).where(
|
||||
and_(
|
||||
UserOrganization.user_id == user_id,
|
||||
UserOrganization.organization_id == organization_id
|
||||
UserOrganization.organization_id == organization_id,
|
||||
)
|
||||
)
|
||||
)
|
||||
@@ -291,7 +285,7 @@ class CRUDOrganization(CRUDBase[Organization, OrganizationCreate, OrganizationUp
|
||||
return True
|
||||
except Exception as e:
|
||||
await db.rollback()
|
||||
logger.error(f"Error removing user from organization: {str(e)}", exc_info=True)
|
||||
logger.exception("Error removing user from organization: %s", e)
|
||||
raise
|
||||
|
||||
async def update_user_role(
|
||||
@@ -301,15 +295,15 @@ class CRUDOrganization(CRUDBase[Organization, OrganizationCreate, OrganizationUp
|
||||
organization_id: UUID,
|
||||
user_id: UUID,
|
||||
role: OrganizationRole,
|
||||
custom_permissions: Optional[str] = None
|
||||
) -> Optional[UserOrganization]:
|
||||
custom_permissions: str | None = None,
|
||||
) -> UserOrganization | None:
|
||||
"""Update a user's role in an organization."""
|
||||
try:
|
||||
result = await db.execute(
|
||||
select(UserOrganization).where(
|
||||
and_(
|
||||
UserOrganization.user_id == user_id,
|
||||
UserOrganization.organization_id == organization_id
|
||||
UserOrganization.organization_id == organization_id,
|
||||
)
|
||||
)
|
||||
)
|
||||
@@ -326,7 +320,7 @@ class CRUDOrganization(CRUDBase[Organization, OrganizationCreate, OrganizationUp
|
||||
return user_org
|
||||
except Exception as e:
|
||||
await db.rollback()
|
||||
logger.error(f"Error updating user role: {str(e)}", exc_info=True)
|
||||
logger.exception("Error updating user role: %s", e)
|
||||
raise
|
||||
|
||||
async def get_organization_members(
|
||||
@@ -336,16 +330,10 @@ class CRUDOrganization(CRUDBase[Organization, OrganizationCreate, OrganizationUp
|
||||
organization_id: UUID,
|
||||
skip: int = 0,
|
||||
limit: int = 100,
|
||||
is_active: bool = True
|
||||
) -> tuple[List[Dict[str, Any]], int]:
|
||||
"""
|
||||
Get members of an organization with user details.
|
||||
|
||||
Returns:
|
||||
Tuple of (members list with user details, total count)
|
||||
"""
|
||||
is_active: bool | None = True,
|
||||
) -> tuple[list[dict[str, Any]], int]:
|
||||
"""Get members of an organization with user details."""
|
||||
try:
|
||||
# Build query with join
|
||||
query = (
|
||||
select(UserOrganization, User)
|
||||
.join(User, UserOrganization.user_id == User.id)
|
||||
@@ -355,50 +343,57 @@ class CRUDOrganization(CRUDBase[Organization, OrganizationCreate, OrganizationUp
|
||||
if is_active is not None:
|
||||
query = query.where(UserOrganization.is_active == is_active)
|
||||
|
||||
# Get total count
|
||||
count_query = select(func.count()).select_from(
|
||||
select(UserOrganization)
|
||||
.where(UserOrganization.organization_id == organization_id)
|
||||
.where(UserOrganization.is_active == is_active if is_active is not None else True)
|
||||
.where(
|
||||
UserOrganization.is_active == is_active
|
||||
if is_active is not None
|
||||
else True
|
||||
)
|
||||
.alias()
|
||||
)
|
||||
count_result = await db.execute(count_query)
|
||||
total = count_result.scalar_one()
|
||||
|
||||
# Apply ordering and pagination
|
||||
query = query.order_by(UserOrganization.created_at.desc()).offset(skip).limit(limit)
|
||||
query = (
|
||||
query.order_by(UserOrganization.created_at.desc())
|
||||
.offset(skip)
|
||||
.limit(limit)
|
||||
)
|
||||
result = await db.execute(query)
|
||||
results = result.all()
|
||||
|
||||
members = []
|
||||
for user_org, user in results:
|
||||
members.append({
|
||||
"user_id": user.id,
|
||||
"email": user.email,
|
||||
"first_name": user.first_name,
|
||||
"last_name": user.last_name,
|
||||
"role": user_org.role,
|
||||
"is_active": user_org.is_active,
|
||||
"joined_at": user_org.created_at
|
||||
})
|
||||
members.append(
|
||||
{
|
||||
"user_id": user.id,
|
||||
"email": user.email,
|
||||
"first_name": user.first_name,
|
||||
"last_name": user.last_name,
|
||||
"role": user_org.role,
|
||||
"is_active": user_org.is_active,
|
||||
"joined_at": user_org.created_at,
|
||||
}
|
||||
)
|
||||
|
||||
return members, total
|
||||
except Exception as e:
|
||||
logger.error(f"Error getting organization members: {str(e)}")
|
||||
logger.error("Error getting organization members: %s", e)
|
||||
raise
|
||||
|
||||
async def get_user_organizations(
|
||||
self,
|
||||
db: AsyncSession,
|
||||
*,
|
||||
user_id: UUID,
|
||||
is_active: bool = True
|
||||
) -> List[Organization]:
|
||||
self, db: AsyncSession, *, user_id: UUID, is_active: bool | None = True
|
||||
) -> list[Organization]:
|
||||
"""Get all organizations a user belongs to."""
|
||||
try:
|
||||
query = (
|
||||
select(Organization)
|
||||
.join(UserOrganization, Organization.id == UserOrganization.organization_id)
|
||||
.join(
|
||||
UserOrganization,
|
||||
Organization.id == UserOrganization.organization_id,
|
||||
)
|
||||
.where(UserOrganization.user_id == user_id)
|
||||
)
|
||||
|
||||
@@ -408,44 +403,40 @@ class CRUDOrganization(CRUDBase[Organization, OrganizationCreate, OrganizationUp
|
||||
result = await db.execute(query)
|
||||
return list(result.scalars().all())
|
||||
except Exception as e:
|
||||
logger.error(f"Error getting user organizations: {str(e)}")
|
||||
logger.error("Error getting user organizations: %s", e)
|
||||
raise
|
||||
|
||||
async def get_user_organizations_with_details(
|
||||
self,
|
||||
db: AsyncSession,
|
||||
*,
|
||||
user_id: UUID,
|
||||
is_active: bool = True
|
||||
) -> List[Dict[str, Any]]:
|
||||
"""
|
||||
Get user's organizations with role and member count in SINGLE QUERY.
|
||||
Eliminates N+1 problem by using subquery for member counts.
|
||||
|
||||
Returns:
|
||||
List of dicts with organization, role, and member_count
|
||||
"""
|
||||
self, db: AsyncSession, *, user_id: UUID, is_active: bool | None = True
|
||||
) -> list[dict[str, Any]]:
|
||||
"""Get user's organizations with role and member count in SINGLE QUERY."""
|
||||
try:
|
||||
# Subquery to get member counts for each organization
|
||||
member_count_subq = (
|
||||
select(
|
||||
UserOrganization.organization_id,
|
||||
func.count(UserOrganization.user_id).label('member_count')
|
||||
func.count(UserOrganization.user_id).label("member_count"),
|
||||
)
|
||||
.where(UserOrganization.is_active == True)
|
||||
.where(UserOrganization.is_active)
|
||||
.group_by(UserOrganization.organization_id)
|
||||
.subquery()
|
||||
)
|
||||
|
||||
# Main query with JOIN to get org, role, and member count
|
||||
query = (
|
||||
select(
|
||||
Organization,
|
||||
UserOrganization.role,
|
||||
func.coalesce(member_count_subq.c.member_count, 0).label('member_count')
|
||||
func.coalesce(member_count_subq.c.member_count, 0).label(
|
||||
"member_count"
|
||||
),
|
||||
)
|
||||
.join(
|
||||
UserOrganization,
|
||||
Organization.id == UserOrganization.organization_id,
|
||||
)
|
||||
.outerjoin(
|
||||
member_count_subq,
|
||||
Organization.id == member_count_subq.c.organization_id,
|
||||
)
|
||||
.join(UserOrganization, Organization.id == UserOrganization.organization_id)
|
||||
.outerjoin(member_count_subq, Organization.id == member_count_subq.c.organization_id)
|
||||
.where(UserOrganization.user_id == user_id)
|
||||
)
|
||||
|
||||
@@ -456,25 +447,17 @@ class CRUDOrganization(CRUDBase[Organization, OrganizationCreate, OrganizationUp
|
||||
rows = result.all()
|
||||
|
||||
return [
|
||||
{
|
||||
'organization': org,
|
||||
'role': role,
|
||||
'member_count': member_count
|
||||
}
|
||||
{"organization": org, "role": role, "member_count": member_count}
|
||||
for org, role, member_count in rows
|
||||
]
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error getting user organizations with details: {str(e)}", exc_info=True)
|
||||
logger.exception("Error getting user organizations with details: %s", e)
|
||||
raise
|
||||
|
||||
async def get_user_role_in_org(
|
||||
self,
|
||||
db: AsyncSession,
|
||||
*,
|
||||
user_id: UUID,
|
||||
organization_id: UUID
|
||||
) -> Optional[OrganizationRole]:
|
||||
self, db: AsyncSession, *, user_id: UUID, organization_id: UUID
|
||||
) -> OrganizationRole | None:
|
||||
"""Get a user's role in a specific organization."""
|
||||
try:
|
||||
result = await db.execute(
|
||||
@@ -482,39 +465,35 @@ class CRUDOrganization(CRUDBase[Organization, OrganizationCreate, OrganizationUp
|
||||
and_(
|
||||
UserOrganization.user_id == user_id,
|
||||
UserOrganization.organization_id == organization_id,
|
||||
UserOrganization.is_active == True
|
||||
UserOrganization.is_active,
|
||||
)
|
||||
)
|
||||
)
|
||||
user_org = result.scalar_one_or_none()
|
||||
|
||||
return user_org.role if user_org else None
|
||||
return user_org.role if user_org else None # pyright: ignore[reportReturnType]
|
||||
except Exception as e:
|
||||
logger.error(f"Error getting user role in org: {str(e)}")
|
||||
logger.error("Error getting user role in org: %s", e)
|
||||
raise
|
||||
|
||||
async def is_user_org_owner(
|
||||
self,
|
||||
db: AsyncSession,
|
||||
*,
|
||||
user_id: UUID,
|
||||
organization_id: UUID
|
||||
self, db: AsyncSession, *, user_id: UUID, organization_id: UUID
|
||||
) -> bool:
|
||||
"""Check if a user is an owner of an organization."""
|
||||
role = await self.get_user_role_in_org(db, user_id=user_id, organization_id=organization_id)
|
||||
role = await self.get_user_role_in_org(
|
||||
db, user_id=user_id, organization_id=organization_id
|
||||
)
|
||||
return role == OrganizationRole.OWNER
|
||||
|
||||
async def is_user_org_admin(
|
||||
self,
|
||||
db: AsyncSession,
|
||||
*,
|
||||
user_id: UUID,
|
||||
organization_id: UUID
|
||||
self, db: AsyncSession, *, user_id: UUID, organization_id: UUID
|
||||
) -> bool:
|
||||
"""Check if a user is an owner or admin of an organization."""
|
||||
role = await self.get_user_role_in_org(db, user_id=user_id, organization_id=organization_id)
|
||||
role = await self.get_user_role_in_org(
|
||||
db, user_id=user_id, organization_id=organization_id
|
||||
)
|
||||
return role in [OrganizationRole.OWNER, OrganizationRole.ADMIN]
|
||||
|
||||
|
||||
# Create a singleton instance for use across the application
|
||||
organization = CRUDOrganization(Organization)
|
||||
# Singleton instance
|
||||
organization_repo = OrganizationRepository(Organization)
|
||||
333
backend/app/repositories/session.py
Normal file
333
backend/app/repositories/session.py
Normal file
@@ -0,0 +1,333 @@
|
||||
# app/repositories/session.py
|
||||
"""Repository for UserSession model async database operations using SQLAlchemy 2.0 patterns."""
|
||||
|
||||
import logging
|
||||
import uuid
|
||||
from datetime import UTC, datetime, timedelta
|
||||
from uuid import UUID
|
||||
|
||||
from sqlalchemy import and_, delete, func, select, update
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
from sqlalchemy.orm import joinedload
|
||||
|
||||
from app.core.repository_exceptions import IntegrityConstraintError, InvalidInputError
|
||||
from app.models.user_session import UserSession
|
||||
from app.repositories.base import BaseRepository
|
||||
from app.schemas.sessions import SessionCreate, SessionUpdate
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class SessionRepository(BaseRepository[UserSession, SessionCreate, SessionUpdate]):
|
||||
"""Repository for UserSession model."""
|
||||
|
||||
async def get_by_jti(self, db: AsyncSession, *, jti: str) -> UserSession | None:
|
||||
"""Get session by refresh token JTI."""
|
||||
try:
|
||||
result = await db.execute(
|
||||
select(UserSession).where(UserSession.refresh_token_jti == jti)
|
||||
)
|
||||
return result.scalar_one_or_none()
|
||||
except Exception as e:
|
||||
logger.error("Error getting session by JTI %s: %s", jti, e)
|
||||
raise
|
||||
|
||||
async def get_active_by_jti(
|
||||
self, db: AsyncSession, *, jti: str
|
||||
) -> UserSession | None:
|
||||
"""Get active session by refresh token JTI."""
|
||||
try:
|
||||
result = await db.execute(
|
||||
select(UserSession).where(
|
||||
and_(
|
||||
UserSession.refresh_token_jti == jti,
|
||||
UserSession.is_active,
|
||||
)
|
||||
)
|
||||
)
|
||||
return result.scalar_one_or_none()
|
||||
except Exception as e:
|
||||
logger.error("Error getting active session by JTI %s: %s", jti, e)
|
||||
raise
|
||||
|
||||
async def get_user_sessions(
|
||||
self,
|
||||
db: AsyncSession,
|
||||
*,
|
||||
user_id: str,
|
||||
active_only: bool = True,
|
||||
with_user: bool = False,
|
||||
) -> list[UserSession]:
|
||||
"""Get all sessions for a user with optional eager loading."""
|
||||
try:
|
||||
user_uuid = UUID(user_id) if isinstance(user_id, str) else user_id
|
||||
|
||||
query = select(UserSession).where(UserSession.user_id == user_uuid)
|
||||
|
||||
if with_user:
|
||||
query = query.options(joinedload(UserSession.user))
|
||||
|
||||
if active_only:
|
||||
query = query.where(UserSession.is_active)
|
||||
|
||||
query = query.order_by(UserSession.last_used_at.desc())
|
||||
result = await db.execute(query)
|
||||
return list(result.scalars().all())
|
||||
except Exception as e:
|
||||
logger.error("Error getting sessions for user %s: %s", user_id, e)
|
||||
raise
|
||||
|
||||
async def create_session(
|
||||
self, db: AsyncSession, *, obj_in: SessionCreate
|
||||
) -> UserSession:
|
||||
"""Create a new user session."""
|
||||
try:
|
||||
db_obj = UserSession(
|
||||
user_id=obj_in.user_id,
|
||||
refresh_token_jti=obj_in.refresh_token_jti,
|
||||
device_name=obj_in.device_name,
|
||||
device_id=obj_in.device_id,
|
||||
ip_address=obj_in.ip_address,
|
||||
user_agent=obj_in.user_agent,
|
||||
last_used_at=obj_in.last_used_at,
|
||||
expires_at=obj_in.expires_at,
|
||||
is_active=True,
|
||||
location_city=obj_in.location_city,
|
||||
location_country=obj_in.location_country,
|
||||
)
|
||||
db.add(db_obj)
|
||||
await db.commit()
|
||||
await db.refresh(db_obj)
|
||||
|
||||
logger.info(
|
||||
"Session created for user %s from %s (IP: %s)",
|
||||
obj_in.user_id,
|
||||
obj_in.device_name,
|
||||
obj_in.ip_address,
|
||||
)
|
||||
|
||||
return db_obj
|
||||
except Exception as e:
|
||||
await db.rollback()
|
||||
logger.exception("Error creating session: %s", e)
|
||||
raise IntegrityConstraintError(f"Failed to create session: {e!s}")
|
||||
|
||||
async def deactivate(
|
||||
self, db: AsyncSession, *, session_id: str
|
||||
) -> UserSession | None:
|
||||
"""Deactivate a session (logout from device)."""
|
||||
try:
|
||||
session = await self.get(db, id=session_id)
|
||||
if not session:
|
||||
logger.warning("Session %s not found for deactivation", session_id)
|
||||
return None
|
||||
|
||||
session.is_active = False
|
||||
db.add(session)
|
||||
await db.commit()
|
||||
await db.refresh(session)
|
||||
|
||||
logger.info(
|
||||
"Session %s deactivated for user %s (%s)",
|
||||
session_id,
|
||||
session.user_id,
|
||||
session.device_name,
|
||||
)
|
||||
|
||||
return session
|
||||
except Exception as e:
|
||||
await db.rollback()
|
||||
logger.error("Error deactivating session %s: %s", session_id, e)
|
||||
raise
|
||||
|
||||
async def deactivate_all_user_sessions(
|
||||
self, db: AsyncSession, *, user_id: str
|
||||
) -> int:
|
||||
"""Deactivate all active sessions for a user (logout from all devices)."""
|
||||
try:
|
||||
user_uuid = UUID(user_id) if isinstance(user_id, str) else user_id
|
||||
|
||||
stmt = (
|
||||
update(UserSession)
|
||||
.where(and_(UserSession.user_id == user_uuid, UserSession.is_active))
|
||||
.values(is_active=False)
|
||||
)
|
||||
|
||||
result = await db.execute(stmt)
|
||||
await db.commit()
|
||||
|
||||
count = result.rowcount
|
||||
|
||||
logger.info("Deactivated %s sessions for user %s", count, user_id)
|
||||
|
||||
return count
|
||||
except Exception as e:
|
||||
await db.rollback()
|
||||
logger.error("Error deactivating all sessions for user %s: %s", user_id, e)
|
||||
raise
|
||||
|
||||
async def update_last_used(
|
||||
self, db: AsyncSession, *, session: UserSession
|
||||
) -> UserSession:
|
||||
"""Update the last_used_at timestamp for a session."""
|
||||
try:
|
||||
session.last_used_at = datetime.now(UTC)
|
||||
db.add(session)
|
||||
await db.commit()
|
||||
await db.refresh(session)
|
||||
return session
|
||||
except Exception as e:
|
||||
await db.rollback()
|
||||
logger.error("Error updating last_used for session %s: %s", session.id, e)
|
||||
raise
|
||||
|
||||
async def update_refresh_token(
|
||||
self,
|
||||
db: AsyncSession,
|
||||
*,
|
||||
session: UserSession,
|
||||
new_jti: str,
|
||||
new_expires_at: datetime,
|
||||
) -> UserSession:
|
||||
"""Update session with new refresh token JTI and expiration."""
|
||||
try:
|
||||
session.refresh_token_jti = new_jti
|
||||
session.expires_at = new_expires_at
|
||||
session.last_used_at = datetime.now(UTC)
|
||||
db.add(session)
|
||||
await db.commit()
|
||||
await db.refresh(session)
|
||||
return session
|
||||
except Exception as e:
|
||||
await db.rollback()
|
||||
logger.error(
|
||||
"Error updating refresh token for session %s: %s", session.id, e
|
||||
)
|
||||
raise
|
||||
|
||||
async def cleanup_expired(self, db: AsyncSession, *, keep_days: int = 30) -> int:
|
||||
"""Clean up expired sessions using optimized bulk DELETE."""
|
||||
try:
|
||||
cutoff_date = datetime.now(UTC) - timedelta(days=keep_days)
|
||||
now = datetime.now(UTC)
|
||||
|
||||
stmt = delete(UserSession).where(
|
||||
and_(
|
||||
UserSession.is_active == False, # noqa: E712
|
||||
UserSession.expires_at < now,
|
||||
UserSession.created_at < cutoff_date,
|
||||
)
|
||||
)
|
||||
|
||||
result = await db.execute(stmt)
|
||||
await db.commit()
|
||||
|
||||
count = result.rowcount
|
||||
|
||||
if count > 0:
|
||||
logger.info("Cleaned up %s expired sessions using bulk DELETE", count)
|
||||
|
||||
return count
|
||||
except Exception as e:
|
||||
await db.rollback()
|
||||
logger.error("Error cleaning up expired sessions: %s", e)
|
||||
raise
|
||||
|
||||
async def cleanup_expired_for_user(self, db: AsyncSession, *, user_id: str) -> int:
|
||||
"""Clean up expired and inactive sessions for a specific user."""
|
||||
try:
|
||||
try:
|
||||
uuid_obj = uuid.UUID(user_id)
|
||||
except (ValueError, AttributeError):
|
||||
logger.error("Invalid UUID format: %s", user_id)
|
||||
raise InvalidInputError(f"Invalid user ID format: {user_id}")
|
||||
|
||||
now = datetime.now(UTC)
|
||||
|
||||
stmt = delete(UserSession).where(
|
||||
and_(
|
||||
UserSession.user_id == uuid_obj,
|
||||
UserSession.is_active == False, # noqa: E712
|
||||
UserSession.expires_at < now,
|
||||
)
|
||||
)
|
||||
|
||||
result = await db.execute(stmt)
|
||||
await db.commit()
|
||||
|
||||
count = result.rowcount
|
||||
|
||||
if count > 0:
|
||||
logger.info(
|
||||
"Cleaned up %s expired sessions for user %s using bulk DELETE",
|
||||
count,
|
||||
user_id,
|
||||
)
|
||||
|
||||
return count
|
||||
except Exception as e:
|
||||
await db.rollback()
|
||||
logger.error(
|
||||
"Error cleaning up expired sessions for user %s: %s", user_id, e
|
||||
)
|
||||
raise
|
||||
|
||||
async def get_user_session_count(self, db: AsyncSession, *, user_id: str) -> int:
|
||||
"""Get count of active sessions for a user."""
|
||||
try:
|
||||
user_uuid = UUID(user_id) if isinstance(user_id, str) else user_id
|
||||
|
||||
result = await db.execute(
|
||||
select(func.count(UserSession.id)).where(
|
||||
and_(UserSession.user_id == user_uuid, UserSession.is_active)
|
||||
)
|
||||
)
|
||||
return result.scalar_one()
|
||||
except Exception as e:
|
||||
logger.error("Error counting sessions for user %s: %s", user_id, e)
|
||||
raise
|
||||
|
||||
async def get_all_sessions(
|
||||
self,
|
||||
db: AsyncSession,
|
||||
*,
|
||||
skip: int = 0,
|
||||
limit: int = 100,
|
||||
active_only: bool = True,
|
||||
with_user: bool = True,
|
||||
) -> tuple[list[UserSession], int]:
|
||||
"""Get all sessions across all users with pagination (admin only)."""
|
||||
try:
|
||||
query = select(UserSession)
|
||||
|
||||
if with_user:
|
||||
query = query.options(joinedload(UserSession.user))
|
||||
|
||||
if active_only:
|
||||
query = query.where(UserSession.is_active)
|
||||
|
||||
count_query = select(func.count(UserSession.id))
|
||||
if active_only:
|
||||
count_query = count_query.where(UserSession.is_active)
|
||||
|
||||
count_result = await db.execute(count_query)
|
||||
total = count_result.scalar_one()
|
||||
|
||||
query = (
|
||||
query.order_by(UserSession.last_used_at.desc())
|
||||
.offset(skip)
|
||||
.limit(limit)
|
||||
)
|
||||
|
||||
result = await db.execute(query)
|
||||
sessions = list(result.scalars().all())
|
||||
|
||||
return sessions, total
|
||||
|
||||
except Exception as e:
|
||||
logger.exception("Error getting all sessions: %s", e)
|
||||
raise
|
||||
|
||||
|
||||
# Singleton instance
|
||||
session_repo = SessionRepository(UserSession)
|
||||
209
backend/app/crud/user.py → backend/app/repositories/user.py
Executable file → Normal file
209
backend/app/crud/user.py → backend/app/repositories/user.py
Executable file → Normal file
@@ -1,8 +1,9 @@
|
||||
# app/crud/user_async.py
|
||||
"""Async CRUD operations for User model using SQLAlchemy 2.0 patterns."""
|
||||
# app/repositories/user.py
|
||||
"""Repository for User model async database operations using SQLAlchemy 2.0 patterns."""
|
||||
|
||||
import logging
|
||||
from datetime import datetime, timezone
|
||||
from typing import Optional, Union, Dict, Any, List, Tuple
|
||||
from datetime import UTC, datetime
|
||||
from typing import Any
|
||||
from uuid import UUID
|
||||
|
||||
from sqlalchemy import or_, select, update
|
||||
@@ -10,31 +11,29 @@ from sqlalchemy.exc import IntegrityError
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
|
||||
from app.core.auth import get_password_hash_async
|
||||
from app.crud.base import CRUDBase
|
||||
from app.core.repository_exceptions import DuplicateEntryError, InvalidInputError
|
||||
from app.models.user import User
|
||||
from app.repositories.base import BaseRepository
|
||||
from app.schemas.users import UserCreate, UserUpdate
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class CRUDUser(CRUDBase[User, UserCreate, UserUpdate]):
|
||||
"""Async CRUD operations for User model."""
|
||||
class UserRepository(BaseRepository[User, UserCreate, UserUpdate]):
|
||||
"""Repository for User model."""
|
||||
|
||||
async def get_by_email(self, db: AsyncSession, *, email: str) -> Optional[User]:
|
||||
async def get_by_email(self, db: AsyncSession, *, email: str) -> User | None:
|
||||
"""Get user by email address."""
|
||||
try:
|
||||
result = await db.execute(
|
||||
select(User).where(User.email == email)
|
||||
)
|
||||
result = await db.execute(select(User).where(User.email == email))
|
||||
return result.scalar_one_or_none()
|
||||
except Exception as e:
|
||||
logger.error(f"Error getting user by email {email}: {str(e)}")
|
||||
logger.error("Error getting user by email %s: %s", email, e)
|
||||
raise
|
||||
|
||||
async def create(self, db: AsyncSession, *, obj_in: UserCreate) -> User:
|
||||
"""Create a new user with async password hashing and error handling."""
|
||||
try:
|
||||
# Hash password asynchronously to avoid blocking event loop
|
||||
password_hash = await get_password_hash_async(obj_in.password)
|
||||
|
||||
db_obj = User(
|
||||
@@ -42,9 +41,13 @@ class CRUDUser(CRUDBase[User, UserCreate, UserUpdate]):
|
||||
password_hash=password_hash,
|
||||
first_name=obj_in.first_name,
|
||||
last_name=obj_in.last_name,
|
||||
phone_number=obj_in.phone_number if hasattr(obj_in, 'phone_number') else None,
|
||||
is_superuser=obj_in.is_superuser if hasattr(obj_in, 'is_superuser') else False,
|
||||
preferences={}
|
||||
phone_number=obj_in.phone_number
|
||||
if hasattr(obj_in, "phone_number")
|
||||
else None,
|
||||
is_superuser=obj_in.is_superuser
|
||||
if hasattr(obj_in, "is_superuser")
|
||||
else False,
|
||||
preferences={},
|
||||
)
|
||||
db.add(db_obj)
|
||||
await db.commit()
|
||||
@@ -52,23 +55,55 @@ class CRUDUser(CRUDBase[User, UserCreate, UserUpdate]):
|
||||
return db_obj
|
||||
except IntegrityError as e:
|
||||
await db.rollback()
|
||||
error_msg = str(e.orig) if hasattr(e, 'orig') else str(e)
|
||||
error_msg = str(e.orig) if hasattr(e, "orig") else str(e)
|
||||
if "email" in error_msg.lower():
|
||||
logger.warning(f"Duplicate email attempted: {obj_in.email}")
|
||||
raise ValueError(f"User with email {obj_in.email} already exists")
|
||||
logger.error(f"Integrity error creating user: {error_msg}")
|
||||
raise ValueError(f"Database integrity error: {error_msg}")
|
||||
logger.warning("Duplicate email attempted: %s", obj_in.email)
|
||||
raise DuplicateEntryError(
|
||||
f"User with email {obj_in.email} already exists"
|
||||
)
|
||||
logger.error("Integrity error creating user: %s", error_msg)
|
||||
raise DuplicateEntryError(f"Database integrity error: {error_msg}")
|
||||
except Exception as e:
|
||||
await db.rollback()
|
||||
logger.error(f"Unexpected error creating user: {str(e)}", exc_info=True)
|
||||
logger.exception("Unexpected error creating user: %s", e)
|
||||
raise
|
||||
|
||||
async def update(
|
||||
async def create_oauth_user(
|
||||
self,
|
||||
db: AsyncSession,
|
||||
*,
|
||||
db_obj: User,
|
||||
obj_in: Union[UserUpdate, Dict[str, Any]]
|
||||
email: str,
|
||||
first_name: str = "User",
|
||||
last_name: str | None = None,
|
||||
) -> User:
|
||||
"""Create a new passwordless user for OAuth sign-in."""
|
||||
try:
|
||||
db_obj = User(
|
||||
email=email,
|
||||
password_hash=None, # OAuth-only user
|
||||
first_name=first_name,
|
||||
last_name=last_name,
|
||||
is_active=True,
|
||||
is_superuser=False,
|
||||
)
|
||||
db.add(db_obj)
|
||||
await db.flush() # Get user.id without committing
|
||||
return db_obj
|
||||
except IntegrityError as e:
|
||||
await db.rollback()
|
||||
error_msg = str(e.orig) if hasattr(e, "orig") else str(e)
|
||||
if "email" in error_msg.lower():
|
||||
logger.warning("Duplicate email attempted: %s", email)
|
||||
raise DuplicateEntryError(f"User with email {email} already exists")
|
||||
logger.error("Integrity error creating OAuth user: %s", error_msg)
|
||||
raise DuplicateEntryError(f"Database integrity error: {error_msg}")
|
||||
except Exception as e:
|
||||
await db.rollback()
|
||||
logger.exception("Unexpected error creating OAuth user: %s", e)
|
||||
raise
|
||||
|
||||
async def update(
|
||||
self, db: AsyncSession, *, db_obj: User, obj_in: UserUpdate | dict[str, Any]
|
||||
) -> User:
|
||||
"""Update user with async password hashing if password is updated."""
|
||||
if isinstance(obj_in, dict):
|
||||
@@ -76,77 +111,65 @@ class CRUDUser(CRUDBase[User, UserCreate, UserUpdate]):
|
||||
else:
|
||||
update_data = obj_in.model_dump(exclude_unset=True)
|
||||
|
||||
# Handle password separately if it exists in update data
|
||||
# Hash password asynchronously to avoid blocking event loop
|
||||
if "password" in update_data:
|
||||
update_data["password_hash"] = await get_password_hash_async(update_data["password"])
|
||||
update_data["password_hash"] = await get_password_hash_async(
|
||||
update_data["password"]
|
||||
)
|
||||
del update_data["password"]
|
||||
|
||||
return await super().update(db, db_obj=db_obj, obj_in=update_data)
|
||||
|
||||
async def update_password(
|
||||
self, db: AsyncSession, *, user: User, password_hash: str
|
||||
) -> User:
|
||||
"""Set a new password hash on a user and commit."""
|
||||
user.password_hash = password_hash
|
||||
await db.commit()
|
||||
await db.refresh(user)
|
||||
return user
|
||||
|
||||
async def get_multi_with_total(
|
||||
self,
|
||||
db: AsyncSession,
|
||||
*,
|
||||
skip: int = 0,
|
||||
limit: int = 100,
|
||||
sort_by: Optional[str] = None,
|
||||
sort_by: str | None = None,
|
||||
sort_order: str = "asc",
|
||||
filters: Optional[Dict[str, Any]] = None,
|
||||
search: Optional[str] = None
|
||||
) -> Tuple[List[User], int]:
|
||||
"""
|
||||
Get multiple users with total count, filtering, sorting, and search.
|
||||
|
||||
Args:
|
||||
db: Database session
|
||||
skip: Number of records to skip
|
||||
limit: Maximum number of records to return
|
||||
sort_by: Field name to sort by
|
||||
sort_order: Sort order ("asc" or "desc")
|
||||
filters: Dictionary of filters (field_name: value)
|
||||
search: Search term to match against email, first_name, last_name
|
||||
|
||||
Returns:
|
||||
Tuple of (users list, total count)
|
||||
"""
|
||||
# Validate pagination
|
||||
filters: dict[str, Any] | None = None,
|
||||
search: str | None = None,
|
||||
) -> tuple[list[User], int]:
|
||||
"""Get multiple users with total count, filtering, sorting, and search."""
|
||||
if skip < 0:
|
||||
raise ValueError("skip must be non-negative")
|
||||
raise InvalidInputError("skip must be non-negative")
|
||||
if limit < 0:
|
||||
raise ValueError("limit must be non-negative")
|
||||
raise InvalidInputError("limit must be non-negative")
|
||||
if limit > 1000:
|
||||
raise ValueError("Maximum limit is 1000")
|
||||
raise InvalidInputError("Maximum limit is 1000")
|
||||
|
||||
try:
|
||||
# Build base query
|
||||
query = select(User)
|
||||
|
||||
# Exclude soft-deleted users
|
||||
query = query.where(User.deleted_at.is_(None))
|
||||
|
||||
# Apply filters
|
||||
if filters:
|
||||
for field, value in filters.items():
|
||||
if hasattr(User, field) and value is not None:
|
||||
query = query.where(getattr(User, field) == value)
|
||||
|
||||
# Apply search
|
||||
if search:
|
||||
search_filter = or_(
|
||||
User.email.ilike(f"%{search}%"),
|
||||
User.first_name.ilike(f"%{search}%"),
|
||||
User.last_name.ilike(f"%{search}%")
|
||||
User.last_name.ilike(f"%{search}%"),
|
||||
)
|
||||
query = query.where(search_filter)
|
||||
|
||||
# Get total count
|
||||
from sqlalchemy import func
|
||||
|
||||
count_query = select(func.count()).select_from(query.alias())
|
||||
count_result = await db.execute(count_query)
|
||||
total = count_result.scalar_one()
|
||||
|
||||
# Apply sorting
|
||||
if sort_by and hasattr(User, sort_by):
|
||||
sort_column = getattr(User, sort_by)
|
||||
if sort_order.lower() == "desc":
|
||||
@@ -154,7 +177,6 @@ class CRUDUser(CRUDBase[User, UserCreate, UserUpdate]):
|
||||
else:
|
||||
query = query.order_by(sort_column.asc())
|
||||
|
||||
# Apply pagination
|
||||
query = query.offset(skip).limit(limit)
|
||||
result = await db.execute(query)
|
||||
users = list(result.scalars().all())
|
||||
@@ -162,88 +184,63 @@ class CRUDUser(CRUDBase[User, UserCreate, UserUpdate]):
|
||||
return users, total
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error retrieving paginated users: {str(e)}")
|
||||
logger.error("Error retrieving paginated users: %s", e)
|
||||
raise
|
||||
|
||||
async def bulk_update_status(
|
||||
self,
|
||||
db: AsyncSession,
|
||||
*,
|
||||
user_ids: List[UUID],
|
||||
is_active: bool
|
||||
self, db: AsyncSession, *, user_ids: list[UUID], is_active: bool
|
||||
) -> int:
|
||||
"""
|
||||
Bulk update is_active status for multiple users.
|
||||
|
||||
Args:
|
||||
db: Database session
|
||||
user_ids: List of user IDs to update
|
||||
is_active: New active status
|
||||
|
||||
Returns:
|
||||
Number of users updated
|
||||
"""
|
||||
"""Bulk update is_active status for multiple users."""
|
||||
try:
|
||||
if not user_ids:
|
||||
return 0
|
||||
|
||||
# Use UPDATE with WHERE IN for efficiency
|
||||
stmt = (
|
||||
update(User)
|
||||
.where(User.id.in_(user_ids))
|
||||
.where(User.deleted_at.is_(None)) # Don't update deleted users
|
||||
.values(is_active=is_active, updated_at=datetime.now(timezone.utc))
|
||||
.where(User.deleted_at.is_(None))
|
||||
.values(is_active=is_active, updated_at=datetime.now(UTC))
|
||||
)
|
||||
|
||||
result = await db.execute(stmt)
|
||||
await db.commit()
|
||||
|
||||
updated_count = result.rowcount
|
||||
logger.info(f"Bulk updated {updated_count} users to is_active={is_active}")
|
||||
logger.info(
|
||||
"Bulk updated %s users to is_active=%s", updated_count, is_active
|
||||
)
|
||||
return updated_count
|
||||
|
||||
except Exception as e:
|
||||
await db.rollback()
|
||||
logger.error(f"Error bulk updating user status: {str(e)}", exc_info=True)
|
||||
logger.exception("Error bulk updating user status: %s", e)
|
||||
raise
|
||||
|
||||
async def bulk_soft_delete(
|
||||
self,
|
||||
db: AsyncSession,
|
||||
*,
|
||||
user_ids: List[UUID],
|
||||
exclude_user_id: Optional[UUID] = None
|
||||
user_ids: list[UUID],
|
||||
exclude_user_id: UUID | None = None,
|
||||
) -> int:
|
||||
"""
|
||||
Bulk soft delete multiple users.
|
||||
|
||||
Args:
|
||||
db: Database session
|
||||
user_ids: List of user IDs to delete
|
||||
exclude_user_id: Optional user ID to exclude (e.g., the admin performing the action)
|
||||
|
||||
Returns:
|
||||
Number of users deleted
|
||||
"""
|
||||
"""Bulk soft delete multiple users."""
|
||||
try:
|
||||
if not user_ids:
|
||||
return 0
|
||||
|
||||
# Remove excluded user from list
|
||||
filtered_ids = [uid for uid in user_ids if uid != exclude_user_id]
|
||||
|
||||
if not filtered_ids:
|
||||
return 0
|
||||
|
||||
# Use UPDATE with WHERE IN for efficiency
|
||||
stmt = (
|
||||
update(User)
|
||||
.where(User.id.in_(filtered_ids))
|
||||
.where(User.deleted_at.is_(None)) # Don't re-delete already deleted users
|
||||
.where(User.deleted_at.is_(None))
|
||||
.values(
|
||||
deleted_at=datetime.now(timezone.utc),
|
||||
deleted_at=datetime.now(UTC),
|
||||
is_active=False,
|
||||
updated_at=datetime.now(timezone.utc)
|
||||
updated_at=datetime.now(UTC),
|
||||
)
|
||||
)
|
||||
|
||||
@@ -251,22 +248,22 @@ class CRUDUser(CRUDBase[User, UserCreate, UserUpdate]):
|
||||
await db.commit()
|
||||
|
||||
deleted_count = result.rowcount
|
||||
logger.info(f"Bulk soft deleted {deleted_count} users")
|
||||
logger.info("Bulk soft deleted %s users", deleted_count)
|
||||
return deleted_count
|
||||
|
||||
except Exception as e:
|
||||
await db.rollback()
|
||||
logger.error(f"Error bulk deleting users: {str(e)}", exc_info=True)
|
||||
logger.exception("Error bulk deleting users: %s", e)
|
||||
raise
|
||||
|
||||
def is_active(self, user: User) -> bool:
|
||||
"""Check if user is active."""
|
||||
return user.is_active
|
||||
return bool(user.is_active)
|
||||
|
||||
def is_superuser(self, user: User) -> bool:
|
||||
"""Check if user is a superuser."""
|
||||
return user.is_superuser
|
||||
return bool(user.is_superuser)
|
||||
|
||||
|
||||
# Create a singleton instance for use across the application
|
||||
user = CRUDUser(User)
|
||||
# Singleton instance
|
||||
user_repo = UserRepository(User)
|
||||
@@ -1,18 +1,20 @@
|
||||
"""
|
||||
Common schemas used across the API for pagination, responses, filtering, and sorting.
|
||||
"""
|
||||
|
||||
from enum import Enum
|
||||
from math import ceil
|
||||
from typing import Generic, TypeVar, List, Optional
|
||||
from typing import TypeVar
|
||||
from uuid import UUID
|
||||
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
T = TypeVar('T')
|
||||
T = TypeVar("T")
|
||||
|
||||
|
||||
class SortOrder(str, Enum):
|
||||
"""Sort order options."""
|
||||
|
||||
ASC = "asc"
|
||||
DESC = "desc"
|
||||
|
||||
@@ -20,16 +22,9 @@ class SortOrder(str, Enum):
|
||||
class PaginationParams(BaseModel):
|
||||
"""Parameters for pagination."""
|
||||
|
||||
page: int = Field(
|
||||
default=1,
|
||||
ge=1,
|
||||
description="Page number (1-indexed)"
|
||||
)
|
||||
page: int = Field(default=1, ge=1, description="Page number (1-indexed)")
|
||||
limit: int = Field(
|
||||
default=20,
|
||||
ge=1,
|
||||
le=100,
|
||||
description="Number of items per page (max 100)"
|
||||
default=20, ge=1, le=100, description="Number of items per page (max 100)"
|
||||
)
|
||||
|
||||
@property
|
||||
@@ -42,34 +37,20 @@ class PaginationParams(BaseModel):
|
||||
"""Alias for offset (compatibility with existing code)."""
|
||||
return self.offset
|
||||
|
||||
model_config = {
|
||||
"json_schema_extra": {
|
||||
"example": {
|
||||
"page": 1,
|
||||
"limit": 20
|
||||
}
|
||||
}
|
||||
}
|
||||
model_config = {"json_schema_extra": {"example": {"page": 1, "limit": 20}}}
|
||||
|
||||
|
||||
class SortParams(BaseModel):
|
||||
"""Parameters for sorting."""
|
||||
|
||||
sort_by: Optional[str] = Field(
|
||||
default=None,
|
||||
description="Field name to sort by"
|
||||
)
|
||||
sort_by: str | None = Field(default=None, description="Field name to sort by")
|
||||
sort_order: SortOrder = Field(
|
||||
default=SortOrder.ASC,
|
||||
description="Sort order (asc or desc)"
|
||||
default=SortOrder.ASC, description="Sort order (asc or desc)"
|
||||
)
|
||||
|
||||
model_config = {
|
||||
"json_schema_extra": {
|
||||
"example": {
|
||||
"sort_by": "created_at",
|
||||
"sort_order": "desc"
|
||||
}
|
||||
"example": {"sort_by": "created_at", "sort_order": "desc"}
|
||||
}
|
||||
}
|
||||
|
||||
@@ -92,32 +73,30 @@ class PaginationMeta(BaseModel):
|
||||
"page_size": 20,
|
||||
"total_pages": 8,
|
||||
"has_next": True,
|
||||
"has_prev": False
|
||||
"has_prev": False,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
class PaginatedResponse(BaseModel, Generic[T]):
|
||||
class PaginatedResponse[T](BaseModel):
|
||||
"""Generic paginated response wrapper."""
|
||||
|
||||
data: List[T] = Field(..., description="List of items")
|
||||
data: list[T] = Field(..., description="List of items")
|
||||
pagination: PaginationMeta = Field(..., description="Pagination metadata")
|
||||
|
||||
model_config = {
|
||||
"json_schema_extra": {
|
||||
"example": {
|
||||
"data": [
|
||||
{"id": "123", "name": "Example Item"}
|
||||
],
|
||||
"data": [{"id": "123", "name": "Example Item"}],
|
||||
"pagination": {
|
||||
"total": 150,
|
||||
"page": 1,
|
||||
"page_size": 20,
|
||||
"total_pages": 8,
|
||||
"has_next": True,
|
||||
"has_prev": False
|
||||
}
|
||||
"has_prev": False,
|
||||
},
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -131,10 +110,7 @@ class MessageResponse(BaseModel):
|
||||
|
||||
model_config = {
|
||||
"json_schema_extra": {
|
||||
"example": {
|
||||
"success": True,
|
||||
"message": "Operation completed successfully"
|
||||
}
|
||||
"example": {"success": True, "message": "Operation completed successfully"}
|
||||
}
|
||||
}
|
||||
|
||||
@@ -142,11 +118,11 @@ class MessageResponse(BaseModel):
|
||||
class BulkActionRequest(BaseModel):
|
||||
"""Request schema for bulk operations on multiple items."""
|
||||
|
||||
ids: List[UUID] = Field(
|
||||
ids: list[UUID] = Field(
|
||||
...,
|
||||
min_length=1,
|
||||
max_length=100,
|
||||
description="List of item IDs to perform action on (max 100)"
|
||||
description="List of item IDs to perform action on (max 100)",
|
||||
)
|
||||
|
||||
model_config = {
|
||||
@@ -154,7 +130,7 @@ class BulkActionRequest(BaseModel):
|
||||
"example": {
|
||||
"ids": [
|
||||
"550e8400-e29b-41d4-a716-446655440000",
|
||||
"6ba7b810-9dad-11d1-80b4-00c04fd430c8"
|
||||
"6ba7b810-9dad-11d1-80b4-00c04fd430c8",
|
||||
]
|
||||
}
|
||||
}
|
||||
@@ -166,24 +142,23 @@ class BulkActionResponse(BaseModel):
|
||||
|
||||
success: bool = Field(default=True, description="Operation success status")
|
||||
message: str = Field(..., description="Human-readable message")
|
||||
affected_count: int = Field(..., description="Number of items affected by the operation")
|
||||
affected_count: int = Field(
|
||||
..., description="Number of items affected by the operation"
|
||||
)
|
||||
|
||||
model_config = {
|
||||
"json_schema_extra": {
|
||||
"example": {
|
||||
"success": True,
|
||||
"message": "Successfully deactivated 5 users",
|
||||
"affected_count": 5
|
||||
"affected_count": 5,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
def create_pagination_meta(
|
||||
total: int,
|
||||
page: int,
|
||||
limit: int,
|
||||
items_count: int
|
||||
total: int, page: int, limit: int, items_count: int
|
||||
) -> PaginationMeta:
|
||||
"""
|
||||
Helper function to create pagination metadata.
|
||||
@@ -205,5 +180,5 @@ def create_pagination_meta(
|
||||
page_size=items_count,
|
||||
total_pages=total_pages,
|
||||
has_next=page < total_pages,
|
||||
has_prev=page > 1
|
||||
has_prev=page > 1,
|
||||
)
|
||||
|
||||
@@ -1,8 +1,8 @@
|
||||
"""
|
||||
Error schemas for standardized API error responses.
|
||||
"""
|
||||
|
||||
from enum import Enum
|
||||
from typing import List, Optional
|
||||
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
@@ -53,14 +53,14 @@ class ErrorDetail(BaseModel):
|
||||
|
||||
code: ErrorCode = Field(..., description="Machine-readable error code")
|
||||
message: str = Field(..., description="Human-readable error message")
|
||||
field: Optional[str] = Field(None, description="Field name if error is field-specific")
|
||||
field: str | None = Field(None, description="Field name if error is field-specific")
|
||||
|
||||
model_config = {
|
||||
"json_schema_extra": {
|
||||
"example": {
|
||||
"code": "VAL_002",
|
||||
"message": "Password must be at least 8 characters long",
|
||||
"field": "password"
|
||||
"field": "password",
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -70,7 +70,7 @@ class ErrorResponse(BaseModel):
|
||||
"""Standardized error response format."""
|
||||
|
||||
success: bool = Field(default=False, description="Always false for error responses")
|
||||
errors: List[ErrorDetail] = Field(..., description="List of errors that occurred")
|
||||
errors: list[ErrorDetail] = Field(..., description="List of errors that occurred")
|
||||
|
||||
model_config = {
|
||||
"json_schema_extra": {
|
||||
@@ -80,9 +80,9 @@ class ErrorResponse(BaseModel):
|
||||
{
|
||||
"code": "AUTH_001",
|
||||
"message": "Invalid email or password",
|
||||
"field": None
|
||||
"field": None,
|
||||
}
|
||||
]
|
||||
],
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
395
backend/app/schemas/oauth.py
Normal file
395
backend/app/schemas/oauth.py
Normal file
@@ -0,0 +1,395 @@
|
||||
"""
|
||||
Pydantic schemas for OAuth authentication.
|
||||
"""
|
||||
|
||||
from datetime import datetime
|
||||
from uuid import UUID
|
||||
|
||||
from pydantic import BaseModel, ConfigDict, Field
|
||||
|
||||
# ============================================================================
|
||||
# OAuth Provider Info (for frontend to display available providers)
|
||||
# ============================================================================
|
||||
|
||||
|
||||
class OAuthProviderInfo(BaseModel):
|
||||
"""Information about an available OAuth provider."""
|
||||
|
||||
provider: str = Field(..., description="Provider identifier (google, github)")
|
||||
name: str = Field(..., description="Human-readable provider name")
|
||||
icon: str | None = Field(None, description="Icon identifier for frontend")
|
||||
|
||||
|
||||
class OAuthProvidersResponse(BaseModel):
|
||||
"""Response containing list of enabled OAuth providers."""
|
||||
|
||||
enabled: bool = Field(..., description="Whether OAuth is globally enabled")
|
||||
providers: list[OAuthProviderInfo] = Field(
|
||||
default_factory=list, description="List of enabled providers"
|
||||
)
|
||||
|
||||
model_config = ConfigDict(
|
||||
json_schema_extra={
|
||||
"example": {
|
||||
"enabled": True,
|
||||
"providers": [
|
||||
{"provider": "google", "name": "Google", "icon": "google"},
|
||||
{"provider": "github", "name": "GitHub", "icon": "github"},
|
||||
],
|
||||
}
|
||||
}
|
||||
)
|
||||
|
||||
|
||||
# ============================================================================
|
||||
# OAuth Account (linked provider accounts)
|
||||
# ============================================================================
|
||||
|
||||
|
||||
class OAuthAccountBase(BaseModel):
|
||||
"""Base schema for OAuth accounts."""
|
||||
|
||||
provider: str = Field(..., max_length=50, description="OAuth provider name")
|
||||
provider_email: str | None = Field(
|
||||
None, max_length=255, description="Email from OAuth provider"
|
||||
)
|
||||
|
||||
|
||||
class OAuthAccountCreate(OAuthAccountBase):
|
||||
"""Schema for creating an OAuth account link (internal use)."""
|
||||
|
||||
user_id: UUID
|
||||
provider_user_id: str = Field(..., max_length=255)
|
||||
access_token: str | None = None
|
||||
refresh_token: str | None = None
|
||||
token_expires_at: datetime | None = None
|
||||
|
||||
|
||||
class OAuthAccountResponse(OAuthAccountBase):
|
||||
"""Schema for OAuth account response to clients."""
|
||||
|
||||
id: UUID
|
||||
created_at: datetime
|
||||
|
||||
model_config = ConfigDict(
|
||||
from_attributes=True,
|
||||
json_schema_extra={
|
||||
"example": {
|
||||
"id": "123e4567-e89b-12d3-a456-426614174000",
|
||||
"provider": "google",
|
||||
"provider_email": "user@gmail.com",
|
||||
"created_at": "2025-11-24T12:00:00Z",
|
||||
}
|
||||
},
|
||||
)
|
||||
|
||||
|
||||
class OAuthAccountsListResponse(BaseModel):
|
||||
"""Response containing list of linked OAuth accounts."""
|
||||
|
||||
accounts: list[OAuthAccountResponse]
|
||||
|
||||
model_config = ConfigDict(
|
||||
json_schema_extra={
|
||||
"example": {
|
||||
"accounts": [
|
||||
{
|
||||
"id": "123e4567-e89b-12d3-a456-426614174000",
|
||||
"provider": "google",
|
||||
"provider_email": "user@gmail.com",
|
||||
"created_at": "2025-11-24T12:00:00Z",
|
||||
}
|
||||
]
|
||||
}
|
||||
}
|
||||
)
|
||||
|
||||
|
||||
# ============================================================================
|
||||
# OAuth Flow (authorization, callback, etc.)
|
||||
# ============================================================================
|
||||
|
||||
|
||||
class OAuthAuthorizeRequest(BaseModel):
|
||||
"""Request parameters for OAuth authorization."""
|
||||
|
||||
provider: str = Field(..., description="OAuth provider (google, github)")
|
||||
redirect_uri: str | None = Field(
|
||||
None, description="Frontend callback URL after OAuth"
|
||||
)
|
||||
mode: str = Field(
|
||||
default="login",
|
||||
description="OAuth mode: login, register, or link",
|
||||
pattern="^(login|register|link)$",
|
||||
)
|
||||
|
||||
|
||||
class OAuthCallbackRequest(BaseModel):
|
||||
"""Request parameters for OAuth callback."""
|
||||
|
||||
code: str = Field(..., description="Authorization code from provider")
|
||||
state: str = Field(..., description="State parameter for CSRF protection")
|
||||
|
||||
|
||||
class OAuthCallbackResponse(BaseModel):
|
||||
"""Response after successful OAuth authentication."""
|
||||
|
||||
access_token: str = Field(..., description="JWT access token")
|
||||
refresh_token: str = Field(..., description="JWT refresh token")
|
||||
token_type: str = Field(default="bearer")
|
||||
expires_in: int = Field(..., description="Token expiration in seconds")
|
||||
is_new_user: bool = Field(
|
||||
default=False, description="Whether a new user was created"
|
||||
)
|
||||
|
||||
model_config = ConfigDict(
|
||||
json_schema_extra={
|
||||
"example": {
|
||||
"access_token": "eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9...",
|
||||
"refresh_token": "eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9...",
|
||||
"token_type": "bearer",
|
||||
"expires_in": 900,
|
||||
"is_new_user": False,
|
||||
}
|
||||
}
|
||||
)
|
||||
|
||||
|
||||
class OAuthUnlinkResponse(BaseModel):
|
||||
"""Response after unlinking an OAuth account."""
|
||||
|
||||
success: bool = Field(..., description="Whether the unlink was successful")
|
||||
message: str = Field(..., description="Status message")
|
||||
|
||||
model_config = ConfigDict(
|
||||
json_schema_extra={
|
||||
"example": {"success": True, "message": "Google account unlinked"}
|
||||
}
|
||||
)
|
||||
|
||||
|
||||
# ============================================================================
|
||||
# OAuth State (CSRF protection - internal use)
|
||||
# ============================================================================
|
||||
|
||||
|
||||
class OAuthStateCreate(BaseModel):
|
||||
"""Schema for creating OAuth state (internal use)."""
|
||||
|
||||
state: str = Field(..., max_length=255)
|
||||
code_verifier: str | None = Field(None, max_length=128)
|
||||
nonce: str | None = Field(None, max_length=255)
|
||||
provider: str = Field(..., max_length=50)
|
||||
redirect_uri: str | None = Field(None, max_length=500)
|
||||
user_id: UUID | None = None
|
||||
expires_at: datetime
|
||||
|
||||
|
||||
# ============================================================================
|
||||
# OAuth Client (Provider Mode - MCP clients)
|
||||
# ============================================================================
|
||||
|
||||
|
||||
class OAuthClientBase(BaseModel):
|
||||
"""Base schema for OAuth clients."""
|
||||
|
||||
client_name: str = Field(..., max_length=255, description="Client application name")
|
||||
client_description: str | None = Field(
|
||||
None, max_length=1000, description="Client description"
|
||||
)
|
||||
redirect_uris: list[str] = Field(
|
||||
default_factory=list, description="Allowed redirect URIs"
|
||||
)
|
||||
allowed_scopes: list[str] = Field(
|
||||
default_factory=list, description="Allowed OAuth scopes"
|
||||
)
|
||||
|
||||
|
||||
class OAuthClientCreate(OAuthClientBase):
|
||||
"""Schema for creating an OAuth client."""
|
||||
|
||||
client_type: str = Field(
|
||||
default="public",
|
||||
description="Client type: public or confidential",
|
||||
pattern="^(public|confidential)$",
|
||||
)
|
||||
|
||||
|
||||
class OAuthClientResponse(OAuthClientBase):
|
||||
"""Schema for OAuth client response."""
|
||||
|
||||
id: UUID
|
||||
client_id: str = Field(..., description="OAuth client ID")
|
||||
client_type: str
|
||||
is_active: bool
|
||||
created_at: datetime
|
||||
|
||||
model_config = ConfigDict(
|
||||
from_attributes=True,
|
||||
json_schema_extra={
|
||||
"example": {
|
||||
"id": "123e4567-e89b-12d3-a456-426614174000",
|
||||
"client_id": "abc123def456",
|
||||
"client_name": "My MCP App",
|
||||
"client_description": "My application that uses MCP",
|
||||
"client_type": "public",
|
||||
"redirect_uris": ["http://localhost:3000/callback"],
|
||||
"allowed_scopes": ["read:users", "write:users"],
|
||||
"is_active": True,
|
||||
"created_at": "2025-11-24T12:00:00Z",
|
||||
}
|
||||
},
|
||||
)
|
||||
|
||||
|
||||
class OAuthClientWithSecret(OAuthClientResponse):
|
||||
"""Schema for OAuth client response including secret (only shown once)."""
|
||||
|
||||
client_secret: str | None = Field(
|
||||
None, description="Client secret (only shown once for confidential clients)"
|
||||
)
|
||||
|
||||
model_config = ConfigDict(
|
||||
from_attributes=True,
|
||||
json_schema_extra={
|
||||
"example": {
|
||||
"id": "123e4567-e89b-12d3-a456-426614174000",
|
||||
"client_id": "abc123def456",
|
||||
"client_secret": "secret_xyz789",
|
||||
"client_name": "My MCP App",
|
||||
"client_type": "confidential",
|
||||
"redirect_uris": ["http://localhost:3000/callback"],
|
||||
"allowed_scopes": ["read:users"],
|
||||
"is_active": True,
|
||||
"created_at": "2025-11-24T12:00:00Z",
|
||||
}
|
||||
},
|
||||
)
|
||||
|
||||
|
||||
# ============================================================================
|
||||
# OAuth Provider Discovery (RFC 8414 - skeleton)
|
||||
# ============================================================================
|
||||
|
||||
|
||||
class OAuthServerMetadata(BaseModel):
|
||||
"""OAuth 2.0 Authorization Server Metadata (RFC 8414)."""
|
||||
|
||||
issuer: str = Field(..., description="Authorization server issuer URL")
|
||||
authorization_endpoint: str = Field(..., description="Authorization endpoint URL")
|
||||
token_endpoint: str = Field(..., description="Token endpoint URL")
|
||||
registration_endpoint: str | None = Field(
|
||||
None, description="Dynamic client registration endpoint"
|
||||
)
|
||||
revocation_endpoint: str | None = Field(
|
||||
None, description="Token revocation endpoint"
|
||||
)
|
||||
introspection_endpoint: str | None = Field(
|
||||
None, description="Token introspection endpoint (RFC 7662)"
|
||||
)
|
||||
scopes_supported: list[str] = Field(
|
||||
default_factory=list, description="Supported scopes"
|
||||
)
|
||||
response_types_supported: list[str] = Field(
|
||||
default_factory=lambda: ["code"], description="Supported response types"
|
||||
)
|
||||
grant_types_supported: list[str] = Field(
|
||||
default_factory=lambda: ["authorization_code", "refresh_token"],
|
||||
description="Supported grant types",
|
||||
)
|
||||
code_challenge_methods_supported: list[str] = Field(
|
||||
default_factory=lambda: ["S256"], description="Supported PKCE methods"
|
||||
)
|
||||
token_endpoint_auth_methods_supported: list[str] = Field(
|
||||
default_factory=lambda: ["client_secret_basic", "client_secret_post", "none"],
|
||||
description="Supported client authentication methods",
|
||||
)
|
||||
|
||||
model_config = ConfigDict(
|
||||
json_schema_extra={
|
||||
"example": {
|
||||
"issuer": "https://api.example.com",
|
||||
"authorization_endpoint": "https://api.example.com/oauth/authorize",
|
||||
"token_endpoint": "https://api.example.com/oauth/token",
|
||||
"revocation_endpoint": "https://api.example.com/oauth/revoke",
|
||||
"introspection_endpoint": "https://api.example.com/oauth/introspect",
|
||||
"scopes_supported": ["openid", "profile", "email", "read:users"],
|
||||
"response_types_supported": ["code"],
|
||||
"grant_types_supported": ["authorization_code", "refresh_token"],
|
||||
"code_challenge_methods_supported": ["S256"],
|
||||
"token_endpoint_auth_methods_supported": [
|
||||
"client_secret_basic",
|
||||
"client_secret_post",
|
||||
"none",
|
||||
],
|
||||
}
|
||||
}
|
||||
)
|
||||
|
||||
|
||||
# ============================================================================
|
||||
# OAuth Token Responses (RFC 6749)
|
||||
# ============================================================================
|
||||
|
||||
|
||||
class OAuthTokenResponse(BaseModel):
|
||||
"""OAuth 2.0 Token Response (RFC 6749 Section 5.1)."""
|
||||
|
||||
access_token: str = Field(..., description="The access token issued by the server")
|
||||
token_type: str = Field(
|
||||
default="Bearer", description="The type of token (typically 'Bearer')"
|
||||
)
|
||||
expires_in: int | None = Field(None, description="Token lifetime in seconds")
|
||||
refresh_token: str | None = Field(
|
||||
None, description="Refresh token for obtaining new access tokens"
|
||||
)
|
||||
scope: str | None = Field(
|
||||
None, description="Space-separated list of granted scopes"
|
||||
)
|
||||
|
||||
model_config = ConfigDict(
|
||||
json_schema_extra={
|
||||
"example": {
|
||||
"access_token": "eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9...",
|
||||
"token_type": "Bearer",
|
||||
"expires_in": 3600,
|
||||
"refresh_token": "dGhpcyBpcyBhIHJlZnJlc2ggdG9rZW4...",
|
||||
"scope": "openid profile email",
|
||||
}
|
||||
}
|
||||
)
|
||||
|
||||
|
||||
class OAuthTokenIntrospectionResponse(BaseModel):
|
||||
"""OAuth 2.0 Token Introspection Response (RFC 7662)."""
|
||||
|
||||
active: bool = Field(..., description="Whether the token is currently active")
|
||||
scope: str | None = Field(None, description="Space-separated list of scopes")
|
||||
client_id: str | None = Field(None, description="Client identifier for the token")
|
||||
username: str | None = Field(
|
||||
None, description="Human-readable identifier for the resource owner"
|
||||
)
|
||||
token_type: str | None = Field(
|
||||
None, description="Type of the token (e.g., 'Bearer')"
|
||||
)
|
||||
exp: int | None = Field(None, description="Token expiration time (Unix timestamp)")
|
||||
iat: int | None = Field(None, description="Token issue time (Unix timestamp)")
|
||||
nbf: int | None = Field(None, description="Token not-before time (Unix timestamp)")
|
||||
sub: str | None = Field(None, description="Subject of the token (user ID)")
|
||||
aud: str | None = Field(None, description="Intended audience of the token")
|
||||
iss: str | None = Field(None, description="Issuer of the token")
|
||||
|
||||
model_config = ConfigDict(
|
||||
json_schema_extra={
|
||||
"example": {
|
||||
"active": True,
|
||||
"scope": "openid profile",
|
||||
"client_id": "client123",
|
||||
"username": "user@example.com",
|
||||
"token_type": "Bearer",
|
||||
"exp": 1735689600,
|
||||
"iat": 1735686000,
|
||||
"sub": "user-uuid-here",
|
||||
}
|
||||
}
|
||||
)
|
||||
@@ -1,10 +1,10 @@
|
||||
# app/schemas/organizations.py
|
||||
import re
|
||||
from datetime import datetime
|
||||
from typing import Optional, Dict, Any, List
|
||||
from typing import Any
|
||||
from uuid import UUID
|
||||
|
||||
from pydantic import BaseModel, field_validator, ConfigDict, Field
|
||||
from pydantic import BaseModel, ConfigDict, Field, field_validator
|
||||
|
||||
from app.models.user_organization import OrganizationRole
|
||||
|
||||
@@ -12,85 +12,94 @@ from app.models.user_organization import OrganizationRole
|
||||
# Organization Schemas
|
||||
class OrganizationBase(BaseModel):
|
||||
"""Base organization schema with common fields."""
|
||||
name: str = Field(..., min_length=1, max_length=255)
|
||||
slug: Optional[str] = Field(None, min_length=1, max_length=255)
|
||||
description: Optional[str] = None
|
||||
is_active: bool = True
|
||||
settings: Optional[Dict[str, Any]] = {}
|
||||
|
||||
@field_validator('slug')
|
||||
name: str = Field(..., min_length=1, max_length=255)
|
||||
slug: str | None = Field(None, min_length=1, max_length=255)
|
||||
description: str | None = None
|
||||
is_active: bool = True
|
||||
settings: dict[str, Any] | None = {}
|
||||
|
||||
@field_validator("slug")
|
||||
@classmethod
|
||||
def validate_slug(cls, v: Optional[str]) -> Optional[str]:
|
||||
def validate_slug(cls, v: str | None) -> str | None:
|
||||
"""Validate slug format: lowercase, alphanumeric, hyphens only."""
|
||||
if v is None:
|
||||
return v
|
||||
if not re.match(r'^[a-z0-9-]+$', v):
|
||||
raise ValueError('Slug must contain only lowercase letters, numbers, and hyphens')
|
||||
if v.startswith('-') or v.endswith('-'):
|
||||
raise ValueError('Slug cannot start or end with a hyphen')
|
||||
if '--' in v:
|
||||
raise ValueError('Slug cannot contain consecutive hyphens')
|
||||
if not re.match(r"^[a-z0-9-]+$", v):
|
||||
raise ValueError(
|
||||
"Slug must contain only lowercase letters, numbers, and hyphens"
|
||||
)
|
||||
if v.startswith("-") or v.endswith("-"):
|
||||
raise ValueError("Slug cannot start or end with a hyphen")
|
||||
if "--" in v:
|
||||
raise ValueError("Slug cannot contain consecutive hyphens")
|
||||
return v
|
||||
|
||||
@field_validator('name')
|
||||
@field_validator("name")
|
||||
@classmethod
|
||||
def validate_name(cls, v: str) -> str:
|
||||
"""Validate organization name."""
|
||||
if not v or v.strip() == "":
|
||||
raise ValueError('Organization name cannot be empty')
|
||||
raise ValueError("Organization name cannot be empty")
|
||||
return v.strip()
|
||||
|
||||
|
||||
class OrganizationCreate(OrganizationBase):
|
||||
"""Schema for creating a new organization."""
|
||||
|
||||
name: str = Field(..., min_length=1, max_length=255)
|
||||
slug: str = Field(..., min_length=1, max_length=255)
|
||||
slug: str = Field(..., min_length=1, max_length=255) # pyright: ignore[reportIncompatibleVariableOverride]
|
||||
|
||||
|
||||
class OrganizationUpdate(BaseModel):
|
||||
"""Schema for updating an organization."""
|
||||
name: Optional[str] = Field(None, min_length=1, max_length=255)
|
||||
slug: Optional[str] = Field(None, min_length=1, max_length=255)
|
||||
description: Optional[str] = None
|
||||
is_active: Optional[bool] = None
|
||||
settings: Optional[Dict[str, Any]] = None
|
||||
|
||||
@field_validator('slug')
|
||||
name: str | None = Field(None, min_length=1, max_length=255)
|
||||
slug: str | None = Field(None, min_length=1, max_length=255)
|
||||
description: str | None = None
|
||||
is_active: bool | None = None
|
||||
settings: dict[str, Any] | None = None
|
||||
|
||||
@field_validator("slug")
|
||||
@classmethod
|
||||
def validate_slug(cls, v: Optional[str]) -> Optional[str]:
|
||||
def validate_slug(cls, v: str | None) -> str | None:
|
||||
"""Validate slug format."""
|
||||
if v is None:
|
||||
return v
|
||||
if not re.match(r'^[a-z0-9-]+$', v):
|
||||
raise ValueError('Slug must contain only lowercase letters, numbers, and hyphens')
|
||||
if v.startswith('-') or v.endswith('-'):
|
||||
raise ValueError('Slug cannot start or end with a hyphen')
|
||||
if '--' in v:
|
||||
raise ValueError('Slug cannot contain consecutive hyphens')
|
||||
if not re.match(r"^[a-z0-9-]+$", v):
|
||||
raise ValueError(
|
||||
"Slug must contain only lowercase letters, numbers, and hyphens"
|
||||
)
|
||||
if v.startswith("-") or v.endswith("-"):
|
||||
raise ValueError("Slug cannot start or end with a hyphen")
|
||||
if "--" in v:
|
||||
raise ValueError("Slug cannot contain consecutive hyphens")
|
||||
return v
|
||||
|
||||
@field_validator('name')
|
||||
@field_validator("name")
|
||||
@classmethod
|
||||
def validate_name(cls, v: Optional[str]) -> Optional[str]:
|
||||
def validate_name(cls, v: str | None) -> str | None:
|
||||
"""Validate organization name."""
|
||||
if v is not None and (not v or v.strip() == ""):
|
||||
raise ValueError('Organization name cannot be empty')
|
||||
raise ValueError("Organization name cannot be empty")
|
||||
return v.strip() if v else v
|
||||
|
||||
|
||||
class OrganizationResponse(OrganizationBase):
|
||||
"""Schema for organization API responses."""
|
||||
|
||||
id: UUID
|
||||
created_at: datetime
|
||||
updated_at: Optional[datetime] = None
|
||||
member_count: Optional[int] = 0
|
||||
updated_at: datetime | None = None
|
||||
member_count: int | None = 0
|
||||
|
||||
model_config = ConfigDict(from_attributes=True)
|
||||
|
||||
|
||||
class OrganizationListResponse(BaseModel):
|
||||
"""Schema for paginated organization list responses."""
|
||||
organizations: List[OrganizationResponse]
|
||||
|
||||
organizations: list[OrganizationResponse]
|
||||
total: int
|
||||
page: int
|
||||
page_size: int
|
||||
@@ -100,44 +109,49 @@ class OrganizationListResponse(BaseModel):
|
||||
# User-Organization Relationship Schemas
|
||||
class UserOrganizationBase(BaseModel):
|
||||
"""Base schema for user-organization relationship."""
|
||||
|
||||
role: OrganizationRole = OrganizationRole.MEMBER
|
||||
is_active: bool = True
|
||||
custom_permissions: Optional[str] = None
|
||||
custom_permissions: str | None = None
|
||||
|
||||
|
||||
class UserOrganizationCreate(BaseModel):
|
||||
"""Schema for adding a user to an organization."""
|
||||
|
||||
user_id: UUID
|
||||
role: OrganizationRole = OrganizationRole.MEMBER
|
||||
custom_permissions: Optional[str] = None
|
||||
custom_permissions: str | None = None
|
||||
|
||||
|
||||
class UserOrganizationUpdate(BaseModel):
|
||||
"""Schema for updating user's role in an organization."""
|
||||
role: Optional[OrganizationRole] = None
|
||||
is_active: Optional[bool] = None
|
||||
custom_permissions: Optional[str] = None
|
||||
|
||||
role: OrganizationRole | None = None
|
||||
is_active: bool | None = None
|
||||
custom_permissions: str | None = None
|
||||
|
||||
|
||||
class UserOrganizationResponse(BaseModel):
|
||||
"""Schema for user-organization relationship responses."""
|
||||
|
||||
user_id: UUID
|
||||
organization_id: UUID
|
||||
role: OrganizationRole
|
||||
is_active: bool
|
||||
custom_permissions: Optional[str] = None
|
||||
custom_permissions: str | None = None
|
||||
created_at: datetime
|
||||
updated_at: Optional[datetime] = None
|
||||
updated_at: datetime | None = None
|
||||
|
||||
model_config = ConfigDict(from_attributes=True)
|
||||
|
||||
|
||||
class OrganizationMemberResponse(BaseModel):
|
||||
"""Schema for organization member information."""
|
||||
|
||||
user_id: UUID
|
||||
email: str
|
||||
first_name: str
|
||||
last_name: Optional[str] = None
|
||||
last_name: str | None = None
|
||||
role: OrganizationRole
|
||||
is_active: bool
|
||||
joined_at: datetime
|
||||
@@ -147,7 +161,8 @@ class OrganizationMemberResponse(BaseModel):
|
||||
|
||||
class OrganizationMemberListResponse(BaseModel):
|
||||
"""Schema for paginated organization member list."""
|
||||
members: List[OrganizationMemberResponse]
|
||||
|
||||
members: list[OrganizationMemberResponse]
|
||||
total: int
|
||||
page: int
|
||||
page_size: int
|
||||
|
||||
@@ -1,37 +1,44 @@
|
||||
"""
|
||||
Pydantic schemas for user session management.
|
||||
"""
|
||||
|
||||
from datetime import datetime
|
||||
from typing import Optional
|
||||
from uuid import UUID
|
||||
|
||||
from pydantic import BaseModel, Field, ConfigDict
|
||||
from pydantic import BaseModel, ConfigDict, Field
|
||||
|
||||
|
||||
class SessionBase(BaseModel):
|
||||
"""Base schema for user sessions."""
|
||||
device_name: Optional[str] = Field(None, max_length=255, description="Friendly device name")
|
||||
device_id: Optional[str] = Field(None, max_length=255, description="Persistent device identifier")
|
||||
|
||||
device_name: str | None = Field(
|
||||
None, max_length=255, description="Friendly device name"
|
||||
)
|
||||
device_id: str | None = Field(
|
||||
None, max_length=255, description="Persistent device identifier"
|
||||
)
|
||||
|
||||
|
||||
class SessionCreate(SessionBase):
|
||||
"""Schema for creating a new session (internal use)."""
|
||||
|
||||
user_id: UUID
|
||||
refresh_token_jti: str = Field(..., max_length=255)
|
||||
ip_address: Optional[str] = Field(None, max_length=45)
|
||||
user_agent: Optional[str] = Field(None, max_length=500)
|
||||
ip_address: str | None = Field(None, max_length=45)
|
||||
user_agent: str | None = Field(None, max_length=500)
|
||||
last_used_at: datetime
|
||||
expires_at: datetime
|
||||
location_city: Optional[str] = Field(None, max_length=100)
|
||||
location_country: Optional[str] = Field(None, max_length=100)
|
||||
location_city: str | None = Field(None, max_length=100)
|
||||
location_country: str | None = Field(None, max_length=100)
|
||||
|
||||
|
||||
class SessionUpdate(BaseModel):
|
||||
"""Schema for updating a session (internal use)."""
|
||||
last_used_at: Optional[datetime] = None
|
||||
is_active: Optional[bool] = None
|
||||
refresh_token_jti: Optional[str] = None
|
||||
expires_at: Optional[datetime] = None
|
||||
|
||||
last_used_at: datetime | None = None
|
||||
is_active: bool | None = None
|
||||
refresh_token_jti: str | None = None
|
||||
expires_at: datetime | None = None
|
||||
|
||||
|
||||
class SessionResponse(SessionBase):
|
||||
@@ -40,14 +47,17 @@ class SessionResponse(SessionBase):
|
||||
|
||||
This is what users see when they list their active sessions.
|
||||
"""
|
||||
|
||||
id: UUID
|
||||
ip_address: Optional[str] = None
|
||||
location_city: Optional[str] = None
|
||||
location_country: Optional[str] = None
|
||||
ip_address: str | None = None
|
||||
location_city: str | None = None
|
||||
location_country: str | None = None
|
||||
last_used_at: datetime
|
||||
created_at: datetime
|
||||
expires_at: datetime
|
||||
is_current: bool = Field(default=False, description="Whether this is the current session")
|
||||
is_current: bool = Field(
|
||||
default=False, description="Whether this is the current session"
|
||||
)
|
||||
|
||||
model_config = ConfigDict(
|
||||
from_attributes=True,
|
||||
@@ -62,14 +72,15 @@ class SessionResponse(SessionBase):
|
||||
"last_used_at": "2025-10-31T12:00:00Z",
|
||||
"created_at": "2025-10-30T09:00:00Z",
|
||||
"expires_at": "2025-11-06T09:00:00Z",
|
||||
"is_current": True
|
||||
"is_current": True,
|
||||
}
|
||||
}
|
||||
},
|
||||
)
|
||||
|
||||
|
||||
class SessionListResponse(BaseModel):
|
||||
"""Response containing list of sessions."""
|
||||
|
||||
sessions: list[SessionResponse]
|
||||
total: int = Field(..., description="Total number of active sessions")
|
||||
|
||||
@@ -84,10 +95,10 @@ class SessionListResponse(BaseModel):
|
||||
"last_used_at": "2025-10-31T12:00:00Z",
|
||||
"created_at": "2025-10-30T09:00:00Z",
|
||||
"expires_at": "2025-11-06T09:00:00Z",
|
||||
"is_current": True
|
||||
"is_current": True,
|
||||
}
|
||||
],
|
||||
"total": 1
|
||||
"total": 1,
|
||||
}
|
||||
}
|
||||
)
|
||||
@@ -95,17 +106,14 @@ class SessionListResponse(BaseModel):
|
||||
|
||||
class LogoutRequest(BaseModel):
|
||||
"""Request schema for logout endpoint."""
|
||||
|
||||
refresh_token: str = Field(
|
||||
...,
|
||||
description="Refresh token for the session to logout from",
|
||||
min_length=10
|
||||
..., description="Refresh token for the session to logout from", min_length=10
|
||||
)
|
||||
|
||||
model_config = ConfigDict(
|
||||
json_schema_extra={
|
||||
"example": {
|
||||
"refresh_token": "eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9..."
|
||||
}
|
||||
"example": {"refresh_token": "eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9..."}
|
||||
}
|
||||
)
|
||||
|
||||
@@ -116,13 +124,14 @@ class AdminSessionResponse(SessionBase):
|
||||
|
||||
Includes user information for admin to see who owns each session.
|
||||
"""
|
||||
|
||||
id: UUID
|
||||
user_id: UUID
|
||||
user_email: str = Field(..., description="Email of the user who owns this session")
|
||||
user_full_name: Optional[str] = Field(None, description="Full name of the user")
|
||||
ip_address: Optional[str] = None
|
||||
location_city: Optional[str] = None
|
||||
location_country: Optional[str] = None
|
||||
user_full_name: str | None = Field(None, description="Full name of the user")
|
||||
ip_address: str | None = None
|
||||
location_city: str | None = None
|
||||
location_country: str | None = None
|
||||
last_used_at: datetime
|
||||
created_at: datetime
|
||||
expires_at: datetime
|
||||
@@ -144,20 +153,21 @@ class AdminSessionResponse(SessionBase):
|
||||
"last_used_at": "2025-10-31T12:00:00Z",
|
||||
"created_at": "2025-10-30T09:00:00Z",
|
||||
"expires_at": "2025-11-06T09:00:00Z",
|
||||
"is_active": True
|
||||
"is_active": True,
|
||||
}
|
||||
}
|
||||
},
|
||||
)
|
||||
|
||||
|
||||
class DeviceInfo(BaseModel):
|
||||
"""Device information extracted from request."""
|
||||
device_name: Optional[str] = None
|
||||
device_id: Optional[str] = None
|
||||
ip_address: Optional[str] = None
|
||||
user_agent: Optional[str] = None
|
||||
location_city: Optional[str] = None
|
||||
location_country: Optional[str] = None
|
||||
|
||||
device_name: str | None = None
|
||||
device_id: str | None = None
|
||||
ip_address: str | None = None
|
||||
user_agent: str | None = None
|
||||
location_city: str | None = None
|
||||
location_country: str | None = None
|
||||
|
||||
model_config = ConfigDict(
|
||||
json_schema_extra={
|
||||
@@ -167,7 +177,7 @@ class DeviceInfo(BaseModel):
|
||||
"ip_address": "192.168.1.50",
|
||||
"user_agent": "Mozilla/5.0 (Macintosh; Intel Mac OS X 10_15_7)...",
|
||||
"location_city": "San Francisco",
|
||||
"location_country": "United States"
|
||||
"location_country": "United States",
|
||||
}
|
||||
}
|
||||
)
|
||||
|
||||
@@ -1,9 +1,9 @@
|
||||
# app/schemas/users.py
|
||||
from datetime import datetime
|
||||
from typing import Optional, Dict, Any
|
||||
from typing import Any
|
||||
from uuid import UUID
|
||||
|
||||
from pydantic import BaseModel, EmailStr, field_validator, ConfigDict, Field
|
||||
from pydantic import BaseModel, ConfigDict, EmailStr, Field, field_validator
|
||||
|
||||
from app.schemas.validators import validate_password_strength, validate_phone_number
|
||||
|
||||
@@ -11,20 +11,21 @@ from app.schemas.validators import validate_password_strength, validate_phone_nu
|
||||
class UserBase(BaseModel):
|
||||
email: EmailStr
|
||||
first_name: str
|
||||
last_name: Optional[str] = None
|
||||
phone_number: Optional[str] = None
|
||||
last_name: str | None = None
|
||||
phone_number: str | None = None
|
||||
|
||||
@field_validator('phone_number')
|
||||
@field_validator("phone_number")
|
||||
@classmethod
|
||||
def validate_phone(cls, v: Optional[str]) -> Optional[str]:
|
||||
def validate_phone(cls, v: str | None) -> str | None:
|
||||
return validate_phone_number(v)
|
||||
|
||||
|
||||
class UserCreate(UserBase):
|
||||
password: str
|
||||
is_superuser: bool = False
|
||||
is_active: bool = True
|
||||
|
||||
@field_validator('password')
|
||||
@field_validator("password")
|
||||
@classmethod
|
||||
def password_strength(cls, v: str) -> str:
|
||||
"""Enterprise-grade password strength validation"""
|
||||
@@ -32,30 +33,57 @@ class UserCreate(UserBase):
|
||||
|
||||
|
||||
class UserUpdate(BaseModel):
|
||||
first_name: Optional[str] = None
|
||||
last_name: Optional[str] = None
|
||||
phone_number: Optional[str] = None
|
||||
password: Optional[str] = None
|
||||
preferences: Optional[Dict[str, Any]] = None
|
||||
is_active: Optional[bool] = None # Changed default from True to None to avoid unintended updates
|
||||
is_superuser: Optional[bool] = None # Explicitly reject privilege escalation attempts
|
||||
first_name: str | None = None
|
||||
last_name: str | None = None
|
||||
phone_number: str | None = None
|
||||
password: str | None = None
|
||||
preferences: dict[str, Any] | None = None
|
||||
locale: str | None = Field(
|
||||
None,
|
||||
max_length=10,
|
||||
pattern=r"^[a-z]{2}(-[A-Z]{2})?$",
|
||||
description="User's preferred locale (BCP 47 format: en, it, en-US, it-IT)",
|
||||
examples=["en", "it", "en-US", "it-IT"],
|
||||
)
|
||||
is_active: bool | None = (
|
||||
None # Changed default from True to None to avoid unintended updates
|
||||
)
|
||||
is_superuser: bool | None = None # Explicitly reject privilege escalation attempts
|
||||
|
||||
@field_validator('phone_number')
|
||||
@field_validator("phone_number")
|
||||
@classmethod
|
||||
def validate_phone(cls, v: Optional[str]) -> Optional[str]:
|
||||
def validate_phone(cls, v: str | None) -> str | None:
|
||||
return validate_phone_number(v)
|
||||
|
||||
@field_validator('password')
|
||||
@field_validator("password")
|
||||
@classmethod
|
||||
def password_strength(cls, v: Optional[str]) -> Optional[str]:
|
||||
def password_strength(cls, v: str | None) -> str | None:
|
||||
"""Enterprise-grade password strength validation"""
|
||||
if v is None:
|
||||
return v
|
||||
return validate_password_strength(v)
|
||||
|
||||
@field_validator('is_superuser')
|
||||
@field_validator("locale")
|
||||
@classmethod
|
||||
def prevent_superuser_modification(cls, v: Optional[bool]) -> Optional[bool]:
|
||||
def validate_locale(cls, v: str | None) -> str | None:
|
||||
"""Validate locale against supported locales."""
|
||||
if v is None:
|
||||
return v
|
||||
# Only support English and Italian for template showcase
|
||||
# Note: Locales stored in lowercase for case-insensitive matching
|
||||
supported_locales = {"en", "it", "en-us", "en-gb", "it-it"}
|
||||
# Normalize to lowercase for comparison and storage
|
||||
v_lower = v.lower()
|
||||
if v_lower not in supported_locales:
|
||||
raise ValueError(
|
||||
f"Unsupported locale '{v}'. Supported locales: {sorted(supported_locales)}"
|
||||
)
|
||||
# Return normalized lowercase version for consistency
|
||||
return v_lower
|
||||
|
||||
@field_validator("is_superuser")
|
||||
@classmethod
|
||||
def prevent_superuser_modification(cls, v: bool | None) -> bool | None:
|
||||
"""Prevent users from modifying their superuser status via this schema."""
|
||||
if v is not None:
|
||||
raise ValueError("Cannot modify superuser status through user update")
|
||||
@@ -67,7 +95,8 @@ class UserInDB(UserBase):
|
||||
is_active: bool
|
||||
is_superuser: bool
|
||||
created_at: datetime
|
||||
updated_at: Optional[datetime] = None
|
||||
updated_at: datetime | None = None
|
||||
locale: str | None = None
|
||||
|
||||
model_config = ConfigDict(from_attributes=True)
|
||||
|
||||
@@ -77,28 +106,29 @@ class UserResponse(UserBase):
|
||||
is_active: bool
|
||||
is_superuser: bool
|
||||
created_at: datetime
|
||||
updated_at: Optional[datetime] = None
|
||||
updated_at: datetime | None = None
|
||||
locale: str | None = None
|
||||
|
||||
model_config = ConfigDict(from_attributes=True)
|
||||
|
||||
|
||||
class Token(BaseModel):
|
||||
access_token: str
|
||||
refresh_token: Optional[str] = None
|
||||
refresh_token: str | None = None
|
||||
token_type: str = "bearer"
|
||||
user: "UserResponse" # Forward reference since UserResponse is defined above
|
||||
expires_in: Optional[int] = None # Token expiration in seconds
|
||||
expires_in: int | None = None # Token expiration in seconds
|
||||
|
||||
|
||||
class TokenPayload(BaseModel):
|
||||
sub: str # User ID
|
||||
exp: int # Expiration time
|
||||
iat: Optional[int] = None # Issued at
|
||||
jti: Optional[str] = None # JWT ID
|
||||
is_superuser: Optional[bool] = False
|
||||
first_name: Optional[str] = None
|
||||
email: Optional[str] = None
|
||||
type: Optional[str] = None # Token type (access/refresh)
|
||||
iat: int | None = None # Issued at
|
||||
jti: str | None = None # JWT ID
|
||||
is_superuser: bool | None = False
|
||||
first_name: str | None = None
|
||||
email: str | None = None
|
||||
type: str | None = None # Token type (access/refresh)
|
||||
|
||||
|
||||
class TokenData(BaseModel):
|
||||
@@ -108,10 +138,11 @@ class TokenData(BaseModel):
|
||||
|
||||
class PasswordChange(BaseModel):
|
||||
"""Schema for changing password (requires current password)."""
|
||||
|
||||
current_password: str
|
||||
new_password: str
|
||||
|
||||
@field_validator('new_password')
|
||||
@field_validator("new_password")
|
||||
@classmethod
|
||||
def password_strength(cls, v: str) -> str:
|
||||
"""Enterprise-grade password strength validation"""
|
||||
@@ -120,10 +151,11 @@ class PasswordChange(BaseModel):
|
||||
|
||||
class PasswordReset(BaseModel):
|
||||
"""Schema for resetting password (via email token)."""
|
||||
|
||||
token: str
|
||||
new_password: str
|
||||
|
||||
@field_validator('new_password')
|
||||
@field_validator("new_password")
|
||||
@classmethod
|
||||
def password_strength(cls, v: str) -> str:
|
||||
"""Enterprise-grade password strength validation"""
|
||||
@@ -141,23 +173,19 @@ class RefreshTokenRequest(BaseModel):
|
||||
|
||||
class PasswordResetRequest(BaseModel):
|
||||
"""Schema for requesting a password reset."""
|
||||
|
||||
email: EmailStr = Field(..., description="Email address of the account")
|
||||
|
||||
model_config = {
|
||||
"json_schema_extra": {
|
||||
"example": {
|
||||
"email": "user@example.com"
|
||||
}
|
||||
}
|
||||
}
|
||||
model_config = {"json_schema_extra": {"example": {"email": "user@example.com"}}}
|
||||
|
||||
|
||||
class PasswordResetConfirm(BaseModel):
|
||||
"""Schema for confirming a password reset with token."""
|
||||
|
||||
token: str = Field(..., description="Password reset token from email")
|
||||
new_password: str = Field(..., min_length=8, description="New password")
|
||||
|
||||
@field_validator('new_password')
|
||||
@field_validator("new_password")
|
||||
@classmethod
|
||||
def password_strength(cls, v: str) -> str:
|
||||
"""Enterprise-grade password strength validation"""
|
||||
@@ -167,7 +195,7 @@ class PasswordResetConfirm(BaseModel):
|
||||
"json_schema_extra": {
|
||||
"example": {
|
||||
"token": "eyJwYXlsb2FkIjp7ImVtYWlsIjoidXNlckBleGFtcGxlLmNvbSIsImV4cCI6MTcxMjM0NTY3OH19",
|
||||
"new_password": "NewSecurePassword123"
|
||||
"new_password": "NewSecurePassword123",
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@@ -4,19 +4,34 @@ Shared validators for Pydantic schemas.
|
||||
This module provides reusable validation functions to ensure consistency
|
||||
across all schemas and avoid code duplication.
|
||||
"""
|
||||
|
||||
import re
|
||||
from typing import Set
|
||||
|
||||
# Common weak passwords that should be rejected
|
||||
COMMON_PASSWORDS: Set[str] = {
|
||||
'password', 'password1', 'password123', 'password1234',
|
||||
'admin', 'admin123', 'admin1234',
|
||||
'welcome', 'welcome1', 'welcome123',
|
||||
'qwerty', 'qwerty123',
|
||||
'12345678', '123456789', '1234567890',
|
||||
'letmein', 'letmein1', 'letmein123',
|
||||
'monkey123', 'dragon123',
|
||||
'passw0rd', 'p@ssw0rd', 'p@ssword',
|
||||
COMMON_PASSWORDS: set[str] = {
|
||||
"password",
|
||||
"password1",
|
||||
"password123",
|
||||
"password1234",
|
||||
"admin",
|
||||
"admin123",
|
||||
"admin1234",
|
||||
"welcome",
|
||||
"welcome1",
|
||||
"welcome123",
|
||||
"qwerty",
|
||||
"qwerty123",
|
||||
"12345678",
|
||||
"123456789",
|
||||
"1234567890",
|
||||
"letmein",
|
||||
"letmein1",
|
||||
"letmein123",
|
||||
"monkey123",
|
||||
"dragon123",
|
||||
"passw0rd",
|
||||
"p@ssw0rd",
|
||||
"p@ssword",
|
||||
}
|
||||
|
||||
|
||||
@@ -45,20 +60,32 @@ def validate_password_strength(password: str) -> str:
|
||||
>>> validate_password_strength("MySecureP@ss123") # Valid
|
||||
>>> validate_password_strength("password1") # Invalid - too weak
|
||||
"""
|
||||
# Check if we are in demo mode
|
||||
from app.core.config import settings
|
||||
|
||||
if settings.DEMO_MODE:
|
||||
# In demo mode, allow specific weak passwords for demo accounts
|
||||
demo_passwords = {"Demo123!", "Admin123!"}
|
||||
if password in demo_passwords:
|
||||
return password
|
||||
|
||||
# Check minimum length
|
||||
if len(password) < 12:
|
||||
raise ValueError('Password must be at least 12 characters long')
|
||||
raise ValueError("Password must be at least 12 characters long")
|
||||
|
||||
# Check against common passwords (case-insensitive)
|
||||
if password.lower() in COMMON_PASSWORDS:
|
||||
raise ValueError('Password is too common. Please choose a stronger password')
|
||||
raise ValueError("Password is too common. Please choose a stronger password")
|
||||
|
||||
# Check for required character types
|
||||
checks = [
|
||||
(any(c.islower() for c in password), 'at least one lowercase letter'),
|
||||
(any(c.isupper() for c in password), 'at least one uppercase letter'),
|
||||
(any(c.isdigit() for c in password), 'at least one digit'),
|
||||
(any(c in '!@#$%^&*()_+-=[]{}|;:,.<>?~`' for c in password), 'at least one special character (!@#$%^&*()_+-=[]{}|;:,.<>?~`)')
|
||||
(any(c.islower() for c in password), "at least one lowercase letter"),
|
||||
(any(c.isupper() for c in password), "at least one uppercase letter"),
|
||||
(any(c.isdigit() for c in password), "at least one digit"),
|
||||
(
|
||||
any(c in "!@#$%^&*()_+-=[]{}|;:,.<>?~`" for c in password),
|
||||
"at least one special character (!@#$%^&*()_+-=[]{}|;:,.<>?~`)",
|
||||
),
|
||||
]
|
||||
|
||||
failed = [msg for check, msg in checks if not check]
|
||||
@@ -94,10 +121,10 @@ def validate_phone_number(phone: str | None) -> str | None:
|
||||
|
||||
# Check for empty strings
|
||||
if not phone or phone.strip() == "":
|
||||
raise ValueError('Phone number cannot be empty')
|
||||
raise ValueError("Phone number cannot be empty")
|
||||
|
||||
# Remove all spaces and formatting characters
|
||||
cleaned = re.sub(r'[\s\-\(\)]', '', phone)
|
||||
cleaned = re.sub(r"[\s\-\(\)]", "", phone)
|
||||
|
||||
# Basic pattern:
|
||||
# Must start with + or 0
|
||||
@@ -105,19 +132,19 @@ def validate_phone_number(phone: str | None) -> str | None:
|
||||
# After 0 must have at least 8 digits
|
||||
# Maximum total length of 15 digits (international standard)
|
||||
# Only allowed characters are + at start and digits
|
||||
pattern = r'^(?:\+[0-9]{8,14}|0[0-9]{8,14})$'
|
||||
pattern = r"^(?:\+[0-9]{8,14}|0[0-9]{8,14})$"
|
||||
|
||||
if not re.match(pattern, cleaned):
|
||||
raise ValueError('Phone number must start with + or 0 followed by 8-14 digits')
|
||||
raise ValueError("Phone number must start with + or 0 followed by 8-14 digits")
|
||||
|
||||
# Additional validation to catch specific invalid cases
|
||||
# NOTE: These checks are defensive code - the regex pattern above already catches these cases
|
||||
if cleaned.count('+') > 1: # pragma: no cover
|
||||
raise ValueError('Phone number can only contain one + symbol at the start')
|
||||
if cleaned.count("+") > 1: # pragma: no cover
|
||||
raise ValueError("Phone number can only contain one + symbol at the start")
|
||||
|
||||
# Check for any non-digit characters (except the leading +)
|
||||
if not all(c.isdigit() for c in cleaned[1:]): # pragma: no cover
|
||||
raise ValueError('Phone number can only contain digits after the prefix')
|
||||
raise ValueError("Phone number can only contain digits after the prefix")
|
||||
|
||||
return cleaned
|
||||
|
||||
@@ -169,16 +196,16 @@ def validate_slug(slug: str) -> str:
|
||||
ValueError: If slug format is invalid
|
||||
"""
|
||||
if not slug or len(slug) < 2:
|
||||
raise ValueError('Slug must be at least 2 characters long')
|
||||
raise ValueError("Slug must be at least 2 characters long")
|
||||
|
||||
if len(slug) > 50:
|
||||
raise ValueError('Slug must be at most 50 characters long')
|
||||
raise ValueError("Slug must be at most 50 characters long")
|
||||
|
||||
# Check format
|
||||
if not re.match(r'^[a-z0-9]+(?:-[a-z0-9]+)*$', slug):
|
||||
if not re.match(r"^[a-z0-9]+(?:-[a-z0-9]+)*$", slug):
|
||||
raise ValueError(
|
||||
'Slug can only contain lowercase letters, numbers, and hyphens. '
|
||||
'It cannot start or end with a hyphen, and cannot contain consecutive hyphens'
|
||||
"Slug can only contain lowercase letters, numbers, and hyphens. "
|
||||
"It cannot start or end with a hyphen, and cannot contain consecutive hyphens"
|
||||
)
|
||||
|
||||
return slug
|
||||
|
||||
@@ -0,0 +1,19 @@
|
||||
# app/services/__init__.py
|
||||
from . import oauth_provider_service
|
||||
from .auth_service import AuthService
|
||||
from .oauth_service import OAuthService
|
||||
from .organization_service import OrganizationService, organization_service
|
||||
from .session_service import SessionService, session_service
|
||||
from .user_service import UserService, user_service
|
||||
|
||||
__all__ = [
|
||||
"AuthService",
|
||||
"OAuthService",
|
||||
"OrganizationService",
|
||||
"SessionService",
|
||||
"UserService",
|
||||
"oauth_provider_service",
|
||||
"organization_service",
|
||||
"session_service",
|
||||
"user_service",
|
||||
]
|
||||
|
||||
@@ -1,32 +1,38 @@
|
||||
# app/services/auth_service.py
|
||||
import logging
|
||||
from typing import Optional
|
||||
from uuid import UUID
|
||||
|
||||
from sqlalchemy import select
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
|
||||
from app.core.auth import (
|
||||
verify_password_async,
|
||||
get_password_hash_async,
|
||||
TokenExpiredError,
|
||||
TokenInvalidError,
|
||||
create_access_token,
|
||||
create_refresh_token,
|
||||
TokenExpiredError,
|
||||
TokenInvalidError
|
||||
get_password_hash_async,
|
||||
verify_password_async,
|
||||
)
|
||||
from app.core.config import settings
|
||||
from app.core.exceptions import AuthenticationError
|
||||
from app.core.exceptions import AuthenticationError, DuplicateError
|
||||
from app.core.repository_exceptions import DuplicateEntryError
|
||||
from app.models.user import User
|
||||
from app.repositories.user import user_repo
|
||||
from app.schemas.users import Token, UserCreate, UserResponse
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
# Pre-computed bcrypt hash used for constant-time comparison when user is not found,
|
||||
# preventing timing attacks that could enumerate valid email addresses.
|
||||
_DUMMY_HASH = "$2b$12$EixZaYVK1fsbw1ZfbX3OXePaWxn96p36zLFbnJHfxPSEFBzXKiHia"
|
||||
|
||||
|
||||
class AuthService:
|
||||
"""Service for handling authentication operations"""
|
||||
|
||||
@staticmethod
|
||||
async def authenticate_user(db: AsyncSession, email: str, password: str) -> Optional[User]:
|
||||
async def authenticate_user(
|
||||
db: AsyncSession, email: str, password: str
|
||||
) -> User | None:
|
||||
"""
|
||||
Authenticate a user with email and password using async password verification.
|
||||
|
||||
@@ -38,10 +44,12 @@ class AuthService:
|
||||
Returns:
|
||||
User if authenticated, None otherwise
|
||||
"""
|
||||
result = await db.execute(select(User).where(User.email == email))
|
||||
user = result.scalar_one_or_none()
|
||||
user = await user_repo.get_by_email(db, email=email)
|
||||
|
||||
if not user:
|
||||
# Perform a dummy verification to match timing of a real bcrypt check,
|
||||
# preventing email enumeration via response-time differences.
|
||||
await verify_password_async(password, _DUMMY_HASH)
|
||||
return None
|
||||
|
||||
# Verify password asynchronously to avoid blocking event loop
|
||||
@@ -70,41 +78,24 @@ class AuthService:
|
||||
"""
|
||||
try:
|
||||
# Check if user already exists
|
||||
result = await db.execute(select(User).where(User.email == user_data.email))
|
||||
existing_user = result.scalar_one_or_none()
|
||||
existing_user = await user_repo.get_by_email(db, email=user_data.email)
|
||||
if existing_user:
|
||||
raise AuthenticationError("User with this email already exists")
|
||||
raise DuplicateError("User with this email already exists")
|
||||
|
||||
# Create new user with async password hashing
|
||||
# Hash password asynchronously to avoid blocking event loop
|
||||
hashed_password = await get_password_hash_async(user_data.password)
|
||||
# Delegate creation (hashing + commit) to the repository
|
||||
user = await user_repo.create(db, obj_in=user_data)
|
||||
|
||||
# Create user object from model
|
||||
user = User(
|
||||
email=user_data.email,
|
||||
password_hash=hashed_password,
|
||||
first_name=user_data.first_name,
|
||||
last_name=user_data.last_name,
|
||||
phone_number=user_data.phone_number,
|
||||
is_active=True,
|
||||
is_superuser=False
|
||||
)
|
||||
|
||||
db.add(user)
|
||||
await db.commit()
|
||||
await db.refresh(user)
|
||||
|
||||
logger.info(f"User created successfully: {user.email}")
|
||||
logger.info("User created successfully: %s", user.email)
|
||||
return user
|
||||
|
||||
except AuthenticationError:
|
||||
# Re-raise authentication errors without rollback
|
||||
except (AuthenticationError, DuplicateError):
|
||||
# Re-raise API exceptions without rollback
|
||||
raise
|
||||
except DuplicateEntryError as e:
|
||||
raise DuplicateError(str(e))
|
||||
except Exception as e:
|
||||
# Rollback on any database errors
|
||||
await db.rollback()
|
||||
logger.error(f"Error creating user: {str(e)}", exc_info=True)
|
||||
raise AuthenticationError(f"Failed to create user: {str(e)}")
|
||||
logger.exception("Error creating user: %s", e)
|
||||
raise AuthenticationError(f"Failed to create user: {e!s}")
|
||||
|
||||
@staticmethod
|
||||
def create_tokens(user: User) -> Token:
|
||||
@@ -121,18 +112,13 @@ class AuthService:
|
||||
claims = {
|
||||
"is_superuser": user.is_superuser,
|
||||
"email": user.email,
|
||||
"first_name": user.first_name
|
||||
"first_name": user.first_name,
|
||||
}
|
||||
|
||||
# Create tokens
|
||||
access_token = create_access_token(
|
||||
subject=str(user.id),
|
||||
claims=claims
|
||||
)
|
||||
access_token = create_access_token(subject=str(user.id), claims=claims)
|
||||
|
||||
refresh_token = create_refresh_token(
|
||||
subject=str(user.id)
|
||||
)
|
||||
refresh_token = create_refresh_token(subject=str(user.id))
|
||||
|
||||
# Convert User model to UserResponse schema
|
||||
user_response = UserResponse.model_validate(user)
|
||||
@@ -141,7 +127,8 @@ class AuthService:
|
||||
access_token=access_token,
|
||||
refresh_token=refresh_token,
|
||||
user=user_response,
|
||||
expires_in=settings.ACCESS_TOKEN_EXPIRE_MINUTES * 60 # Convert minutes to seconds
|
||||
expires_in=settings.ACCESS_TOKEN_EXPIRE_MINUTES
|
||||
* 60, # Convert minutes to seconds
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
@@ -171,8 +158,7 @@ class AuthService:
|
||||
user_id = token_data.user_id
|
||||
|
||||
# Get user from database
|
||||
result = await db.execute(select(User).where(User.id == user_id))
|
||||
user = result.scalar_one_or_none()
|
||||
user = await user_repo.get(db, id=str(user_id))
|
||||
if not user or not user.is_active:
|
||||
raise TokenInvalidError("Invalid user or inactive account")
|
||||
|
||||
@@ -180,11 +166,13 @@ class AuthService:
|
||||
return AuthService.create_tokens(user)
|
||||
|
||||
except (TokenExpiredError, TokenInvalidError) as e:
|
||||
logger.warning(f"Token refresh failed: {str(e)}")
|
||||
logger.warning("Token refresh failed: %s", e)
|
||||
raise
|
||||
|
||||
@staticmethod
|
||||
async def change_password(db: AsyncSession, user_id: UUID, current_password: str, new_password: str) -> bool:
|
||||
async def change_password(
|
||||
db: AsyncSession, user_id: UUID, current_password: str, new_password: str
|
||||
) -> bool:
|
||||
"""
|
||||
Change a user's password.
|
||||
|
||||
@@ -201,8 +189,7 @@ class AuthService:
|
||||
AuthenticationError: If current password is incorrect or update fails
|
||||
"""
|
||||
try:
|
||||
result = await db.execute(select(User).where(User.id == user_id))
|
||||
user = result.scalar_one_or_none()
|
||||
user = await user_repo.get(db, id=str(user_id))
|
||||
if not user:
|
||||
raise AuthenticationError("User not found")
|
||||
|
||||
@@ -211,10 +198,10 @@ class AuthService:
|
||||
raise AuthenticationError("Current password is incorrect")
|
||||
|
||||
# Hash new password asynchronously to avoid blocking event loop
|
||||
user.password_hash = await get_password_hash_async(new_password)
|
||||
await db.commit()
|
||||
new_hash = await get_password_hash_async(new_password)
|
||||
await user_repo.update_password(db, user=user, password_hash=new_hash)
|
||||
|
||||
logger.info(f"Password changed successfully for user {user_id}")
|
||||
logger.info("Password changed successfully for user %s", user_id)
|
||||
return True
|
||||
|
||||
except AuthenticationError:
|
||||
@@ -223,5 +210,34 @@ class AuthService:
|
||||
except Exception as e:
|
||||
# Rollback on any database errors
|
||||
await db.rollback()
|
||||
logger.error(f"Error changing password for user {user_id}: {str(e)}", exc_info=True)
|
||||
raise AuthenticationError(f"Failed to change password: {str(e)}")
|
||||
logger.exception("Error changing password for user %s: %s", user_id, e)
|
||||
raise AuthenticationError(f"Failed to change password: {e!s}")
|
||||
|
||||
@staticmethod
|
||||
async def reset_password(
|
||||
db: AsyncSession, *, email: str, new_password: str
|
||||
) -> User:
|
||||
"""
|
||||
Reset a user's password without requiring the current password.
|
||||
|
||||
Args:
|
||||
db: Database session
|
||||
email: User email address
|
||||
new_password: New password to set
|
||||
|
||||
Returns:
|
||||
Updated user
|
||||
|
||||
Raises:
|
||||
AuthenticationError: If user not found or inactive
|
||||
"""
|
||||
user = await user_repo.get_by_email(db, email=email)
|
||||
if not user:
|
||||
raise AuthenticationError("User not found")
|
||||
if not user.is_active:
|
||||
raise AuthenticationError("User account is inactive")
|
||||
|
||||
new_hash = await get_password_hash_async(new_password)
|
||||
user = await user_repo.update_password(db, user=user, password_hash=new_hash)
|
||||
logger.info("Password reset successfully for %s", email)
|
||||
return user
|
||||
|
||||
@@ -5,9 +5,9 @@ Email service with placeholder implementation.
|
||||
This service provides email sending functionality with a simple console/log-based
|
||||
placeholder that can be easily replaced with a real email provider (SendGrid, SES, etc.)
|
||||
"""
|
||||
|
||||
import logging
|
||||
from abc import ABC, abstractmethod
|
||||
from typing import List, Optional
|
||||
|
||||
from app.core.config import settings
|
||||
|
||||
@@ -20,13 +20,12 @@ class EmailBackend(ABC):
|
||||
@abstractmethod
|
||||
async def send_email(
|
||||
self,
|
||||
to: List[str],
|
||||
to: list[str],
|
||||
subject: str,
|
||||
html_content: str,
|
||||
text_content: Optional[str] = None
|
||||
text_content: str | None = None,
|
||||
) -> bool:
|
||||
"""Send an email."""
|
||||
pass
|
||||
|
||||
|
||||
class ConsoleEmailBackend(EmailBackend):
|
||||
@@ -39,10 +38,10 @@ class ConsoleEmailBackend(EmailBackend):
|
||||
|
||||
async def send_email(
|
||||
self,
|
||||
to: List[str],
|
||||
to: list[str],
|
||||
subject: str,
|
||||
html_content: str,
|
||||
text_content: Optional[str] = None
|
||||
text_content: str | None = None,
|
||||
) -> bool:
|
||||
"""
|
||||
Log email content to console/logs.
|
||||
@@ -59,8 +58,8 @@ class ConsoleEmailBackend(EmailBackend):
|
||||
logger.info("=" * 80)
|
||||
logger.info("EMAIL SENT (Console Backend)")
|
||||
logger.info("=" * 80)
|
||||
logger.info(f"To: {', '.join(to)}")
|
||||
logger.info(f"Subject: {subject}")
|
||||
logger.info("To: %s", ", ".join(to))
|
||||
logger.info("Subject: %s", subject)
|
||||
logger.info("-" * 80)
|
||||
if text_content:
|
||||
logger.info("Plain Text Content:")
|
||||
@@ -88,10 +87,10 @@ class SMTPEmailBackend(EmailBackend):
|
||||
|
||||
async def send_email(
|
||||
self,
|
||||
to: List[str],
|
||||
to: list[str],
|
||||
subject: str,
|
||||
html_content: str,
|
||||
text_content: Optional[str] = None
|
||||
text_content: str | None = None,
|
||||
) -> bool:
|
||||
"""Send email via SMTP."""
|
||||
# TODO: Implement SMTP sending
|
||||
@@ -108,7 +107,7 @@ class EmailService:
|
||||
and can be configured to use different backends (console, SMTP, SendGrid, etc.)
|
||||
"""
|
||||
|
||||
def __init__(self, backend: Optional[EmailBackend] = None):
|
||||
def __init__(self, backend: EmailBackend | None = None):
|
||||
"""
|
||||
Initialize email service with a backend.
|
||||
|
||||
@@ -118,10 +117,7 @@ class EmailService:
|
||||
self.backend = backend or ConsoleEmailBackend()
|
||||
|
||||
async def send_password_reset_email(
|
||||
self,
|
||||
to_email: str,
|
||||
reset_token: str,
|
||||
user_name: Optional[str] = None
|
||||
self, to_email: str, reset_token: str, user_name: str | None = None
|
||||
) -> bool:
|
||||
"""
|
||||
Send password reset email.
|
||||
@@ -142,7 +138,7 @@ class EmailService:
|
||||
|
||||
# Plain text version
|
||||
text_content = f"""
|
||||
Hello{' ' + user_name if user_name else ''},
|
||||
Hello{" " + user_name if user_name else ""},
|
||||
|
||||
You requested a password reset for your account. Click the link below to reset your password:
|
||||
|
||||
@@ -177,7 +173,7 @@ The {settings.PROJECT_NAME} Team
|
||||
<h1>Password Reset</h1>
|
||||
</div>
|
||||
<div class="content">
|
||||
<p>Hello{' ' + user_name if user_name else ''},</p>
|
||||
<p>Hello{" " + user_name if user_name else ""},</p>
|
||||
<p>You requested a password reset for your account. Click the button below to reset your password:</p>
|
||||
<p style="text-align: center;">
|
||||
<a href="{reset_url}" class="button">Reset Password</a>
|
||||
@@ -200,17 +196,14 @@ The {settings.PROJECT_NAME} Team
|
||||
to=[to_email],
|
||||
subject=subject,
|
||||
html_content=html_content,
|
||||
text_content=text_content
|
||||
text_content=text_content,
|
||||
)
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to send password reset email to {to_email}: {str(e)}")
|
||||
logger.error("Failed to send password reset email to %s: %s", to_email, e)
|
||||
return False
|
||||
|
||||
async def send_email_verification(
|
||||
self,
|
||||
to_email: str,
|
||||
verification_token: str,
|
||||
user_name: Optional[str] = None
|
||||
self, to_email: str, verification_token: str, user_name: str | None = None
|
||||
) -> bool:
|
||||
"""
|
||||
Send email verification email.
|
||||
@@ -224,14 +217,16 @@ The {settings.PROJECT_NAME} Team
|
||||
True if email sent successfully
|
||||
"""
|
||||
# Generate verification URL
|
||||
verification_url = f"{settings.FRONTEND_URL}/verify-email?token={verification_token}"
|
||||
verification_url = (
|
||||
f"{settings.FRONTEND_URL}/verify-email?token={verification_token}"
|
||||
)
|
||||
|
||||
# Prepare email content
|
||||
subject = "Verify Your Email Address"
|
||||
|
||||
# Plain text version
|
||||
text_content = f"""
|
||||
Hello{' ' + user_name if user_name else ''},
|
||||
Hello{" " + user_name if user_name else ""},
|
||||
|
||||
Thank you for signing up! Please verify your email address by clicking the link below:
|
||||
|
||||
@@ -266,7 +261,7 @@ The {settings.PROJECT_NAME} Team
|
||||
<h1>Verify Your Email</h1>
|
||||
</div>
|
||||
<div class="content">
|
||||
<p>Hello{' ' + user_name if user_name else ''},</p>
|
||||
<p>Hello{" " + user_name if user_name else ""},</p>
|
||||
<p>Thank you for signing up! Please verify your email address by clicking the button below:</p>
|
||||
<p style="text-align: center;">
|
||||
<a href="{verification_url}" class="button">Verify Email</a>
|
||||
@@ -289,10 +284,10 @@ The {settings.PROJECT_NAME} Team
|
||||
to=[to_email],
|
||||
subject=subject,
|
||||
html_content=html_content,
|
||||
text_content=text_content
|
||||
text_content=text_content,
|
||||
)
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to send verification email to {to_email}: {str(e)}")
|
||||
logger.error("Failed to send verification email to %s: %s", to_email, e)
|
||||
return False
|
||||
|
||||
|
||||
|
||||
970
backend/app/services/oauth_provider_service.py
Executable file
970
backend/app/services/oauth_provider_service.py
Executable file
@@ -0,0 +1,970 @@
|
||||
"""
|
||||
OAuth Provider Service for MCP integration.
|
||||
|
||||
Implements OAuth 2.0 Authorization Server functionality:
|
||||
- Authorization code flow with PKCE
|
||||
- Token issuance (JWT access tokens, opaque refresh tokens)
|
||||
- Token refresh
|
||||
- Token revocation
|
||||
- Consent management
|
||||
|
||||
Security features:
|
||||
- PKCE required for public clients (S256)
|
||||
- Short-lived authorization codes (10 minutes)
|
||||
- JWT access tokens (self-contained, no DB lookup)
|
||||
- Secure refresh token storage (hashed)
|
||||
- Token rotation on refresh
|
||||
- Comprehensive validation
|
||||
"""
|
||||
|
||||
import base64
|
||||
import hashlib
|
||||
import logging
|
||||
import secrets
|
||||
from datetime import UTC, datetime, timedelta
|
||||
from typing import Any
|
||||
from uuid import UUID
|
||||
|
||||
import jwt
|
||||
from jwt.exceptions import ExpiredSignatureError, InvalidTokenError
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
|
||||
from app.core.config import settings
|
||||
from app.models.oauth_client import OAuthClient
|
||||
from app.models.user import User
|
||||
from app.repositories.oauth_authorization_code import oauth_authorization_code_repo
|
||||
from app.repositories.oauth_client import oauth_client_repo
|
||||
from app.repositories.oauth_consent import oauth_consent_repo
|
||||
from app.repositories.oauth_provider_token import oauth_provider_token_repo
|
||||
from app.repositories.user import user_repo
|
||||
from app.schemas.oauth import OAuthClientCreate
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
# Constants
|
||||
AUTHORIZATION_CODE_EXPIRY_MINUTES = 10
|
||||
ACCESS_TOKEN_EXPIRY_MINUTES = 60 # 1 hour for MCP clients
|
||||
REFRESH_TOKEN_EXPIRY_DAYS = 30
|
||||
|
||||
|
||||
class OAuthProviderError(Exception):
|
||||
"""Base exception for OAuth provider errors."""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
error: str,
|
||||
error_description: str | None = None,
|
||||
error_uri: str | None = None,
|
||||
):
|
||||
self.error = error
|
||||
self.error_description = error_description
|
||||
self.error_uri = error_uri
|
||||
super().__init__(error_description or error)
|
||||
|
||||
|
||||
class InvalidClientError(OAuthProviderError):
|
||||
"""Client authentication failed."""
|
||||
|
||||
def __init__(self, description: str = "Invalid client credentials"):
|
||||
super().__init__("invalid_client", description)
|
||||
|
||||
|
||||
class InvalidGrantError(OAuthProviderError):
|
||||
"""Invalid authorization grant."""
|
||||
|
||||
def __init__(self, description: str = "Invalid grant"):
|
||||
super().__init__("invalid_grant", description)
|
||||
|
||||
|
||||
class InvalidRequestError(OAuthProviderError):
|
||||
"""Invalid request parameters."""
|
||||
|
||||
def __init__(self, description: str = "Invalid request"):
|
||||
super().__init__("invalid_request", description)
|
||||
|
||||
|
||||
class InvalidScopeError(OAuthProviderError):
|
||||
"""Invalid scope requested."""
|
||||
|
||||
def __init__(self, description: str = "Invalid scope"):
|
||||
super().__init__("invalid_scope", description)
|
||||
|
||||
|
||||
class UnauthorizedClientError(OAuthProviderError):
|
||||
"""Client not authorized for this grant type."""
|
||||
|
||||
def __init__(self, description: str = "Unauthorized client"):
|
||||
super().__init__("unauthorized_client", description)
|
||||
|
||||
|
||||
class AccessDeniedError(OAuthProviderError):
|
||||
"""User denied authorization."""
|
||||
|
||||
def __init__(self, description: str = "Access denied"):
|
||||
super().__init__("access_denied", description)
|
||||
|
||||
|
||||
# ============================================================================
|
||||
# Helper Functions
|
||||
# ============================================================================
|
||||
|
||||
|
||||
def generate_code() -> str:
|
||||
"""Generate a cryptographically secure authorization code."""
|
||||
return secrets.token_urlsafe(64)
|
||||
|
||||
|
||||
def generate_token() -> str:
|
||||
"""Generate a cryptographically secure token."""
|
||||
return secrets.token_urlsafe(48)
|
||||
|
||||
|
||||
def generate_jti() -> str:
|
||||
"""Generate a unique JWT ID."""
|
||||
return secrets.token_urlsafe(32)
|
||||
|
||||
|
||||
def hash_token(token: str) -> str:
|
||||
"""Hash a token using SHA-256."""
|
||||
return hashlib.sha256(token.encode()).hexdigest()
|
||||
|
||||
|
||||
def verify_pkce(code_verifier: str, code_challenge: str, method: str) -> bool:
|
||||
"""
|
||||
Verify PKCE code_verifier against stored code_challenge.
|
||||
|
||||
SECURITY: Only S256 method is supported. The 'plain' method provides
|
||||
no security benefit and is explicitly rejected per RFC 7636 Section 4.3.
|
||||
"""
|
||||
if method != "S256":
|
||||
# SECURITY: Reject any method other than S256
|
||||
# 'plain' method provides no security against code interception attacks
|
||||
logger.warning("PKCE verification rejected for unsupported method: %s", method)
|
||||
return False
|
||||
|
||||
# SHA-256 hash, then base64url encode (RFC 7636 Section 4.2)
|
||||
digest = hashlib.sha256(code_verifier.encode()).digest()
|
||||
computed = base64.urlsafe_b64encode(digest).rstrip(b"=").decode()
|
||||
return secrets.compare_digest(computed, code_challenge)
|
||||
|
||||
|
||||
def parse_scope(scope: str) -> list[str]:
|
||||
"""Parse space-separated scope string into list."""
|
||||
return [s.strip() for s in scope.split() if s.strip()]
|
||||
|
||||
|
||||
def join_scope(scopes: list[str]) -> str:
|
||||
"""Join scope list into space-separated string."""
|
||||
return " ".join(sorted(set(scopes)))
|
||||
|
||||
|
||||
# ============================================================================
|
||||
# Client Validation
|
||||
# ============================================================================
|
||||
|
||||
|
||||
async def get_client(db: AsyncSession, client_id: str) -> OAuthClient | None:
|
||||
"""Get OAuth client by client_id."""
|
||||
return await oauth_client_repo.get_by_client_id(db, client_id=client_id)
|
||||
|
||||
|
||||
async def validate_client(
|
||||
db: AsyncSession,
|
||||
client_id: str,
|
||||
client_secret: str | None = None,
|
||||
require_secret: bool = False,
|
||||
) -> OAuthClient:
|
||||
"""
|
||||
Validate OAuth client credentials.
|
||||
|
||||
Args:
|
||||
db: Database session
|
||||
client_id: Client identifier
|
||||
client_secret: Client secret (required for confidential clients)
|
||||
require_secret: Whether to require secret validation
|
||||
|
||||
Returns:
|
||||
Validated OAuthClient
|
||||
|
||||
Raises:
|
||||
InvalidClientError: If client validation fails
|
||||
"""
|
||||
client = await get_client(db, client_id)
|
||||
if not client:
|
||||
raise InvalidClientError("Unknown client_id")
|
||||
|
||||
# Confidential clients must provide valid secret
|
||||
if client.client_type == "confidential" or require_secret:
|
||||
if not client_secret:
|
||||
raise InvalidClientError("Client secret required")
|
||||
if not client.client_secret_hash:
|
||||
raise InvalidClientError("Client not configured with secret")
|
||||
|
||||
# SECURITY: Verify secret using bcrypt
|
||||
from app.core.auth import verify_password
|
||||
|
||||
stored_hash = str(client.client_secret_hash)
|
||||
|
||||
if not stored_hash.startswith("$2"):
|
||||
raise InvalidClientError(
|
||||
"Client secret uses deprecated hash format. "
|
||||
"Please regenerate your client credentials."
|
||||
)
|
||||
|
||||
if not verify_password(client_secret, stored_hash):
|
||||
raise InvalidClientError("Invalid client secret")
|
||||
|
||||
return client
|
||||
|
||||
|
||||
def validate_redirect_uri(client: OAuthClient, redirect_uri: str) -> None:
|
||||
"""
|
||||
Validate redirect_uri against client's registered URIs.
|
||||
|
||||
Raises:
|
||||
InvalidRequestError: If redirect_uri is not registered
|
||||
"""
|
||||
if not client.redirect_uris:
|
||||
raise InvalidRequestError("Client has no registered redirect URIs")
|
||||
|
||||
if redirect_uri not in client.redirect_uris:
|
||||
raise InvalidRequestError("Invalid redirect_uri")
|
||||
|
||||
|
||||
def validate_scopes(client: OAuthClient, requested_scopes: list[str]) -> list[str]:
|
||||
"""
|
||||
Validate requested scopes against client's allowed scopes.
|
||||
|
||||
Returns:
|
||||
List of valid scopes (intersection of requested and allowed)
|
||||
|
||||
Raises:
|
||||
InvalidScopeError: If no valid scopes
|
||||
"""
|
||||
allowed = set(client.allowed_scopes or [])
|
||||
requested = set(requested_scopes)
|
||||
|
||||
# If no scopes requested, use all allowed scopes
|
||||
if not requested:
|
||||
return list(allowed)
|
||||
|
||||
valid = requested & allowed
|
||||
if not valid:
|
||||
raise InvalidScopeError(
|
||||
"None of the requested scopes are allowed for this client"
|
||||
)
|
||||
|
||||
# Warn if some scopes were filtered out
|
||||
invalid = requested - allowed
|
||||
if invalid:
|
||||
logger.warning(
|
||||
"Client %s requested invalid scopes: %s", client.client_id, invalid
|
||||
)
|
||||
|
||||
return list(valid)
|
||||
|
||||
|
||||
# ============================================================================
|
||||
# Authorization Code Flow
|
||||
# ============================================================================
|
||||
|
||||
|
||||
async def create_authorization_code(
|
||||
db: AsyncSession,
|
||||
client: OAuthClient,
|
||||
user: User,
|
||||
redirect_uri: str,
|
||||
scope: str,
|
||||
code_challenge: str | None = None,
|
||||
code_challenge_method: str | None = None,
|
||||
state: str | None = None,
|
||||
nonce: str | None = None,
|
||||
) -> str:
|
||||
"""
|
||||
Create an authorization code for the authorization code flow.
|
||||
|
||||
Args:
|
||||
db: Database session
|
||||
client: Validated OAuth client
|
||||
user: Authenticated user
|
||||
redirect_uri: Validated redirect URI
|
||||
scope: Granted scopes (space-separated)
|
||||
code_challenge: PKCE code challenge
|
||||
code_challenge_method: PKCE method (S256)
|
||||
state: CSRF state parameter
|
||||
nonce: OpenID Connect nonce
|
||||
|
||||
Returns:
|
||||
Authorization code string
|
||||
"""
|
||||
# Public clients MUST use PKCE
|
||||
if client.client_type == "public":
|
||||
if not code_challenge or code_challenge_method != "S256":
|
||||
raise InvalidRequestError("PKCE with S256 is required for public clients")
|
||||
|
||||
code = generate_code()
|
||||
expires_at = datetime.now(UTC) + timedelta(
|
||||
minutes=AUTHORIZATION_CODE_EXPIRY_MINUTES
|
||||
)
|
||||
|
||||
await oauth_authorization_code_repo.create_code(
|
||||
db,
|
||||
code=code,
|
||||
client_id=client.client_id,
|
||||
user_id=user.id,
|
||||
redirect_uri=redirect_uri,
|
||||
scope=scope,
|
||||
expires_at=expires_at,
|
||||
code_challenge=code_challenge,
|
||||
code_challenge_method=code_challenge_method,
|
||||
state=state,
|
||||
nonce=nonce,
|
||||
)
|
||||
|
||||
logger.info(
|
||||
"Created authorization code for user %s and client %s",
|
||||
user.id,
|
||||
client.client_id,
|
||||
)
|
||||
return code
|
||||
|
||||
|
||||
async def exchange_authorization_code(
|
||||
db: AsyncSession,
|
||||
code: str,
|
||||
client_id: str,
|
||||
redirect_uri: str,
|
||||
code_verifier: str | None = None,
|
||||
client_secret: str | None = None,
|
||||
device_info: str | None = None,
|
||||
ip_address: str | None = None,
|
||||
) -> dict[str, Any]:
|
||||
"""
|
||||
Exchange authorization code for tokens.
|
||||
|
||||
Args:
|
||||
db: Database session
|
||||
code: Authorization code
|
||||
client_id: Client identifier
|
||||
redirect_uri: Must match the original redirect_uri
|
||||
code_verifier: PKCE code verifier
|
||||
client_secret: Client secret (for confidential clients)
|
||||
device_info: Optional device information
|
||||
ip_address: Optional IP address
|
||||
|
||||
Returns:
|
||||
Token response dict with access_token, refresh_token, etc.
|
||||
|
||||
Raises:
|
||||
InvalidGrantError: If code is invalid, expired, or already used
|
||||
InvalidClientError: If client validation fails
|
||||
"""
|
||||
# Atomically mark code as used and fetch it (prevents race condition)
|
||||
# RFC 6749 Section 4.1.2: Authorization codes MUST be single-use
|
||||
updated_id = await oauth_authorization_code_repo.consume_code_atomically(
|
||||
db, code=code
|
||||
)
|
||||
|
||||
if not updated_id:
|
||||
# Either code doesn't exist or was already used
|
||||
# Check if it exists to provide appropriate error
|
||||
existing_code = await oauth_authorization_code_repo.get_by_code(db, code=code)
|
||||
|
||||
if existing_code and existing_code.used:
|
||||
# Code reuse is a security incident - revoke all tokens for this grant
|
||||
logger.warning(
|
||||
"Authorization code reuse detected for client %s",
|
||||
existing_code.client_id,
|
||||
)
|
||||
await revoke_tokens_for_user_client(
|
||||
db, UUID(str(existing_code.user_id)), str(existing_code.client_id)
|
||||
)
|
||||
raise InvalidGrantError("Authorization code has already been used")
|
||||
else:
|
||||
raise InvalidGrantError("Invalid authorization code")
|
||||
|
||||
# Now fetch the full auth code record
|
||||
auth_code = await oauth_authorization_code_repo.get_by_id(db, code_id=updated_id)
|
||||
if auth_code is None:
|
||||
raise InvalidGrantError("Authorization code not found after consumption")
|
||||
|
||||
if auth_code.is_expired:
|
||||
raise InvalidGrantError("Authorization code has expired")
|
||||
|
||||
if auth_code.client_id != client_id:
|
||||
raise InvalidGrantError("Authorization code was not issued to this client")
|
||||
|
||||
if auth_code.redirect_uri != redirect_uri:
|
||||
raise InvalidGrantError("redirect_uri mismatch")
|
||||
|
||||
# Validate client - ALWAYS require secret for confidential clients
|
||||
client = await get_client(db, client_id)
|
||||
if not client:
|
||||
raise InvalidClientError("Unknown client_id")
|
||||
|
||||
# Confidential clients MUST authenticate (RFC 6749 Section 3.2.1)
|
||||
if client.client_type == "confidential":
|
||||
if not client_secret:
|
||||
raise InvalidClientError("Client secret required for confidential clients")
|
||||
client = await validate_client(
|
||||
db, client_id, client_secret, require_secret=True
|
||||
)
|
||||
elif client_secret:
|
||||
# Public client provided secret - validate it if given
|
||||
client = await validate_client(
|
||||
db, client_id, client_secret, require_secret=True
|
||||
)
|
||||
|
||||
# Verify PKCE
|
||||
if auth_code.code_challenge:
|
||||
if not code_verifier:
|
||||
raise InvalidGrantError("code_verifier required")
|
||||
if not verify_pkce(
|
||||
code_verifier,
|
||||
str(auth_code.code_challenge),
|
||||
str(auth_code.code_challenge_method or "S256"),
|
||||
):
|
||||
raise InvalidGrantError("Invalid code_verifier")
|
||||
elif client.client_type == "public":
|
||||
# Public clients without PKCE - this shouldn't happen if we validated on authorize
|
||||
raise InvalidGrantError("PKCE required for public clients")
|
||||
|
||||
# Get user
|
||||
user = await user_repo.get(db, id=str(auth_code.user_id))
|
||||
if not user or not user.is_active:
|
||||
raise InvalidGrantError("User not found or inactive")
|
||||
|
||||
# Generate tokens
|
||||
return await create_tokens(
|
||||
db=db,
|
||||
client=client,
|
||||
user=user,
|
||||
scope=str(auth_code.scope),
|
||||
nonce=str(auth_code.nonce) if auth_code.nonce else None,
|
||||
device_info=device_info,
|
||||
ip_address=ip_address,
|
||||
)
|
||||
|
||||
|
||||
# ============================================================================
|
||||
# Token Generation
|
||||
# ============================================================================
|
||||
|
||||
|
||||
async def create_tokens(
|
||||
db: AsyncSession,
|
||||
client: OAuthClient,
|
||||
user: User,
|
||||
scope: str,
|
||||
nonce: str | None = None,
|
||||
device_info: str | None = None,
|
||||
ip_address: str | None = None,
|
||||
) -> dict[str, Any]:
|
||||
"""
|
||||
Create access and refresh tokens.
|
||||
|
||||
Args:
|
||||
db: Database session
|
||||
client: OAuth client
|
||||
user: User
|
||||
scope: Granted scopes
|
||||
nonce: OpenID Connect nonce (included in ID token)
|
||||
device_info: Optional device information
|
||||
ip_address: Optional IP address
|
||||
|
||||
Returns:
|
||||
Token response dict
|
||||
"""
|
||||
now = datetime.now(UTC)
|
||||
jti = generate_jti()
|
||||
|
||||
# Access token expiry
|
||||
access_token_lifetime = int(client.access_token_lifetime or "3600")
|
||||
access_expires = now + timedelta(seconds=access_token_lifetime)
|
||||
|
||||
# Refresh token expiry
|
||||
refresh_token_lifetime = int(
|
||||
client.refresh_token_lifetime or str(REFRESH_TOKEN_EXPIRY_DAYS * 86400)
|
||||
)
|
||||
refresh_expires = now + timedelta(seconds=refresh_token_lifetime)
|
||||
|
||||
# Create JWT access token
|
||||
# SECURITY: Include all standard JWT claims per RFC 7519
|
||||
access_token_payload = {
|
||||
"iss": settings.OAUTH_ISSUER,
|
||||
"sub": str(user.id),
|
||||
"aud": client.client_id,
|
||||
"exp": int(access_expires.timestamp()),
|
||||
"iat": int(now.timestamp()),
|
||||
"nbf": int(now.timestamp()), # Not Before - token is valid immediately
|
||||
"jti": jti,
|
||||
"scope": scope,
|
||||
"client_id": client.client_id,
|
||||
# User info (basic claims)
|
||||
"email": user.email,
|
||||
"name": f"{user.first_name or ''} {user.last_name or ''}".strip() or user.email,
|
||||
}
|
||||
|
||||
# Add nonce for OpenID Connect
|
||||
if nonce:
|
||||
access_token_payload["nonce"] = nonce
|
||||
|
||||
access_token = jwt.encode(
|
||||
access_token_payload,
|
||||
settings.SECRET_KEY,
|
||||
algorithm=settings.ALGORITHM,
|
||||
)
|
||||
|
||||
# Create opaque refresh token
|
||||
refresh_token = generate_token()
|
||||
refresh_token_hash = hash_token(refresh_token)
|
||||
|
||||
# Store refresh token in database
|
||||
await oauth_provider_token_repo.create_token(
|
||||
db,
|
||||
token_hash=refresh_token_hash,
|
||||
jti=jti,
|
||||
client_id=client.client_id,
|
||||
user_id=user.id,
|
||||
scope=scope,
|
||||
expires_at=refresh_expires,
|
||||
device_info=device_info,
|
||||
ip_address=ip_address,
|
||||
)
|
||||
|
||||
logger.info("Issued tokens for user %s to client %s", user.id, client.client_id)
|
||||
|
||||
return {
|
||||
"access_token": access_token,
|
||||
"token_type": "Bearer",
|
||||
"expires_in": access_token_lifetime,
|
||||
"refresh_token": refresh_token,
|
||||
"scope": scope,
|
||||
}
|
||||
|
||||
|
||||
async def refresh_tokens(
|
||||
db: AsyncSession,
|
||||
refresh_token: str,
|
||||
client_id: str,
|
||||
client_secret: str | None = None,
|
||||
scope: str | None = None,
|
||||
device_info: str | None = None,
|
||||
ip_address: str | None = None,
|
||||
) -> dict[str, Any]:
|
||||
"""
|
||||
Refresh access token using refresh token.
|
||||
|
||||
Implements token rotation - old refresh token is invalidated,
|
||||
new refresh token is issued.
|
||||
|
||||
Args:
|
||||
db: Database session
|
||||
refresh_token: Refresh token
|
||||
client_id: Client identifier
|
||||
client_secret: Client secret (for confidential clients)
|
||||
scope: Optional reduced scope
|
||||
device_info: Optional device information
|
||||
ip_address: Optional IP address
|
||||
|
||||
Returns:
|
||||
New token response dict
|
||||
|
||||
Raises:
|
||||
InvalidGrantError: If refresh token is invalid
|
||||
"""
|
||||
# Find refresh token
|
||||
token_hash = hash_token(refresh_token)
|
||||
token_record = await oauth_provider_token_repo.get_by_token_hash(
|
||||
db, token_hash=token_hash
|
||||
)
|
||||
|
||||
if not token_record:
|
||||
raise InvalidGrantError("Invalid refresh token")
|
||||
|
||||
if token_record.revoked:
|
||||
# Token reuse after revocation - security incident
|
||||
logger.warning(
|
||||
"Revoked refresh token reuse detected for client %s", token_record.client_id
|
||||
)
|
||||
raise InvalidGrantError("Refresh token has been revoked")
|
||||
|
||||
if token_record.is_expired:
|
||||
raise InvalidGrantError("Refresh token has expired")
|
||||
|
||||
if token_record.client_id != client_id:
|
||||
raise InvalidGrantError("Refresh token was not issued to this client")
|
||||
|
||||
# Validate client
|
||||
client = await validate_client(
|
||||
db,
|
||||
client_id,
|
||||
client_secret,
|
||||
require_secret=(client_secret is not None),
|
||||
)
|
||||
|
||||
# Get user
|
||||
user = await user_repo.get(db, id=str(token_record.user_id))
|
||||
if not user or not user.is_active:
|
||||
raise InvalidGrantError("User not found or inactive")
|
||||
|
||||
# Validate scope (can only reduce, not expand)
|
||||
token_scope = str(token_record.scope) if token_record.scope else ""
|
||||
original_scopes = set(parse_scope(token_scope))
|
||||
if scope:
|
||||
requested_scopes = set(parse_scope(scope))
|
||||
if not requested_scopes.issubset(original_scopes):
|
||||
raise InvalidScopeError("Cannot expand scope on refresh")
|
||||
final_scope = join_scope(list(requested_scopes))
|
||||
else:
|
||||
final_scope = token_scope
|
||||
|
||||
# Revoke old refresh token (token rotation)
|
||||
await oauth_provider_token_repo.revoke(db, token=token_record)
|
||||
|
||||
# Issue new tokens
|
||||
device = str(token_record.device_info) if token_record.device_info else None
|
||||
ip_addr = str(token_record.ip_address) if token_record.ip_address else None
|
||||
return await create_tokens(
|
||||
db=db,
|
||||
client=client,
|
||||
user=user,
|
||||
scope=final_scope,
|
||||
device_info=device_info or device,
|
||||
ip_address=ip_address or ip_addr,
|
||||
)
|
||||
|
||||
|
||||
# ============================================================================
|
||||
# Token Revocation
|
||||
# ============================================================================
|
||||
|
||||
|
||||
async def revoke_token(
|
||||
db: AsyncSession,
|
||||
token: str,
|
||||
token_type_hint: str | None = None,
|
||||
client_id: str | None = None,
|
||||
client_secret: str | None = None,
|
||||
) -> bool:
|
||||
"""
|
||||
Revoke a token (access or refresh).
|
||||
|
||||
For refresh tokens: marks as revoked in database
|
||||
For access tokens: we can't truly revoke JWTs, but we can revoke
|
||||
the associated refresh token to prevent further refreshes
|
||||
|
||||
Args:
|
||||
db: Database session
|
||||
token: Token to revoke
|
||||
token_type_hint: "access_token" or "refresh_token"
|
||||
client_id: Client identifier (for validation)
|
||||
client_secret: Client secret (for confidential clients)
|
||||
|
||||
Returns:
|
||||
True if token was revoked, False if not found
|
||||
"""
|
||||
# Try as refresh token first (more likely)
|
||||
if token_type_hint != "access_token":
|
||||
token_hash = hash_token(token)
|
||||
refresh_record = await oauth_provider_token_repo.get_by_token_hash(
|
||||
db, token_hash=token_hash
|
||||
)
|
||||
|
||||
if refresh_record:
|
||||
# Validate client if provided
|
||||
if client_id and refresh_record.client_id != client_id:
|
||||
raise InvalidClientError("Token was not issued to this client")
|
||||
|
||||
await oauth_provider_token_repo.revoke(db, token=refresh_record)
|
||||
logger.info("Revoked refresh token %s...", refresh_record.jti[:8])
|
||||
return True
|
||||
|
||||
# Try as access token (JWT)
|
||||
if token_type_hint != "refresh_token":
|
||||
try:
|
||||
payload = jwt.decode(
|
||||
token,
|
||||
settings.SECRET_KEY,
|
||||
algorithms=[settings.ALGORITHM],
|
||||
options={
|
||||
"verify_exp": False,
|
||||
"verify_aud": False,
|
||||
}, # Allow expired tokens
|
||||
)
|
||||
jti = payload.get("jti")
|
||||
if jti:
|
||||
# Find and revoke the associated refresh token
|
||||
refresh_record = await oauth_provider_token_repo.get_by_jti(db, jti=jti)
|
||||
if refresh_record:
|
||||
if client_id and refresh_record.client_id != client_id:
|
||||
raise InvalidClientError("Token was not issued to this client")
|
||||
await oauth_provider_token_repo.revoke(db, token=refresh_record)
|
||||
logger.info(
|
||||
"Revoked refresh token via access token JTI %s...", jti[:8]
|
||||
)
|
||||
return True
|
||||
except InvalidTokenError:
|
||||
pass
|
||||
except Exception: # noqa: S110 - Intentional: invalid JWT not an error
|
||||
pass
|
||||
|
||||
return False
|
||||
|
||||
|
||||
async def revoke_tokens_for_user_client(
|
||||
db: AsyncSession,
|
||||
user_id: UUID,
|
||||
client_id: str,
|
||||
) -> int:
|
||||
"""
|
||||
Revoke all tokens for a specific user-client pair.
|
||||
|
||||
Used when security incidents are detected (e.g., code reuse).
|
||||
|
||||
Args:
|
||||
db: Database session
|
||||
user_id: User identifier
|
||||
client_id: Client identifier
|
||||
|
||||
Returns:
|
||||
Number of tokens revoked
|
||||
"""
|
||||
count = await oauth_provider_token_repo.revoke_all_for_user_client(
|
||||
db, user_id=user_id, client_id=client_id
|
||||
)
|
||||
|
||||
if count > 0:
|
||||
logger.warning(
|
||||
"Revoked %s tokens for user %s and client %s", count, user_id, client_id
|
||||
)
|
||||
|
||||
return count
|
||||
|
||||
|
||||
async def revoke_all_user_tokens(db: AsyncSession, user_id: UUID) -> int:
|
||||
"""
|
||||
Revoke all OAuth provider tokens for a user.
|
||||
|
||||
Used when user changes password or explicitly logs out everywhere.
|
||||
|
||||
Args:
|
||||
db: Database session
|
||||
user_id: User identifier
|
||||
|
||||
Returns:
|
||||
Number of tokens revoked
|
||||
"""
|
||||
count = await oauth_provider_token_repo.revoke_all_for_user(db, user_id=user_id)
|
||||
|
||||
if count > 0:
|
||||
logger.info("Revoked %s OAuth provider tokens for user %s", count, user_id)
|
||||
|
||||
return count
|
||||
|
||||
|
||||
# ============================================================================
|
||||
# Token Introspection (RFC 7662)
|
||||
# ============================================================================
|
||||
|
||||
|
||||
async def introspect_token(
|
||||
db: AsyncSession,
|
||||
token: str,
|
||||
token_type_hint: str | None = None,
|
||||
client_id: str | None = None,
|
||||
client_secret: str | None = None,
|
||||
) -> dict[str, Any]:
|
||||
"""
|
||||
Introspect a token to determine its validity and metadata.
|
||||
|
||||
Implements RFC 7662 Token Introspection.
|
||||
|
||||
Args:
|
||||
db: Database session
|
||||
token: Token to introspect
|
||||
token_type_hint: "access_token" or "refresh_token"
|
||||
client_id: Client requesting introspection
|
||||
client_secret: Client secret
|
||||
|
||||
Returns:
|
||||
Introspection response dict
|
||||
"""
|
||||
# Validate client if credentials provided
|
||||
if client_id:
|
||||
await validate_client(db, client_id, client_secret)
|
||||
|
||||
# Try as access token (JWT) first
|
||||
if token_type_hint != "refresh_token":
|
||||
try:
|
||||
payload = jwt.decode(
|
||||
token,
|
||||
settings.SECRET_KEY,
|
||||
algorithms=[settings.ALGORITHM],
|
||||
options={
|
||||
"verify_aud": False
|
||||
}, # Don't require audience match for introspection
|
||||
)
|
||||
|
||||
# Check if associated refresh token is revoked
|
||||
jti = payload.get("jti")
|
||||
if jti:
|
||||
refresh_record = await oauth_provider_token_repo.get_by_jti(db, jti=jti)
|
||||
if refresh_record and refresh_record.revoked:
|
||||
return {"active": False}
|
||||
|
||||
return {
|
||||
"active": True,
|
||||
"scope": payload.get("scope", ""),
|
||||
"client_id": payload.get("client_id"),
|
||||
"username": payload.get("email"),
|
||||
"token_type": "Bearer",
|
||||
"exp": payload.get("exp"),
|
||||
"iat": payload.get("iat"),
|
||||
"sub": payload.get("sub"),
|
||||
"aud": payload.get("aud"),
|
||||
"iss": payload.get("iss"),
|
||||
}
|
||||
except ExpiredSignatureError:
|
||||
return {"active": False}
|
||||
except InvalidTokenError:
|
||||
pass
|
||||
except Exception: # noqa: S110 - Intentional: invalid JWT falls through to refresh token check
|
||||
pass
|
||||
|
||||
# Try as refresh token
|
||||
if token_type_hint != "access_token":
|
||||
token_hash = hash_token(token)
|
||||
refresh_record = await oauth_provider_token_repo.get_by_token_hash(
|
||||
db, token_hash=token_hash
|
||||
)
|
||||
|
||||
if refresh_record and refresh_record.is_valid:
|
||||
return {
|
||||
"active": True,
|
||||
"scope": refresh_record.scope,
|
||||
"client_id": refresh_record.client_id,
|
||||
"token_type": "refresh_token",
|
||||
"exp": int(refresh_record.expires_at.timestamp()),
|
||||
"iat": int(refresh_record.created_at.timestamp()),
|
||||
"sub": str(refresh_record.user_id),
|
||||
}
|
||||
|
||||
return {"active": False}
|
||||
|
||||
|
||||
# ============================================================================
|
||||
# Consent Management
|
||||
# ============================================================================
|
||||
|
||||
|
||||
async def get_consent(
|
||||
db: AsyncSession,
|
||||
user_id: UUID,
|
||||
client_id: str,
|
||||
):
|
||||
"""Get existing consent record for user-client pair."""
|
||||
return await oauth_consent_repo.get_consent(
|
||||
db, user_id=user_id, client_id=client_id
|
||||
)
|
||||
|
||||
|
||||
async def check_consent(
|
||||
db: AsyncSession,
|
||||
user_id: UUID,
|
||||
client_id: str,
|
||||
requested_scopes: list[str],
|
||||
) -> bool:
|
||||
"""
|
||||
Check if user has already consented to the requested scopes.
|
||||
|
||||
Returns True if all requested scopes are already granted.
|
||||
"""
|
||||
consent = await get_consent(db, user_id, client_id)
|
||||
if not consent:
|
||||
return False
|
||||
return consent.has_scopes(requested_scopes)
|
||||
|
||||
|
||||
async def grant_consent(
|
||||
db: AsyncSession,
|
||||
user_id: UUID,
|
||||
client_id: str,
|
||||
scopes: list[str],
|
||||
):
|
||||
"""
|
||||
Grant or update consent for a user-client pair.
|
||||
|
||||
If consent already exists, updates the granted scopes.
|
||||
"""
|
||||
return await oauth_consent_repo.grant_consent(
|
||||
db, user_id=user_id, client_id=client_id, scopes=scopes
|
||||
)
|
||||
|
||||
|
||||
async def revoke_consent(
|
||||
db: AsyncSession,
|
||||
user_id: UUID,
|
||||
client_id: str,
|
||||
) -> bool:
|
||||
"""
|
||||
Revoke consent and all tokens for a user-client pair.
|
||||
|
||||
Returns True if consent was found and revoked.
|
||||
"""
|
||||
# Revoke all tokens first
|
||||
await revoke_tokens_for_user_client(db, user_id, client_id)
|
||||
|
||||
# Delete consent record
|
||||
return await oauth_consent_repo.revoke_consent(
|
||||
db, user_id=user_id, client_id=client_id
|
||||
)
|
||||
|
||||
|
||||
# ============================================================================
|
||||
# Cleanup
|
||||
# ============================================================================
|
||||
|
||||
|
||||
async def register_client(db: AsyncSession, client_data: OAuthClientCreate) -> tuple:
|
||||
"""Create a new OAuth client. Returns (client, secret)."""
|
||||
return await oauth_client_repo.create_client(db, obj_in=client_data)
|
||||
|
||||
|
||||
async def list_clients(db: AsyncSession) -> list:
|
||||
"""List all registered OAuth clients."""
|
||||
return await oauth_client_repo.get_all_clients(db)
|
||||
|
||||
|
||||
async def delete_client_by_id(db: AsyncSession, client_id: str) -> None:
|
||||
"""Delete an OAuth client by client_id."""
|
||||
await oauth_client_repo.delete_client(db, client_id=client_id)
|
||||
|
||||
|
||||
async def list_user_consents(db: AsyncSession, user_id: UUID) -> list[dict]:
|
||||
"""Get all OAuth consents for a user with client details."""
|
||||
return await oauth_consent_repo.get_user_consents_with_clients(db, user_id=user_id)
|
||||
|
||||
|
||||
async def cleanup_expired_codes(db: AsyncSession) -> int:
|
||||
"""
|
||||
Delete expired authorization codes.
|
||||
|
||||
Should be called periodically (e.g., every hour).
|
||||
|
||||
Returns:
|
||||
Number of codes deleted
|
||||
"""
|
||||
return await oauth_authorization_code_repo.cleanup_expired(db)
|
||||
|
||||
|
||||
async def cleanup_expired_tokens(db: AsyncSession) -> int:
|
||||
"""
|
||||
Delete expired and revoked refresh tokens.
|
||||
|
||||
Should be called periodically (e.g., daily).
|
||||
|
||||
Returns:
|
||||
Number of tokens deleted
|
||||
"""
|
||||
return await oauth_provider_token_repo.cleanup_expired(db, cutoff_days=7)
|
||||
744
backend/app/services/oauth_service.py
Normal file
744
backend/app/services/oauth_service.py
Normal file
@@ -0,0 +1,744 @@
|
||||
"""
|
||||
OAuth Service for handling social authentication flows.
|
||||
|
||||
Supports:
|
||||
- Google OAuth (OpenID Connect)
|
||||
- GitHub OAuth
|
||||
|
||||
Features:
|
||||
- PKCE support for public clients
|
||||
- State parameter for CSRF protection
|
||||
- Auto-linking by email (configurable)
|
||||
- Account linking for existing users
|
||||
"""
|
||||
|
||||
import logging
|
||||
import secrets
|
||||
from datetime import UTC, datetime, timedelta
|
||||
from typing import TypedDict, cast
|
||||
from uuid import UUID
|
||||
|
||||
from authlib.integrations.httpx_client import AsyncOAuth2Client
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
|
||||
from app.core.auth import create_access_token, create_refresh_token
|
||||
from app.core.config import settings
|
||||
from app.core.exceptions import AuthenticationError
|
||||
from app.models.user import User
|
||||
from app.repositories.oauth_account import oauth_account_repo as oauth_account
|
||||
from app.repositories.oauth_state import oauth_state_repo as oauth_state
|
||||
from app.repositories.user import user_repo
|
||||
from app.schemas.oauth import (
|
||||
OAuthAccountCreate,
|
||||
OAuthCallbackResponse,
|
||||
OAuthProviderInfo,
|
||||
OAuthProvidersResponse,
|
||||
OAuthStateCreate,
|
||||
)
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class _OAuthProviderConfigRequired(TypedDict):
|
||||
name: str
|
||||
icon: str
|
||||
authorize_url: str
|
||||
token_url: str
|
||||
userinfo_url: str
|
||||
scopes: list[str]
|
||||
supports_pkce: bool
|
||||
|
||||
|
||||
class OAuthProviderConfig(_OAuthProviderConfigRequired, total=False):
|
||||
"""Type definition for OAuth provider configuration."""
|
||||
|
||||
email_url: str # Optional, GitHub-only
|
||||
|
||||
|
||||
# Provider configurations
|
||||
OAUTH_PROVIDERS: dict[str, OAuthProviderConfig] = {
|
||||
"google": {
|
||||
"name": "Google",
|
||||
"icon": "google",
|
||||
"authorize_url": "https://accounts.google.com/o/oauth2/v2/auth",
|
||||
"token_url": "https://oauth2.googleapis.com/token",
|
||||
"userinfo_url": "https://www.googleapis.com/oauth2/v3/userinfo",
|
||||
"scopes": ["openid", "email", "profile"],
|
||||
"supports_pkce": True,
|
||||
},
|
||||
"github": {
|
||||
"name": "GitHub",
|
||||
"icon": "github",
|
||||
"authorize_url": "https://github.com/login/oauth/authorize",
|
||||
"token_url": "https://github.com/login/oauth/access_token",
|
||||
"userinfo_url": "https://api.github.com/user",
|
||||
"email_url": "https://api.github.com/user/emails",
|
||||
"scopes": ["read:user", "user:email"],
|
||||
"supports_pkce": False, # GitHub doesn't support PKCE
|
||||
},
|
||||
}
|
||||
|
||||
|
||||
class OAuthService:
|
||||
"""Service for handling OAuth authentication flows."""
|
||||
|
||||
@staticmethod
|
||||
def get_enabled_providers() -> OAuthProvidersResponse:
|
||||
"""
|
||||
Get list of enabled OAuth providers.
|
||||
|
||||
Returns:
|
||||
OAuthProvidersResponse with enabled providers
|
||||
"""
|
||||
providers = []
|
||||
|
||||
for provider_id in settings.enabled_oauth_providers:
|
||||
if provider_id in OAUTH_PROVIDERS:
|
||||
config = OAUTH_PROVIDERS[provider_id]
|
||||
providers.append(
|
||||
OAuthProviderInfo(
|
||||
provider=provider_id,
|
||||
name=config["name"],
|
||||
icon=config["icon"],
|
||||
)
|
||||
)
|
||||
|
||||
return OAuthProvidersResponse(
|
||||
enabled=settings.OAUTH_ENABLED and len(providers) > 0,
|
||||
providers=providers,
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
def _get_provider_credentials(provider: str) -> tuple[str, str]:
|
||||
"""Get client ID and secret for a provider."""
|
||||
if provider == "google":
|
||||
client_id = settings.OAUTH_GOOGLE_CLIENT_ID
|
||||
client_secret = settings.OAUTH_GOOGLE_CLIENT_SECRET
|
||||
elif provider == "github":
|
||||
client_id = settings.OAUTH_GITHUB_CLIENT_ID
|
||||
client_secret = settings.OAUTH_GITHUB_CLIENT_SECRET
|
||||
else:
|
||||
raise AuthenticationError(f"Unknown OAuth provider: {provider}")
|
||||
|
||||
if not client_id or not client_secret:
|
||||
raise AuthenticationError(f"OAuth provider {provider} is not configured")
|
||||
|
||||
return client_id, client_secret
|
||||
|
||||
@staticmethod
|
||||
async def create_authorization_url(
|
||||
db: AsyncSession,
|
||||
*,
|
||||
provider: str,
|
||||
redirect_uri: str,
|
||||
user_id: str | None = None,
|
||||
) -> tuple[str, str]:
|
||||
"""
|
||||
Create OAuth authorization URL with state and optional PKCE.
|
||||
|
||||
Args:
|
||||
db: Database session
|
||||
provider: OAuth provider (google, github)
|
||||
redirect_uri: Callback URL after OAuth
|
||||
user_id: User ID if linking account (user is logged in)
|
||||
|
||||
Returns:
|
||||
Tuple of (authorization_url, state)
|
||||
|
||||
Raises:
|
||||
AuthenticationError: If provider is not configured
|
||||
"""
|
||||
if not settings.OAUTH_ENABLED:
|
||||
raise AuthenticationError("OAuth is not enabled")
|
||||
|
||||
if provider not in OAUTH_PROVIDERS:
|
||||
raise AuthenticationError(f"Unknown OAuth provider: {provider}")
|
||||
|
||||
if provider not in settings.enabled_oauth_providers:
|
||||
raise AuthenticationError(f"OAuth provider {provider} is not enabled")
|
||||
|
||||
config = OAUTH_PROVIDERS[provider]
|
||||
client_id, client_secret = OAuthService._get_provider_credentials(provider)
|
||||
|
||||
# Generate state for CSRF protection
|
||||
state = secrets.token_urlsafe(32)
|
||||
|
||||
# Generate PKCE code verifier and challenge if supported
|
||||
code_verifier = None
|
||||
code_challenge = None
|
||||
if config.get("supports_pkce"):
|
||||
code_verifier = secrets.token_urlsafe(64)
|
||||
# Create code_challenge using S256 method
|
||||
import base64
|
||||
import hashlib
|
||||
|
||||
code_challenge_bytes = hashlib.sha256(code_verifier.encode()).digest()
|
||||
code_challenge = (
|
||||
base64.urlsafe_b64encode(code_challenge_bytes).decode().rstrip("=")
|
||||
)
|
||||
|
||||
# Generate nonce for OIDC (Google)
|
||||
nonce = secrets.token_urlsafe(32) if provider == "google" else None
|
||||
|
||||
# Store state in database
|
||||
from uuid import UUID
|
||||
|
||||
state_data = OAuthStateCreate(
|
||||
state=state,
|
||||
code_verifier=code_verifier,
|
||||
nonce=nonce,
|
||||
provider=provider,
|
||||
redirect_uri=redirect_uri,
|
||||
user_id=UUID(user_id) if user_id else None,
|
||||
expires_at=datetime.now(UTC)
|
||||
+ timedelta(minutes=settings.OAUTH_STATE_EXPIRE_MINUTES),
|
||||
)
|
||||
await oauth_state.create_state(db, obj_in=state_data)
|
||||
|
||||
# Build authorization URL
|
||||
async with AsyncOAuth2Client(
|
||||
client_id=client_id,
|
||||
client_secret=client_secret,
|
||||
redirect_uri=redirect_uri,
|
||||
) as client:
|
||||
# Prepare authorization params
|
||||
auth_params = {
|
||||
"state": state,
|
||||
"scope": " ".join(config["scopes"]),
|
||||
}
|
||||
|
||||
if code_challenge:
|
||||
auth_params["code_challenge"] = code_challenge
|
||||
auth_params["code_challenge_method"] = "S256"
|
||||
|
||||
if nonce:
|
||||
auth_params["nonce"] = nonce
|
||||
|
||||
url, _ = client.create_authorization_url(
|
||||
config["authorize_url"],
|
||||
**auth_params,
|
||||
)
|
||||
|
||||
logger.info("OAuth authorization URL created for %s", provider)
|
||||
return url, state
|
||||
|
||||
@staticmethod
|
||||
async def handle_callback(
|
||||
db: AsyncSession,
|
||||
*,
|
||||
code: str,
|
||||
state: str,
|
||||
redirect_uri: str,
|
||||
) -> OAuthCallbackResponse:
|
||||
"""
|
||||
Handle OAuth callback and authenticate/create user.
|
||||
|
||||
Args:
|
||||
db: Database session
|
||||
code: Authorization code from provider
|
||||
state: State parameter for CSRF verification
|
||||
redirect_uri: Callback URL (must match authorization request)
|
||||
|
||||
Returns:
|
||||
OAuthCallbackResponse with tokens
|
||||
|
||||
Raises:
|
||||
AuthenticationError: If authentication fails
|
||||
"""
|
||||
# Validate and consume state
|
||||
state_record = await oauth_state.get_and_consume_state(db, state=state)
|
||||
if not state_record:
|
||||
raise AuthenticationError("Invalid or expired OAuth state")
|
||||
|
||||
# SECURITY: Validate redirect_uri matches the one from authorization request
|
||||
# This prevents authorization code injection attacks (RFC 6749 Section 10.6)
|
||||
if state_record.redirect_uri != redirect_uri:
|
||||
logger.warning(
|
||||
"OAuth redirect_uri mismatch: expected %s, got %s",
|
||||
state_record.redirect_uri,
|
||||
redirect_uri,
|
||||
)
|
||||
raise AuthenticationError("Redirect URI mismatch")
|
||||
|
||||
# Extract provider from state record (str for type safety)
|
||||
provider: str = str(state_record.provider)
|
||||
|
||||
if provider not in OAUTH_PROVIDERS:
|
||||
raise AuthenticationError(f"Unknown OAuth provider: {provider}")
|
||||
|
||||
config = OAUTH_PROVIDERS[provider]
|
||||
client_id, client_secret = OAuthService._get_provider_credentials(provider)
|
||||
|
||||
# Exchange code for tokens
|
||||
async with AsyncOAuth2Client(
|
||||
client_id=client_id,
|
||||
client_secret=client_secret,
|
||||
redirect_uri=redirect_uri,
|
||||
) as client:
|
||||
try:
|
||||
# Prepare token request params
|
||||
token_params: dict[str, str] = {"code": code}
|
||||
|
||||
if state_record.code_verifier:
|
||||
token_params["code_verifier"] = str(state_record.code_verifier)
|
||||
|
||||
token = await client.fetch_token(
|
||||
config["token_url"],
|
||||
**token_params,
|
||||
)
|
||||
|
||||
# SECURITY: Validate ID token signature and nonce for OpenID Connect
|
||||
# This prevents token forgery and replay attacks (OIDC Core 3.1.3.7)
|
||||
if provider == "google" and state_record.nonce:
|
||||
id_token = token.get("id_token")
|
||||
if id_token:
|
||||
await OAuthService._verify_google_id_token(
|
||||
id_token=str(id_token),
|
||||
expected_nonce=str(state_record.nonce),
|
||||
client_id=client_id,
|
||||
)
|
||||
except AuthenticationError:
|
||||
raise
|
||||
except Exception as e:
|
||||
logger.error("OAuth token exchange failed: %s", e)
|
||||
raise AuthenticationError("Failed to exchange authorization code")
|
||||
|
||||
# Get user info from provider
|
||||
try:
|
||||
access_token = token.get("access_token")
|
||||
if not access_token:
|
||||
raise AuthenticationError("No access token received")
|
||||
|
||||
user_info = await OAuthService._get_user_info(
|
||||
client, provider, config, access_token
|
||||
)
|
||||
except Exception as e:
|
||||
logger.error("Failed to get user info: %s", e)
|
||||
raise AuthenticationError(
|
||||
"Failed to get user information from provider"
|
||||
)
|
||||
|
||||
# Process user info and create/link account
|
||||
provider_user_id = str(user_info.get("id") or user_info.get("sub"))
|
||||
# Email can be None if user didn't grant email permission
|
||||
# SECURITY: Normalize email (lowercase, strip) to prevent case-based account duplication
|
||||
email_raw = user_info.get("email")
|
||||
provider_email: str | None = (
|
||||
str(email_raw).lower().strip() if email_raw else None
|
||||
)
|
||||
|
||||
if not provider_user_id:
|
||||
raise AuthenticationError("Provider did not return user ID")
|
||||
|
||||
# Check if this OAuth account already exists
|
||||
existing_oauth = await oauth_account.get_by_provider_id(
|
||||
db, provider=provider, provider_user_id=provider_user_id
|
||||
)
|
||||
|
||||
is_new_user = False
|
||||
|
||||
if existing_oauth:
|
||||
# Existing OAuth account - login
|
||||
user = existing_oauth.user
|
||||
if not user.is_active:
|
||||
raise AuthenticationError("User account is inactive")
|
||||
|
||||
# Update tokens if stored
|
||||
if token.get("access_token"):
|
||||
await oauth_account.update_tokens(
|
||||
db,
|
||||
account=existing_oauth,
|
||||
access_token=token.get("access_token"),
|
||||
refresh_token=token.get("refresh_token"),
|
||||
token_expires_at=datetime.now(UTC)
|
||||
+ timedelta(seconds=token.get("expires_in", 3600)),
|
||||
)
|
||||
|
||||
logger.info("OAuth login successful for %s via %s", user.email, provider)
|
||||
|
||||
elif state_record.user_id:
|
||||
# Account linking flow (user is already logged in)
|
||||
user = await user_repo.get(db, id=str(state_record.user_id))
|
||||
|
||||
if not user:
|
||||
raise AuthenticationError("User not found for account linking")
|
||||
|
||||
# Check if user already has this provider linked
|
||||
user_id = cast(UUID, user.id)
|
||||
existing_provider = await oauth_account.get_user_account_by_provider(
|
||||
db, user_id=user_id, provider=provider
|
||||
)
|
||||
if existing_provider:
|
||||
raise AuthenticationError(
|
||||
f"You already have a {provider} account linked"
|
||||
)
|
||||
|
||||
# Create OAuth account link
|
||||
oauth_create = OAuthAccountCreate(
|
||||
user_id=user_id,
|
||||
provider=provider,
|
||||
provider_user_id=provider_user_id,
|
||||
provider_email=provider_email,
|
||||
access_token=token.get("access_token"),
|
||||
refresh_token=token.get("refresh_token"),
|
||||
token_expires_at=datetime.now(UTC)
|
||||
+ timedelta(seconds=token.get("expires_in", 3600))
|
||||
if token.get("expires_in")
|
||||
else None,
|
||||
)
|
||||
await oauth_account.create_account(db, obj_in=oauth_create)
|
||||
|
||||
logger.info("OAuth account linked: %s -> %s", provider, user.email)
|
||||
|
||||
else:
|
||||
# New OAuth login - check for existing user by email
|
||||
user = None
|
||||
|
||||
if provider_email and settings.OAUTH_AUTO_LINK_BY_EMAIL:
|
||||
user = await user_repo.get_by_email(db, email=provider_email)
|
||||
|
||||
if user:
|
||||
# Auto-link to existing user
|
||||
if not user.is_active:
|
||||
raise AuthenticationError("User account is inactive")
|
||||
|
||||
# Check if user already has this provider linked
|
||||
user_id = cast(UUID, user.id)
|
||||
existing_provider = await oauth_account.get_user_account_by_provider(
|
||||
db, user_id=user_id, provider=provider
|
||||
)
|
||||
if existing_provider:
|
||||
# This shouldn't happen if we got here, but safety check
|
||||
logger.warning(
|
||||
"OAuth account already linked (race condition?): %s -> %s",
|
||||
provider,
|
||||
user.email,
|
||||
)
|
||||
else:
|
||||
# Create OAuth account link
|
||||
oauth_create = OAuthAccountCreate(
|
||||
user_id=user_id,
|
||||
provider=provider,
|
||||
provider_user_id=provider_user_id,
|
||||
provider_email=provider_email,
|
||||
access_token=token.get("access_token"),
|
||||
refresh_token=token.get("refresh_token"),
|
||||
token_expires_at=datetime.now(UTC)
|
||||
+ timedelta(seconds=token.get("expires_in", 3600))
|
||||
if token.get("expires_in")
|
||||
else None,
|
||||
)
|
||||
await oauth_account.create_account(db, obj_in=oauth_create)
|
||||
|
||||
logger.info(
|
||||
"OAuth auto-linked by email: %s -> %s", provider, user.email
|
||||
)
|
||||
|
||||
else:
|
||||
# Create new user
|
||||
if not provider_email:
|
||||
raise AuthenticationError(
|
||||
f"Email is required for registration. "
|
||||
f"Please grant email permission to {provider}."
|
||||
)
|
||||
|
||||
user = await OAuthService._create_oauth_user(
|
||||
db,
|
||||
email=provider_email,
|
||||
provider=provider,
|
||||
provider_user_id=provider_user_id,
|
||||
user_info=user_info,
|
||||
token=token,
|
||||
)
|
||||
is_new_user = True
|
||||
|
||||
logger.info("New user created via OAuth: %s (%s)", user.email, provider)
|
||||
|
||||
# Generate JWT tokens
|
||||
claims = {
|
||||
"is_superuser": user.is_superuser,
|
||||
"email": user.email,
|
||||
"first_name": user.first_name,
|
||||
}
|
||||
|
||||
access_token_jwt = create_access_token(subject=str(user.id), claims=claims)
|
||||
refresh_token_jwt = create_refresh_token(subject=str(user.id))
|
||||
|
||||
return OAuthCallbackResponse(
|
||||
access_token=access_token_jwt,
|
||||
refresh_token=refresh_token_jwt,
|
||||
token_type="bearer",
|
||||
expires_in=settings.ACCESS_TOKEN_EXPIRE_MINUTES * 60,
|
||||
is_new_user=is_new_user,
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
async def _get_user_info(
|
||||
client: AsyncOAuth2Client,
|
||||
provider: str,
|
||||
config: OAuthProviderConfig,
|
||||
access_token: str,
|
||||
) -> dict[str, object]:
|
||||
"""Get user info from OAuth provider."""
|
||||
headers = {"Authorization": f"Bearer {access_token}"}
|
||||
|
||||
if provider == "github":
|
||||
# GitHub returns JSON with Accept header
|
||||
headers["Accept"] = "application/vnd.github+json"
|
||||
|
||||
resp = await client.get(config["userinfo_url"], headers=headers)
|
||||
resp.raise_for_status()
|
||||
user_info = resp.json()
|
||||
|
||||
# GitHub requires separate request for email
|
||||
if provider == "github" and not user_info.get("email"):
|
||||
email_resp = await client.get(
|
||||
config["email_url"], # pyright: ignore[reportTypedDictNotRequiredAccess]
|
||||
headers=headers,
|
||||
)
|
||||
email_resp.raise_for_status()
|
||||
emails = email_resp.json()
|
||||
|
||||
# Find primary verified email
|
||||
for email_data in emails:
|
||||
if email_data.get("primary") and email_data.get("verified"):
|
||||
user_info["email"] = email_data["email"]
|
||||
break
|
||||
|
||||
return user_info
|
||||
|
||||
# Google's OIDC configuration endpoints
|
||||
GOOGLE_JWKS_URL = "https://www.googleapis.com/oauth2/v3/certs"
|
||||
GOOGLE_ISSUERS = ("https://accounts.google.com", "accounts.google.com")
|
||||
|
||||
@staticmethod
|
||||
async def _verify_google_id_token(
|
||||
id_token: str,
|
||||
expected_nonce: str,
|
||||
client_id: str,
|
||||
) -> dict[str, object]:
|
||||
"""
|
||||
Verify Google ID token signature and claims.
|
||||
|
||||
SECURITY: This properly verifies the ID token by:
|
||||
1. Fetching Google's public keys (JWKS)
|
||||
2. Verifying the JWT signature against the public key
|
||||
3. Validating issuer, audience, expiry, and nonce claims
|
||||
|
||||
Args:
|
||||
id_token: The ID token JWT string
|
||||
expected_nonce: The nonce we sent in the authorization request
|
||||
client_id: Our OAuth client ID (expected audience)
|
||||
|
||||
Returns:
|
||||
Decoded ID token payload
|
||||
|
||||
Raises:
|
||||
AuthenticationError: If verification fails
|
||||
"""
|
||||
import httpx
|
||||
import jwt as pyjwt
|
||||
from jwt.algorithms import RSAAlgorithm
|
||||
from jwt.exceptions import InvalidTokenError
|
||||
|
||||
try:
|
||||
# Fetch Google's public keys (JWKS)
|
||||
# In production, consider caching this with TTL matching Cache-Control header
|
||||
async with httpx.AsyncClient() as client:
|
||||
jwks_response = await client.get(
|
||||
OAuthService.GOOGLE_JWKS_URL,
|
||||
timeout=10.0,
|
||||
)
|
||||
jwks_response.raise_for_status()
|
||||
jwks = jwks_response.json()
|
||||
|
||||
# Get the key ID from the token header
|
||||
unverified_header = pyjwt.get_unverified_header(id_token)
|
||||
kid = unverified_header.get("kid")
|
||||
if not kid:
|
||||
raise AuthenticationError("ID token missing key ID (kid)")
|
||||
|
||||
# Find the matching public key
|
||||
jwk_data = None
|
||||
for key in jwks.get("keys", []):
|
||||
if key.get("kid") == kid:
|
||||
jwk_data = key
|
||||
break
|
||||
|
||||
if not jwk_data:
|
||||
raise AuthenticationError("ID token signed with unknown key")
|
||||
|
||||
# Convert JWK to a public key object for PyJWT
|
||||
public_key = RSAAlgorithm.from_jwk(jwk_data)
|
||||
|
||||
# Verify the token signature and decode claims
|
||||
# PyJWT will verify signature against the RSA public key
|
||||
payload = pyjwt.decode(
|
||||
id_token,
|
||||
public_key,
|
||||
algorithms=["RS256"], # Google uses RS256
|
||||
audience=client_id,
|
||||
issuer=OAuthService.GOOGLE_ISSUERS,
|
||||
options={
|
||||
"verify_signature": True,
|
||||
"verify_aud": True,
|
||||
"verify_iss": True,
|
||||
"verify_exp": True,
|
||||
"verify_iat": True,
|
||||
},
|
||||
)
|
||||
|
||||
# Verify nonce (OIDC replay attack protection)
|
||||
token_nonce = payload.get("nonce")
|
||||
if token_nonce != expected_nonce:
|
||||
logger.warning(
|
||||
"OAuth ID token nonce mismatch: expected %s, got %s",
|
||||
expected_nonce,
|
||||
token_nonce,
|
||||
)
|
||||
raise AuthenticationError("Invalid ID token nonce")
|
||||
|
||||
logger.debug("Google ID token verified successfully")
|
||||
return payload
|
||||
|
||||
except InvalidTokenError as e:
|
||||
logger.warning("Google ID token verification failed: %s", e)
|
||||
raise AuthenticationError("Invalid ID token signature")
|
||||
except httpx.HTTPError as e:
|
||||
logger.error("Failed to fetch Google JWKS: %s", e)
|
||||
# If we can't verify the ID token, fail closed for security
|
||||
raise AuthenticationError("Failed to verify ID token")
|
||||
except Exception as e:
|
||||
logger.error("Unexpected error verifying Google ID token: %s", e)
|
||||
raise AuthenticationError("ID token verification error")
|
||||
|
||||
@staticmethod
|
||||
async def _create_oauth_user(
|
||||
db: AsyncSession,
|
||||
*,
|
||||
email: str,
|
||||
provider: str,
|
||||
provider_user_id: str,
|
||||
user_info: dict,
|
||||
token: dict,
|
||||
) -> User:
|
||||
"""Create a new user from OAuth provider data."""
|
||||
# Extract name from user_info
|
||||
first_name = "User"
|
||||
last_name = None
|
||||
|
||||
if provider == "google":
|
||||
first_name = user_info.get("given_name") or user_info.get("name", "User")
|
||||
last_name = user_info.get("family_name")
|
||||
elif provider == "github":
|
||||
# GitHub has full name, try to split
|
||||
name = user_info.get("name") or user_info.get("login", "User")
|
||||
parts = name.split(" ", 1)
|
||||
first_name = parts[0]
|
||||
last_name = parts[1] if len(parts) > 1 else None
|
||||
|
||||
# Create user (no password for OAuth-only users)
|
||||
user = User(
|
||||
email=email,
|
||||
password_hash=None, # OAuth-only user
|
||||
first_name=first_name,
|
||||
last_name=last_name,
|
||||
is_active=True,
|
||||
is_superuser=False,
|
||||
)
|
||||
db.add(user)
|
||||
await db.flush() # Get user.id
|
||||
|
||||
# Create OAuth account link
|
||||
user_id = cast(UUID, user.id)
|
||||
oauth_create = OAuthAccountCreate(
|
||||
user_id=user_id,
|
||||
provider=provider,
|
||||
provider_user_id=provider_user_id,
|
||||
provider_email=email,
|
||||
access_token=token.get("access_token"),
|
||||
refresh_token=token.get("refresh_token"),
|
||||
token_expires_at=datetime.now(UTC)
|
||||
+ timedelta(seconds=token.get("expires_in", 3600))
|
||||
if token.get("expires_in")
|
||||
else None,
|
||||
)
|
||||
await oauth_account.create_account(db, obj_in=oauth_create)
|
||||
|
||||
await db.refresh(user)
|
||||
|
||||
return user
|
||||
|
||||
@staticmethod
|
||||
async def unlink_provider(
|
||||
db: AsyncSession,
|
||||
*,
|
||||
user: User,
|
||||
provider: str,
|
||||
) -> bool:
|
||||
"""
|
||||
Unlink an OAuth provider from a user account.
|
||||
|
||||
Args:
|
||||
db: Database session
|
||||
user: User to unlink from
|
||||
provider: Provider to unlink
|
||||
|
||||
Returns:
|
||||
True if unlinked successfully
|
||||
|
||||
Raises:
|
||||
AuthenticationError: If unlinking would leave user without login method
|
||||
"""
|
||||
# Check if user can safely remove this OAuth account
|
||||
# Note: We query directly instead of using user.can_remove_oauth property
|
||||
# because the property uses lazy loading which doesn't work in async context
|
||||
user_id = cast(UUID, user.id)
|
||||
has_password = user.password_hash is not None
|
||||
oauth_accounts = await oauth_account.get_user_accounts(db, user_id=user_id)
|
||||
can_remove = has_password or len(oauth_accounts) > 1
|
||||
|
||||
if not can_remove:
|
||||
raise AuthenticationError(
|
||||
"Cannot unlink OAuth account. You must have either a password set "
|
||||
"or at least one other OAuth provider linked."
|
||||
)
|
||||
|
||||
deleted = await oauth_account.delete_account(
|
||||
db, user_id=user_id, provider=provider
|
||||
)
|
||||
|
||||
if not deleted:
|
||||
raise AuthenticationError(f"No {provider} account found to unlink")
|
||||
|
||||
logger.info("OAuth provider unlinked: %s from %s", provider, user.email)
|
||||
return True
|
||||
|
||||
@staticmethod
|
||||
async def get_user_accounts(db: AsyncSession, *, user_id: UUID) -> list:
|
||||
"""Get all OAuth accounts linked to a user."""
|
||||
return await oauth_account.get_user_accounts(db, user_id=user_id)
|
||||
|
||||
@staticmethod
|
||||
async def get_user_account_by_provider(
|
||||
db: AsyncSession, *, user_id: UUID, provider: str
|
||||
):
|
||||
"""Get a specific OAuth account for a user and provider."""
|
||||
return await oauth_account.get_user_account_by_provider(
|
||||
db, user_id=user_id, provider=provider
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
async def cleanup_expired_states(db: AsyncSession) -> int:
|
||||
"""
|
||||
Clean up expired OAuth states.
|
||||
|
||||
Should be called periodically (e.g., by a background task).
|
||||
|
||||
Args:
|
||||
db: Database session
|
||||
|
||||
Returns:
|
||||
Number of states cleaned up
|
||||
"""
|
||||
return await oauth_state.cleanup_expired(db)
|
||||
155
backend/app/services/organization_service.py
Normal file
155
backend/app/services/organization_service.py
Normal file
@@ -0,0 +1,155 @@
|
||||
# app/services/organization_service.py
|
||||
"""Service layer for organization operations — delegates to OrganizationRepository."""
|
||||
|
||||
import logging
|
||||
from typing import Any
|
||||
from uuid import UUID
|
||||
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
|
||||
from app.core.exceptions import NotFoundError
|
||||
from app.models.organization import Organization
|
||||
from app.models.user_organization import OrganizationRole, UserOrganization
|
||||
from app.repositories.organization import OrganizationRepository, organization_repo
|
||||
from app.schemas.organizations import OrganizationCreate, OrganizationUpdate
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class OrganizationService:
|
||||
"""Service for organization management operations."""
|
||||
|
||||
def __init__(
|
||||
self, organization_repository: OrganizationRepository | None = None
|
||||
) -> None:
|
||||
self._repo = organization_repository or organization_repo
|
||||
|
||||
async def get_organization(self, db: AsyncSession, org_id: str) -> Organization:
|
||||
"""Get organization by ID, raising NotFoundError if not found."""
|
||||
org = await self._repo.get(db, id=org_id)
|
||||
if not org:
|
||||
raise NotFoundError(f"Organization {org_id} not found")
|
||||
return org
|
||||
|
||||
async def create_organization(
|
||||
self, db: AsyncSession, *, obj_in: OrganizationCreate
|
||||
) -> Organization:
|
||||
"""Create a new organization."""
|
||||
return await self._repo.create(db, obj_in=obj_in)
|
||||
|
||||
async def update_organization(
|
||||
self,
|
||||
db: AsyncSession,
|
||||
*,
|
||||
org: Organization,
|
||||
obj_in: OrganizationUpdate | dict[str, Any],
|
||||
) -> Organization:
|
||||
"""Update an existing organization."""
|
||||
return await self._repo.update(db, db_obj=org, obj_in=obj_in)
|
||||
|
||||
async def remove_organization(self, db: AsyncSession, org_id: str) -> None:
|
||||
"""Permanently delete an organization by ID."""
|
||||
await self._repo.remove(db, id=org_id)
|
||||
|
||||
async def get_member_count(self, db: AsyncSession, *, organization_id: UUID) -> int:
|
||||
"""Get number of active members in an organization."""
|
||||
return await self._repo.get_member_count(db, organization_id=organization_id)
|
||||
|
||||
async def get_multi_with_member_counts(
|
||||
self,
|
||||
db: AsyncSession,
|
||||
*,
|
||||
skip: int = 0,
|
||||
limit: int = 100,
|
||||
is_active: bool | None = None,
|
||||
search: str | None = None,
|
||||
) -> tuple[list[dict[str, Any]], int]:
|
||||
"""List organizations with member counts and pagination."""
|
||||
return await self._repo.get_multi_with_member_counts(
|
||||
db, skip=skip, limit=limit, is_active=is_active, search=search
|
||||
)
|
||||
|
||||
async def get_user_organizations_with_details(
|
||||
self,
|
||||
db: AsyncSession,
|
||||
*,
|
||||
user_id: UUID,
|
||||
is_active: bool | None = None,
|
||||
) -> list[dict[str, Any]]:
|
||||
"""Get all organizations a user belongs to, with membership details."""
|
||||
return await self._repo.get_user_organizations_with_details(
|
||||
db, user_id=user_id, is_active=is_active
|
||||
)
|
||||
|
||||
async def get_organization_members(
|
||||
self,
|
||||
db: AsyncSession,
|
||||
*,
|
||||
organization_id: UUID,
|
||||
skip: int = 0,
|
||||
limit: int = 100,
|
||||
is_active: bool | None = True,
|
||||
) -> tuple[list[dict[str, Any]], int]:
|
||||
"""Get members of an organization with pagination."""
|
||||
return await self._repo.get_organization_members(
|
||||
db,
|
||||
organization_id=organization_id,
|
||||
skip=skip,
|
||||
limit=limit,
|
||||
is_active=is_active,
|
||||
)
|
||||
|
||||
async def add_member(
|
||||
self,
|
||||
db: AsyncSession,
|
||||
*,
|
||||
organization_id: UUID,
|
||||
user_id: UUID,
|
||||
role: OrganizationRole = OrganizationRole.MEMBER,
|
||||
) -> UserOrganization:
|
||||
"""Add a user to an organization."""
|
||||
return await self._repo.add_user(
|
||||
db, organization_id=organization_id, user_id=user_id, role=role
|
||||
)
|
||||
|
||||
async def remove_member(
|
||||
self,
|
||||
db: AsyncSession,
|
||||
*,
|
||||
organization_id: UUID,
|
||||
user_id: UUID,
|
||||
) -> bool:
|
||||
"""Remove a user from an organization. Returns True if found and removed."""
|
||||
return await self._repo.remove_user(
|
||||
db, organization_id=organization_id, user_id=user_id
|
||||
)
|
||||
|
||||
async def get_user_role_in_org(
|
||||
self, db: AsyncSession, *, user_id: UUID, organization_id: UUID
|
||||
) -> OrganizationRole | None:
|
||||
"""Get the role of a user in an organization."""
|
||||
return await self._repo.get_user_role_in_org(
|
||||
db, user_id=user_id, organization_id=organization_id
|
||||
)
|
||||
|
||||
async def get_org_distribution(
|
||||
self, db: AsyncSession, *, limit: int = 6
|
||||
) -> list[dict[str, Any]]:
|
||||
"""Return top organizations by member count for admin dashboard."""
|
||||
from sqlalchemy import func, select
|
||||
|
||||
result = await db.execute(
|
||||
select(
|
||||
Organization.name,
|
||||
func.count(UserOrganization.user_id).label("count"),
|
||||
)
|
||||
.join(UserOrganization, Organization.id == UserOrganization.organization_id)
|
||||
.group_by(Organization.name)
|
||||
.order_by(func.count(UserOrganization.user_id).desc())
|
||||
.limit(limit)
|
||||
)
|
||||
return [{"name": row.name, "value": row.count} for row in result.all()]
|
||||
|
||||
|
||||
# Default singleton
|
||||
organization_service = OrganizationService()
|
||||
@@ -3,11 +3,12 @@ Background job for cleaning up expired sessions.
|
||||
|
||||
This service runs periodically to remove old session records from the database.
|
||||
"""
|
||||
|
||||
import logging
|
||||
from datetime import datetime, timezone
|
||||
from datetime import UTC, datetime
|
||||
|
||||
from app.core.database import SessionLocal
|
||||
from app.crud.session import session as session_crud
|
||||
from app.repositories.session import session_repo as session_repo
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
@@ -31,15 +32,15 @@ async def cleanup_expired_sessions(keep_days: int = 30) -> int:
|
||||
|
||||
async with SessionLocal() as db:
|
||||
try:
|
||||
# Use CRUD method to cleanup
|
||||
count = await session_crud.cleanup_expired(db, keep_days=keep_days)
|
||||
# Use repository method to cleanup
|
||||
count = await session_repo.cleanup_expired(db, keep_days=keep_days)
|
||||
|
||||
logger.info(f"Session cleanup complete: {count} sessions deleted")
|
||||
logger.info("Session cleanup complete: %s sessions deleted", count)
|
||||
|
||||
return count
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error during session cleanup: {str(e)}", exc_info=True)
|
||||
logger.exception("Error during session cleanup: %s", e)
|
||||
return 0
|
||||
|
||||
|
||||
@@ -52,20 +53,21 @@ async def get_session_statistics() -> dict:
|
||||
"""
|
||||
async with SessionLocal() as db:
|
||||
try:
|
||||
from sqlalchemy import func, select
|
||||
|
||||
from app.models.user_session import UserSession
|
||||
from sqlalchemy import select, func
|
||||
|
||||
total_result = await db.execute(select(func.count(UserSession.id)))
|
||||
total_sessions = total_result.scalar_one()
|
||||
|
||||
active_result = await db.execute(
|
||||
select(func.count(UserSession.id)).where(UserSession.is_active == True)
|
||||
select(func.count(UserSession.id)).where(UserSession.is_active)
|
||||
)
|
||||
active_sessions = active_result.scalar_one()
|
||||
|
||||
expired_result = await db.execute(
|
||||
select(func.count(UserSession.id)).where(
|
||||
UserSession.expires_at < datetime.now(timezone.utc)
|
||||
UserSession.expires_at < datetime.now(UTC)
|
||||
)
|
||||
)
|
||||
expired_sessions = expired_result.scalar_one()
|
||||
@@ -77,10 +79,10 @@ async def get_session_statistics() -> dict:
|
||||
"expired": expired_sessions,
|
||||
}
|
||||
|
||||
logger.info(f"Session statistics: {stats}")
|
||||
logger.info("Session statistics: %s", stats)
|
||||
|
||||
return stats
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error getting session statistics: {str(e)}", exc_info=True)
|
||||
logger.exception("Error getting session statistics: %s", e)
|
||||
return {}
|
||||
|
||||
97
backend/app/services/session_service.py
Normal file
97
backend/app/services/session_service.py
Normal file
@@ -0,0 +1,97 @@
|
||||
# app/services/session_service.py
|
||||
"""Service layer for session operations — delegates to SessionRepository."""
|
||||
|
||||
import logging
|
||||
from datetime import datetime
|
||||
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
|
||||
from app.models.user_session import UserSession
|
||||
from app.repositories.session import SessionRepository, session_repo
|
||||
from app.schemas.sessions import SessionCreate
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class SessionService:
|
||||
"""Service for user session management operations."""
|
||||
|
||||
def __init__(self, session_repository: SessionRepository | None = None) -> None:
|
||||
self._repo = session_repository or session_repo
|
||||
|
||||
async def create_session(
|
||||
self, db: AsyncSession, *, obj_in: SessionCreate
|
||||
) -> UserSession:
|
||||
"""Create a new session record."""
|
||||
return await self._repo.create_session(db, obj_in=obj_in)
|
||||
|
||||
async def get_session(
|
||||
self, db: AsyncSession, session_id: str
|
||||
) -> UserSession | None:
|
||||
"""Get session by ID."""
|
||||
return await self._repo.get(db, id=session_id)
|
||||
|
||||
async def get_user_sessions(
|
||||
self, db: AsyncSession, *, user_id: str, active_only: bool = True
|
||||
) -> list[UserSession]:
|
||||
"""Get all sessions for a user."""
|
||||
return await self._repo.get_user_sessions(
|
||||
db, user_id=user_id, active_only=active_only
|
||||
)
|
||||
|
||||
async def get_active_by_jti(
|
||||
self, db: AsyncSession, *, jti: str
|
||||
) -> UserSession | None:
|
||||
"""Get active session by refresh token JTI."""
|
||||
return await self._repo.get_active_by_jti(db, jti=jti)
|
||||
|
||||
async def get_by_jti(self, db: AsyncSession, *, jti: str) -> UserSession | None:
|
||||
"""Get session by refresh token JTI (active or inactive)."""
|
||||
return await self._repo.get_by_jti(db, jti=jti)
|
||||
|
||||
async def deactivate(
|
||||
self, db: AsyncSession, *, session_id: str
|
||||
) -> UserSession | None:
|
||||
"""Deactivate a session (logout from device)."""
|
||||
return await self._repo.deactivate(db, session_id=session_id)
|
||||
|
||||
async def deactivate_all_user_sessions(
|
||||
self, db: AsyncSession, *, user_id: str
|
||||
) -> int:
|
||||
"""Deactivate all sessions for a user. Returns count deactivated."""
|
||||
return await self._repo.deactivate_all_user_sessions(db, user_id=user_id)
|
||||
|
||||
async def update_refresh_token(
|
||||
self,
|
||||
db: AsyncSession,
|
||||
*,
|
||||
session: UserSession,
|
||||
new_jti: str,
|
||||
new_expires_at: datetime,
|
||||
) -> UserSession:
|
||||
"""Update session with a rotated refresh token."""
|
||||
return await self._repo.update_refresh_token(
|
||||
db, session=session, new_jti=new_jti, new_expires_at=new_expires_at
|
||||
)
|
||||
|
||||
async def cleanup_expired_for_user(self, db: AsyncSession, *, user_id: str) -> int:
|
||||
"""Remove expired sessions for a user. Returns count removed."""
|
||||
return await self._repo.cleanup_expired_for_user(db, user_id=user_id)
|
||||
|
||||
async def get_all_sessions(
|
||||
self,
|
||||
db: AsyncSession,
|
||||
*,
|
||||
skip: int = 0,
|
||||
limit: int = 100,
|
||||
active_only: bool = True,
|
||||
with_user: bool = True,
|
||||
) -> tuple[list[UserSession], int]:
|
||||
"""Get all sessions with pagination (admin only)."""
|
||||
return await self._repo.get_all_sessions(
|
||||
db, skip=skip, limit=limit, active_only=active_only, with_user=with_user
|
||||
)
|
||||
|
||||
|
||||
# Default singleton
|
||||
session_service = SessionService()
|
||||
120
backend/app/services/user_service.py
Normal file
120
backend/app/services/user_service.py
Normal file
@@ -0,0 +1,120 @@
|
||||
# app/services/user_service.py
|
||||
"""Service layer for user operations — delegates to UserRepository."""
|
||||
|
||||
import logging
|
||||
from typing import Any
|
||||
from uuid import UUID
|
||||
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
|
||||
from app.core.exceptions import NotFoundError
|
||||
from app.models.user import User
|
||||
from app.repositories.user import UserRepository, user_repo
|
||||
from app.schemas.users import UserCreate, UserUpdate
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class UserService:
|
||||
"""Service for user management operations."""
|
||||
|
||||
def __init__(self, user_repository: UserRepository | None = None) -> None:
|
||||
self._repo = user_repository or user_repo
|
||||
|
||||
async def get_user(self, db: AsyncSession, user_id: str) -> User:
|
||||
"""Get user by ID, raising NotFoundError if not found."""
|
||||
user = await self._repo.get(db, id=user_id)
|
||||
if not user:
|
||||
raise NotFoundError(f"User {user_id} not found")
|
||||
return user
|
||||
|
||||
async def get_by_email(self, db: AsyncSession, email: str) -> User | None:
|
||||
"""Get user by email address."""
|
||||
return await self._repo.get_by_email(db, email=email)
|
||||
|
||||
async def create_user(self, db: AsyncSession, user_data: UserCreate) -> User:
|
||||
"""Create a new user."""
|
||||
return await self._repo.create(db, obj_in=user_data)
|
||||
|
||||
async def update_user(
|
||||
self, db: AsyncSession, *, user: User, obj_in: UserUpdate | dict[str, Any]
|
||||
) -> User:
|
||||
"""Update an existing user."""
|
||||
return await self._repo.update(db, db_obj=user, obj_in=obj_in)
|
||||
|
||||
async def soft_delete_user(self, db: AsyncSession, user_id: str) -> None:
|
||||
"""Soft-delete a user by ID."""
|
||||
await self._repo.soft_delete(db, id=user_id)
|
||||
|
||||
async def list_users(
|
||||
self,
|
||||
db: AsyncSession,
|
||||
*,
|
||||
skip: int = 0,
|
||||
limit: int = 100,
|
||||
sort_by: str | None = None,
|
||||
sort_order: str = "asc",
|
||||
filters: dict[str, Any] | None = None,
|
||||
search: str | None = None,
|
||||
) -> tuple[list[User], int]:
|
||||
"""List users with pagination, sorting, filtering, and search."""
|
||||
return await self._repo.get_multi_with_total(
|
||||
db,
|
||||
skip=skip,
|
||||
limit=limit,
|
||||
sort_by=sort_by,
|
||||
sort_order=sort_order,
|
||||
filters=filters,
|
||||
search=search,
|
||||
)
|
||||
|
||||
async def bulk_update_status(
|
||||
self, db: AsyncSession, *, user_ids: list[UUID], is_active: bool
|
||||
) -> int:
|
||||
"""Bulk update active status for multiple users. Returns count updated."""
|
||||
return await self._repo.bulk_update_status(
|
||||
db, user_ids=user_ids, is_active=is_active
|
||||
)
|
||||
|
||||
async def bulk_soft_delete(
|
||||
self,
|
||||
db: AsyncSession,
|
||||
*,
|
||||
user_ids: list[UUID],
|
||||
exclude_user_id: UUID | None = None,
|
||||
) -> int:
|
||||
"""Bulk soft-delete multiple users. Returns count deleted."""
|
||||
return await self._repo.bulk_soft_delete(
|
||||
db, user_ids=user_ids, exclude_user_id=exclude_user_id
|
||||
)
|
||||
|
||||
async def get_stats(self, db: AsyncSession) -> dict[str, Any]:
|
||||
"""Return user stats needed for the admin dashboard."""
|
||||
from sqlalchemy import func, select
|
||||
|
||||
total_users = (
|
||||
await db.execute(select(func.count()).select_from(User))
|
||||
).scalar() or 0
|
||||
active_count = (
|
||||
await db.execute(
|
||||
select(func.count()).select_from(User).where(User.is_active)
|
||||
)
|
||||
).scalar() or 0
|
||||
inactive_count = (
|
||||
await db.execute(
|
||||
select(func.count()).select_from(User).where(User.is_active.is_(False))
|
||||
)
|
||||
).scalar() or 0
|
||||
all_users = list(
|
||||
(await db.execute(select(User).order_by(User.created_at))).scalars().all()
|
||||
)
|
||||
return {
|
||||
"total_users": total_users,
|
||||
"active_count": active_count,
|
||||
"inactive_count": inactive_count,
|
||||
"all_users": all_users,
|
||||
}
|
||||
|
||||
|
||||
# Default singleton
|
||||
user_service = UserService()
|
||||
@@ -2,7 +2,8 @@
|
||||
Authentication utilities for testing.
|
||||
This module provides tools to bypass FastAPI's authentication in tests.
|
||||
"""
|
||||
from typing import Callable, Dict, Optional
|
||||
|
||||
from collections.abc import Callable
|
||||
|
||||
from fastapi import FastAPI
|
||||
from fastapi.security import OAuth2PasswordBearer
|
||||
@@ -13,9 +14,9 @@ from app.models.user import User
|
||||
|
||||
|
||||
def create_test_auth_client(
|
||||
app: FastAPI,
|
||||
test_user: User,
|
||||
extra_overrides: Optional[Dict[Callable, Callable]] = None
|
||||
app: FastAPI,
|
||||
test_user: User,
|
||||
extra_overrides: dict[Callable, Callable] | None = None,
|
||||
) -> TestClient:
|
||||
"""
|
||||
Create a test client with authentication pre-configured.
|
||||
@@ -47,10 +48,7 @@ def create_test_auth_client(
|
||||
return TestClient(app)
|
||||
|
||||
|
||||
def create_test_optional_auth_client(
|
||||
app: FastAPI,
|
||||
test_user: User
|
||||
) -> TestClient:
|
||||
def create_test_optional_auth_client(app: FastAPI, test_user: User) -> TestClient:
|
||||
"""
|
||||
Create a test client with optional authentication pre-configured.
|
||||
|
||||
@@ -70,10 +68,7 @@ def create_test_optional_auth_client(
|
||||
return TestClient(app)
|
||||
|
||||
|
||||
def create_test_superuser_client(
|
||||
app: FastAPI,
|
||||
test_user: User
|
||||
) -> TestClient:
|
||||
def create_test_superuser_client(app: FastAPI, test_user: User) -> TestClient:
|
||||
"""
|
||||
Create a test client with superuser authentication pre-configured.
|
||||
|
||||
@@ -120,7 +115,7 @@ def cleanup_test_client_auth(app: FastAPI) -> None:
|
||||
auth_deps = [
|
||||
get_current_user,
|
||||
get_optional_current_user,
|
||||
OAuth2PasswordBearer(tokenUrl="/api/v1/auth/login")
|
||||
OAuth2PasswordBearer(tokenUrl="/api/v1/auth/login"),
|
||||
]
|
||||
|
||||
# Remove overrides
|
||||
|
||||
@@ -1,8 +1,8 @@
|
||||
"""
|
||||
Utility functions for extracting and parsing device information from HTTP requests.
|
||||
"""
|
||||
|
||||
import re
|
||||
from typing import Optional
|
||||
|
||||
from fastapi import Request
|
||||
|
||||
@@ -19,11 +19,11 @@ def extract_device_info(request: Request) -> DeviceInfo:
|
||||
Returns:
|
||||
DeviceInfo object with parsed device information
|
||||
"""
|
||||
user_agent = request.headers.get('user-agent', '')
|
||||
user_agent = request.headers.get("user-agent", "")
|
||||
|
||||
device_info = DeviceInfo(
|
||||
device_name=parse_device_name(user_agent),
|
||||
device_id=request.headers.get('x-device-id'), # Client must send this header
|
||||
device_id=request.headers.get("x-device-id"), # Client must send this header
|
||||
ip_address=get_client_ip(request),
|
||||
user_agent=user_agent[:500] if user_agent else None, # Truncate to max length
|
||||
location_city=None, # Can be populated via IP geolocation service
|
||||
@@ -33,7 +33,7 @@ def extract_device_info(request: Request) -> DeviceInfo:
|
||||
return device_info
|
||||
|
||||
|
||||
def parse_device_name(user_agent: str) -> Optional[str]:
|
||||
def parse_device_name(user_agent: str) -> str | None:
|
||||
"""
|
||||
Parse user agent string to extract a friendly device name.
|
||||
|
||||
@@ -54,48 +54,48 @@ def parse_device_name(user_agent: str) -> Optional[str]:
|
||||
user_agent_lower = user_agent.lower()
|
||||
|
||||
# Mobile devices (check first, as they can contain desktop patterns too)
|
||||
if 'iphone' in user_agent_lower:
|
||||
if "iphone" in user_agent_lower:
|
||||
return "iPhone"
|
||||
elif 'ipad' in user_agent_lower:
|
||||
elif "ipad" in user_agent_lower:
|
||||
return "iPad"
|
||||
elif 'android' in user_agent_lower:
|
||||
elif "android" in user_agent_lower:
|
||||
# Try to extract device model
|
||||
android_match = re.search(r'android.*;\s*([^)]+)\s*build', user_agent_lower)
|
||||
android_match = re.search(r"android.*;\s*([^)]+)\s*build", user_agent_lower)
|
||||
if android_match:
|
||||
device_model = android_match.group(1).strip()
|
||||
return f"Android ({device_model.title()})"
|
||||
return "Android device"
|
||||
elif 'windows phone' in user_agent_lower:
|
||||
elif "windows phone" in user_agent_lower:
|
||||
return "Windows Phone"
|
||||
|
||||
# Tablets (check before desktop, as some tablets contain "android")
|
||||
elif 'tablet' in user_agent_lower:
|
||||
elif "tablet" in user_agent_lower:
|
||||
return "Tablet"
|
||||
|
||||
# Smart TVs (check before desktop OS patterns)
|
||||
elif any(tv in user_agent_lower for tv in ['smart-tv', 'smarttv']):
|
||||
elif any(tv in user_agent_lower for tv in ["smart-tv", "smarttv"]):
|
||||
return "Smart TV"
|
||||
|
||||
# Game consoles (check before desktop OS patterns, as Xbox contains "Windows")
|
||||
elif 'playstation' in user_agent_lower:
|
||||
elif "playstation" in user_agent_lower:
|
||||
return "PlayStation"
|
||||
elif 'xbox' in user_agent_lower:
|
||||
elif "xbox" in user_agent_lower:
|
||||
return "Xbox"
|
||||
elif 'nintendo' in user_agent_lower:
|
||||
elif "nintendo" in user_agent_lower:
|
||||
return "Nintendo"
|
||||
|
||||
# Desktop operating systems
|
||||
elif 'macintosh' in user_agent_lower or 'mac os x' in user_agent_lower:
|
||||
elif "macintosh" in user_agent_lower or "mac os x" in user_agent_lower:
|
||||
# Try to extract browser
|
||||
browser = extract_browser(user_agent)
|
||||
return f"{browser} on Mac" if browser else "Mac"
|
||||
elif 'windows' in user_agent_lower:
|
||||
elif "windows" in user_agent_lower:
|
||||
browser = extract_browser(user_agent)
|
||||
return f"{browser} on Windows" if browser else "Windows PC"
|
||||
elif 'linux' in user_agent_lower and 'android' not in user_agent_lower:
|
||||
elif "linux" in user_agent_lower and "android" not in user_agent_lower:
|
||||
browser = extract_browser(user_agent)
|
||||
return f"{browser} on Linux" if browser else "Linux"
|
||||
elif 'cros' in user_agent_lower:
|
||||
elif "cros" in user_agent_lower:
|
||||
return "Chromebook"
|
||||
|
||||
# Fallback: just return browser name if detected
|
||||
@@ -106,7 +106,7 @@ def parse_device_name(user_agent: str) -> Optional[str]:
|
||||
return "Unknown device"
|
||||
|
||||
|
||||
def extract_browser(user_agent: str) -> Optional[str]:
|
||||
def extract_browser(user_agent: str) -> str | None:
|
||||
"""
|
||||
Extract browser name from user agent string.
|
||||
|
||||
@@ -126,26 +126,26 @@ def extract_browser(user_agent: str) -> Optional[str]:
|
||||
user_agent_lower = user_agent.lower()
|
||||
|
||||
# Check specific browsers (order matters - check Edge before Chrome!)
|
||||
if 'edg/' in user_agent_lower or 'edge/' in user_agent_lower:
|
||||
if "edg/" in user_agent_lower or "edge/" in user_agent_lower:
|
||||
return "Edge"
|
||||
elif 'opr/' in user_agent_lower or 'opera' in user_agent_lower:
|
||||
elif "opr/" in user_agent_lower or "opera" in user_agent_lower:
|
||||
return "Opera"
|
||||
elif 'chrome/' in user_agent_lower:
|
||||
elif "chrome/" in user_agent_lower:
|
||||
return "Chrome"
|
||||
elif 'safari/' in user_agent_lower:
|
||||
elif "safari/" in user_agent_lower:
|
||||
# Make sure it's actually Safari, not Chrome (which also contains "Safari")
|
||||
if 'chrome' not in user_agent_lower:
|
||||
if "chrome" not in user_agent_lower:
|
||||
return "Safari"
|
||||
return None
|
||||
elif 'firefox/' in user_agent_lower:
|
||||
elif "firefox/" in user_agent_lower:
|
||||
return "Firefox"
|
||||
elif 'msie' in user_agent_lower or 'trident/' in user_agent_lower:
|
||||
elif "msie" in user_agent_lower or "trident/" in user_agent_lower:
|
||||
return "Internet Explorer"
|
||||
|
||||
return None
|
||||
|
||||
|
||||
def get_client_ip(request: Request) -> Optional[str]:
|
||||
def get_client_ip(request: Request) -> str | None:
|
||||
"""
|
||||
Extract client IP address from request, considering proxy headers.
|
||||
|
||||
@@ -163,14 +163,14 @@ def get_client_ip(request: Request) -> Optional[str]:
|
||||
- request.client.host is fallback for direct connections
|
||||
"""
|
||||
# Check X-Forwarded-For (common in proxied environments)
|
||||
x_forwarded_for = request.headers.get('x-forwarded-for')
|
||||
x_forwarded_for = request.headers.get("x-forwarded-for")
|
||||
if x_forwarded_for:
|
||||
# Get the first IP (original client)
|
||||
client_ip = x_forwarded_for.split(',')[0].strip()
|
||||
client_ip = x_forwarded_for.split(",")[0].strip()
|
||||
return client_ip
|
||||
|
||||
# Check X-Real-IP (used by some proxies like nginx)
|
||||
x_real_ip = request.headers.get('x-real-ip')
|
||||
x_real_ip = request.headers.get("x-real-ip")
|
||||
if x_real_ip:
|
||||
return x_real_ip.strip()
|
||||
|
||||
@@ -195,9 +195,17 @@ def is_mobile_device(user_agent: str) -> bool:
|
||||
return False
|
||||
|
||||
mobile_patterns = [
|
||||
'mobile', 'android', 'iphone', 'ipad', 'ipod',
|
||||
'blackberry', 'windows phone', 'webos', 'opera mini',
|
||||
'iemobile', 'mobile safari'
|
||||
"mobile",
|
||||
"android",
|
||||
"iphone",
|
||||
"ipad",
|
||||
"ipod",
|
||||
"blackberry",
|
||||
"windows phone",
|
||||
"webos",
|
||||
"opera mini",
|
||||
"iemobile",
|
||||
"mobile safari",
|
||||
]
|
||||
|
||||
user_agent_lower = user_agent.lower()
|
||||
@@ -220,7 +228,7 @@ def get_device_type(user_agent: str) -> str:
|
||||
user_agent_lower = user_agent.lower()
|
||||
|
||||
# Check for tablets first (they can contain "mobile" too)
|
||||
if 'ipad' in user_agent_lower or 'tablet' in user_agent_lower:
|
||||
if "ipad" in user_agent_lower or "tablet" in user_agent_lower:
|
||||
return "tablet"
|
||||
|
||||
# Check for mobile
|
||||
@@ -228,7 +236,7 @@ def get_device_type(user_agent: str) -> str:
|
||||
return "mobile"
|
||||
|
||||
# Check for desktop OS patterns
|
||||
if any(os in user_agent_lower for os in ['windows', 'macintosh', 'linux', 'cros']):
|
||||
if any(os in user_agent_lower for os in ["windows", "macintosh", "linux", "cros"]):
|
||||
return "desktop"
|
||||
|
||||
return "other"
|
||||
|
||||
@@ -5,18 +5,21 @@ This module provides utilities for creating and verifying signed tokens,
|
||||
useful for operations like file uploads, password resets, or any other
|
||||
time-limited, single-use operations.
|
||||
"""
|
||||
|
||||
import base64
|
||||
import hashlib
|
||||
import hmac
|
||||
import json
|
||||
import secrets
|
||||
import time
|
||||
from typing import Dict, Any, Optional
|
||||
from typing import Any
|
||||
|
||||
from app.core.config import settings
|
||||
|
||||
|
||||
def create_upload_token(file_path: str, content_type: str, expires_in: int = 300) -> str:
|
||||
def create_upload_token(
|
||||
file_path: str, content_type: str, expires_in: int = 300
|
||||
) -> str:
|
||||
"""
|
||||
Create a signed token for secure file uploads.
|
||||
|
||||
@@ -40,34 +43,29 @@ def create_upload_token(file_path: str, content_type: str, expires_in: int = 300
|
||||
"path": file_path,
|
||||
"content_type": content_type,
|
||||
"exp": int(time.time()) + expires_in,
|
||||
"nonce": secrets.token_hex(8) # Add randomness to prevent token reuse
|
||||
"nonce": secrets.token_hex(8), # Add randomness to prevent token reuse
|
||||
}
|
||||
|
||||
# Convert to JSON and encode
|
||||
payload_bytes = json.dumps(payload).encode('utf-8')
|
||||
payload_bytes = json.dumps(payload).encode("utf-8")
|
||||
|
||||
# Create a signature using HMAC-SHA256 for security
|
||||
# This prevents length extension attacks that plain SHA-256 is vulnerable to
|
||||
signature = hmac.new(
|
||||
settings.SECRET_KEY.encode('utf-8'),
|
||||
payload_bytes,
|
||||
hashlib.sha256
|
||||
settings.SECRET_KEY.encode("utf-8"), payload_bytes, hashlib.sha256
|
||||
).hexdigest()
|
||||
|
||||
# Combine payload and signature
|
||||
token_data = {
|
||||
"payload": payload,
|
||||
"signature": signature
|
||||
}
|
||||
token_data = {"payload": payload, "signature": signature}
|
||||
|
||||
# Encode the final token
|
||||
token_json = json.dumps(token_data)
|
||||
token = base64.urlsafe_b64encode(token_json.encode('utf-8')).decode('utf-8')
|
||||
token = base64.urlsafe_b64encode(token_json.encode("utf-8")).decode("utf-8")
|
||||
|
||||
return token
|
||||
|
||||
|
||||
def verify_upload_token(token: str) -> Optional[Dict[str, Any]]:
|
||||
def verify_upload_token(token: str) -> dict[str, Any] | None:
|
||||
"""
|
||||
Verify an upload token and return the payload if valid.
|
||||
|
||||
@@ -88,7 +86,7 @@ def verify_upload_token(token: str) -> Optional[Dict[str, Any]]:
|
||||
"""
|
||||
try:
|
||||
# Decode the token
|
||||
token_json = base64.urlsafe_b64decode(token.encode('utf-8')).decode('utf-8')
|
||||
token_json = base64.urlsafe_b64decode(token.encode("utf-8")).decode("utf-8")
|
||||
token_data = json.loads(token_json)
|
||||
|
||||
# Extract payload and signature
|
||||
@@ -96,11 +94,9 @@ def verify_upload_token(token: str) -> Optional[Dict[str, Any]]:
|
||||
signature = token_data["signature"]
|
||||
|
||||
# Verify signature using HMAC and constant-time comparison
|
||||
payload_bytes = json.dumps(payload).encode('utf-8')
|
||||
payload_bytes = json.dumps(payload).encode("utf-8")
|
||||
expected_signature = hmac.new(
|
||||
settings.SECRET_KEY.encode('utf-8'),
|
||||
payload_bytes,
|
||||
hashlib.sha256
|
||||
settings.SECRET_KEY.encode("utf-8"), payload_bytes, hashlib.sha256
|
||||
).hexdigest()
|
||||
|
||||
if not hmac.compare_digest(signature, expected_signature):
|
||||
@@ -136,34 +132,29 @@ def create_password_reset_token(email: str, expires_in: int = 3600) -> str:
|
||||
"email": email,
|
||||
"exp": int(time.time()) + expires_in,
|
||||
"nonce": secrets.token_hex(16), # Extra randomness
|
||||
"purpose": "password_reset"
|
||||
"purpose": "password_reset",
|
||||
}
|
||||
|
||||
# Convert to JSON and encode
|
||||
payload_bytes = json.dumps(payload).encode('utf-8')
|
||||
payload_bytes = json.dumps(payload).encode("utf-8")
|
||||
|
||||
# Create a signature using HMAC-SHA256 for security
|
||||
# This prevents length extension attacks that plain SHA-256 is vulnerable to
|
||||
signature = hmac.new(
|
||||
settings.SECRET_KEY.encode('utf-8'),
|
||||
payload_bytes,
|
||||
hashlib.sha256
|
||||
settings.SECRET_KEY.encode("utf-8"), payload_bytes, hashlib.sha256
|
||||
).hexdigest()
|
||||
|
||||
# Combine payload and signature
|
||||
token_data = {
|
||||
"payload": payload,
|
||||
"signature": signature
|
||||
}
|
||||
token_data = {"payload": payload, "signature": signature}
|
||||
|
||||
# Encode the final token
|
||||
token_json = json.dumps(token_data)
|
||||
token = base64.urlsafe_b64encode(token_json.encode('utf-8')).decode('utf-8')
|
||||
token = base64.urlsafe_b64encode(token_json.encode("utf-8")).decode("utf-8")
|
||||
|
||||
return token
|
||||
|
||||
|
||||
def verify_password_reset_token(token: str) -> Optional[str]:
|
||||
def verify_password_reset_token(token: str) -> str | None:
|
||||
"""
|
||||
Verify a password reset token and return the email if valid.
|
||||
|
||||
@@ -182,7 +173,7 @@ def verify_password_reset_token(token: str) -> Optional[str]:
|
||||
"""
|
||||
try:
|
||||
# Decode the token
|
||||
token_json = base64.urlsafe_b64decode(token.encode('utf-8')).decode('utf-8')
|
||||
token_json = base64.urlsafe_b64decode(token.encode("utf-8")).decode("utf-8")
|
||||
token_data = json.loads(token_json)
|
||||
|
||||
# Extract payload and signature
|
||||
@@ -194,11 +185,9 @@ def verify_password_reset_token(token: str) -> Optional[str]:
|
||||
return None
|
||||
|
||||
# Verify signature using HMAC and constant-time comparison
|
||||
payload_bytes = json.dumps(payload).encode('utf-8')
|
||||
payload_bytes = json.dumps(payload).encode("utf-8")
|
||||
expected_signature = hmac.new(
|
||||
settings.SECRET_KEY.encode('utf-8'),
|
||||
payload_bytes,
|
||||
hashlib.sha256
|
||||
settings.SECRET_KEY.encode("utf-8"), payload_bytes, hashlib.sha256
|
||||
).hexdigest()
|
||||
|
||||
if not hmac.compare_digest(signature, expected_signature):
|
||||
@@ -234,34 +223,29 @@ def create_email_verification_token(email: str, expires_in: int = 86400) -> str:
|
||||
"email": email,
|
||||
"exp": int(time.time()) + expires_in,
|
||||
"nonce": secrets.token_hex(16),
|
||||
"purpose": "email_verification"
|
||||
"purpose": "email_verification",
|
||||
}
|
||||
|
||||
# Convert to JSON and encode
|
||||
payload_bytes = json.dumps(payload).encode('utf-8')
|
||||
payload_bytes = json.dumps(payload).encode("utf-8")
|
||||
|
||||
# Create a signature using HMAC-SHA256 for security
|
||||
# This prevents length extension attacks that plain SHA-256 is vulnerable to
|
||||
signature = hmac.new(
|
||||
settings.SECRET_KEY.encode('utf-8'),
|
||||
payload_bytes,
|
||||
hashlib.sha256
|
||||
settings.SECRET_KEY.encode("utf-8"), payload_bytes, hashlib.sha256
|
||||
).hexdigest()
|
||||
|
||||
# Combine payload and signature
|
||||
token_data = {
|
||||
"payload": payload,
|
||||
"signature": signature
|
||||
}
|
||||
token_data = {"payload": payload, "signature": signature}
|
||||
|
||||
# Encode the final token
|
||||
token_json = json.dumps(token_data)
|
||||
token = base64.urlsafe_b64encode(token_json.encode('utf-8')).decode('utf-8')
|
||||
token = base64.urlsafe_b64encode(token_json.encode("utf-8")).decode("utf-8")
|
||||
|
||||
return token
|
||||
|
||||
|
||||
def verify_email_verification_token(token: str) -> Optional[str]:
|
||||
def verify_email_verification_token(token: str) -> str | None:
|
||||
"""
|
||||
Verify an email verification token and return the email if valid.
|
||||
|
||||
@@ -280,7 +264,7 @@ def verify_email_verification_token(token: str) -> Optional[str]:
|
||||
"""
|
||||
try:
|
||||
# Decode the token
|
||||
token_json = base64.urlsafe_b64decode(token.encode('utf-8')).decode('utf-8')
|
||||
token_json = base64.urlsafe_b64decode(token.encode("utf-8")).decode("utf-8")
|
||||
token_data = json.loads(token_json)
|
||||
|
||||
# Extract payload and signature
|
||||
@@ -292,11 +276,9 @@ def verify_email_verification_token(token: str) -> Optional[str]:
|
||||
return None
|
||||
|
||||
# Verify signature using HMAC and constant-time comparison
|
||||
payload_bytes = json.dumps(payload).encode('utf-8')
|
||||
payload_bytes = json.dumps(payload).encode("utf-8")
|
||||
expected_signature = hmac.new(
|
||||
settings.SECRET_KEY.encode('utf-8'),
|
||||
payload_bytes,
|
||||
hashlib.sha256
|
||||
settings.SECRET_KEY.encode("utf-8"), payload_bytes, hashlib.sha256
|
||||
).hexdigest()
|
||||
|
||||
if not hmac.compare_digest(signature, expected_signature):
|
||||
|
||||
@@ -9,17 +9,19 @@ from app.core.database import Base
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
def get_test_engine():
|
||||
"""Create an SQLite in-memory engine specifically for testing"""
|
||||
test_engine = create_engine(
|
||||
"sqlite:///:memory:",
|
||||
connect_args={"check_same_thread": False},
|
||||
poolclass=StaticPool, # Use static pool for in-memory testing
|
||||
echo=False
|
||||
echo=False,
|
||||
)
|
||||
|
||||
return test_engine
|
||||
|
||||
|
||||
def setup_test_db():
|
||||
"""Create a test database and session factory"""
|
||||
# Create a new engine for this test run
|
||||
@@ -30,14 +32,12 @@ def setup_test_db():
|
||||
|
||||
# Create session factory
|
||||
TestingSessionLocal = sessionmaker(
|
||||
autocommit=False,
|
||||
autoflush=False,
|
||||
bind=test_engine,
|
||||
expire_on_commit=False
|
||||
autocommit=False, autoflush=False, bind=test_engine, expire_on_commit=False
|
||||
)
|
||||
|
||||
return test_engine, TestingSessionLocal
|
||||
|
||||
|
||||
def teardown_test_db(engine):
|
||||
"""Clean up after tests"""
|
||||
# Drop all tables
|
||||
@@ -46,13 +46,14 @@ def teardown_test_db(engine):
|
||||
# Dispose of engine
|
||||
engine.dispose()
|
||||
|
||||
|
||||
async def get_async_test_engine():
|
||||
"""Create an async SQLite in-memory engine specifically for testing"""
|
||||
test_engine = create_async_engine(
|
||||
"sqlite+aiosqlite:///:memory:",
|
||||
connect_args={"check_same_thread": False},
|
||||
poolclass=StaticPool, # Use static pool for in-memory testing
|
||||
echo=False
|
||||
echo=False,
|
||||
)
|
||||
return test_engine
|
||||
|
||||
@@ -64,12 +65,12 @@ async def setup_async_test_db():
|
||||
async with test_engine.begin() as conn:
|
||||
await conn.run_sync(Base.metadata.create_all)
|
||||
|
||||
AsyncTestingSessionLocal = sessionmaker(
|
||||
AsyncTestingSessionLocal = sessionmaker( # pyright: ignore[reportCallIssue]
|
||||
autocommit=False,
|
||||
autoflush=False,
|
||||
bind=test_engine,
|
||||
bind=test_engine, # pyright: ignore[reportArgumentType]
|
||||
expire_on_commit=False,
|
||||
class_=AsyncSession
|
||||
class_=AsyncSession,
|
||||
)
|
||||
|
||||
return test_engine, AsyncTestingSessionLocal
|
||||
|
||||
@@ -79,12 +79,13 @@ This FastAPI backend application follows a **clean layered architecture** patter
|
||||
|
||||
### Authentication & Security
|
||||
|
||||
- **python-jose**: JWT token generation and validation
|
||||
- Cryptographic signing
|
||||
- **PyJWT**: JWT token generation and validation
|
||||
- Cryptographic signing (HS256, RS256)
|
||||
- Token expiration handling
|
||||
- Claims validation
|
||||
- JWK support for Google ID token verification
|
||||
|
||||
- **passlib + bcrypt**: Password hashing
|
||||
- **bcrypt**: Password hashing
|
||||
- Industry-standard bcrypt algorithm
|
||||
- Configurable cost factor
|
||||
- Salt generation
|
||||
@@ -117,7 +118,8 @@ backend/
|
||||
│ ├── api/ # API layer
|
||||
│ │ ├── dependencies/ # Dependency injection
|
||||
│ │ │ ├── auth.py # Authentication dependencies
|
||||
│ │ │ └── permissions.py # Authorization dependencies
|
||||
│ │ │ ├── permissions.py # Authorization dependencies
|
||||
│ │ │ └── services.py # Service singleton injection
|
||||
│ │ ├── routes/ # API endpoints
|
||||
│ │ │ ├── auth.py # Authentication routes
|
||||
│ │ │ ├── users.py # User management routes
|
||||
@@ -131,13 +133,14 @@ backend/
|
||||
│ │ ├── config.py # Application configuration
|
||||
│ │ ├── database.py # Database connection
|
||||
│ │ ├── exceptions.py # Custom exception classes
|
||||
│ │ ├── repository_exceptions.py # Repository-level exception hierarchy
|
||||
│ │ └── middleware.py # Custom middleware
|
||||
│ │
|
||||
│ ├── crud/ # Database operations
|
||||
│ │ ├── base.py # Generic CRUD base class
|
||||
│ │ ├── user.py # User CRUD operations
|
||||
│ │ ├── session.py # Session CRUD operations
|
||||
│ │ └── organization.py # Organization CRUD
|
||||
│ ├── repositories/ # Data access layer
|
||||
│ │ ├── base.py # Generic repository base class
|
||||
│ │ ├── user.py # User repository
|
||||
│ │ ├── session.py # Session repository
|
||||
│ │ └── organization.py # Organization repository
|
||||
│ │
|
||||
│ ├── models/ # SQLAlchemy models
|
||||
│ │ ├── base.py # Base model with mixins
|
||||
@@ -153,8 +156,11 @@ backend/
|
||||
│ │ ├── sessions.py # Session schemas
|
||||
│ │ └── organizations.py # Organization schemas
|
||||
│ │
|
||||
│ ├── services/ # Business logic
|
||||
│ ├── services/ # Business logic layer
|
||||
│ │ ├── auth_service.py # Authentication service
|
||||
│ │ ├── user_service.py # User management service
|
||||
│ │ ├── session_service.py # Session management service
|
||||
│ │ ├── organization_service.py # Organization service
|
||||
│ │ ├── email_service.py # Email service
|
||||
│ │ └── session_cleanup.py # Background cleanup
|
||||
│ │
|
||||
@@ -168,20 +174,25 @@ backend/
|
||||
│
|
||||
├── tests/ # Test suite
|
||||
│ ├── api/ # Integration tests
|
||||
│ ├── crud/ # CRUD tests
|
||||
│ ├── repositories/ # Repository unit tests
|
||||
│ ├── services/ # Service unit tests
|
||||
│ ├── models/ # Model tests
|
||||
│ ├── services/ # Service tests
|
||||
│ └── conftest.py # Test configuration
|
||||
│
|
||||
├── docs/ # Documentation
|
||||
│ ├── ARCHITECTURE.md # This file
|
||||
│ ├── CODING_STANDARDS.md # Coding standards
|
||||
│ ├── COMMON_PITFALLS.md # Common mistakes to avoid
|
||||
│ ├── E2E_TESTING.md # E2E testing guide
|
||||
│ └── FEATURE_EXAMPLE.md # Feature implementation guide
|
||||
│
|
||||
├── requirements.txt # Python dependencies
|
||||
├── pytest.ini # Pytest configuration
|
||||
├── .coveragerc # Coverage configuration
|
||||
└── alembic.ini # Alembic configuration
|
||||
├── pyproject.toml # Dependencies, tool configs (Ruff, pytest, coverage, Pyright)
|
||||
├── uv.lock # Locked dependency versions (commit to git)
|
||||
├── Makefile # Development commands (quality, security, testing)
|
||||
├── .pre-commit-config.yaml # Pre-commit hook configuration
|
||||
├── .secrets.baseline # detect-secrets baseline (known false positives)
|
||||
├── alembic.ini # Alembic configuration
|
||||
└── migrate.py # Migration helper script
|
||||
```
|
||||
|
||||
## Layered Architecture
|
||||
@@ -214,11 +225,11 @@ The application follows a strict 5-layer architecture:
|
||||
└──────────────────────────┬──────────────────────────────────┘
|
||||
│ calls
|
||||
┌──────────────────────────▼──────────────────────────────────┐
|
||||
│ CRUD Layer (crud/) │
|
||||
│ Repository Layer (repositories/) │
|
||||
│ - Database operations │
|
||||
│ - Query building │
|
||||
│ - Transaction management │
|
||||
│ - Error handling │
|
||||
│ - Custom repository exceptions │
|
||||
│ - No business logic │
|
||||
└──────────────────────────┬──────────────────────────────────┘
|
||||
│ uses
|
||||
┌──────────────────────────▼──────────────────────────────────┐
|
||||
@@ -262,7 +273,7 @@ async def get_current_user_info(
|
||||
|
||||
**Rules**:
|
||||
- Should NOT contain business logic
|
||||
- Should NOT directly perform database operations (use CRUD or services)
|
||||
- Should NOT directly call repositories (use services injected via `dependencies/services.py`)
|
||||
- Must validate all input via Pydantic schemas
|
||||
- Must specify response models
|
||||
- Should apply appropriate rate limits
|
||||
@@ -279,9 +290,9 @@ async def get_current_user_info(
|
||||
|
||||
**Example**:
|
||||
```python
|
||||
def get_current_user(
|
||||
async def get_current_user(
|
||||
token: str = Depends(oauth2_scheme),
|
||||
db: Session = Depends(get_db)
|
||||
db: AsyncSession = Depends(get_db)
|
||||
) -> User:
|
||||
"""
|
||||
Extract and validate user from JWT token.
|
||||
@@ -295,7 +306,7 @@ def get_current_user(
|
||||
except Exception:
|
||||
raise AuthenticationError("Invalid authentication credentials")
|
||||
|
||||
user = user_crud.get(db, id=user_id)
|
||||
user = await user_repo.get(db, id=user_id)
|
||||
if not user:
|
||||
raise AuthenticationError("User not found")
|
||||
|
||||
@@ -313,7 +324,7 @@ def get_current_user(
|
||||
**Responsibility**: Implement complex business logic
|
||||
|
||||
**Key Functions**:
|
||||
- Orchestrate multiple CRUD operations
|
||||
- Orchestrate multiple repository operations
|
||||
- Implement business rules
|
||||
- Handle external service integration
|
||||
- Coordinate transactions
|
||||
@@ -323,9 +334,9 @@ def get_current_user(
|
||||
class AuthService:
|
||||
"""Authentication service with business logic."""
|
||||
|
||||
def login(
|
||||
async def login(
|
||||
self,
|
||||
db: Session,
|
||||
db: AsyncSession,
|
||||
email: str,
|
||||
password: str,
|
||||
request: Request
|
||||
@@ -339,8 +350,8 @@ class AuthService:
|
||||
3. Generate tokens
|
||||
4. Return tokens and user info
|
||||
"""
|
||||
# Validate credentials
|
||||
user = user_crud.get_by_email(db, email=email)
|
||||
# Validate credentials via repository
|
||||
user = await user_repo.get_by_email(db, email=email)
|
||||
if not user or not verify_password(password, user.hashed_password):
|
||||
raise AuthenticationError("Invalid credentials")
|
||||
|
||||
@@ -350,11 +361,10 @@ class AuthService:
|
||||
# Extract device info
|
||||
device_info = extract_device_info(request)
|
||||
|
||||
# Create session
|
||||
session = session_crud.create_session(
|
||||
# Create session via repository
|
||||
session = await session_repo.create(
|
||||
db,
|
||||
user_id=user.id,
|
||||
device_info=device_info
|
||||
obj_in=SessionCreate(user_id=user.id, **device_info)
|
||||
)
|
||||
|
||||
# Generate tokens
|
||||
@@ -373,75 +383,60 @@ class AuthService:
|
||||
|
||||
**Rules**:
|
||||
- Contains business logic, not just data operations
|
||||
- Can call multiple CRUD operations
|
||||
- Can call multiple repository operations
|
||||
- Should handle complex workflows
|
||||
- Must maintain data consistency
|
||||
- Should use transactions when needed
|
||||
|
||||
#### 4. CRUD Layer (`app/crud/`)
|
||||
#### 4. Repository Layer (`app/repositories/`)
|
||||
|
||||
**Responsibility**: Database operations and queries
|
||||
**Responsibility**: Database operations and queries — no business logic
|
||||
|
||||
**Key Functions**:
|
||||
- Create, read, update, delete operations
|
||||
- Build database queries
|
||||
- Handle database errors
|
||||
- Raise custom repository exceptions (`DuplicateEntryError`, `IntegrityConstraintError`)
|
||||
- Manage soft deletes
|
||||
- Implement pagination and filtering
|
||||
|
||||
**Example**:
|
||||
```python
|
||||
class CRUDSession(CRUDBase[UserSession, SessionCreate, SessionUpdate]):
|
||||
"""CRUD operations for user sessions."""
|
||||
class SessionRepository(RepositoryBase[UserSession, SessionCreate, SessionUpdate]):
|
||||
"""Repository for user sessions — database operations only."""
|
||||
|
||||
def get_by_jti(self, db: Session, jti: UUID) -> Optional[UserSession]:
|
||||
async def get_by_jti(self, db: AsyncSession, *, jti: str) -> UserSession | None:
|
||||
"""Get session by refresh token JTI."""
|
||||
try:
|
||||
return (
|
||||
db.query(UserSession)
|
||||
.filter(UserSession.refresh_token_jti == jti)
|
||||
.first()
|
||||
)
|
||||
except Exception as e:
|
||||
logger.error(f"Error getting session by JTI: {str(e)}")
|
||||
return None
|
||||
result = await db.execute(
|
||||
select(UserSession).where(UserSession.refresh_token_jti == jti)
|
||||
)
|
||||
return result.scalar_one_or_none()
|
||||
|
||||
def get_active_by_jti(
|
||||
self,
|
||||
db: Session,
|
||||
jti: UUID
|
||||
) -> Optional[UserSession]:
|
||||
"""Get active session by refresh token JTI."""
|
||||
session = self.get_by_jti(db, jti=jti)
|
||||
if session and session.is_active and not session.is_expired:
|
||||
return session
|
||||
return None
|
||||
|
||||
def deactivate(self, db: Session, session_id: UUID) -> bool:
|
||||
async def deactivate(self, db: AsyncSession, *, session_id: UUID) -> bool:
|
||||
"""Deactivate a session (logout)."""
|
||||
try:
|
||||
session = self.get(db, id=session_id)
|
||||
session = await self.get(db, id=session_id)
|
||||
if not session:
|
||||
return False
|
||||
|
||||
session.is_active = False
|
||||
db.commit()
|
||||
await db.commit()
|
||||
logger.info(f"Session {session_id} deactivated")
|
||||
return True
|
||||
|
||||
except Exception as e:
|
||||
db.rollback()
|
||||
await db.rollback()
|
||||
logger.error(f"Error deactivating session: {str(e)}")
|
||||
return False
|
||||
```
|
||||
|
||||
**Rules**:
|
||||
- Should NOT contain business logic
|
||||
- Must handle database exceptions
|
||||
- Must use parameterized queries (SQLAlchemy does this)
|
||||
- Must raise custom repository exceptions (not raw `ValueError`/`IntegrityError`)
|
||||
- Must use async SQLAlchemy 2.0 `select()` API (never `db.query()`)
|
||||
- Should log all database errors
|
||||
- Must rollback on errors
|
||||
- Should use soft deletes when possible
|
||||
- **Never imported directly by routes** — always called through services
|
||||
|
||||
#### 5. Data Layer (`app/models/` + `app/schemas/`)
|
||||
|
||||
@@ -546,51 +541,23 @@ SessionLocal = sessionmaker(
|
||||
#### Dependency Injection Pattern
|
||||
|
||||
```python
|
||||
def get_db() -> Generator[Session, None, None]:
|
||||
async def get_db() -> AsyncGenerator[AsyncSession, None]:
|
||||
"""
|
||||
Database session dependency for FastAPI routes.
|
||||
Async database session dependency for FastAPI routes.
|
||||
|
||||
Automatically commits on success, rolls back on error.
|
||||
The session is passed to service methods; commit/rollback is
|
||||
managed inside service or repository methods.
|
||||
"""
|
||||
db = SessionLocal()
|
||||
try:
|
||||
async with AsyncSessionLocal() as db:
|
||||
yield db
|
||||
finally:
|
||||
db.close()
|
||||
|
||||
# Usage in routes
|
||||
# Usage in routes — always through a service, never direct repository
|
||||
@router.get("/users")
|
||||
def list_users(db: Session = Depends(get_db)):
|
||||
return user_crud.get_multi(db)
|
||||
```
|
||||
|
||||
#### Context Manager Pattern
|
||||
|
||||
```python
|
||||
@contextmanager
|
||||
def transaction_scope() -> Generator[Session, None, None]:
|
||||
"""
|
||||
Context manager for database transactions.
|
||||
|
||||
Use for complex operations requiring multiple steps.
|
||||
Automatically commits on success, rolls back on error.
|
||||
"""
|
||||
db = SessionLocal()
|
||||
try:
|
||||
yield db
|
||||
db.commit()
|
||||
except Exception:
|
||||
db.rollback()
|
||||
raise
|
||||
finally:
|
||||
db.close()
|
||||
|
||||
# Usage in services
|
||||
def complex_operation():
|
||||
with transaction_scope() as db:
|
||||
user = user_crud.create(db, obj_in=user_data)
|
||||
session = session_crud.create(db, session_data)
|
||||
return user, session
|
||||
async def list_users(
|
||||
user_service: UserService = Depends(get_user_service),
|
||||
db: AsyncSession = Depends(get_db),
|
||||
):
|
||||
return await user_service.get_users(db)
|
||||
```
|
||||
|
||||
### Model Mixins
|
||||
@@ -782,22 +749,15 @@ def get_profile(
|
||||
|
||||
```python
|
||||
@router.delete("/sessions/{session_id}")
|
||||
def revoke_session(
|
||||
async def revoke_session(
|
||||
session_id: UUID,
|
||||
current_user: User = Depends(get_current_user),
|
||||
db: Session = Depends(get_db)
|
||||
session_service: SessionService = Depends(get_session_service),
|
||||
db: AsyncSession = Depends(get_db),
|
||||
):
|
||||
"""Users can only revoke their own sessions."""
|
||||
session = session_crud.get(db, id=session_id)
|
||||
|
||||
if not session:
|
||||
raise NotFoundError("Session not found")
|
||||
|
||||
# Check ownership
|
||||
if session.user_id != current_user.id:
|
||||
raise AuthorizationError("You can only revoke your own sessions")
|
||||
|
||||
session_crud.deactivate(db, session_id=session_id)
|
||||
# SessionService verifies ownership and raises NotFoundError / AuthorizationError
|
||||
await session_service.revoke_session(db, session_id=session_id, user_id=current_user.id)
|
||||
return MessageResponse(success=True, message="Session revoked")
|
||||
```
|
||||
|
||||
@@ -818,6 +778,84 @@ def add_member(
|
||||
pass
|
||||
```
|
||||
|
||||
### OAuth Integration
|
||||
|
||||
The system supports two OAuth modes:
|
||||
|
||||
#### OAuth Consumer Mode (Social Login)
|
||||
|
||||
Users can authenticate via Google or GitHub OAuth providers:
|
||||
|
||||
```python
|
||||
# Get authorization URL with PKCE support
|
||||
GET /oauth/authorize/{provider}?redirect_uri=https://yourapp.com/callback
|
||||
|
||||
# Handle callback and exchange code for tokens
|
||||
POST /oauth/callback/{provider}
|
||||
{
|
||||
"code": "authorization_code_from_provider",
|
||||
"state": "csrf_state_token"
|
||||
}
|
||||
```
|
||||
|
||||
**Security Features:**
|
||||
- PKCE (S256) for Google
|
||||
- State parameter for CSRF protection
|
||||
- Nonce for Google OIDC replay attack prevention
|
||||
- Google ID token signature verification via JWKS
|
||||
- Email normalization to prevent account duplication
|
||||
- Auto-linking by email (configurable)
|
||||
|
||||
#### OAuth Provider Mode (MCP Integration)
|
||||
|
||||
Full OAuth 2.0 Authorization Server for third-party clients (RFC compliant):
|
||||
|
||||
```
|
||||
┌─────────────┐ ┌─────────────┐
|
||||
│ MCP Client │ │ Backend │
|
||||
└──────┬──────┘ └──────┬──────┘
|
||||
│ │
|
||||
│ GET /.well-known/oauth-authorization-server│
|
||||
│─────────────────────────────────────────────>│
|
||||
│ {metadata} │
|
||||
│<─────────────────────────────────────────────│
|
||||
│ │
|
||||
│ GET /oauth/provider/authorize │
|
||||
│ ?response_type=code&client_id=... │
|
||||
│ &redirect_uri=...&code_challenge=... │
|
||||
│─────────────────────────────────────────────>│
|
||||
│ │
|
||||
│ (User consents) │
|
||||
│ │
|
||||
│ 302 redirect_uri?code=AUTH_CODE&state=... │
|
||||
│<─────────────────────────────────────────────│
|
||||
│ │
|
||||
│ POST /oauth/provider/token │
|
||||
│ {grant_type=authorization_code, │
|
||||
│ code=AUTH_CODE, code_verifier=...} │
|
||||
│─────────────────────────────────────────────>│
|
||||
│ │
|
||||
│ {access_token, refresh_token, expires_in} │
|
||||
│<─────────────────────────────────────────────│
|
||||
│ │
|
||||
```
|
||||
|
||||
**Endpoints:**
|
||||
- `GET /.well-known/oauth-authorization-server` - RFC 8414 metadata
|
||||
- `GET /oauth/provider/authorize` - Authorization endpoint
|
||||
- `POST /oauth/provider/token` - Token endpoint (authorization_code, refresh_token)
|
||||
- `POST /oauth/provider/revoke` - RFC 7009 token revocation
|
||||
- `POST /oauth/provider/introspect` - RFC 7662 token introspection
|
||||
|
||||
**Security Features:**
|
||||
- PKCE S256 required for public clients (plain method rejected)
|
||||
- Authorization codes are single-use with 10-minute expiry
|
||||
- Code reuse detection triggers security incident (all tokens revoked)
|
||||
- Refresh token rotation on use
|
||||
- Opaque refresh tokens (hashed in database)
|
||||
- JWT access tokens with standard claims
|
||||
- Consent management per client
|
||||
|
||||
## Error Handling
|
||||
|
||||
### Exception Hierarchy
|
||||
@@ -983,23 +1021,27 @@ from app.services.session_cleanup import cleanup_expired_sessions
|
||||
|
||||
scheduler = AsyncIOScheduler()
|
||||
|
||||
@app.on_event("startup")
|
||||
async def startup_event():
|
||||
"""Start background jobs on application startup."""
|
||||
if not settings.IS_TEST: # Don't run in tests
|
||||
@asynccontextmanager
|
||||
async def lifespan(app: FastAPI):
|
||||
"""Application lifespan context manager."""
|
||||
# Startup
|
||||
if os.getenv("IS_TEST", "False") != "True":
|
||||
scheduler.add_job(
|
||||
cleanup_expired_sessions,
|
||||
"cron",
|
||||
hour=2, # Run at 2 AM daily
|
||||
id="cleanup_expired_sessions"
|
||||
id="cleanup_expired_sessions",
|
||||
replace_existing=True,
|
||||
)
|
||||
scheduler.start()
|
||||
logger.info("Background jobs started")
|
||||
|
||||
@app.on_event("shutdown")
|
||||
async def shutdown_event():
|
||||
"""Stop background jobs on application shutdown."""
|
||||
scheduler.shutdown()
|
||||
yield
|
||||
|
||||
# Shutdown
|
||||
if os.getenv("IS_TEST", "False") != "True":
|
||||
scheduler.shutdown()
|
||||
await close_async_db() # Dispose database engine connections
|
||||
```
|
||||
|
||||
### Job Implementation
|
||||
@@ -1014,8 +1056,8 @@ async def cleanup_expired_sessions():
|
||||
Runs daily at 2 AM. Removes sessions expired for more than 30 days.
|
||||
"""
|
||||
try:
|
||||
with transaction_scope() as db:
|
||||
count = session_crud.cleanup_expired(db, keep_days=30)
|
||||
async with AsyncSessionLocal() as db:
|
||||
count = await session_repo.cleanup_expired(db, keep_days=30)
|
||||
logger.info(f"Cleaned up {count} expired sessions")
|
||||
except Exception as e:
|
||||
logger.error(f"Error cleaning up sessions: {str(e)}", exc_info=True)
|
||||
@@ -1032,7 +1074,7 @@ async def cleanup_expired_sessions():
|
||||
│Integration │ ← API endpoint tests
|
||||
│ Tests │
|
||||
├─────────────┤
|
||||
│ Unit │ ← CRUD, services, utilities
|
||||
│ Unit │ ← repositories, services, utilities
|
||||
│ Tests │
|
||||
└─────────────┘
|
||||
```
|
||||
@@ -1127,6 +1169,8 @@ app.add_middleware(
|
||||
|
||||
## Performance Considerations
|
||||
|
||||
> 📖 For the full benchmarking guide (how to run, read results, write new benchmarks, and manage baselines), see **[BENCHMARKS.md](BENCHMARKS.md)**.
|
||||
|
||||
### Database Connection Pooling
|
||||
|
||||
- Pool size: 20 connections
|
||||
|
||||
311
backend/docs/BENCHMARKS.md
Normal file
311
backend/docs/BENCHMARKS.md
Normal file
@@ -0,0 +1,311 @@
|
||||
# Performance Benchmarks Guide
|
||||
|
||||
Automated performance benchmarking infrastructure using **pytest-benchmark** to detect latency regressions in critical API endpoints.
|
||||
|
||||
## Table of Contents
|
||||
|
||||
- [Why Benchmark?](#why-benchmark)
|
||||
- [Quick Start](#quick-start)
|
||||
- [How It Works](#how-it-works)
|
||||
- [Understanding Results](#understanding-results)
|
||||
- [Test Organization](#test-organization)
|
||||
- [Writing Benchmark Tests](#writing-benchmark-tests)
|
||||
- [Baseline Management](#baseline-management)
|
||||
- [CI/CD Integration](#cicd-integration)
|
||||
- [Troubleshooting](#troubleshooting)
|
||||
|
||||
---
|
||||
|
||||
## Why Benchmark?
|
||||
|
||||
Performance regressions are silent bugs — they don't break tests or cause errors, but they degrade the user experience over time. Common causes include:
|
||||
|
||||
- **Unintended N+1 queries** after adding a relationship
|
||||
- **Heavier serialization** after adding new fields to a response model
|
||||
- **Middleware overhead** from new security headers or logging
|
||||
- **Dependency upgrades** that introduce slower code paths
|
||||
|
||||
Without automated benchmarks, these regressions go unnoticed until users complain. Performance benchmarks serve as an **early warning system** — they measure endpoint latency on every run and flag significant deviations from an established baseline.
|
||||
|
||||
### What benchmarks give you
|
||||
|
||||
| Benefit | Description |
|
||||
|---------|-------------|
|
||||
| **Regression detection** | Automatically flags when an endpoint becomes significantly slower |
|
||||
| **Baseline tracking** | Stores known-good performance numbers for comparison |
|
||||
| **Confidence in refactors** | Verify that code changes don't degrade response times |
|
||||
| **Visibility** | Makes performance a first-class, measurable quality attribute |
|
||||
|
||||
---
|
||||
|
||||
## Quick Start
|
||||
|
||||
```bash
|
||||
# Run benchmarks (no comparison, just see current numbers)
|
||||
make benchmark
|
||||
|
||||
# Save current results as the baseline
|
||||
make benchmark-save
|
||||
|
||||
# Run benchmarks and compare against the saved baseline
|
||||
make benchmark-check
|
||||
```
|
||||
|
||||
---
|
||||
|
||||
## How It Works
|
||||
|
||||
The benchmarking system has three layers:
|
||||
|
||||
### 1. pytest-benchmark integration
|
||||
|
||||
[pytest-benchmark](https://pytest-benchmark.readthedocs.io/) is a pytest plugin that provides a `benchmark` fixture. It handles:
|
||||
|
||||
- **Calibration**: Automatically determines how many iterations to run for statistical significance
|
||||
- **Timing**: Uses `time.perf_counter` for high-resolution measurements
|
||||
- **Statistics**: Computes min, max, mean, median, standard deviation, IQR, and outlier detection
|
||||
- **Comparison**: Compares current results against saved baselines and flags regressions
|
||||
|
||||
### 2. Benchmark types
|
||||
|
||||
The test suite includes two categories of performance tests:
|
||||
|
||||
| Type | How it works | Examples |
|
||||
|------|-------------|----------|
|
||||
| **pytest-benchmark tests** | Uses the `benchmark` fixture for precise, multi-round timing | `test_health_endpoint_performance`, `test_openapi_schema_performance`, `test_password_hashing_performance`, `test_password_verification_performance`, `test_access_token_creation_performance`, `test_refresh_token_creation_performance`, `test_token_decode_performance` |
|
||||
| **Manual latency tests** | Uses `time.perf_counter` with explicit thresholds (for async endpoints that pytest-benchmark doesn't support natively) | `test_login_latency`, `test_get_current_user_latency`, `test_register_latency`, `test_token_refresh_latency`, `test_sessions_list_latency`, `test_user_profile_update_latency` |
|
||||
|
||||
### 3. Regression detection
|
||||
|
||||
When running `make benchmark-check`, the system:
|
||||
|
||||
1. Runs all benchmark tests
|
||||
2. Compares results against the saved baseline (`.benchmarks/` directory)
|
||||
3. **Fails the build** if any test's mean time exceeds **200%** of the baseline (i.e., 3× slower)
|
||||
|
||||
The `200%` threshold in `--benchmark-compare-fail=mean:200%` means "fail if the mean increased by more than 200% relative to the baseline." This is deliberately generous to avoid false positives from normal run-to-run variance while still catching real regressions.
|
||||
|
||||
---
|
||||
|
||||
## Understanding Results
|
||||
|
||||
A typical benchmark output looks like this:
|
||||
|
||||
```
|
||||
--------------------------------------------------------------------------------------- benchmark: 2 tests --------------------------------------------------------------------------------------
|
||||
Name (time in ms) Min Max Mean StdDev Median IQR Outliers OPS Rounds Iterations
|
||||
-------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------
|
||||
test_health_endpoint_performance 0.9841 (1.0) 1.5513 (1.0) 1.1390 (1.0) 0.1098 (1.0) 1.1151 (1.0) 0.1672 (1.0) 39;2 877.9666 (1.0) 133 1
|
||||
test_openapi_schema_performance 1.6523 (1.68) 2.0892 (1.35) 1.7843 (1.57) 0.1553 (1.41) 1.7200 (1.54) 0.1727 (1.03) 2;0 560.4471 (0.64) 10 1
|
||||
-------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------
|
||||
```
|
||||
|
||||
### Column reference
|
||||
|
||||
| Column | Meaning |
|
||||
|--------|---------|
|
||||
| **Min** | Fastest single execution |
|
||||
| **Max** | Slowest single execution |
|
||||
| **Mean** | Average across all rounds — the primary metric for regression detection |
|
||||
| **StdDev** | How much results vary between rounds (lower = more stable) |
|
||||
| **Median** | Middle value, less sensitive to outliers than mean |
|
||||
| **IQR** | Interquartile range — spread of the middle 50% of results |
|
||||
| **Outliers** | Format `A;B` — A = within 1 StdDev, B = within 1.5 IQR from quartiles |
|
||||
| **OPS** | Operations per second (`1 / Mean`) |
|
||||
| **Rounds** | How many times the test was executed (auto-calibrated) |
|
||||
| **Iterations** | Iterations per round (usually 1 for ms-scale tests) |
|
||||
|
||||
### The ratio numbers `(1.0)`, `(1.68)`, etc.
|
||||
|
||||
These show how each test compares **to the best result in that column**. The fastest test is always `(1.0)`, and others show their relative factor. For example, `(1.68)` means "1.68× slower than the fastest."
|
||||
|
||||
### Color coding
|
||||
|
||||
- **Green**: The fastest (best) value in each column
|
||||
- **Red**: The slowest (worst) value in each column
|
||||
|
||||
This is a **relative ranking within the current run** — red does NOT mean the test failed or that performance is bad. It simply highlights which endpoint is the slower one in the group.
|
||||
|
||||
### What's "normal"?
|
||||
|
||||
For this project's current endpoints:
|
||||
|
||||
| Test | Expected range | Why |
|
||||
|------|---------------|-----|
|
||||
| `GET /health` | ~1–1.5ms | Minimal logic, mocked DB check |
|
||||
| `GET /api/v1/openapi.json` | ~1.5–2.5ms | Serializes entire API schema |
|
||||
| `get_password_hash` | ~200ms | CPU-bound bcrypt hashing |
|
||||
| `verify_password` | ~200ms | CPU-bound bcrypt verification |
|
||||
| `create_access_token` | ~17–20µs | JWT encoding with HMAC-SHA256 |
|
||||
| `create_refresh_token` | ~17–20µs | JWT encoding with HMAC-SHA256 |
|
||||
| `decode_token` | ~20–25µs | JWT decoding and claim validation |
|
||||
| `POST /api/v1/auth/login` | < 500ms threshold | Includes bcrypt password verification |
|
||||
| `POST /api/v1/auth/register` | < 500ms threshold | Includes bcrypt password hashing |
|
||||
| `POST /api/v1/auth/refresh` | < 200ms threshold | Token rotation + DB session update |
|
||||
| `GET /api/v1/users/me` | < 200ms threshold | DB lookup + token validation |
|
||||
| `GET /api/v1/sessions/me` | < 200ms threshold | Session list query + token validation |
|
||||
| `PATCH /api/v1/users/me` | < 200ms threshold | DB update + token validation |
|
||||
|
||||
---
|
||||
|
||||
## Test Organization
|
||||
|
||||
```
|
||||
backend/tests/
|
||||
├── benchmarks/
|
||||
│ └── test_endpoint_performance.py # All performance benchmark tests
|
||||
│
|
||||
backend/.benchmarks/ # Saved baselines (auto-generated)
|
||||
└── Linux-CPython-3.12-64bit/
|
||||
└── 0001_baseline.json # Platform-specific baseline file
|
||||
```
|
||||
|
||||
### Test markers
|
||||
|
||||
All benchmark tests use the `@pytest.mark.benchmark` marker. The `--benchmark-only` flag ensures that only tests using the `benchmark` fixture are executed during benchmark runs, while manual latency tests (async) are skipped.
|
||||
|
||||
---
|
||||
|
||||
## Writing Benchmark Tests
|
||||
|
||||
### Stateless endpoint (using pytest-benchmark fixture)
|
||||
|
||||
```python
|
||||
import pytest
|
||||
from fastapi.testclient import TestClient
|
||||
|
||||
def test_my_endpoint_performance(sync_client, benchmark):
|
||||
"""Benchmark: GET /my-endpoint should respond within acceptable latency."""
|
||||
result = benchmark(sync_client.get, "/my-endpoint")
|
||||
assert result.status_code == 200
|
||||
```
|
||||
|
||||
The `benchmark` fixture handles all timing, calibration, and statistics automatically. Just pass it the callable and arguments.
|
||||
|
||||
### Async / DB-dependent endpoint (manual timing)
|
||||
|
||||
For async endpoints that require database access, use manual timing with an explicit threshold:
|
||||
|
||||
```python
|
||||
import time
|
||||
import pytest
|
||||
|
||||
MAX_RESPONSE_MS = 300
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_my_async_endpoint_latency(client, setup_fixture):
|
||||
"""Performance: endpoint must respond under threshold."""
|
||||
iterations = 5
|
||||
total_ms = 0.0
|
||||
|
||||
for _ in range(iterations):
|
||||
start = time.perf_counter()
|
||||
response = await client.get("/api/v1/my-endpoint")
|
||||
elapsed_ms = (time.perf_counter() - start) * 1000
|
||||
total_ms += elapsed_ms
|
||||
assert response.status_code == 200
|
||||
|
||||
mean_ms = total_ms / iterations
|
||||
assert mean_ms < MAX_RESPONSE_MS, (
|
||||
f"Latency regression: {mean_ms:.1f}ms exceeds {MAX_RESPONSE_MS}ms threshold"
|
||||
)
|
||||
```
|
||||
|
||||
### Guidelines for new benchmarks
|
||||
|
||||
1. **Benchmark critical paths** — endpoints users hit frequently or where latency matters most
|
||||
2. **Mock external dependencies** for stateless tests to isolate endpoint overhead
|
||||
3. **Set generous thresholds** for manual tests — account for CI variability
|
||||
4. **Keep benchmarks fast** — they run on every check, so avoid heavy setup
|
||||
|
||||
---
|
||||
|
||||
## Baseline Management
|
||||
|
||||
### Saving a baseline
|
||||
|
||||
```bash
|
||||
make benchmark-save
|
||||
```
|
||||
|
||||
This runs all benchmarks and saves results to `.benchmarks/<platform>/0001_baseline.json`. The baseline captures:
|
||||
- Mean, min, max, median, stddev for each test
|
||||
- Machine info (CPU, OS, Python version)
|
||||
- Timestamp
|
||||
|
||||
### Comparing against baseline
|
||||
|
||||
```bash
|
||||
make benchmark-check
|
||||
```
|
||||
|
||||
If no baseline exists, this command automatically creates one and prints a warning. On subsequent runs, it compares current results against the saved baseline.
|
||||
|
||||
### When to update the baseline
|
||||
|
||||
- **After intentional performance changes** (e.g., you optimized an endpoint — save the new, faster baseline)
|
||||
- **After infrastructure changes** (e.g., new CI runner, different hardware)
|
||||
- **After adding new benchmark tests** (the new tests need a baseline entry)
|
||||
|
||||
```bash
|
||||
# Update the baseline after intentional changes
|
||||
make benchmark-save
|
||||
```
|
||||
|
||||
### Version control
|
||||
|
||||
The `.benchmarks/` directory can be committed to version control so that CI pipelines can compare against a known-good baseline. However, since benchmark results are machine-specific, you may prefer to generate baselines in CI rather than committing local results.
|
||||
|
||||
---
|
||||
|
||||
## CI/CD Integration
|
||||
|
||||
Add benchmark checking to your CI pipeline to catch regressions on every PR:
|
||||
|
||||
```yaml
|
||||
# Example GitHub Actions step
|
||||
- name: Performance regression check
|
||||
run: |
|
||||
cd backend
|
||||
make benchmark-save # Create baseline from main branch
|
||||
# ... apply PR changes ...
|
||||
make benchmark-check # Compare PR against baseline
|
||||
```
|
||||
|
||||
A more robust approach:
|
||||
1. Save the baseline on the `main` branch after each merge
|
||||
2. On PR branches, run `make benchmark-check` against the `main` baseline
|
||||
3. The pipeline fails if any endpoint regresses beyond the 200% threshold
|
||||
|
||||
---
|
||||
|
||||
## Troubleshooting
|
||||
|
||||
### "No benchmark baseline found" warning
|
||||
|
||||
```
|
||||
⚠️ No benchmark baseline found. Run 'make benchmark-save' first to create one.
|
||||
```
|
||||
|
||||
This means no baseline file exists yet. The command will auto-create one. Future runs of `make benchmark-check` will compare against it.
|
||||
|
||||
### Machine info mismatch warning
|
||||
|
||||
```
|
||||
WARNING: benchmark machine_info is different
|
||||
```
|
||||
|
||||
This is expected when comparing baselines generated on a different machine or OS. The comparison still works, but absolute numbers may differ. Re-save the baseline on the current machine if needed.
|
||||
|
||||
### High variance (large StdDev)
|
||||
|
||||
If StdDev is high relative to the Mean, results may be unreliable. Common causes:
|
||||
- System under load during benchmark run
|
||||
- Garbage collection interference
|
||||
- Thermal throttling
|
||||
|
||||
Try running benchmarks on an idle system or increasing `min_rounds` in `pyproject.toml`.
|
||||
|
||||
### Only 7 of 13 tests run
|
||||
|
||||
The async tests (`test_login_latency`, `test_get_current_user_latency`, `test_register_latency`, `test_token_refresh_latency`, `test_sessions_list_latency`, `test_user_profile_update_latency`) are skipped during `--benchmark-only` runs because they don't use the `benchmark` fixture. They run as part of the normal test suite (`make test`) with manual threshold assertions.
|
||||
@@ -8,6 +8,7 @@ This document outlines the coding standards and best practices for the FastAPI b
|
||||
- [Code Organization](#code-organization)
|
||||
- [Naming Conventions](#naming-conventions)
|
||||
- [Error Handling](#error-handling)
|
||||
- [Data Models and Migrations](#data-models-and-migrations)
|
||||
- [Database Operations](#database-operations)
|
||||
- [API Endpoints](#api-endpoints)
|
||||
- [Authentication & Security](#authentication--security)
|
||||
@@ -74,15 +75,14 @@ def create_user(db: Session, user_in: UserCreate) -> User:
|
||||
### 4. Code Formatting
|
||||
|
||||
Use automated formatters:
|
||||
- **Black**: Code formatting
|
||||
- **isort**: Import sorting
|
||||
- **flake8**: Linting
|
||||
- **Ruff**: Code formatting and linting (replaces Black, isort, flake8)
|
||||
- **pyright**: Static type checking
|
||||
|
||||
Run before committing:
|
||||
Run before committing (or use `make validate`):
|
||||
```bash
|
||||
black app tests
|
||||
isort app tests
|
||||
flake8 app tests
|
||||
uv run ruff format app tests
|
||||
uv run ruff check app tests
|
||||
uv run pyright app
|
||||
```
|
||||
|
||||
## Code Organization
|
||||
@@ -93,19 +93,17 @@ Follow the 5-layer architecture strictly:
|
||||
|
||||
```
|
||||
API Layer (routes/)
|
||||
↓ calls
|
||||
Dependencies (dependencies/)
|
||||
↓ injects
|
||||
↓ calls (via service injected from dependencies/services.py)
|
||||
Service Layer (services/)
|
||||
↓ calls
|
||||
CRUD Layer (crud/)
|
||||
Repository Layer (repositories/)
|
||||
↓ uses
|
||||
Models & Schemas (models/, schemas/)
|
||||
```
|
||||
|
||||
**Rules:**
|
||||
- Routes should NOT directly call CRUD operations (use services when business logic is needed)
|
||||
- CRUD operations should NOT contain business logic
|
||||
- Routes must NEVER import repositories directly — always use a service
|
||||
- Services call repositories; repositories contain only database operations
|
||||
- Models should NOT import from higher layers
|
||||
- Each layer should only depend on the layer directly below it
|
||||
|
||||
@@ -124,7 +122,7 @@ from sqlalchemy.orm import Session
|
||||
|
||||
# 3. Local application imports
|
||||
from app.api.dependencies.auth import get_current_user
|
||||
from app.crud import user_crud
|
||||
from app.api.dependencies.services import get_user_service
|
||||
from app.models.user import User
|
||||
from app.schemas.users import UserResponse, UserCreate
|
||||
```
|
||||
@@ -216,7 +214,7 @@ if not user:
|
||||
|
||||
### Error Handling Pattern
|
||||
|
||||
Always follow this pattern in CRUD operations (Async version):
|
||||
Always follow this pattern in repository operations (Async version):
|
||||
|
||||
```python
|
||||
from sqlalchemy.exc import IntegrityError, OperationalError, DataError
|
||||
@@ -282,9 +280,154 @@ All error responses follow this structure:
|
||||
}
|
||||
```
|
||||
|
||||
## Data Models and Migrations
|
||||
|
||||
### Model Definition Best Practices
|
||||
|
||||
To ensure Alembic autogenerate works reliably without drift, follow these rules:
|
||||
|
||||
#### 1. Simple Indexes: Use Column-Level or `__table_args__`, Not Both
|
||||
|
||||
```python
|
||||
# ❌ BAD - Creates DUPLICATE indexes with different names
|
||||
class User(Base):
|
||||
role = Column(String(50), index=True) # Creates ix_users_role
|
||||
|
||||
__table_args__ = (
|
||||
Index("ix_user_role", "role"), # Creates ANOTHER index!
|
||||
)
|
||||
|
||||
# ✅ GOOD - Choose ONE approach
|
||||
class User(Base):
|
||||
role = Column(String(50)) # No index=True
|
||||
|
||||
__table_args__ = (
|
||||
Index("ix_user_role", "role"), # Single index with explicit name
|
||||
)
|
||||
|
||||
# ✅ ALSO GOOD - For simple single-column indexes
|
||||
class User(Base):
|
||||
role = Column(String(50), index=True) # Auto-named ix_users_role
|
||||
```
|
||||
|
||||
#### 2. Composite Indexes: Always Use `__table_args__`
|
||||
|
||||
```python
|
||||
class UserOrganization(Base):
|
||||
__tablename__ = "user_organizations"
|
||||
|
||||
user_id = Column(UUID, nullable=False)
|
||||
organization_id = Column(UUID, nullable=False)
|
||||
is_active = Column(Boolean, default=True, nullable=False, index=True)
|
||||
|
||||
__table_args__ = (
|
||||
Index("ix_user_org_user_active", "user_id", "is_active"),
|
||||
Index("ix_user_org_org_active", "organization_id", "is_active"),
|
||||
)
|
||||
```
|
||||
|
||||
#### 3. Functional/Partial Indexes: Use `ix_perf_` Prefix
|
||||
|
||||
Alembic **cannot** auto-detect:
|
||||
- **Functional indexes**: `LOWER(column)`, `UPPER(column)`, expressions
|
||||
- **Partial indexes**: Indexes with `WHERE` clauses
|
||||
|
||||
**Solution**: Use the `ix_perf_` naming prefix. Any index with this prefix is automatically excluded from autogenerate by `env.py`.
|
||||
|
||||
```python
|
||||
# In migration file (NOT in model) - use ix_perf_ prefix:
|
||||
op.create_index(
|
||||
"ix_perf_users_email_lower", # <-- ix_perf_ prefix!
|
||||
"users",
|
||||
[sa.text("LOWER(email)")], # Functional
|
||||
postgresql_where=sa.text("deleted_at IS NULL"), # Partial
|
||||
)
|
||||
```
|
||||
|
||||
**No need to update `env.py`** - the prefix convention handles it automatically:
|
||||
|
||||
```python
|
||||
# env.py - already configured:
|
||||
def include_object(object, name, type_, reflected, compare_to):
|
||||
if type_ == "index" and name:
|
||||
if name.startswith("ix_perf_"): # Auto-excluded!
|
||||
return False
|
||||
return True
|
||||
```
|
||||
|
||||
**To add new performance indexes:**
|
||||
1. Create a new migration file
|
||||
2. Name your indexes with `ix_perf_` prefix
|
||||
3. Done - Alembic will ignore them automatically
|
||||
|
||||
#### 4. Use Correct Types
|
||||
|
||||
```python
|
||||
# ✅ GOOD - PostgreSQL-native types
|
||||
from sqlalchemy.dialects.postgresql import JSONB, UUID
|
||||
|
||||
class User(Base):
|
||||
id = Column(UUID(as_uuid=True), primary_key=True)
|
||||
preferences = Column(JSONB) # Not JSON!
|
||||
|
||||
# ❌ BAD - Generic types may cause migration drift
|
||||
from sqlalchemy import JSON
|
||||
preferences = Column(JSON) # May detect as different from JSONB
|
||||
```
|
||||
|
||||
### Migration Workflow
|
||||
|
||||
#### Creating Migrations
|
||||
|
||||
```bash
|
||||
# Generate autogenerate migration:
|
||||
python migrate.py generate "Add new field"
|
||||
|
||||
# Or inside Docker:
|
||||
docker exec -w /app backend uv run alembic revision --autogenerate -m "Add new field"
|
||||
|
||||
# Apply migration:
|
||||
python migrate.py apply
|
||||
# Or: docker exec -w /app backend uv run alembic upgrade head
|
||||
```
|
||||
|
||||
#### Testing for Drift
|
||||
|
||||
After any model changes, verify no unintended drift:
|
||||
|
||||
```bash
|
||||
# Generate test migration
|
||||
docker exec -w /app backend uv run alembic revision --autogenerate -m "test_drift"
|
||||
|
||||
# Check the generated file - should be empty (just 'pass')
|
||||
# If it has operations, investigate why
|
||||
|
||||
# Delete test file
|
||||
rm backend/app/alembic/versions/*_test_drift.py
|
||||
```
|
||||
|
||||
#### Migration File Structure
|
||||
|
||||
```
|
||||
backend/app/alembic/versions/
|
||||
├── cbddc8aa6eda_initial_models.py # Auto-generated, tracks all models
|
||||
├── 0002_performance_indexes.py # Manual, functional/partial indexes
|
||||
└── __init__.py
|
||||
```
|
||||
|
||||
### Summary: What Goes Where
|
||||
|
||||
| Index Type | In Model? | Alembic Detects? | Where to Define |
|
||||
|------------|-----------|------------------|-----------------|
|
||||
| Simple column (`index=True`) | Yes | Yes | Column definition |
|
||||
| Composite (`col1, col2`) | Yes | Yes | `__table_args__` |
|
||||
| Unique composite | Yes | Yes | `__table_args__` with `unique=True` |
|
||||
| Functional (`LOWER(col)`) | No | No | Migration with `ix_perf_` prefix |
|
||||
| Partial (`WHERE ...`) | No | No | Migration with `ix_perf_` prefix |
|
||||
|
||||
## Database Operations
|
||||
|
||||
### Async CRUD Pattern
|
||||
### Async Repository Pattern
|
||||
|
||||
**IMPORTANT**: This application uses **async SQLAlchemy** with modern patterns for better performance and testability.
|
||||
|
||||
@@ -296,19 +439,19 @@ All error responses follow this structure:
|
||||
4. **Testability**: Easy to mock and test
|
||||
5. **Consistent Ordering**: Always order queries for pagination
|
||||
|
||||
### Use the Async CRUD Base Class
|
||||
### Use the Async Repository Base Class
|
||||
|
||||
Always inherit from `CRUDBase` for database operations:
|
||||
Always inherit from `RepositoryBase` for database operations:
|
||||
|
||||
```python
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
from sqlalchemy import select
|
||||
from app.crud.base import CRUDBase
|
||||
from app.repositories.base import RepositoryBase
|
||||
from app.models.user import User
|
||||
from app.schemas.users import UserCreate, UserUpdate
|
||||
|
||||
class CRUDUser(CRUDBase[User, UserCreate, UserUpdate]):
|
||||
"""CRUD operations for User model."""
|
||||
class UserRepository(RepositoryBase[User, UserCreate, UserUpdate]):
|
||||
"""Repository for User model — database operations only."""
|
||||
|
||||
async def get_by_email(
|
||||
self,
|
||||
@@ -321,7 +464,7 @@ class CRUDUser(CRUDBase[User, UserCreate, UserUpdate]):
|
||||
)
|
||||
return result.scalar_one_or_none()
|
||||
|
||||
user_crud = CRUDUser(User)
|
||||
user_repo = UserRepository(User)
|
||||
```
|
||||
|
||||
**Key Points:**
|
||||
@@ -330,6 +473,7 @@ user_crud = CRUDUser(User)
|
||||
- Use `await db.execute()` for queries
|
||||
- Use `.scalar_one_or_none()` instead of `.first()`
|
||||
- Use `T | None` instead of `Optional[T]`
|
||||
- Repository instances are used internally by services — never import them in routes
|
||||
|
||||
### Modern SQLAlchemy Patterns
|
||||
|
||||
@@ -417,13 +561,13 @@ async def create_user(
|
||||
The database session is automatically managed by FastAPI.
|
||||
Commit on success, rollback on error.
|
||||
"""
|
||||
return await user_crud.create(db, obj_in=user_in)
|
||||
return await user_service.create_user(db, obj_in=user_in)
|
||||
```
|
||||
|
||||
**Key Points:**
|
||||
- Route functions must be `async def`
|
||||
- Database parameter is `AsyncSession`
|
||||
- Always `await` CRUD operations
|
||||
- Always `await` repository operations
|
||||
|
||||
#### In Services (Multiple Operations)
|
||||
|
||||
@@ -436,12 +580,11 @@ async def complex_operation(
|
||||
"""
|
||||
Perform multiple database operations atomically.
|
||||
|
||||
The session automatically commits on success or rolls back on error.
|
||||
Services call repositories; commit/rollback is handled inside
|
||||
each repository method.
|
||||
"""
|
||||
user = await user_crud.create(db, obj_in=user_data)
|
||||
session = await session_crud.create(db, obj_in=session_data)
|
||||
|
||||
# Commit is handled by the route's dependency
|
||||
user = await user_repo.create(db, obj_in=user_data)
|
||||
session = await session_repo.create(db, obj_in=session_data)
|
||||
return user, session
|
||||
```
|
||||
|
||||
@@ -451,10 +594,10 @@ Prefer soft deletes over hard deletes for audit trails:
|
||||
|
||||
```python
|
||||
# Good - Soft delete (sets deleted_at)
|
||||
await user_crud.soft_delete(db, id=user_id)
|
||||
await user_repo.soft_delete(db, id=user_id)
|
||||
|
||||
# Acceptable only when required - Hard delete
|
||||
user_crud.remove(db, id=user_id)
|
||||
await user_repo.remove(db, id=user_id)
|
||||
```
|
||||
|
||||
### Query Patterns
|
||||
@@ -594,9 +737,10 @@ Always implement pagination for list endpoints:
|
||||
from app.schemas.common import PaginationParams, PaginatedResponse
|
||||
|
||||
@router.get("/users", response_model=PaginatedResponse[UserResponse])
|
||||
def list_users(
|
||||
async def list_users(
|
||||
pagination: PaginationParams = Depends(),
|
||||
db: Session = Depends(get_db)
|
||||
user_service: UserService = Depends(get_user_service),
|
||||
db: AsyncSession = Depends(get_db),
|
||||
):
|
||||
"""
|
||||
List all users with pagination.
|
||||
@@ -604,10 +748,8 @@ def list_users(
|
||||
Default page size: 20
|
||||
Maximum page size: 100
|
||||
"""
|
||||
users, total = user_crud.get_multi_with_total(
|
||||
db,
|
||||
skip=pagination.offset,
|
||||
limit=pagination.limit
|
||||
users, total = await user_service.get_users(
|
||||
db, skip=pagination.offset, limit=pagination.limit
|
||||
)
|
||||
return PaginatedResponse(data=users, pagination=pagination.create_meta(total))
|
||||
```
|
||||
@@ -670,19 +812,17 @@ def admin_route(
|
||||
pass
|
||||
|
||||
# Check ownership
|
||||
def delete_resource(
|
||||
async def delete_resource(
|
||||
resource_id: UUID,
|
||||
current_user: User = Depends(get_current_user),
|
||||
db: Session = Depends(get_db)
|
||||
resource_service: ResourceService = Depends(get_resource_service),
|
||||
db: AsyncSession = Depends(get_db),
|
||||
):
|
||||
resource = resource_crud.get(db, id=resource_id)
|
||||
if not resource:
|
||||
raise NotFoundError("Resource not found")
|
||||
|
||||
if resource.user_id != current_user.id and not current_user.is_superuser:
|
||||
raise AuthorizationError("You can only delete your own resources")
|
||||
|
||||
resource_crud.remove(db, id=resource_id)
|
||||
# Service handles ownership check and raises appropriate errors
|
||||
await resource_service.delete_resource(
|
||||
db, resource_id=resource_id, user_id=current_user.id,
|
||||
is_superuser=current_user.is_superuser,
|
||||
)
|
||||
```
|
||||
|
||||
### Input Validation
|
||||
@@ -716,9 +856,9 @@ tests/
|
||||
├── api/ # Integration tests
|
||||
│ ├── test_users.py
|
||||
│ └── test_auth.py
|
||||
├── crud/ # Unit tests for CRUD
|
||||
├── models/ # Model tests
|
||||
└── services/ # Service tests
|
||||
├── repositories/ # Unit tests for repositories
|
||||
├── services/ # Unit tests for services
|
||||
└── models/ # Model tests
|
||||
```
|
||||
|
||||
### Async Testing with pytest-asyncio
|
||||
@@ -781,7 +921,7 @@ async def test_user(db_session: AsyncSession) -> User:
|
||||
@pytest.mark.asyncio
|
||||
async def test_get_user(db_session: AsyncSession, test_user: User):
|
||||
"""Test retrieving a user by ID."""
|
||||
user = await user_crud.get(db_session, id=test_user.id)
|
||||
user = await user_repo.get(db_session, id=test_user.id)
|
||||
assert user is not None
|
||||
assert user.email == test_user.email
|
||||
```
|
||||
|
||||
@@ -334,14 +334,14 @@ def login(request: Request, credentials: OAuth2PasswordRequestForm):
|
||||
# ❌ WRONG - Returns password hash!
|
||||
@router.get("/users/{user_id}")
|
||||
def get_user(user_id: UUID, db: Session = Depends(get_db)) -> User:
|
||||
return user_crud.get(db, id=user_id) # Returns ORM model with ALL fields!
|
||||
return user_repo.get(db, id=user_id) # Returns ORM model with ALL fields!
|
||||
```
|
||||
|
||||
```python
|
||||
# ✅ CORRECT - Use response schema
|
||||
@router.get("/users/{user_id}", response_model=UserResponse)
|
||||
def get_user(user_id: UUID, db: Session = Depends(get_db)):
|
||||
user = user_crud.get(db, id=user_id)
|
||||
user = user_repo.get(db, id=user_id)
|
||||
if not user:
|
||||
raise HTTPException(status_code=404, detail="User not found")
|
||||
return user # Pydantic filters to only UserResponse fields
|
||||
@@ -506,8 +506,8 @@ def revoke_session(
|
||||
current_user: User = Depends(get_current_user),
|
||||
db: Session = Depends(get_db)
|
||||
):
|
||||
session = session_crud.get(db, id=session_id)
|
||||
session_crud.deactivate(db, session_id=session_id)
|
||||
session = session_repo.get(db, id=session_id)
|
||||
session_repo.deactivate(db, session_id=session_id)
|
||||
# BUG: User can revoke ANYONE'S session!
|
||||
return {"message": "Session revoked"}
|
||||
```
|
||||
@@ -520,7 +520,7 @@ def revoke_session(
|
||||
current_user: User = Depends(get_current_user),
|
||||
db: Session = Depends(get_db)
|
||||
):
|
||||
session = session_crud.get(db, id=session_id)
|
||||
session = session_repo.get(db, id=session_id)
|
||||
|
||||
if not session:
|
||||
raise NotFoundError("Session not found")
|
||||
@@ -529,7 +529,7 @@ def revoke_session(
|
||||
if session.user_id != current_user.id:
|
||||
raise AuthorizationError("You can only revoke your own sessions")
|
||||
|
||||
session_crud.deactivate(db, session_id=session_id)
|
||||
session_repo.deactivate(db, session_id=session_id)
|
||||
return {"message": "Session revoked"}
|
||||
```
|
||||
|
||||
@@ -616,7 +616,43 @@ def create_user(
|
||||
return user
|
||||
```
|
||||
|
||||
**Rule**: Add type hints to ALL functions. Use `mypy` to enforce type checking.
|
||||
**Rule**: Add type hints to ALL functions. Use `pyright` to enforce type checking (`make type-check`).
|
||||
|
||||
---
|
||||
|
||||
---
|
||||
|
||||
### ❌ PITFALL #19: Importing Repositories Directly in Routes
|
||||
|
||||
**Issue**: Routes should never call repositories directly. The layered architecture requires all business operations to go through the service layer.
|
||||
|
||||
```python
|
||||
# ❌ WRONG - Route bypasses service layer
|
||||
from app.repositories.session import session_repo
|
||||
|
||||
@router.get("/sessions/me")
|
||||
async def list_sessions(
|
||||
current_user: User = Depends(get_current_active_user),
|
||||
db: AsyncSession = Depends(get_db),
|
||||
):
|
||||
return await session_repo.get_user_sessions(db, user_id=current_user.id)
|
||||
```
|
||||
|
||||
```python
|
||||
# ✅ CORRECT - Route calls service injected via dependency
|
||||
from app.api.dependencies.services import get_session_service
|
||||
from app.services.session_service import SessionService
|
||||
|
||||
@router.get("/sessions/me")
|
||||
async def list_sessions(
|
||||
current_user: User = Depends(get_current_active_user),
|
||||
session_service: SessionService = Depends(get_session_service),
|
||||
db: AsyncSession = Depends(get_db),
|
||||
):
|
||||
return await session_service.get_user_sessions(db, user_id=current_user.id)
|
||||
```
|
||||
|
||||
**Rule**: Routes import from `app.api.dependencies.services`, never from `app.repositories.*`. Services are the only callers of repositories.
|
||||
|
||||
---
|
||||
|
||||
@@ -649,6 +685,11 @@ Use this checklist to catch issues before code review:
|
||||
- [ ] Resource ownership verification
|
||||
- [ ] CORS configured (no wildcards in production)
|
||||
|
||||
### Architecture
|
||||
- [ ] Routes never import repositories directly (only services)
|
||||
- [ ] Services call repositories; repositories call database only
|
||||
- [ ] New service registered in `app/api/dependencies/services.py`
|
||||
|
||||
### Python
|
||||
- [ ] Use `==` not `is` for value comparison
|
||||
- [ ] No mutable default arguments
|
||||
@@ -661,21 +702,18 @@ Use this checklist to catch issues before code review:
|
||||
|
||||
### Pre-commit Checks
|
||||
|
||||
Add these to your development workflow:
|
||||
Add these to your development workflow (or use `make validate`):
|
||||
|
||||
```bash
|
||||
# Format code
|
||||
black app tests
|
||||
isort app tests
|
||||
# Format + lint (Ruff replaces Black, isort, flake8)
|
||||
uv run ruff format app tests
|
||||
uv run ruff check app tests
|
||||
|
||||
# Type checking
|
||||
mypy app --strict
|
||||
|
||||
# Linting
|
||||
flake8 app tests
|
||||
uv run pyright app
|
||||
|
||||
# Run tests
|
||||
pytest --cov=app --cov-report=term-missing
|
||||
IS_TEST=True uv run pytest --cov=app --cov-report=term-missing
|
||||
|
||||
# Check coverage (should be 80%+)
|
||||
coverage report --fail-under=80
|
||||
@@ -693,6 +731,6 @@ Add new entries when:
|
||||
|
||||
---
|
||||
|
||||
**Last Updated**: 2025-10-31
|
||||
**Issues Cataloged**: 18 common pitfalls
|
||||
**Last Updated**: 2026-02-28
|
||||
**Issues Cataloged**: 19 common pitfalls
|
||||
**Remember**: This document exists because these issues HAVE occurred. Don't skip it.
|
||||
|
||||
348
backend/docs/E2E_TESTING.md
Normal file
348
backend/docs/E2E_TESTING.md
Normal file
@@ -0,0 +1,348 @@
|
||||
# Backend E2E Testing Guide
|
||||
|
||||
End-to-end testing infrastructure using **Testcontainers** (real PostgreSQL) and **Schemathesis** (OpenAPI contract testing).
|
||||
|
||||
## Table of Contents
|
||||
|
||||
- [Quick Start](#quick-start)
|
||||
- [Requirements](#requirements)
|
||||
- [How It Works](#how-it-works)
|
||||
- [Test Organization](#test-organization)
|
||||
- [Writing E2E Tests](#writing-e2e-tests)
|
||||
- [Running Tests](#running-tests)
|
||||
- [Troubleshooting](#troubleshooting)
|
||||
- [CI/CD Integration](#cicd-integration)
|
||||
|
||||
---
|
||||
|
||||
## Quick Start
|
||||
|
||||
```bash
|
||||
# 1. Install E2E dependencies
|
||||
make install-e2e
|
||||
|
||||
# 2. Ensure Docker is running
|
||||
make check-docker
|
||||
|
||||
# 3. Run E2E tests
|
||||
make test-e2e
|
||||
```
|
||||
|
||||
---
|
||||
|
||||
## Requirements
|
||||
|
||||
### Docker
|
||||
|
||||
E2E tests use Testcontainers to spin up real PostgreSQL containers. Docker must be running:
|
||||
|
||||
- **macOS/Windows**: Docker Desktop
|
||||
- **Linux**: Docker Engine (`sudo systemctl start docker`)
|
||||
|
||||
### Dependencies
|
||||
|
||||
E2E tests require additional packages beyond the standard dev dependencies:
|
||||
|
||||
```bash
|
||||
# Install E2E dependencies
|
||||
make install-e2e
|
||||
|
||||
# Or manually:
|
||||
uv sync --extra dev --extra e2e
|
||||
```
|
||||
|
||||
This installs:
|
||||
- `testcontainers[postgres]>=4.0.0` - Docker container management
|
||||
- `schemathesis>=3.30.0` - OpenAPI contract testing
|
||||
|
||||
---
|
||||
|
||||
## How It Works
|
||||
|
||||
### Testcontainers
|
||||
|
||||
Testcontainers automatically manages Docker containers for tests:
|
||||
|
||||
1. **Session-scoped container**: A single PostgreSQL 17 container starts once per test session
|
||||
2. **Function-scoped isolation**: Each test gets fresh tables (drop + recreate)
|
||||
3. **Automatic cleanup**: Container is destroyed when tests complete
|
||||
|
||||
This approach catches bugs that SQLite-based tests miss:
|
||||
- PostgreSQL-specific SQL behavior
|
||||
- Real constraint violations
|
||||
- Actual transaction semantics
|
||||
- JSONB column behavior
|
||||
|
||||
### Schemathesis
|
||||
|
||||
Schemathesis generates test cases from your OpenAPI schema:
|
||||
|
||||
1. **Schema loading**: Reads `/api/v1/openapi.json` from your FastAPI app
|
||||
2. **Test generation**: Creates test cases for each endpoint
|
||||
3. **Response validation**: Verifies responses match documented schema
|
||||
|
||||
This catches:
|
||||
- Undocumented response codes
|
||||
- Schema mismatches (wrong types, missing fields)
|
||||
- Edge cases in input validation
|
||||
|
||||
---
|
||||
|
||||
## Test Organization
|
||||
|
||||
```
|
||||
backend/tests/
|
||||
├── e2e/ # E2E tests (PostgreSQL, Docker required)
|
||||
│ ├── __init__.py
|
||||
│ ├── conftest.py # Testcontainers fixtures
|
||||
│ ├── test_api_contracts.py # Schemathesis schema tests
|
||||
│ └── test_database_workflows.py # PostgreSQL workflow tests
|
||||
│
|
||||
├── api/ # Integration tests (SQLite, fast)
|
||||
├── repositories/ # Repository unit tests
|
||||
└── conftest.py # Standard fixtures
|
||||
```
|
||||
|
||||
### Test Markers
|
||||
|
||||
Tests use pytest markers for filtering:
|
||||
|
||||
| Marker | Description |
|
||||
|--------|-------------|
|
||||
| `@pytest.mark.e2e` | End-to-end test requiring Docker |
|
||||
| `@pytest.mark.postgres` | PostgreSQL-specific test |
|
||||
| `@pytest.mark.schemathesis` | Schemathesis schema test |
|
||||
|
||||
---
|
||||
|
||||
## Writing E2E Tests
|
||||
|
||||
### Basic E2E Test
|
||||
|
||||
```python
|
||||
import pytest
|
||||
from uuid import uuid4
|
||||
|
||||
@pytest.mark.e2e
|
||||
@pytest.mark.postgres
|
||||
@pytest.mark.asyncio
|
||||
async def test_user_workflow(e2e_client):
|
||||
"""Test user registration with real PostgreSQL."""
|
||||
email = f"test-{uuid4().hex[:8]}@example.com"
|
||||
|
||||
response = await e2e_client.post(
|
||||
"/api/v1/auth/register",
|
||||
json={
|
||||
"email": email,
|
||||
"password": "SecurePassword123!",
|
||||
"first_name": "Test",
|
||||
"last_name": "User",
|
||||
},
|
||||
)
|
||||
|
||||
assert response.status_code in [200, 201]
|
||||
assert response.json()["email"] == email
|
||||
```
|
||||
|
||||
### Available Fixtures
|
||||
|
||||
| Fixture | Scope | Description |
|
||||
|---------|-------|-------------|
|
||||
| `postgres_container` | session | Raw Testcontainers PostgreSQL container |
|
||||
| `async_postgres_url` | session | Asyncpg-compatible connection URL |
|
||||
| `e2e_db_session` | function | SQLAlchemy AsyncSession with fresh tables |
|
||||
| `e2e_client` | function | httpx AsyncClient connected to real DB |
|
||||
|
||||
### Schemathesis Test
|
||||
|
||||
```python
|
||||
import pytest
|
||||
import schemathesis
|
||||
from hypothesis import settings, Phase
|
||||
|
||||
from app.main import app
|
||||
|
||||
schema = schemathesis.from_asgi("/api/v1/openapi.json", app=app)
|
||||
|
||||
@pytest.mark.e2e
|
||||
@pytest.mark.schemathesis
|
||||
@schema.parametrize(endpoint="/api/v1/auth/register")
|
||||
@settings(max_examples=20)
|
||||
def test_registration_schema(case):
|
||||
"""Test registration endpoint conforms to schema."""
|
||||
response = case.call_asgi()
|
||||
case.validate_response(response)
|
||||
```
|
||||
|
||||
---
|
||||
|
||||
## Running Tests
|
||||
|
||||
### Commands
|
||||
|
||||
```bash
|
||||
# Run all E2E tests
|
||||
make test-e2e
|
||||
|
||||
# Run only Schemathesis schema tests
|
||||
make test-e2e-schema
|
||||
|
||||
# Run all tests (unit + integration + E2E)
|
||||
make test-all
|
||||
|
||||
# Check Docker availability
|
||||
make check-docker
|
||||
```
|
||||
|
||||
### Direct pytest
|
||||
|
||||
```bash
|
||||
# All E2E tests
|
||||
IS_TEST=True PYTHONPATH=. uv run pytest tests/e2e/ -v
|
||||
|
||||
# Only PostgreSQL tests
|
||||
IS_TEST=True PYTHONPATH=. uv run pytest tests/e2e/ -v -m postgres
|
||||
|
||||
# Only Schemathesis tests
|
||||
IS_TEST=True PYTHONPATH=. uv run pytest tests/e2e/ -v -m schemathesis
|
||||
```
|
||||
|
||||
---
|
||||
|
||||
## Troubleshooting
|
||||
|
||||
### Docker Not Running
|
||||
|
||||
**Error:**
|
||||
```
|
||||
Docker is not running!
|
||||
E2E tests require Docker to be running.
|
||||
```
|
||||
|
||||
**Solution:**
|
||||
```bash
|
||||
# macOS/Windows
|
||||
# Open Docker Desktop
|
||||
|
||||
# Linux
|
||||
sudo systemctl start docker
|
||||
```
|
||||
|
||||
### Testcontainers Not Installed
|
||||
|
||||
**Error:**
|
||||
```
|
||||
SKIPPED: testcontainers not installed - run: make install-e2e
|
||||
```
|
||||
|
||||
**Solution:**
|
||||
```bash
|
||||
make install-e2e
|
||||
# Or: uv sync --extra dev --extra e2e
|
||||
```
|
||||
|
||||
### Container Startup Timeout
|
||||
|
||||
**Error:**
|
||||
```
|
||||
testcontainers.core.waiting_utils.UnexpectedResponse
|
||||
```
|
||||
|
||||
**Solutions:**
|
||||
1. Increase Docker resources (memory, CPU)
|
||||
2. Pull the image manually: `docker pull postgres:17-alpine`
|
||||
3. Check Docker daemon logs: `docker logs`
|
||||
|
||||
### Port Conflicts
|
||||
|
||||
**Error:**
|
||||
```
|
||||
Error starting container: port is already allocated
|
||||
```
|
||||
|
||||
**Solution:**
|
||||
Testcontainers uses random ports, so conflicts are rare. If occurring:
|
||||
1. Stop other PostgreSQL containers: `docker stop $(docker ps -q)`
|
||||
2. Check for orphaned containers: `docker container prune`
|
||||
|
||||
### Ryuk/Reaper Port 8080 Issues
|
||||
|
||||
**Error:**
|
||||
```
|
||||
ConnectionError: Port mapping for container ... and port 8080 is not available
|
||||
```
|
||||
|
||||
**Solution:**
|
||||
This is related to the Testcontainers Reaper (Ryuk) which handles automatic cleanup.
|
||||
The `conftest.py` automatically disables Ryuk to avoid this issue. If you still encounter
|
||||
this error, ensure you're using the latest conftest.py or set the environment variable:
|
||||
|
||||
```bash
|
||||
export TESTCONTAINERS_RYUK_DISABLED=true
|
||||
```
|
||||
|
||||
### Parallel Test Execution Issues
|
||||
|
||||
**Error:**
|
||||
```
|
||||
ScopeMismatch: ... cannot use a higher-scoped fixture 'postgres_container'
|
||||
```
|
||||
|
||||
**Solution:**
|
||||
E2E tests must run sequentially (not in parallel) because they share a session-scoped
|
||||
PostgreSQL container. The Makefile commands use `-n 0` to disable parallel execution.
|
||||
If running pytest directly, add `-n 0`:
|
||||
|
||||
```bash
|
||||
IS_TEST=True PYTHONPATH=. uv run pytest tests/e2e/ -v -n 0
|
||||
```
|
||||
|
||||
---
|
||||
|
||||
## CI/CD Integration
|
||||
|
||||
### GitHub Actions
|
||||
|
||||
A workflow template is provided at `.github/workflows/backend-e2e-tests.yml.template`.
|
||||
|
||||
To enable:
|
||||
1. Rename to `backend-e2e-tests.yml`
|
||||
2. Push to repository
|
||||
|
||||
The workflow:
|
||||
- Runs on pushes to `main`/`develop` affecting `backend/`
|
||||
- Uses `continue-on-error: true` (E2E failures don't block merge)
|
||||
- Caches uv dependencies for speed
|
||||
|
||||
### Local CI Simulation
|
||||
|
||||
```bash
|
||||
# Run what CI runs
|
||||
make test-all
|
||||
```
|
||||
|
||||
---
|
||||
|
||||
## Best Practices
|
||||
|
||||
### DO
|
||||
|
||||
- Use unique emails per test: `f"test-{uuid4().hex[:8]}@example.com"`
|
||||
- Mark tests with appropriate markers: `@pytest.mark.e2e`
|
||||
- Keep E2E tests focused on critical workflows
|
||||
- Use `e2e_client` fixture for most tests
|
||||
|
||||
### DON'T
|
||||
|
||||
- Share state between tests (each test gets fresh tables)
|
||||
- Test every endpoint in E2E (use unit tests for edge cases)
|
||||
- Skip the `IS_TEST=True` environment variable
|
||||
- Run E2E tests without Docker
|
||||
|
||||
---
|
||||
|
||||
## Further Reading
|
||||
|
||||
- [Testcontainers Documentation](https://testcontainers.com/guides/getting-started-with-testcontainers-for-python/)
|
||||
- [Schemathesis Documentation](https://schemathesis.readthedocs.io/)
|
||||
- [pytest-asyncio Documentation](https://pytest-asyncio.readthedocs.io/)
|
||||
File diff suppressed because it is too large
Load Diff
Some files were not shown because too many files have changed in this diff Show More
Reference in New Issue
Block a user