Compare commits
308 Commits
80c26c3df2
...
main
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
a94e29d99c | ||
|
|
81e48c73ca | ||
|
|
a3f78dc801 | ||
|
|
07309013d7 | ||
|
|
846fc31190 | ||
|
|
ff7a67cb58 | ||
|
|
0760a8284d | ||
|
|
ce4d0c7b0d | ||
|
|
4ceb8ad98c | ||
|
|
f8aafb250d | ||
|
|
4385d20ca6 | ||
|
|
1a36907f10 | ||
|
|
0553a1fc53 | ||
|
|
57e969ed67 | ||
|
|
68275b1dd3 | ||
|
|
80d2dc0cb2 | ||
|
|
a8aa416ecb | ||
|
|
4c6bf55bcc | ||
|
|
98b455fdc3 | ||
|
|
0646c96b19 | ||
|
|
62afb328fe | ||
|
|
b9a746bc16 | ||
|
|
de8e18e97d | ||
|
|
a3e557d022 | ||
|
|
4e357db25d | ||
|
|
568aad3673 | ||
|
|
ddcf926158 | ||
|
|
865eeece58 | ||
|
|
05fb3612f9 | ||
|
|
1b2e7dde35 | ||
|
|
29074f26a6 | ||
|
|
77ed190310 | ||
|
|
2bbe925cef | ||
|
|
4a06b96b2e | ||
|
|
088c1725b0 | ||
|
|
7ba1767cea | ||
|
|
c63b6a4f76 | ||
|
|
803b720530 | ||
|
|
7ff00426f2 | ||
|
|
b3f0dd4005 | ||
|
|
707315facd | ||
|
|
38114b79f9 | ||
|
|
1cb3658369 | ||
|
|
dc875c5c95 | ||
|
|
0ea428b718 | ||
|
|
400d6f6f75 | ||
|
|
7716468d72 | ||
|
|
48f052200f | ||
|
|
fbb030da69 | ||
|
|
d49f819469 | ||
|
|
507f2e9c00 | ||
|
|
c0b253a010 | ||
|
|
fcbcff99e9 | ||
|
|
b49678b7df | ||
|
|
aeed9dfdbc | ||
|
|
13f617828b | ||
|
|
84e0a7fe81 | ||
|
|
063a35e698 | ||
|
|
a2246fb6e1 | ||
|
|
16ee4e0cb3 | ||
|
|
e6792c2d6c | ||
|
|
1d20b149dc | ||
|
|
570848cc2d | ||
|
|
6b970765ba | ||
|
|
e79215b4de | ||
|
|
3bf28aa121 | ||
|
|
cda9810a7e | ||
|
|
d47bd34a92 | ||
|
|
5b0ae54365 | ||
|
|
372af25aaa | ||
|
|
d0b717a128 | ||
|
|
9d40aece30 | ||
|
|
487c8a3863 | ||
|
|
8659e884e9 | ||
|
|
a05def5906 | ||
|
|
9f655913b1 | ||
|
|
13abd159fa | ||
|
|
acfe59c8b3 | ||
|
|
2e4700ae9b | ||
|
|
8c83e2a699 | ||
|
|
9b6356b0db | ||
|
|
a410586cfb | ||
|
|
0e34cab921 | ||
|
|
3cf3858fca | ||
|
|
db0c555041 | ||
|
|
51ad80071a | ||
|
|
d730ab7526 | ||
|
|
b218be9318 | ||
|
|
e6813c87c3 | ||
|
|
210204eb7a | ||
|
|
6ad4cda3f4 | ||
|
|
54ceaa6f5d | ||
|
|
34e7f69465 | ||
|
|
8fdbc2b359 | ||
|
|
28b1cc6e48 | ||
|
|
5a21847382 | ||
|
|
444d495f83 | ||
|
|
a943f79ce7 | ||
|
|
f54905abd0 | ||
|
|
0105e765b3 | ||
|
|
bb06b450fd | ||
|
|
c1d6a04276 | ||
|
|
d7b333385d | ||
|
|
f02320e57c | ||
|
|
3ec589293c | ||
|
|
7b1bea2966 | ||
|
|
da7b6b5bfa | ||
|
|
7aa63d79df | ||
|
|
333c9c40af | ||
|
|
0b192ce030 | ||
|
|
da021d0640 | ||
|
|
d1b47006f4 | ||
|
|
a73d3c7d3e | ||
|
|
55ae92c460 | ||
|
|
fe6a98c379 | ||
|
|
b7c1191335 | ||
|
|
68e04a911a | ||
|
|
3001484948 | ||
|
|
c9f4772196 | ||
|
|
14e5839476 | ||
|
|
228d12b379 | ||
|
|
46ff95d8b9 | ||
|
|
235c309e4e | ||
|
|
5c47be2ee5 | ||
|
|
e9f787040a | ||
|
|
2532d1ac3c | ||
|
|
1f45ca2b50 | ||
|
|
8a343580ce | ||
|
|
424ca166b8 | ||
|
|
c589b565f0 | ||
|
|
a5c671c133 | ||
|
|
d8bde80d4f | ||
|
|
35efa24ce5 | ||
|
|
96df7edf88 | ||
|
|
464a6140c4 | ||
|
|
b2f3ec8f25 | ||
|
|
c8f90e9e8c | ||
|
|
2169618bc8 | ||
|
|
a84fd11cc7 | ||
|
|
6824fd7c33 | ||
|
|
d5eb855ae1 | ||
|
|
a6a10855fa | ||
|
|
bf95aab7ec | ||
|
|
214d0b1765 | ||
|
|
b630559e0b | ||
|
|
fe289228e1 | ||
|
|
63c171f83e | ||
|
|
e02329b734 | ||
|
|
e1d5914e7f | ||
|
|
d6a06e45ec | ||
|
|
e74830bec5 | ||
|
|
51ef4632e6 | ||
|
|
b749f62abd | ||
|
|
3b28b5cf97 | ||
|
|
652fb6b180 | ||
|
|
6b556431d3 | ||
|
|
f8b77200f0 | ||
|
|
f99de75dc6 | ||
|
|
4420756741 | ||
|
|
dde4a5979d | ||
|
|
2696f44198 | ||
|
|
9dc1a70038 | ||
|
|
234c197ee1 | ||
|
|
ff758f5d10 | ||
|
|
da1f4e365a | ||
|
|
01e0b9ab21 | ||
|
|
96ae9295d3 | ||
|
|
94ebda084b | ||
|
|
5f3a098403 | ||
|
|
7556353078 | ||
|
|
f22f87250c | ||
|
|
91bc4f190d | ||
|
|
c10c1d1c39 | ||
|
|
dde091138e | ||
|
|
9c72fe87f9 | ||
|
|
abce06ad67 | ||
|
|
d0f1a7cc4b | ||
|
|
f9f58b5f27 | ||
|
|
67860c68e3 | ||
|
|
11a78dfcc3 | ||
|
|
402c041d15 | ||
|
|
e64b0e8085 | ||
|
|
df8ef98857 | ||
|
|
9ffd61527c | ||
|
|
63650f563d | ||
|
|
f23fdb974a | ||
|
|
7c98ceb5b9 | ||
|
|
26d43ff9e1 | ||
|
|
4bf34ea287 | ||
|
|
852c7eceff | ||
|
|
532577f36c | ||
|
|
9843cf8218 | ||
|
|
2ee48bf3fa | ||
|
|
a36c1b61bb | ||
|
|
0cba8ea62a | ||
|
|
01b406bca7 | ||
|
|
77b914ffa2 | ||
|
|
10ff6a1a96 | ||
|
|
88dc81735b | ||
|
|
e81f54564b | ||
|
|
f7133807fc | ||
|
|
388ca08724 | ||
|
|
54a14047be | ||
|
|
65f209c679 | ||
|
|
64a4b3fb11 | ||
|
|
1c7f34c078 | ||
|
|
fe5d152cee | ||
|
|
15f522b9b1 | ||
|
|
fded54e61a | ||
|
|
77594e478d | ||
|
|
ac3fac0426 | ||
|
|
0e554ef35e | ||
|
|
aedc770afb | ||
|
|
54c32bf97f | ||
|
|
1b9854d412 | ||
|
|
911d4a594e | ||
|
|
86d8e1cace | ||
|
|
2c05f17ec5 | ||
|
|
68e28e4c76 | ||
|
|
6d1b730ae7 | ||
|
|
29f98f059b | ||
|
|
b181182c3b | ||
|
|
92b7de352c | ||
|
|
aff76e3a69 | ||
|
|
13771c5354 | ||
|
|
c3c6a18dd1 | ||
|
|
68e7ebc4e0 | ||
|
|
df299e3e45 | ||
|
|
8e497770c9 | ||
|
|
58b761106b | ||
|
|
e734acf31d | ||
|
|
76d36e1b12 | ||
|
|
6e95469d99 | ||
|
|
6d9b98943c | ||
|
|
30cbaf8ad5 | ||
|
|
13f830ed6d | ||
|
|
c051bbf0aa | ||
|
|
b39b7b4c94 | ||
|
|
9f88736d13 | ||
|
|
ccd535cf0e | ||
|
|
30dca45097 | ||
|
|
a460e0e4f2 | ||
|
|
08511ae07b | ||
|
|
1439380126 | ||
|
|
378b04d505 | ||
|
|
af260e4748 | ||
|
|
30f0ec5a64 | ||
|
|
04110cbf1c | ||
|
|
461d3caf31 | ||
|
|
789a76071d | ||
|
|
4536c607eb | ||
|
|
bf04c98408 | ||
|
|
4885df80a7 | ||
|
|
29ff97f726 | ||
|
|
406c3bcc82 | ||
|
|
1aab73cb72 | ||
|
|
f77f2700f2 | ||
|
|
f354ec610b | ||
|
|
e25b010b57 | ||
|
|
0b0d1d2b06 | ||
|
|
bc53504cbf | ||
|
|
d75a8de91b | ||
|
|
a82e5ea0e6 | ||
|
|
189ad948ac | ||
|
|
e2a8656f81 | ||
|
|
ce5ed70dd2 | ||
|
|
230210f3db | ||
|
|
a9e972d583 | ||
|
|
a95b25cab8 | ||
|
|
976fd1d4ad | ||
|
|
293fbcb27e | ||
|
|
f117960323 | ||
|
|
a1b11fadcb | ||
|
|
b8d3248a48 | ||
|
|
a062daddc5 | ||
|
|
efcf10f9aa | ||
|
|
ee938ce6a6 | ||
|
|
035e6af446 | ||
|
|
c79b76be41 | ||
|
|
61173d0dc1 | ||
|
|
ea544ecbac | ||
|
|
3ad48843e4 | ||
|
|
544be2bea4 | ||
|
|
3fe5d301f8 | ||
|
|
819f3ba963 | ||
|
|
9ae89a20b3 | ||
|
|
c58cce358f | ||
|
|
38eb5313fc | ||
|
|
4de440ed2d | ||
|
|
cc98a76e24 | ||
|
|
925950d58e | ||
|
|
dbb05289b2 | ||
|
|
f4be8b56f0 | ||
|
|
31e2109278 | ||
|
|
b4866f9100 | ||
|
|
092a82ee07 | ||
|
|
92a8699479 | ||
|
|
8a7a3b9521 | ||
|
|
6d811747ee | ||
|
|
76023694f8 | ||
|
|
cf5bb41c17 | ||
|
|
1f15ee6db3 | ||
|
|
26ff08d9f9 | ||
|
|
19ecd04a41 | ||
|
|
9554782202 | ||
|
|
59f8c8076b | ||
|
|
e8156b751e | ||
|
|
86f67a925c |
55
.env.demo
Normal file
55
.env.demo
Normal file
@@ -0,0 +1,55 @@
|
||||
# Common settings
|
||||
PROJECT_NAME=App
|
||||
VERSION=1.0.0
|
||||
|
||||
# Database settings
|
||||
POSTGRES_USER=postgres
|
||||
POSTGRES_PASSWORD=postgres
|
||||
POSTGRES_DB=app
|
||||
POSTGRES_HOST=db
|
||||
POSTGRES_PORT=5432
|
||||
DATABASE_URL=postgresql://postgres:postgres@db:5432/app
|
||||
|
||||
# Backend settings
|
||||
BACKEND_PORT=8000
|
||||
# CRITICAL: Generate a secure SECRET_KEY for production!
|
||||
# Generate with: python -c 'import secrets; print(secrets.token_urlsafe(32))'
|
||||
# Must be at least 32 characters
|
||||
SECRET_KEY=demo_secret_key_for_testing_only_do_not_use_in_prod
|
||||
ENVIRONMENT=development
|
||||
DEMO_MODE=true
|
||||
DEBUG=true
|
||||
BACKEND_CORS_ORIGINS=["http://localhost:3000"]
|
||||
FIRST_SUPERUSER_EMAIL=admin@example.com
|
||||
# IMPORTANT: Use a strong password (min 12 chars, mixed case, digits)
|
||||
# Default weak passwords like 'Admin123' are rejected
|
||||
FIRST_SUPERUSER_PASSWORD=Admin123!
|
||||
|
||||
# OAuth Configuration (Social Login)
|
||||
# Set OAUTH_ENABLED=true and configure at least one provider
|
||||
OAUTH_ENABLED=false
|
||||
OAUTH_AUTO_LINK_BY_EMAIL=true
|
||||
|
||||
# Google OAuth (from Google Cloud Console > APIs & Services > Credentials)
|
||||
# https://console.cloud.google.com/apis/credentials
|
||||
# OAUTH_GOOGLE_CLIENT_ID=your-google-client-id.apps.googleusercontent.com
|
||||
# OAUTH_GOOGLE_CLIENT_SECRET=your-google-client-secret
|
||||
|
||||
# GitHub OAuth (from GitHub > Settings > Developer settings > OAuth Apps)
|
||||
# https://github.com/settings/developers
|
||||
# OAUTH_GITHUB_CLIENT_ID=your-github-client-id
|
||||
# OAUTH_GITHUB_CLIENT_SECRET=your-github-client-secret
|
||||
|
||||
# OAuth Provider Mode (Authorization Server for MCP/third-party clients)
|
||||
# Set OAUTH_PROVIDER_ENABLED=true to act as an OAuth 2.0 Authorization Server
|
||||
OAUTH_PROVIDER_ENABLED=true
|
||||
# IMPORTANT: Must be HTTPS in production!
|
||||
OAUTH_ISSUER=http://localhost:8000
|
||||
|
||||
# Frontend settings
|
||||
FRONTEND_PORT=3000
|
||||
FRONTEND_URL=http://localhost:3000
|
||||
NEXT_PUBLIC_API_URL=http://localhost:8000
|
||||
NEXT_PUBLIC_API_BASE_URL=http://localhost:8000
|
||||
NEXT_PUBLIC_APP_URL=http://localhost:3000
|
||||
NODE_ENV=development
|
||||
@@ -5,7 +5,7 @@ VERSION=1.0.0
|
||||
# Database settings
|
||||
POSTGRES_USER=postgres
|
||||
POSTGRES_PASSWORD=postgres
|
||||
POSTGRES_DB=App
|
||||
POSTGRES_DB=app
|
||||
POSTGRES_HOST=db
|
||||
POSTGRES_PORT=5432
|
||||
DATABASE_URL=postgresql://${POSTGRES_USER}:${POSTGRES_PASSWORD}@${POSTGRES_HOST}:${POSTGRES_PORT}/${POSTGRES_DB}
|
||||
@@ -17,6 +17,7 @@ BACKEND_PORT=8000
|
||||
# Must be at least 32 characters
|
||||
SECRET_KEY=your_secret_key_here_REPLACE_WITH_GENERATED_KEY_32_CHARS_MIN
|
||||
ENVIRONMENT=development
|
||||
DEMO_MODE=false
|
||||
DEBUG=true
|
||||
BACKEND_CORS_ORIGINS=["http://localhost:3000"]
|
||||
FIRST_SUPERUSER_EMAIL=admin@example.com
|
||||
@@ -24,7 +25,31 @@ FIRST_SUPERUSER_EMAIL=admin@example.com
|
||||
# Default weak passwords like 'Admin123' are rejected
|
||||
FIRST_SUPERUSER_PASSWORD=YourStrongPassword123!
|
||||
|
||||
# OAuth Configuration (Social Login)
|
||||
# Set OAUTH_ENABLED=true and configure at least one provider
|
||||
OAUTH_ENABLED=false
|
||||
OAUTH_AUTO_LINK_BY_EMAIL=true
|
||||
|
||||
# Google OAuth (from Google Cloud Console > APIs & Services > Credentials)
|
||||
# https://console.cloud.google.com/apis/credentials
|
||||
# OAUTH_GOOGLE_CLIENT_ID=your-google-client-id.apps.googleusercontent.com
|
||||
# OAUTH_GOOGLE_CLIENT_SECRET=your-google-client-secret
|
||||
|
||||
# GitHub OAuth (from GitHub > Settings > Developer settings > OAuth Apps)
|
||||
# https://github.com/settings/developers
|
||||
# OAUTH_GITHUB_CLIENT_ID=your-github-client-id
|
||||
# OAUTH_GITHUB_CLIENT_SECRET=your-github-client-secret
|
||||
|
||||
# OAuth Provider Mode (Authorization Server for MCP/third-party clients)
|
||||
# Set OAUTH_PROVIDER_ENABLED=true to act as an OAuth 2.0 Authorization Server
|
||||
OAUTH_PROVIDER_ENABLED=false
|
||||
# IMPORTANT: Must be HTTPS in production!
|
||||
OAUTH_ISSUER=http://localhost:8000
|
||||
|
||||
# Frontend settings
|
||||
FRONTEND_PORT=3000
|
||||
FRONTEND_URL=http://localhost:3000
|
||||
NEXT_PUBLIC_API_URL=http://localhost:8000
|
||||
NEXT_PUBLIC_API_BASE_URL=http://localhost:8000
|
||||
NEXT_PUBLIC_APP_URL=http://localhost:3000
|
||||
NODE_ENV=development
|
||||
|
||||
108
.github/workflows/README.md
vendored
Normal file
108
.github/workflows/README.md
vendored
Normal file
@@ -0,0 +1,108 @@
|
||||
# GitHub Actions Workflows
|
||||
|
||||
This directory contains CI/CD workflow templates for automated testing and deployment.
|
||||
|
||||
## 🚀 Quick Setup
|
||||
|
||||
To enable CI/CD workflows:
|
||||
|
||||
1. **Rename template files** by removing the `.template` extension:
|
||||
```bash
|
||||
mv backend-tests.yml.template backend-tests.yml
|
||||
mv frontend-tests.yml.template frontend-tests.yml
|
||||
mv e2e-tests.yml.template e2e-tests.yml
|
||||
```
|
||||
|
||||
2. **Set up Codecov** (optional, for coverage badges):
|
||||
- Sign up at https://codecov.io
|
||||
- Add your repository
|
||||
- Get your `CODECOV_TOKEN`
|
||||
- Add it to GitHub repository secrets
|
||||
|
||||
3. **Update README badges**:
|
||||
Replace the static badges in the main README.md with:
|
||||
```markdown
|
||||
[](https://github.com/YOUR_ORG/YOUR_REPO/actions/workflows/backend-tests.yml)
|
||||
[](https://codecov.io/gh/YOUR_ORG/YOUR_REPO)
|
||||
[](https://github.com/YOUR_ORG/YOUR_REPO/actions/workflows/frontend-tests.yml)
|
||||
[](https://codecov.io/gh/YOUR_ORG/YOUR_REPO)
|
||||
[](https://github.com/YOUR_ORG/YOUR_REPO/actions/workflows/e2e-tests.yml)
|
||||
```
|
||||
|
||||
## 📋 Workflow Descriptions
|
||||
|
||||
### `backend-tests.yml`
|
||||
- Runs on: Push to main/develop, PRs affecting backend code
|
||||
- Tests: Backend unit and integration tests
|
||||
- Coverage: Uploads to Codecov
|
||||
- Database: Spins up PostgreSQL service
|
||||
|
||||
### `frontend-tests.yml`
|
||||
- Runs on: Push to main/develop, PRs affecting frontend code
|
||||
- Tests: Frontend unit tests (Jest)
|
||||
- Coverage: Uploads to Codecov
|
||||
- Fast: Uses bun cache
|
||||
|
||||
### `e2e-tests.yml`
|
||||
- Runs on: All pushes and PRs
|
||||
- Tests: End-to-end tests (Playwright)
|
||||
- Coverage: Full stack integration
|
||||
- Artifacts: Saves test reports for 30 days
|
||||
|
||||
## 🔧 Customization
|
||||
|
||||
### Adjust trigger paths
|
||||
Modify the `paths` section to control when workflows run:
|
||||
```yaml
|
||||
paths:
|
||||
- 'backend/**'
|
||||
- 'shared/**' # Add if you have shared code
|
||||
```
|
||||
|
||||
### Change test commands
|
||||
Update the test steps to match your needs:
|
||||
```yaml
|
||||
- name: Run tests
|
||||
run: pytest -v --custom-flag
|
||||
```
|
||||
|
||||
### Add deployment
|
||||
Create a new workflow for deployment:
|
||||
```yaml
|
||||
name: Deploy to Production
|
||||
on:
|
||||
push:
|
||||
branches: [ main ]
|
||||
tags: [ 'v*' ]
|
||||
```
|
||||
|
||||
## 🛡️ Security
|
||||
|
||||
- Never commit secrets to workflows
|
||||
- Use GitHub Secrets for sensitive data
|
||||
- Review workflow permissions
|
||||
- Keep actions up to date
|
||||
|
||||
## 📊 Coverage Reports
|
||||
|
||||
With Codecov enabled, you'll get:
|
||||
- Coverage trends over time
|
||||
- PR coverage comparisons
|
||||
- Coverage per file/folder
|
||||
- Interactive coverage explorer
|
||||
|
||||
Access at: `https://codecov.io/gh/YOUR_ORG/YOUR_REPO`
|
||||
|
||||
## 💡 Tips
|
||||
|
||||
- **PR checks**: Workflows run on PRs automatically
|
||||
- **Status checks**: Set as required in branch protection
|
||||
- **Debug logs**: Re-run with debug logging enabled
|
||||
- **Artifacts**: Download from workflow run page
|
||||
- **Matrix builds**: Test multiple Python/Node versions
|
||||
|
||||
## 📚 Further Reading
|
||||
|
||||
- [GitHub Actions Documentation](https://docs.github.com/en/actions)
|
||||
- [Codecov Documentation](https://docs.codecov.com)
|
||||
- [Playwright CI Guide](https://playwright.dev/docs/ci)
|
||||
77
.github/workflows/backend-e2e-tests.yml.template
vendored
Normal file
77
.github/workflows/backend-e2e-tests.yml.template
vendored
Normal file
@@ -0,0 +1,77 @@
|
||||
# Backend E2E Tests CI Pipeline
|
||||
#
|
||||
# Runs end-to-end tests with real PostgreSQL via Testcontainers
|
||||
# and validates API contracts with Schemathesis.
|
||||
#
|
||||
# To enable: Rename this file to backend-e2e-tests.yml
|
||||
|
||||
name: Backend E2E Tests
|
||||
|
||||
on:
|
||||
push:
|
||||
branches: [main, develop]
|
||||
paths:
|
||||
- 'backend/**'
|
||||
- '.github/workflows/backend-e2e-tests.yml'
|
||||
pull_request:
|
||||
branches: [main, develop]
|
||||
paths:
|
||||
- 'backend/**'
|
||||
workflow_dispatch:
|
||||
|
||||
jobs:
|
||||
e2e-tests:
|
||||
runs-on: ubuntu-latest
|
||||
# E2E test failures don't block merge - they're advisory
|
||||
continue-on-error: true
|
||||
|
||||
steps:
|
||||
- name: Checkout code
|
||||
uses: actions/checkout@v4
|
||||
|
||||
- name: Set up Python
|
||||
uses: actions/setup-python@v5
|
||||
with:
|
||||
python-version: '3.12'
|
||||
|
||||
- name: Install uv
|
||||
uses: astral-sh/setup-uv@v4
|
||||
with:
|
||||
version: "latest"
|
||||
|
||||
- name: Cache uv dependencies
|
||||
uses: actions/cache@v4
|
||||
with:
|
||||
path: ~/.cache/uv
|
||||
key: uv-${{ runner.os }}-${{ hashFiles('backend/uv.lock') }}
|
||||
restore-keys: |
|
||||
uv-${{ runner.os }}-
|
||||
|
||||
- name: Install dependencies (with E2E)
|
||||
working-directory: ./backend
|
||||
run: uv sync --extra dev --extra e2e
|
||||
|
||||
- name: Check Docker availability
|
||||
id: docker-check
|
||||
run: |
|
||||
if docker info > /dev/null 2>&1; then
|
||||
echo "available=true" >> $GITHUB_OUTPUT
|
||||
echo "Docker is available"
|
||||
else
|
||||
echo "available=false" >> $GITHUB_OUTPUT
|
||||
echo "::warning::Docker not available - E2E tests will be skipped"
|
||||
fi
|
||||
|
||||
- name: Run E2E tests
|
||||
if: steps.docker-check.outputs.available == 'true'
|
||||
working-directory: ./backend
|
||||
env:
|
||||
IS_TEST: "True"
|
||||
SECRET_KEY: "e2e-test-secret-key-minimum-32-characters-long"
|
||||
PYTHONPATH: "."
|
||||
run: |
|
||||
uv run pytest tests/e2e/ -v --tb=short
|
||||
|
||||
- name: E2E tests skipped
|
||||
if: steps.docker-check.outputs.available != 'true'
|
||||
run: echo "E2E tests were skipped due to Docker unavailability"
|
||||
86
.github/workflows/backend-tests.yml.template
vendored
Normal file
86
.github/workflows/backend-tests.yml.template
vendored
Normal file
@@ -0,0 +1,86 @@
|
||||
# Backend Unit Tests CI Pipeline
|
||||
#
|
||||
# Rename this file to backend-tests.yml to enable it
|
||||
# This will make the backend test badges dynamic
|
||||
#
|
||||
# Required repository secrets:
|
||||
# - None (uses default GITHUB_TOKEN)
|
||||
|
||||
name: Backend Tests
|
||||
|
||||
on:
|
||||
push:
|
||||
branches: [ main, develop ]
|
||||
paths:
|
||||
- 'backend/**'
|
||||
- '.github/workflows/backend-tests.yml'
|
||||
pull_request:
|
||||
branches: [ main, develop ]
|
||||
paths:
|
||||
- 'backend/**'
|
||||
|
||||
jobs:
|
||||
test:
|
||||
runs-on: ubuntu-latest
|
||||
|
||||
services:
|
||||
postgres:
|
||||
image: postgres:15
|
||||
env:
|
||||
POSTGRES_USER: postgres
|
||||
POSTGRES_PASSWORD: postgres
|
||||
POSTGRES_DB: test_db
|
||||
options: >-
|
||||
--health-cmd pg_isready
|
||||
--health-interval 10s
|
||||
--health-timeout 5s
|
||||
--health-retries 5
|
||||
ports:
|
||||
- 5432:5432
|
||||
|
||||
steps:
|
||||
- uses: actions/checkout@v4
|
||||
|
||||
- name: Set up Python
|
||||
uses: actions/setup-python@v5
|
||||
with:
|
||||
python-version: '3.12'
|
||||
cache: 'pip'
|
||||
|
||||
- name: Install dependencies
|
||||
working-directory: ./backend
|
||||
run: |
|
||||
python -m pip install --upgrade pip
|
||||
pip install -r requirements.txt
|
||||
|
||||
- name: Run tests with coverage
|
||||
working-directory: ./backend
|
||||
env:
|
||||
IS_TEST: True
|
||||
POSTGRES_HOST: localhost
|
||||
POSTGRES_PORT: 5432
|
||||
POSTGRES_USER: postgres
|
||||
POSTGRES_PASSWORD: postgres
|
||||
POSTGRES_DB: test_db
|
||||
SECRET_KEY: test-secret-key-for-ci-only
|
||||
run: |
|
||||
pytest --cov=app --cov-report=xml --cov-report=term-missing -v
|
||||
|
||||
- name: Upload coverage to Codecov
|
||||
uses: codecov/codecov-action@v4
|
||||
with:
|
||||
files: ./backend/coverage.xml
|
||||
flags: backend
|
||||
name: backend-coverage
|
||||
fail_ci_if_error: true
|
||||
token: ${{ secrets.CODECOV_TOKEN }}
|
||||
|
||||
- name: Generate coverage badge
|
||||
uses: schneegans/dynamic-badges-action@v1.7.0
|
||||
with:
|
||||
auth: ${{ secrets.GIST_SECRET }}
|
||||
gistID: YOUR_GIST_ID_HERE
|
||||
filename: backend-coverage.json
|
||||
label: backend coverage
|
||||
message: ${{ env.COVERAGE }}%
|
||||
color: brightgreen
|
||||
105
.github/workflows/e2e-tests.yml.template
vendored
Normal file
105
.github/workflows/e2e-tests.yml.template
vendored
Normal file
@@ -0,0 +1,105 @@
|
||||
# End-to-End Tests CI Pipeline
|
||||
#
|
||||
# Rename this file to e2e-tests.yml to enable it
|
||||
# This will make the E2E test badges dynamic
|
||||
#
|
||||
# Required repository secrets:
|
||||
# - None (uses default GITHUB_TOKEN)
|
||||
|
||||
name: E2E Tests
|
||||
|
||||
on:
|
||||
push:
|
||||
branches: [ main, develop ]
|
||||
pull_request:
|
||||
branches: [ main, develop ]
|
||||
|
||||
jobs:
|
||||
test:
|
||||
runs-on: ubuntu-latest
|
||||
timeout-minutes: 20
|
||||
|
||||
services:
|
||||
postgres:
|
||||
image: postgres:15
|
||||
env:
|
||||
POSTGRES_USER: postgres
|
||||
POSTGRES_PASSWORD: postgres
|
||||
POSTGRES_DB: test_db
|
||||
options: >-
|
||||
--health-cmd pg_isready
|
||||
--health-interval 10s
|
||||
--health-timeout 5s
|
||||
--health-retries 5
|
||||
ports:
|
||||
- 5432:5432
|
||||
|
||||
steps:
|
||||
- uses: actions/checkout@v4
|
||||
|
||||
- name: Set up Python
|
||||
uses: actions/setup-python@v5
|
||||
with:
|
||||
python-version: '3.12'
|
||||
cache: 'pip'
|
||||
|
||||
- name: Set up Node.js
|
||||
uses: actions/setup-node@v4
|
||||
with:
|
||||
node-version: '20'
|
||||
cache: 'npm'
|
||||
cache-dependency-path: './frontend/package-lock.json'
|
||||
|
||||
- name: Install backend dependencies
|
||||
working-directory: ./backend
|
||||
run: |
|
||||
python -m pip install --upgrade pip
|
||||
pip install -r requirements.txt
|
||||
|
||||
- name: Setup backend database
|
||||
working-directory: ./backend
|
||||
env:
|
||||
POSTGRES_HOST: localhost
|
||||
POSTGRES_PORT: 5432
|
||||
POSTGRES_USER: postgres
|
||||
POSTGRES_PASSWORD: postgres
|
||||
POSTGRES_DB: test_db
|
||||
SECRET_KEY: test-secret-key-for-ci-only
|
||||
run: |
|
||||
alembic upgrade head
|
||||
python -c "from app.init_db import init_db; import asyncio; asyncio.run(init_db())"
|
||||
|
||||
- name: Start backend server
|
||||
working-directory: ./backend
|
||||
env:
|
||||
POSTGRES_HOST: localhost
|
||||
POSTGRES_PORT: 5432
|
||||
POSTGRES_USER: postgres
|
||||
POSTGRES_PASSWORD: postgres
|
||||
POSTGRES_DB: test_db
|
||||
SECRET_KEY: test-secret-key-for-ci-only
|
||||
run: |
|
||||
uvicorn app.main:app --host 0.0.0.0 --port 8000 &
|
||||
sleep 5 # Wait for server to start
|
||||
|
||||
- name: Install frontend dependencies
|
||||
working-directory: ./frontend
|
||||
run: npm ci
|
||||
|
||||
- name: Install Playwright browsers
|
||||
working-directory: ./frontend
|
||||
run: npx playwright install --with-deps chromium
|
||||
|
||||
- name: Run E2E tests
|
||||
working-directory: ./frontend
|
||||
env:
|
||||
NEXT_PUBLIC_API_URL: http://localhost:8000/api/v1
|
||||
run: npm run test:e2e
|
||||
|
||||
- name: Upload test results
|
||||
if: always()
|
||||
uses: actions/upload-artifact@v4
|
||||
with:
|
||||
name: playwright-report
|
||||
path: frontend/playwright-report/
|
||||
retention-days: 30
|
||||
51
.github/workflows/frontend-tests.yml.template
vendored
Normal file
51
.github/workflows/frontend-tests.yml.template
vendored
Normal file
@@ -0,0 +1,51 @@
|
||||
# Frontend Unit Tests CI Pipeline
|
||||
#
|
||||
# Rename this file to frontend-tests.yml to enable it
|
||||
# This will make the frontend test badges dynamic
|
||||
#
|
||||
# Required repository secrets:
|
||||
# - CODECOV_TOKEN (for coverage upload)
|
||||
|
||||
name: Frontend Tests
|
||||
|
||||
on:
|
||||
push:
|
||||
branches: [ main, develop ]
|
||||
paths:
|
||||
- 'frontend/**'
|
||||
- '.github/workflows/frontend-tests.yml'
|
||||
pull_request:
|
||||
branches: [ main, develop ]
|
||||
paths:
|
||||
- 'frontend/**'
|
||||
|
||||
jobs:
|
||||
test:
|
||||
runs-on: ubuntu-latest
|
||||
|
||||
steps:
|
||||
- uses: actions/checkout@v4
|
||||
|
||||
- name: Set up Node.js
|
||||
uses: actions/setup-node@v4
|
||||
with:
|
||||
node-version: '20'
|
||||
cache: 'npm'
|
||||
cache-dependency-path: './frontend/package-lock.json'
|
||||
|
||||
- name: Install dependencies
|
||||
working-directory: ./frontend
|
||||
run: npm ci
|
||||
|
||||
- name: Run unit tests with coverage
|
||||
working-directory: ./frontend
|
||||
run: npm run test:coverage
|
||||
|
||||
- name: Upload coverage to Codecov
|
||||
uses: codecov/codecov-action@v4
|
||||
with:
|
||||
files: ./frontend/coverage/coverage-final.json
|
||||
flags: frontend
|
||||
name: frontend-coverage
|
||||
fail_ci_if_error: true
|
||||
token: ${{ secrets.CODECOV_TOKEN }}
|
||||
11
.gitignore
vendored
Normal file → Executable file
11
.gitignore
vendored
Normal file → Executable file
@@ -27,6 +27,10 @@ coverage
|
||||
# nyc test coverage
|
||||
.nyc_output
|
||||
|
||||
# Playwright authentication state (contains test auth tokens)
|
||||
frontend/e2e/.auth/
|
||||
**/playwright/.auth/
|
||||
|
||||
# Grunt intermediate storage (https://gruntjs.com/creating-plugins#storing-task-files)
|
||||
.grunt
|
||||
|
||||
@@ -147,7 +151,6 @@ dist/
|
||||
downloads/
|
||||
eggs/
|
||||
.eggs/
|
||||
lib/
|
||||
lib64/
|
||||
parts/
|
||||
sdist/
|
||||
@@ -175,6 +178,7 @@ htmlcov/
|
||||
.nox/
|
||||
.coverage
|
||||
.coverage.*
|
||||
coverage.json
|
||||
.cache
|
||||
nosetests.xml
|
||||
coverage.xml
|
||||
@@ -183,7 +187,7 @@ coverage.xml
|
||||
.hypothesis/
|
||||
.pytest_cache/
|
||||
cover/
|
||||
|
||||
backend/.benchmarks
|
||||
# Translations
|
||||
*.mo
|
||||
*.pot
|
||||
@@ -264,6 +268,7 @@ celerybeat.pid
|
||||
.env
|
||||
.env.*
|
||||
!.env.template
|
||||
!.env.demo
|
||||
.venv
|
||||
env/
|
||||
venv/
|
||||
@@ -302,6 +307,6 @@ cython_debug/
|
||||
# option (not recommended) you can uncomment the following to ignore the entire idea folder.
|
||||
.idea/
|
||||
*.iml
|
||||
|
||||
.junie/*
|
||||
# Docker volumes
|
||||
postgres_data*/
|
||||
|
||||
315
AGENTS.md
Normal file
315
AGENTS.md
Normal file
@@ -0,0 +1,315 @@
|
||||
# AGENTS.md
|
||||
|
||||
AI coding assistant context for FastAPI + Next.js Full-Stack Template.
|
||||
|
||||
## Quick Start
|
||||
|
||||
```bash
|
||||
# Backend (Python with uv)
|
||||
cd backend
|
||||
make install-dev # Install dependencies
|
||||
make test # Run tests
|
||||
uv run uvicorn app.main:app --reload # Start dev server
|
||||
|
||||
# Frontend (Node.js)
|
||||
cd frontend
|
||||
bun install # Install dependencies
|
||||
bun run dev # Start dev server
|
||||
bun run generate:api # Generate API client from OpenAPI
|
||||
bun run test:e2e # Run E2E tests
|
||||
```
|
||||
|
||||
**Access points:**
|
||||
- Frontend: **http://localhost:3000**
|
||||
- Backend API: **http://localhost:8000**
|
||||
- API Docs: **http://localhost:8000/docs**
|
||||
|
||||
Default superuser (change in production):
|
||||
- Email: `admin@example.com`
|
||||
- Password: `admin123`
|
||||
|
||||
## Project Architecture
|
||||
|
||||
**Full-stack TypeScript/Python application:**
|
||||
|
||||
```
|
||||
├── backend/ # FastAPI backend
|
||||
│ ├── app/
|
||||
│ │ ├── api/ # API routes (auth, users, organizations, admin)
|
||||
│ │ ├── core/ # Core functionality (auth, config, database)
|
||||
│ │ ├── repositories/ # Repository pattern (database operations)
|
||||
│ │ ├── models/ # SQLAlchemy ORM models
|
||||
│ │ ├── schemas/ # Pydantic request/response schemas
|
||||
│ │ ├── services/ # Business logic layer
|
||||
│ │ └── utils/ # Utilities (security, device detection)
|
||||
│ ├── tests/ # 96% coverage, 987 tests
|
||||
│ └── alembic/ # Database migrations
|
||||
│
|
||||
└── frontend/ # Next.js 16 frontend
|
||||
├── src/
|
||||
│ ├── app/ # App Router pages (Next.js 16)
|
||||
│ ├── components/ # React components
|
||||
│ ├── lib/
|
||||
│ │ ├── api/ # Auto-generated API client
|
||||
│ │ └── stores/ # Zustand state management
|
||||
│ └── hooks/ # Custom React hooks
|
||||
└── e2e/ # Playwright E2E tests (56 passing)
|
||||
```
|
||||
|
||||
## Critical Implementation Notes
|
||||
|
||||
### Authentication Flow
|
||||
- **JWT-based**: Access tokens (15 min) + refresh tokens (7 days)
|
||||
- **OAuth/Social Login**: Google and GitHub with PKCE support
|
||||
- **Session tracking**: Database-backed with device info, IP, user agent
|
||||
- **Token refresh**: Validates JTI in database, not just JWT signature
|
||||
- **Authorization**: FastAPI dependencies in `api/dependencies/auth.py`
|
||||
- `get_current_user`: Requires valid access token
|
||||
- `get_current_active_user`: Requires active account
|
||||
- `get_optional_current_user`: Accepts authenticated or anonymous
|
||||
- `get_current_superuser`: Requires superuser flag
|
||||
|
||||
### OAuth Provider Mode (MCP Integration)
|
||||
Full OAuth 2.0 Authorization Server for MCP (Model Context Protocol) clients:
|
||||
- **Authorization Code Flow with PKCE**: RFC 7636 compliant
|
||||
- **JWT access tokens**: Self-contained, no DB lookup required
|
||||
- **Opaque refresh tokens**: Stored hashed in database, supports rotation
|
||||
- **Token introspection**: RFC 7662 compliant endpoint
|
||||
- **Token revocation**: RFC 7009 compliant endpoint
|
||||
- **Server metadata**: RFC 8414 compliant discovery endpoint
|
||||
- **Consent management**: User can review and revoke app permissions
|
||||
|
||||
**API endpoints:**
|
||||
- `GET /.well-known/oauth-authorization-server` - Server metadata
|
||||
- `GET /oauth/provider/authorize` - Authorization endpoint
|
||||
- `POST /oauth/provider/authorize/consent` - Consent submission
|
||||
- `POST /oauth/provider/token` - Token endpoint
|
||||
- `POST /oauth/provider/revoke` - Token revocation
|
||||
- `POST /oauth/provider/introspect` - Token introspection
|
||||
- Client management endpoints (admin only)
|
||||
|
||||
**Scopes supported:** `openid`, `profile`, `email`, `read:users`, `write:users`, `admin`
|
||||
|
||||
**OAuth Configuration (backend `.env`):**
|
||||
```bash
|
||||
# OAuth Social Login (as OAuth Consumer)
|
||||
OAUTH_ENABLED=true # Enable OAuth social login
|
||||
OAUTH_AUTO_LINK_BY_EMAIL=true # Auto-link accounts by email
|
||||
OAUTH_STATE_EXPIRE_MINUTES=10 # CSRF state expiration
|
||||
|
||||
# Google OAuth
|
||||
OAUTH_GOOGLE_CLIENT_ID=your-google-client-id
|
||||
OAUTH_GOOGLE_CLIENT_SECRET=your-google-client-secret
|
||||
|
||||
# GitHub OAuth
|
||||
OAUTH_GITHUB_CLIENT_ID=your-github-client-id
|
||||
OAUTH_GITHUB_CLIENT_SECRET=your-github-client-secret
|
||||
|
||||
# OAuth Provider Mode (as Authorization Server for MCP)
|
||||
OAUTH_PROVIDER_ENABLED=true # Enable OAuth provider mode
|
||||
OAUTH_ISSUER=https://api.yourdomain.com # JWT issuer URL (must be HTTPS in production)
|
||||
```
|
||||
|
||||
### Database Pattern
|
||||
- **Async SQLAlchemy 2.0** with PostgreSQL
|
||||
- **Connection pooling**: 20 base connections, 50 max overflow
|
||||
- **Repository base class**: `repositories/base.py` with common operations
|
||||
- **Migrations**: Alembic with helper script `migrate.py`
|
||||
- `python migrate.py auto "message"` - Generate and apply
|
||||
- `python migrate.py list` - View history
|
||||
|
||||
### Frontend State Management
|
||||
- **Zustand stores**: Lightweight state management
|
||||
- **TanStack Query**: API data fetching/caching
|
||||
- **Auto-generated client**: From OpenAPI spec via `bun run generate:api`
|
||||
- **Dependency Injection**: ALWAYS use `useAuth()` from `AuthContext`, NEVER import `useAuthStore` directly
|
||||
|
||||
### Internationalization (i18n)
|
||||
- **next-intl v4**: Type-safe internationalization library
|
||||
- **Locale routing**: `/en/*`, `/it/*` (English and Italian supported)
|
||||
- **Translation files**: `frontend/messages/en.json`, `frontend/messages/it.json`
|
||||
- **LocaleSwitcher**: Component for seamless language switching
|
||||
- **SEO-friendly**: Locale-aware metadata, sitemaps, and robots.txt
|
||||
- **Type safety**: Full TypeScript support for translations
|
||||
- **Utilities**: `frontend/src/lib/i18n/` (metadata, routing, utils)
|
||||
|
||||
### Organization System
|
||||
Three-tier RBAC:
|
||||
- **Owner**: Full control (delete org, manage all members)
|
||||
- **Admin**: Add/remove members, assign admin role (not owner)
|
||||
- **Member**: Read-only organization access
|
||||
|
||||
Permission dependencies in `api/dependencies/permissions.py`:
|
||||
- `require_organization_owner`
|
||||
- `require_organization_admin`
|
||||
- `require_organization_member`
|
||||
- `can_manage_organization_member`
|
||||
|
||||
### Testing Infrastructure
|
||||
|
||||
**Backend Unit/Integration (pytest + SQLite):**
|
||||
- 96% coverage, 987 tests
|
||||
- Security-focused: JWT attacks, session hijacking, privilege escalation
|
||||
- Async fixtures in `tests/conftest.py`
|
||||
- Run: `IS_TEST=True uv run pytest` or `make test`
|
||||
- Coverage: `make test-cov`
|
||||
|
||||
**Backend E2E (pytest + Testcontainers + Schemathesis):**
|
||||
- Real PostgreSQL via Docker containers
|
||||
- OpenAPI contract testing with Schemathesis
|
||||
- Install: `make install-e2e`
|
||||
- Run: `make test-e2e`
|
||||
- Schema tests: `make test-e2e-schema`
|
||||
- Docs: `backend/docs/E2E_TESTING.md`
|
||||
|
||||
**Frontend Unit Tests (Jest):**
|
||||
- 97% coverage
|
||||
- Component, hook, and utility testing
|
||||
- Run: `bun run test`
|
||||
- Coverage: `bun run test:coverage`
|
||||
|
||||
**Frontend E2E Tests (Playwright):**
|
||||
- 56 passing, 1 skipped (zero flaky tests)
|
||||
- Complete user flows (auth, navigation, settings)
|
||||
- Run: `bun run test:e2e`
|
||||
- UI mode: `bun run test:e2e:ui`
|
||||
|
||||
### Development Tooling
|
||||
|
||||
**Backend:**
|
||||
- **uv**: Modern Python package manager (10-100x faster than pip)
|
||||
- **Ruff**: All-in-one linting/formatting (replaces Black, Flake8, isort)
|
||||
- **Pyright**: Static type checking (strict mode)
|
||||
- **pip-audit**: Dependency vulnerability scanning (OSV database)
|
||||
- **detect-secrets**: Hardcoded secrets detection
|
||||
- **pip-licenses**: License compliance checking
|
||||
- **pre-commit**: Git hook framework (Ruff, detect-secrets, standard checks)
|
||||
- **Makefile**: `make help` for all commands
|
||||
|
||||
**Frontend:**
|
||||
- **Next.js 16**: App Router with React 19
|
||||
- **TypeScript**: Full type safety
|
||||
- **TailwindCSS + shadcn/ui**: Design system
|
||||
- **ESLint + Prettier**: Code quality
|
||||
|
||||
### Environment Configuration
|
||||
|
||||
**Backend** (`.env`):
|
||||
```bash
|
||||
POSTGRES_USER=postgres
|
||||
POSTGRES_PASSWORD=your_password
|
||||
POSTGRES_HOST=db
|
||||
POSTGRES_PORT=5432
|
||||
POSTGRES_DB=app
|
||||
|
||||
SECRET_KEY=your-secret-key-min-32-chars
|
||||
ENVIRONMENT=development|production
|
||||
CSP_MODE=relaxed|strict|disabled
|
||||
|
||||
FIRST_SUPERUSER_EMAIL=admin@example.com
|
||||
FIRST_SUPERUSER_PASSWORD=admin123
|
||||
|
||||
BACKEND_CORS_ORIGINS=["http://localhost:3000"]
|
||||
```
|
||||
|
||||
**Frontend** (`.env.local`):
|
||||
```bash
|
||||
NEXT_PUBLIC_API_URL=http://localhost:8000/api/v1
|
||||
```
|
||||
|
||||
## Common Development Workflows
|
||||
|
||||
### Adding a New API Endpoint
|
||||
|
||||
1. **Define schema** in `backend/app/schemas/`
|
||||
2. **Create repository** in `backend/app/repositories/`
|
||||
3. **Implement route** in `backend/app/api/routes/`
|
||||
4. **Register router** in `backend/app/api/main.py`
|
||||
5. **Write tests** in `backend/tests/api/`
|
||||
6. **Generate frontend client**: `bun run generate:api`
|
||||
|
||||
### Database Migrations
|
||||
|
||||
```bash
|
||||
cd backend
|
||||
python migrate.py generate "description" # Create migration
|
||||
python migrate.py apply # Apply migrations
|
||||
python migrate.py auto "description" # Generate + apply
|
||||
```
|
||||
|
||||
### Frontend Component Development
|
||||
|
||||
1. **Create component** in `frontend/src/components/`
|
||||
2. **Follow design system** (see `frontend/docs/design-system/`)
|
||||
3. **Use dependency injection** for auth (`useAuth()` not `useAuthStore`)
|
||||
4. **Write tests** in `frontend/tests/` or `__tests__/`
|
||||
5. **Run type check**: `bun run type-check`
|
||||
|
||||
## Security Features
|
||||
|
||||
- **Password hashing**: bcrypt with salt rounds
|
||||
- **Rate limiting**: 60 req/min default, 10 req/min on auth endpoints
|
||||
- **Security headers**: CSP, X-Frame-Options, HSTS, etc.
|
||||
- **CSRF protection**: Built into FastAPI
|
||||
- **Session revocation**: Database-backed session tracking
|
||||
- **Comprehensive security tests**: JWT algorithm attacks, session hijacking, privilege escalation
|
||||
- **Dependency vulnerability scanning**: `make dep-audit` (pip-audit against OSV database)
|
||||
- **License compliance**: `make license-check` (blocks GPL-3.0/AGPL)
|
||||
- **Secrets detection**: Pre-commit hook blocks hardcoded secrets
|
||||
- **Unified security pipeline**: `make audit` (all security checks), `make check` (quality + security + tests)
|
||||
|
||||
## Docker Deployment
|
||||
|
||||
```bash
|
||||
# Development (with hot reload)
|
||||
docker-compose -f docker-compose.dev.yml up
|
||||
|
||||
# Production
|
||||
docker-compose up -d
|
||||
|
||||
# Run migrations
|
||||
docker-compose exec backend alembic upgrade head
|
||||
|
||||
# Create first superuser
|
||||
docker-compose exec backend python -c "from app.init_db import init_db; import asyncio; asyncio.run(init_db())"
|
||||
```
|
||||
|
||||
## Documentation
|
||||
|
||||
**For comprehensive documentation, see:**
|
||||
- **[README.md](./README.md)** - User-facing project overview
|
||||
- **[CLAUDE.md](./CLAUDE.md)** - Claude Code-specific guidance
|
||||
- **Backend docs**: `backend/docs/` (Architecture, Coding Standards, Common Pitfalls, Feature Examples)
|
||||
- **Frontend docs**: `frontend/docs/` (Design System, Architecture, E2E Testing)
|
||||
- **API docs**: http://localhost:8000/docs (Swagger UI when running)
|
||||
|
||||
## Current Status (Nov 2025)
|
||||
|
||||
### Completed Features ✅
|
||||
- Authentication system (JWT with refresh tokens, OAuth/social login)
|
||||
- **OAuth Provider Mode (MCP-ready)**: Full OAuth 2.0 Authorization Server
|
||||
- Session management (device tracking, revocation)
|
||||
- User management (full lifecycle, password change)
|
||||
- Organization system (multi-tenant with RBAC)
|
||||
- Admin panel (user/org management, bulk operations)
|
||||
- **Internationalization (i18n)** with English and Italian
|
||||
- Comprehensive test coverage (96% backend, 97% frontend unit, 56 E2E tests)
|
||||
- Design system documentation
|
||||
- **Marketing landing page** with animations
|
||||
- **`/dev` documentation portal** with live examples
|
||||
- **Toast notifications**, charts, markdown rendering
|
||||
- **SEO optimization** (sitemap, robots.txt, locale metadata)
|
||||
- Docker deployment
|
||||
|
||||
### In Progress 🚧
|
||||
- Frontend admin pages (70% complete)
|
||||
- Email integration (templates ready, SMTP pending)
|
||||
|
||||
### Planned 🔮
|
||||
- GitHub Actions CI/CD
|
||||
- Additional languages (Spanish, French, German, etc.)
|
||||
- SSO/SAML authentication
|
||||
- Real-time notifications (WebSockets)
|
||||
- Webhook system
|
||||
- Background job processing
|
||||
- File upload/storage
|
||||
253
CLAUDE.md
Normal file
253
CLAUDE.md
Normal file
@@ -0,0 +1,253 @@
|
||||
# CLAUDE.md
|
||||
|
||||
Claude Code context for FastAPI + Next.js Full-Stack Template.
|
||||
|
||||
**See [AGENTS.md](./AGENTS.md) for project context, architecture, and development commands.**
|
||||
|
||||
## Claude Code-Specific Guidance
|
||||
|
||||
### Critical User Preferences
|
||||
|
||||
#### File Operations - NEVER Use Heredoc/Cat Append
|
||||
**ALWAYS use Read/Write/Edit tools instead of `cat >> file << EOF` commands.**
|
||||
|
||||
This triggers manual approval dialogs and disrupts workflow.
|
||||
|
||||
```bash
|
||||
# WRONG ❌
|
||||
cat >> file.txt << EOF
|
||||
content
|
||||
EOF
|
||||
|
||||
# CORRECT ✅ - Use Read, then Write tools
|
||||
```
|
||||
|
||||
#### Work Style
|
||||
- User prefers autonomous operation without frequent interruptions
|
||||
- Ask for batch permissions upfront for long work sessions
|
||||
- Work independently, document decisions clearly
|
||||
- Only use emojis if the user explicitly requests it
|
||||
|
||||
### When Working with This Stack
|
||||
|
||||
**Dependency Management:**
|
||||
- Backend uses **uv** (modern Python package manager), not pip
|
||||
- Always use `uv run` prefix: `IS_TEST=True uv run pytest`
|
||||
- Or use Makefile commands: `make test`, `make install-dev`
|
||||
- Add dependencies: `uv add <package>` or `uv add --dev <package>`
|
||||
|
||||
**Database Migrations:**
|
||||
- Use the `migrate.py` helper script, not Alembic directly
|
||||
- Generate + apply: `python migrate.py auto "message"`
|
||||
- Never commit migrations without testing them first
|
||||
- Check current state: `python migrate.py current`
|
||||
|
||||
**Frontend API Client Generation:**
|
||||
- Run `bun run generate:api` after backend schema changes
|
||||
- Client is auto-generated from OpenAPI spec
|
||||
- Located in `frontend/src/lib/api/generated/`
|
||||
- NEVER manually edit generated files
|
||||
|
||||
**Testing Commands:**
|
||||
- Backend unit/integration: `IS_TEST=True uv run pytest` (always prefix with `IS_TEST=True`)
|
||||
- Backend E2E (requires Docker): `make test-e2e`
|
||||
- Frontend unit: `bun run test`
|
||||
- Frontend E2E: `bun run test:e2e`
|
||||
- Use `make test` or `make test-cov` in backend for convenience
|
||||
|
||||
**Security & Quality Commands (Backend):**
|
||||
- `make validate` — lint + format + type checks
|
||||
- `make audit` — dependency vulnerabilities + license compliance
|
||||
- `make validate-all` — quality + security checks
|
||||
- `make check` — **full pipeline**: quality + security + tests
|
||||
|
||||
**Backend E2E Testing (requires Docker):**
|
||||
- Install deps: `make install-e2e`
|
||||
- Run all E2E tests: `make test-e2e`
|
||||
- Run schema tests only: `make test-e2e-schema`
|
||||
- Run all tests: `make test-all` (unit + E2E)
|
||||
- Uses Testcontainers (real PostgreSQL) + Schemathesis (OpenAPI contract testing)
|
||||
- Markers: `@pytest.mark.e2e`, `@pytest.mark.postgres`, `@pytest.mark.schemathesis`
|
||||
- See: `backend/docs/E2E_TESTING.md` for complete guide
|
||||
|
||||
### 🔴 CRITICAL: Auth Store Dependency Injection Pattern
|
||||
|
||||
**ALWAYS use `useAuth()` from `AuthContext`, NEVER import `useAuthStore` directly!**
|
||||
|
||||
```typescript
|
||||
// ❌ WRONG - Bypasses dependency injection
|
||||
import { useAuthStore } from '@/lib/stores/authStore';
|
||||
const { user, isAuthenticated } = useAuthStore();
|
||||
|
||||
// ✅ CORRECT - Uses dependency injection
|
||||
import { useAuth } from '@/lib/auth/AuthContext';
|
||||
const { user, isAuthenticated } = useAuth();
|
||||
```
|
||||
|
||||
**Why This Matters:**
|
||||
- E2E tests inject mock stores via `window.__TEST_AUTH_STORE__`
|
||||
- Unit tests inject via `<AuthProvider store={mockStore}>`
|
||||
- Direct `useAuthStore` imports bypass this injection → **tests fail**
|
||||
- ESLint will catch violations (added Nov 2025)
|
||||
|
||||
**Exceptions:**
|
||||
1. `AuthContext.tsx` - DI boundary, legitimately needs real store
|
||||
2. `client.ts` - Non-React context, uses dynamic import + `__TEST_AUTH_STORE__` check
|
||||
|
||||
### E2E Test Best Practices
|
||||
|
||||
When writing or fixing Playwright tests:
|
||||
|
||||
**Navigation Pattern:**
|
||||
```typescript
|
||||
// ✅ CORRECT - Use Promise.all for Next.js Link clicks
|
||||
await Promise.all([
|
||||
page.waitForURL('/target', { timeout: 10000 }),
|
||||
link.click()
|
||||
]);
|
||||
```
|
||||
|
||||
**Selectors:**
|
||||
- Use ID-based selectors for validation errors: `#email-error`
|
||||
- Error IDs use dashes not underscores: `#new-password-error`
|
||||
- Target `.border-destructive[role="alert"]` to avoid Next.js route announcer conflicts
|
||||
- Avoid generic `[role="alert"]` which matches multiple elements
|
||||
|
||||
**URL Assertions:**
|
||||
```typescript
|
||||
// ✅ Use regex to handle query params
|
||||
await expect(page).toHaveURL(/\/auth\/login/);
|
||||
|
||||
// ❌ Don't use exact strings (fails with query params)
|
||||
await expect(page).toHaveURL('/auth/login');
|
||||
```
|
||||
|
||||
**Configuration:**
|
||||
- Uses 12 workers in non-CI mode (`playwright.config.ts`)
|
||||
- Reduces to 2 workers in CI for stability
|
||||
- Tests are designed to be non-flaky with proper waits
|
||||
|
||||
### Important Implementation Details
|
||||
|
||||
**Authentication Testing:**
|
||||
- Backend fixtures in `tests/conftest.py`:
|
||||
- `async_test_db`: Fresh SQLite per test
|
||||
- `async_test_user` / `async_test_superuser`: Pre-created users
|
||||
- `user_token` / `superuser_token`: Access tokens for API calls
|
||||
- Always use `@pytest.mark.asyncio` for async tests
|
||||
- Use `@pytest_asyncio.fixture` for async fixtures
|
||||
|
||||
**Database Testing:**
|
||||
```python
|
||||
# Mock database exceptions correctly
|
||||
from unittest.mock import patch, AsyncMock
|
||||
|
||||
async def mock_commit():
|
||||
raise OperationalError("Connection lost", {}, Exception())
|
||||
|
||||
with patch.object(session, 'commit', side_effect=mock_commit):
|
||||
with patch.object(session, 'rollback', new_callable=AsyncMock) as mock_rollback:
|
||||
with pytest.raises(OperationalError):
|
||||
await repo_method(session, obj_in=data)
|
||||
mock_rollback.assert_called_once()
|
||||
```
|
||||
|
||||
**Frontend Component Development:**
|
||||
- Follow design system docs in `frontend/docs/design-system/`
|
||||
- Read `08-ai-guidelines.md` for AI code generation rules
|
||||
- Use parent-controlled spacing (see `04-spacing-philosophy.md`)
|
||||
- WCAG AA compliance required (see `07-accessibility.md`)
|
||||
|
||||
**Security Considerations:**
|
||||
- Backend has comprehensive security tests (JWT attacks, session hijacking)
|
||||
- Never skip security headers in production
|
||||
- Rate limiting is configured in route decorators: `@limiter.limit("10/minute")`
|
||||
- Session revocation is database-backed, not just JWT expiry
|
||||
- Run `make audit` to check for dependency vulnerabilities and license compliance
|
||||
- Run `make check` for the full pipeline: quality + security + tests
|
||||
- Pre-commit hooks enforce Ruff lint/format and detect-secrets on every commit
|
||||
- Setup hooks: `cd backend && uv run pre-commit install`
|
||||
|
||||
### Common Workflows Guidance
|
||||
|
||||
**When Adding a New Feature:**
|
||||
1. Start with backend schema and repository
|
||||
2. Implement API route with proper authorization
|
||||
3. Write backend tests (aim for >90% coverage)
|
||||
4. Generate frontend API client: `bun run generate:api`
|
||||
5. Implement frontend components
|
||||
6. Write frontend unit tests
|
||||
7. Add E2E tests for critical flows
|
||||
8. Update relevant documentation
|
||||
|
||||
**When Fixing Tests:**
|
||||
- Backend: Check test database isolation and async fixture usage
|
||||
- Frontend unit: Verify mocking of `useAuth()` not `useAuthStore`
|
||||
- E2E: Use `Promise.all()` pattern and regex URL assertions
|
||||
|
||||
**When Debugging:**
|
||||
- Backend: Check `IS_TEST=True` environment variable is set
|
||||
- Frontend: Run `bun run type-check` first
|
||||
- E2E: Use `bun run test:e2e:debug` for step-by-step debugging
|
||||
- Check logs: Backend has detailed error logging
|
||||
|
||||
**Demo Mode (Frontend-Only Showcase):**
|
||||
- Enable: `echo "NEXT_PUBLIC_DEMO_MODE=true" > frontend/.env.local`
|
||||
- Uses MSW (Mock Service Worker) to intercept API calls in browser
|
||||
- Zero backend required - perfect for Vercel deployments
|
||||
- **Fully Automated**: MSW handlers auto-generated from OpenAPI spec
|
||||
- Run `bun run generate:api` → updates both API client AND MSW handlers
|
||||
- No manual synchronization needed!
|
||||
- Demo credentials (any password ≥8 chars works):
|
||||
- User: `demo@example.com` / `DemoPass123`
|
||||
- Admin: `admin@example.com` / `AdminPass123`
|
||||
- **Safe**: MSW never runs during tests (Jest or Playwright)
|
||||
- **Coverage**: Mock files excluded from linting and coverage
|
||||
- **Documentation**: `frontend/docs/DEMO_MODE.md` for complete guide
|
||||
|
||||
### Tool Usage Preferences
|
||||
|
||||
**Prefer specialized tools over bash:**
|
||||
- Use Read/Write/Edit tools for file operations
|
||||
- Never use `cat`, `echo >`, or heredoc for file manipulation
|
||||
- Use Task tool with `subagent_type=Explore` for codebase exploration
|
||||
- Use Grep tool for code search, not bash `grep`
|
||||
|
||||
**When to use parallel tool calls:**
|
||||
- Independent git commands: `git status`, `git diff`, `git log`
|
||||
- Reading multiple unrelated files
|
||||
- Running multiple test suites simultaneously
|
||||
- Independent validation steps
|
||||
|
||||
## Custom Skills
|
||||
|
||||
No Claude Code Skills installed yet. To create one, invoke the built-in "skill-creator" skill.
|
||||
|
||||
**Potential skill ideas for this project:**
|
||||
- API endpoint generator workflow (schema → repository → route → tests → frontend client)
|
||||
- Component generator with design system compliance
|
||||
- Database migration troubleshooting helper
|
||||
- Test coverage analyzer and improvement suggester
|
||||
- E2E test generator for new features
|
||||
|
||||
## Additional Resources
|
||||
|
||||
**Comprehensive Documentation:**
|
||||
- [AGENTS.md](./AGENTS.md) - Framework-agnostic AI assistant context
|
||||
- [README.md](./README.md) - User-facing project overview
|
||||
- `backend/docs/` - Backend architecture, coding standards, common pitfalls
|
||||
- `frontend/docs/design-system/` - Complete design system guide
|
||||
|
||||
**API Documentation (when running):**
|
||||
- Swagger UI: http://localhost:8000/docs
|
||||
- ReDoc: http://localhost:8000/redoc
|
||||
- OpenAPI JSON: http://localhost:8000/api/v1/openapi.json
|
||||
|
||||
**Testing Documentation:**
|
||||
- Backend tests: `backend/tests/` (97% coverage)
|
||||
- Frontend E2E: `frontend/e2e/README.md`
|
||||
- Design system: `frontend/docs/design-system/08-ai-guidelines.md`
|
||||
|
||||
---
|
||||
|
||||
**For project architecture, development commands, and general context, see [AGENTS.md](./AGENTS.md).**
|
||||
392
CONTRIBUTING.md
Normal file
392
CONTRIBUTING.md
Normal file
@@ -0,0 +1,392 @@
|
||||
# Contributing to FastAPI + Next.js Template
|
||||
|
||||
First off, thank you for considering contributing to this project! 🎉
|
||||
|
||||
This template aims to be a rock-solid foundation for full-stack applications, and your contributions help make that possible.
|
||||
|
||||
## Table of Contents
|
||||
|
||||
- [Code of Conduct](#code-of-conduct)
|
||||
- [How Can I Contribute?](#how-can-i-contribute)
|
||||
- [Development Setup](#development-setup)
|
||||
- [Coding Standards](#coding-standards)
|
||||
- [Testing Guidelines](#testing-guidelines)
|
||||
- [Commit Messages](#commit-messages)
|
||||
- [Pull Request Process](#pull-request-process)
|
||||
|
||||
---
|
||||
|
||||
## Code of Conduct
|
||||
|
||||
This project is committed to providing a welcoming and inclusive environment. We expect all contributors to:
|
||||
|
||||
- Be respectful and considerate
|
||||
- Welcome newcomers and help them learn
|
||||
- Focus on constructive criticism
|
||||
- Accept feedback gracefully
|
||||
- Prioritize the community's well-being
|
||||
|
||||
Unacceptable behavior includes harassment, trolling, insulting comments, and personal attacks.
|
||||
|
||||
---
|
||||
|
||||
## How Can I Contribute?
|
||||
|
||||
### Reporting Bugs
|
||||
|
||||
Found a bug? Help us fix it!
|
||||
|
||||
1. **Check existing issues** to avoid duplicates
|
||||
2. **Create a new issue** with:
|
||||
- Clear, descriptive title
|
||||
- Steps to reproduce
|
||||
- Expected vs. actual behavior
|
||||
- Environment details (OS, Python/Node version, etc.)
|
||||
- Screenshots/logs if applicable
|
||||
|
||||
### Suggesting Features
|
||||
|
||||
Have an idea for improvement?
|
||||
|
||||
1. **Check existing issues/discussions** first
|
||||
2. **Open a discussion** to gauge interest
|
||||
3. **Explain the use case** and benefits
|
||||
4. **Consider implementation complexity**
|
||||
|
||||
Remember: This is a *template*, not a full application. Features should be:
|
||||
- Broadly useful
|
||||
- Well-documented
|
||||
- Thoroughly tested
|
||||
- Maintainable long-term
|
||||
|
||||
### Improving Documentation
|
||||
|
||||
Documentation improvements are always welcome!
|
||||
|
||||
- Fix typos or unclear explanations
|
||||
- Add examples or diagrams
|
||||
- Expand on complex topics
|
||||
- Update outdated information
|
||||
- Translate documentation (future)
|
||||
|
||||
### Contributing Code
|
||||
|
||||
Ready to write some code? Awesome!
|
||||
|
||||
1. **Pick an issue** (or create one)
|
||||
2. **Comment** that you're working on it
|
||||
3. **Fork and branch** from `main`
|
||||
4. **Write code** following our standards
|
||||
5. **Add tests** (required for features)
|
||||
6. **Update docs** if needed
|
||||
7. **Submit a PR** with clear description
|
||||
|
||||
---
|
||||
|
||||
## Development Setup
|
||||
|
||||
### Backend Development
|
||||
|
||||
```bash
|
||||
cd backend
|
||||
|
||||
# Install dependencies (uv manages virtual environment automatically)
|
||||
make install-dev
|
||||
|
||||
# Setup pre-commit hooks
|
||||
uv run pre-commit install
|
||||
|
||||
# Setup environment
|
||||
cp .env.example .env
|
||||
# Edit .env with your settings
|
||||
|
||||
# Run migrations
|
||||
python migrate.py apply
|
||||
|
||||
# Run quality + security checks
|
||||
make validate-all
|
||||
|
||||
# Run tests
|
||||
make test
|
||||
|
||||
# Run full pipeline (quality + security + tests)
|
||||
make check
|
||||
|
||||
# Start dev server
|
||||
uvicorn app.main:app --reload
|
||||
```
|
||||
|
||||
### Frontend Development
|
||||
|
||||
```bash
|
||||
cd frontend
|
||||
|
||||
# Install dependencies
|
||||
bun install
|
||||
|
||||
# Setup environment
|
||||
cp .env.local.example .env.local
|
||||
|
||||
# Generate API client
|
||||
bun run generate:api
|
||||
|
||||
# Run tests
|
||||
bun run test
|
||||
bun run test:e2e:ui
|
||||
|
||||
# Start dev server
|
||||
bun run dev
|
||||
```
|
||||
|
||||
---
|
||||
|
||||
## Coding Standards
|
||||
|
||||
### Backend (Python)
|
||||
|
||||
- **Style**: Follow PEP 8
|
||||
- **Type hints**: Use type annotations
|
||||
- **Async**: Use async/await for I/O operations
|
||||
- **Documentation**: Docstrings for all public functions/classes
|
||||
- **Error handling**: Use custom exceptions appropriately
|
||||
- **Security**: Never trust user input, validate everything
|
||||
|
||||
Example:
|
||||
```python
|
||||
async def get_user_by_email(
|
||||
db: AsyncSession,
|
||||
*,
|
||||
email: str
|
||||
) -> Optional[User]:
|
||||
"""
|
||||
Get user by email address.
|
||||
|
||||
Args:
|
||||
db: Database session
|
||||
email: User's email address
|
||||
|
||||
Returns:
|
||||
User if found, None otherwise
|
||||
"""
|
||||
result = await db.execute(
|
||||
select(User).where(User.email == email)
|
||||
)
|
||||
return result.scalar_one_or_none()
|
||||
```
|
||||
|
||||
### Frontend (TypeScript/React)
|
||||
|
||||
- **Style**: Use Prettier (configured)
|
||||
- **TypeScript**: Strict mode, no `any` types
|
||||
- **Components**: Functional components with hooks
|
||||
- **Naming**: PascalCase for components, camelCase for functions
|
||||
- **Imports**: Use absolute imports with `@/` alias
|
||||
- **Dependencies**: Use provided auth context (never import stores directly)
|
||||
|
||||
Example:
|
||||
```typescript
|
||||
interface UserProfileProps {
|
||||
userId: string;
|
||||
}
|
||||
|
||||
export function UserProfile({ userId }: UserProfileProps) {
|
||||
const { user } = useAuth();
|
||||
const { data, isLoading } = useQuery({
|
||||
queryKey: ['user', userId],
|
||||
queryFn: () => fetchUser(userId),
|
||||
});
|
||||
|
||||
if (isLoading) return <LoadingSpinner />;
|
||||
|
||||
return <div>...</div>;
|
||||
}
|
||||
```
|
||||
|
||||
### Key Patterns
|
||||
|
||||
- **Backend**: Use repository pattern, keep routes thin, business logic in services
|
||||
- **Frontend**: Use React Query for server state, Zustand for client state
|
||||
- **Both**: Handle errors gracefully, log appropriately, write tests
|
||||
|
||||
---
|
||||
|
||||
## Testing Guidelines
|
||||
|
||||
### Backend Tests
|
||||
|
||||
- **Coverage target**: >90% for new code
|
||||
- **Test types**: Unit, integration, and security tests
|
||||
- **Fixtures**: Use pytest fixtures from `conftest.py`
|
||||
- **Database**: Use `async_test_db` fixture for isolation
|
||||
- **Assertions**: Be specific about what you're testing
|
||||
|
||||
```python
|
||||
@pytest.mark.asyncio
|
||||
async def test_create_user(client, async_test_superuser, superuser_token):
|
||||
"""Test creating a new user."""
|
||||
response = await client.post(
|
||||
"/api/v1/admin/users",
|
||||
headers={"Authorization": f"Bearer {superuser_token}"},
|
||||
json={
|
||||
"email": "newuser@example.com",
|
||||
"password": "SecurePass123!",
|
||||
"first_name": "New",
|
||||
"last_name": "User"
|
||||
}
|
||||
)
|
||||
|
||||
assert response.status_code == 201
|
||||
data = response.json()
|
||||
assert data["email"] == "newuser@example.com"
|
||||
assert "password" not in data # Never expose passwords
|
||||
```
|
||||
|
||||
### Frontend E2E Tests
|
||||
|
||||
- **Use Playwright**: For end-to-end user flows
|
||||
- **Be specific**: Use accessible selectors (roles, labels)
|
||||
- **Be reliable**: Avoid flaky tests with proper waits
|
||||
- **Be fast**: Group related tests, use parallel execution
|
||||
|
||||
```typescript
|
||||
test('user can login and view profile', async ({ page }) => {
|
||||
// Login
|
||||
await page.goto('/auth/login');
|
||||
await page.fill('#email', 'user@example.com');
|
||||
await page.fill('#password', 'password123');
|
||||
await page.click('button[type="submit"]');
|
||||
|
||||
// Should redirect to dashboard
|
||||
await expect(page).toHaveURL(/\/dashboard/);
|
||||
|
||||
// Should see user name
|
||||
await expect(page.getByText('Welcome, John')).toBeVisible();
|
||||
});
|
||||
```
|
||||
|
||||
### Unit Tests (Frontend)
|
||||
|
||||
- **Test behavior**: Not implementation details
|
||||
- **Mock dependencies**: Use Jest mocks appropriately
|
||||
- **Test accessibility**: Include a11y checks when relevant
|
||||
|
||||
---
|
||||
|
||||
## Commit Messages
|
||||
|
||||
Write clear, descriptive commit messages:
|
||||
|
||||
### Format
|
||||
|
||||
```
|
||||
<type>: <subject>
|
||||
|
||||
<body (optional)>
|
||||
|
||||
<footer (optional)>
|
||||
```
|
||||
|
||||
### Types
|
||||
|
||||
- `feat`: New feature
|
||||
- `fix`: Bug fix
|
||||
- `docs`: Documentation changes
|
||||
- `style`: Code style changes (formatting, no logic change)
|
||||
- `refactor`: Code refactoring
|
||||
- `test`: Adding or updating tests
|
||||
- `chore`: Maintenance tasks
|
||||
|
||||
### Examples
|
||||
|
||||
**Good:**
|
||||
```
|
||||
feat: add password reset flow
|
||||
|
||||
Implements complete password reset with email tokens.
|
||||
Tokens expire after 1 hour for security.
|
||||
|
||||
Closes #123
|
||||
```
|
||||
|
||||
**Also good (simple change):**
|
||||
```
|
||||
fix: correct pagination offset calculation
|
||||
```
|
||||
|
||||
**Not great:**
|
||||
```
|
||||
Fixed stuff
|
||||
```
|
||||
|
||||
---
|
||||
|
||||
## Pull Request Process
|
||||
|
||||
### Before Submitting
|
||||
|
||||
- [ ] Code follows project style guidelines
|
||||
- [ ] `make check` passes (quality + security + tests) in backend
|
||||
- [ ] New tests added for new features
|
||||
- [ ] Documentation updated if needed
|
||||
- [ ] No merge conflicts with `main`
|
||||
- [ ] Commits are logical and well-described
|
||||
|
||||
### PR Template
|
||||
|
||||
```markdown
|
||||
## Description
|
||||
Brief description of changes
|
||||
|
||||
## Type of Change
|
||||
- [ ] Bug fix
|
||||
- [ ] New feature
|
||||
- [ ] Documentation update
|
||||
- [ ] Refactoring
|
||||
|
||||
## Testing
|
||||
How was this tested?
|
||||
|
||||
## Screenshots (if applicable)
|
||||
|
||||
## Checklist
|
||||
- [ ] Tests added/updated
|
||||
- [ ] Documentation updated
|
||||
- [ ] No breaking changes
|
||||
- [ ] Follows coding standards
|
||||
```
|
||||
|
||||
### Review Process
|
||||
|
||||
1. **Submit PR** with clear description
|
||||
2. **CI checks** must pass (when implemented)
|
||||
3. **Code review** by maintainers
|
||||
4. **Address feedback** if requested
|
||||
5. **Approval** from at least one maintainer
|
||||
6. **Merge** by maintainer
|
||||
|
||||
### After Merge
|
||||
|
||||
- Your contribution will be in the next release
|
||||
- You'll be added to contributors list
|
||||
- Feel awesome! 🎉
|
||||
|
||||
---
|
||||
|
||||
## Questions?
|
||||
|
||||
- **Documentation issues?** Ask in your PR or issue
|
||||
- **Unsure about implementation?** Open a discussion first
|
||||
- **Need help?** Tag maintainers in your issue/PR
|
||||
|
||||
---
|
||||
|
||||
## Recognition
|
||||
|
||||
Contributors are recognized in:
|
||||
- GitHub contributors page
|
||||
- Release notes (for significant contributions)
|
||||
- README acknowledgments (for major features)
|
||||
|
||||
---
|
||||
|
||||
Thank you for contributing! Every contribution, no matter how small, makes this template better for everyone. 🚀
|
||||
115
Makefile
Normal file → Executable file
115
Makefile
Normal file → Executable file
@@ -1,27 +1,124 @@
|
||||
.PHONY: dev prod down clean
|
||||
.PHONY: help dev dev-full prod down logs logs-dev clean clean-slate drop-db reset-db push-images deploy scan-images
|
||||
|
||||
VERSION ?= latest
|
||||
REGISTRY := gitea.pragmazest.com/cardosofelipe/app
|
||||
REGISTRY ?= ghcr.io/cardosofelipe/pragma-stack
|
||||
|
||||
# Default target
|
||||
help:
|
||||
@echo "FastAPI + Next.js Full-Stack Template"
|
||||
@echo ""
|
||||
@echo "Development:"
|
||||
@echo " make dev - Start backend + db (frontend runs separately)"
|
||||
@echo " make dev-full - Start all services including frontend"
|
||||
@echo " make down - Stop all services"
|
||||
@echo " make logs-dev - Follow dev container logs"
|
||||
@echo ""
|
||||
@echo "Database:"
|
||||
@echo " make drop-db - Drop and recreate empty database"
|
||||
@echo " make reset-db - Drop database and apply all migrations"
|
||||
@echo ""
|
||||
@echo "Production:"
|
||||
@echo " make prod - Start production stack"
|
||||
@echo " make deploy - Pull and deploy latest images"
|
||||
@echo " make push-images - Build and push images to registry"
|
||||
@echo " make scan-images - Scan production images for CVEs (requires trivy)"
|
||||
@echo " make logs - Follow production container logs"
|
||||
@echo ""
|
||||
@echo "Cleanup:"
|
||||
@echo " make clean - Stop containers"
|
||||
@echo " make clean-slate - Stop containers AND delete volumes (DATA LOSS!)"
|
||||
@echo ""
|
||||
@echo "Subdirectory commands:"
|
||||
@echo " cd backend && make help - Backend-specific commands"
|
||||
@echo " cd frontend && npm run - Frontend-specific commands"
|
||||
|
||||
# ============================================================================
|
||||
# Development
|
||||
# ============================================================================
|
||||
|
||||
dev:
|
||||
docker compose -f docker-compose.dev.yml up --build -d
|
||||
# Bring up all dev services except the frontend
|
||||
docker compose -f docker-compose.dev.yml up --build -d --scale frontend=0
|
||||
@echo ""
|
||||
@echo "Frontend is not started by 'make dev'."
|
||||
@echo "To run the frontend locally, open a new terminal and run:"
|
||||
@echo " cd frontend && npm run dev"
|
||||
|
||||
prod:
|
||||
docker compose up --build -d
|
||||
dev-full:
|
||||
# Bring up all dev services including the frontend (full stack)
|
||||
docker compose -f docker-compose.dev.yml up --build -d
|
||||
|
||||
down:
|
||||
docker compose down
|
||||
|
||||
logs:
|
||||
docker compose logs -f
|
||||
|
||||
logs-dev:
|
||||
docker compose -f docker-compose.dev.yml logs -f
|
||||
|
||||
# ============================================================================
|
||||
# Database Management
|
||||
# ============================================================================
|
||||
|
||||
drop-db:
|
||||
@echo "Dropping local database..."
|
||||
@docker compose -f docker-compose.dev.yml exec -T db psql -U postgres -c "DROP DATABASE IF EXISTS app WITH (FORCE);" 2>/dev/null || \
|
||||
docker compose -f docker-compose.dev.yml exec -T db psql -U postgres -c "DROP DATABASE IF EXISTS app;"
|
||||
@docker compose -f docker-compose.dev.yml exec -T db psql -U postgres -c "CREATE DATABASE app;"
|
||||
@echo "Database dropped and recreated (empty)"
|
||||
|
||||
reset-db: drop-db
|
||||
@echo "Applying migrations..."
|
||||
@cd backend && uv run python migrate.py --local apply
|
||||
@echo "Database reset complete!"
|
||||
|
||||
# ============================================================================
|
||||
# Production / Deployment
|
||||
# ============================================================================
|
||||
|
||||
prod:
|
||||
docker compose up --build -d
|
||||
|
||||
deploy:
|
||||
docker compose -f docker-compose.deploy.yml pull
|
||||
docker compose -f docker-compose.deploy.yml up -d
|
||||
|
||||
clean:
|
||||
docker compose down -
|
||||
|
||||
push-images:
|
||||
docker build -t $(REGISTRY)/backend:$(VERSION) ./backend
|
||||
docker build -t $(REGISTRY)/frontend:$(VERSION) ./frontend
|
||||
docker push $(REGISTRY)/backend:$(VERSION)
|
||||
docker push $(REGISTRY)/frontend:$(VERSION)
|
||||
docker push $(REGISTRY)/frontend:$(VERSION)
|
||||
|
||||
scan-images:
|
||||
@docker info > /dev/null 2>&1 || (echo "❌ Docker is not running!"; exit 1)
|
||||
@echo "🐳 Building and scanning production images for CVEs..."
|
||||
docker build -t $(REGISTRY)/backend:scan --target production ./backend
|
||||
docker build -t $(REGISTRY)/frontend:scan --target runner ./frontend
|
||||
@echo ""
|
||||
@echo "=== Backend Image Scan ==="
|
||||
@if command -v trivy > /dev/null 2>&1; then \
|
||||
trivy image --severity HIGH,CRITICAL --exit-code 1 $(REGISTRY)/backend:scan; \
|
||||
else \
|
||||
echo "ℹ️ Trivy not found locally, using Docker to run Trivy..."; \
|
||||
docker run --rm -v /var/run/docker.sock:/var/run/docker.sock aquasec/trivy image --severity HIGH,CRITICAL --exit-code 1 $(REGISTRY)/backend:scan; \
|
||||
fi
|
||||
@echo ""
|
||||
@echo "=== Frontend Image Scan ==="
|
||||
@if command -v trivy > /dev/null 2>&1; then \
|
||||
trivy image --severity HIGH,CRITICAL --exit-code 1 $(REGISTRY)/frontend:scan; \
|
||||
else \
|
||||
docker run --rm -v /var/run/docker.sock:/var/run/docker.sock aquasec/trivy image --severity HIGH,CRITICAL --exit-code 1 $(REGISTRY)/frontend:scan; \
|
||||
fi
|
||||
@echo "✅ No HIGH/CRITICAL CVEs found in production images!"
|
||||
|
||||
# ============================================================================
|
||||
# Cleanup
|
||||
# ============================================================================
|
||||
|
||||
clean:
|
||||
docker compose down
|
||||
|
||||
# WARNING! THIS REMOVES CONTAINERS AND VOLUMES AS WELL - DO NOT USE THIS UNLESS YOU WANT TO START OVER WITH DATA AND ALL
|
||||
clean-slate:
|
||||
docker compose -f docker-compose.dev.yml down -v --remove-orphans
|
||||
|
||||
853
README.md
853
README.md
@@ -1,260 +1,659 @@
|
||||
# FastNext Stack
|
||||
# <img src="frontend/public/logo.svg" alt="PragmaStack" width="32" height="32" style="vertical-align: middle" /> PragmaStack
|
||||
|
||||
A modern, Docker-ready full-stack template combining FastAPI, Next.js, and PostgreSQL. Built for developers who need a robust starting point for web applications with TypeScript frontend and Python backend.
|
||||
> **The Pragmatic Full-Stack Template. Production-ready, security-first, and opinionated.**
|
||||
|
||||
## Features
|
||||
[](./backend/tests)
|
||||
[](./frontend/tests)
|
||||
[](./frontend/e2e)
|
||||
[](./LICENSE)
|
||||
[](./CONTRIBUTING.md)
|
||||
|
||||
- 🐍 **FastAPI Backend**
|
||||
- Python 3.12 with modern async support
|
||||
- SQLAlchemy ORM with async capabilities
|
||||
- Alembic migrations
|
||||
- JWT authentication ready
|
||||
- Pydantic data validation
|
||||
- Comprehensive testing setup
|
||||

|
||||
|
||||
- ⚛️ **Next.js Frontend**
|
||||
- React 19 with TypeScript
|
||||
- Tailwind CSS for styling
|
||||
- Modern app router architecture
|
||||
- Built-in API route support
|
||||
- SEO-friendly by default
|
||||
---
|
||||
|
||||
- 🛠️ **Development Experience**
|
||||
- Docker-based development environment
|
||||
- Hot-reloading for both frontend and backend
|
||||
- Unified development workflow
|
||||
- Comprehensive testing setup
|
||||
- Type safety across the stack
|
||||
## Why PragmaStack?
|
||||
|
||||
- 🚀 **Production Ready**
|
||||
- Multi-stage Docker builds
|
||||
- Production-optimized configurations
|
||||
- Environment-based settings
|
||||
- Health checks and container orchestration
|
||||
- CORS security configured
|
||||
Building a modern full-stack application often leads to "analysis paralysis" or "boilerplate fatigue". You spend weeks setting up authentication, testing, and linting before writing a single line of business logic.
|
||||
|
||||
## Quick Start
|
||||
**PragmaStack cuts through the noise.**
|
||||
|
||||
1. Clone the template:
|
||||
```bash
|
||||
git clone https://github.com/yourusername/fastnext-stack myproject
|
||||
cd myproject
|
||||
```
|
||||
We provide a **pragmatic**, opinionated foundation that prioritizes:
|
||||
- **Speed**: Ship features, not config files.
|
||||
- **Robustness**: Security and testing are not optional.
|
||||
- **Clarity**: Code that is easy to read and maintain.
|
||||
|
||||
2. Create environment files:
|
||||
```bash
|
||||
cp .env.template .env
|
||||
```
|
||||
Whether you're building a SaaS, an internal tool, or a side project, PragmaStack gives you a solid starting point without the bloat.
|
||||
|
||||
3. Start development environment:
|
||||
```bash
|
||||
make dev
|
||||
```
|
||||
---
|
||||
|
||||
4. Access the applications:
|
||||
- Frontend: http://localhost:3000
|
||||
- Backend: http://localhost:8000
|
||||
- API Docs: http://localhost:8000/docs
|
||||
## ✨ Features
|
||||
|
||||
## Project Structure
|
||||
### 🔐 **Authentication & Security**
|
||||
- JWT-based authentication with access + refresh tokens
|
||||
- **OAuth/Social Login** (Google, GitHub) with PKCE support
|
||||
- **OAuth 2.0 Authorization Server** (MCP-ready) for third-party integrations
|
||||
- Session management with device tracking and revocation
|
||||
- Password reset flow (email integration ready)
|
||||
- Secure password hashing (bcrypt)
|
||||
- CSRF protection, rate limiting, and security headers
|
||||
- Comprehensive security tests (JWT algorithm attacks, session hijacking, privilege escalation)
|
||||
|
||||
```
|
||||
fast-next-template/
|
||||
├── backend/ # FastAPI backend
|
||||
│ ├── app/
|
||||
│ │ ├── alembic/ # Database migrations
|
||||
│ │ ├── api/ # API routes and dependencies
|
||||
│ │ ├── core/ # Core functionality (auth, config, db)
|
||||
│ │ ├── crud/ # Database CRUD operations
|
||||
│ │ ├── models/ # SQLAlchemy models
|
||||
│ │ ├── schemas/ # Pydantic schemas
|
||||
│ │ ├── services/ # Business logic services
|
||||
│ │ ├── utils/ # Utility functions
|
||||
│ │ ├── init_db.py # Database initialization script
|
||||
│ │ └── main.py # FastAPI application entry
|
||||
│ ├── tests/ # Comprehensive test suite
|
||||
│ ├── migrate.py # Migration helper CLI
|
||||
│ ├── requirements.txt # Python dependencies
|
||||
│ └── Dockerfile # Multi-stage container build
|
||||
├── frontend/ # Next.js frontend
|
||||
│ ├── src/
|
||||
│ │ ├── app/ # Next.js app router
|
||||
│ │ └── components/ # React components
|
||||
│ ├── public/ # Static assets
|
||||
│ └── Dockerfile # Next.js container build
|
||||
├── docker-compose.yml # Production compose configuration
|
||||
├── docker-compose.dev.yml # Development compose configuration
|
||||
├── docker-compose.deploy.yml # Deployment with pre-built images
|
||||
└── .env.template # Environment variables template
|
||||
```
|
||||
### 🔌 **OAuth Provider Mode (MCP Integration)**
|
||||
Full OAuth 2.0 Authorization Server for Model Context Protocol (MCP) and third-party clients:
|
||||
- **RFC 7636**: Authorization Code Flow with PKCE (S256 only)
|
||||
- **RFC 8414**: Server metadata discovery at `/.well-known/oauth-authorization-server`
|
||||
- **RFC 7662**: Token introspection endpoint
|
||||
- **RFC 7009**: Token revocation endpoint
|
||||
- **JWT access tokens**: Self-contained, configurable lifetime
|
||||
- **Opaque refresh tokens**: Secure rotation, database-backed revocation
|
||||
- **Consent management**: Users can review and revoke app permissions
|
||||
- **Client management**: Admin endpoints for registering OAuth clients
|
||||
- **Scopes**: `openid`, `profile`, `email`, `read:users`, `write:users`, `admin`
|
||||
|
||||
## Backend Features
|
||||
### 👥 **Multi-Tenancy & Organizations**
|
||||
- Full organization system with role-based access control (Owner, Admin, Member)
|
||||
- Invite/remove members, manage permissions
|
||||
- Organization-scoped data access
|
||||
- User can belong to multiple organizations
|
||||
|
||||
### Authentication System
|
||||
- **JWT-based authentication** with access and refresh tokens
|
||||
- **User management** with email/password authentication
|
||||
- **Password hashing** using bcrypt
|
||||
- **Token expiration** handling (access: 1 day, refresh: 60 days)
|
||||
- **Optional authentication** support for public/private endpoints
|
||||
- **Superuser** authorization support
|
||||
### 🛠️ **Admin Panel**
|
||||
- Complete user management (full lifecycle, activate/deactivate, bulk operations)
|
||||
- Organization management (create, edit, delete, member management)
|
||||
- Session monitoring across all users
|
||||
- Real-time statistics dashboard
|
||||
- Admin-only routes with proper authorization
|
||||
|
||||
### Database Management
|
||||
- **PostgreSQL** with optimized connection pooling
|
||||
- **Alembic migrations** with auto-generation support
|
||||
- **Migration CLI helper** (`migrate.py`) for easy database management:
|
||||
```bash
|
||||
python migrate.py generate "add users table" # Generate migration
|
||||
python migrate.py apply # Apply migrations
|
||||
python migrate.py list # List all migrations
|
||||
python migrate.py current # Show current revision
|
||||
python migrate.py check # Check DB connection
|
||||
python migrate.py auto "message" # Generate and apply
|
||||
```
|
||||
- **Automatic database initialization** with first superuser creation
|
||||
### 🎨 **Modern Frontend**
|
||||
- Next.js 16 with App Router and React 19
|
||||
- **PragmaStack Design System** built on shadcn/ui + TailwindCSS
|
||||
- Pre-configured theme with dark mode support (coming soon)
|
||||
- Responsive, accessible components (WCAG AA compliant)
|
||||
- Rich marketing landing page with animated components
|
||||
- Live component showcase and documentation at `/dev`
|
||||
|
||||
### Testing Infrastructure
|
||||
- **92 comprehensive tests** covering all core functionality
|
||||
- **SQLite in-memory** database for fast test execution
|
||||
- **Auth test utilities** for easy endpoint testing
|
||||
- **Mocking support** for external dependencies
|
||||
- **Test fixtures** for common scenarios
|
||||
### 🌍 **Internationalization (i18n)**
|
||||
- Built-in multi-language support with next-intl v4
|
||||
- Locale-based routing (`/en/*`, `/it/*`)
|
||||
- Seamless language switching with LocaleSwitcher component
|
||||
- SEO-friendly URLs and metadata per locale
|
||||
- Translation files for English and Italian (easily extensible)
|
||||
- Type-safe translations throughout the app
|
||||
|
||||
### Security Utilities
|
||||
- **Upload token system** for secure file operations
|
||||
- **HMAC-based signing** for token validation
|
||||
- **Time-limited tokens** with expiration
|
||||
- **Nonce support** to prevent token reuse
|
||||
### 🎯 **Content & UX Features**
|
||||
- **Toast notifications** with Sonner for elegant user feedback
|
||||
- **Smooth animations** powered by Framer Motion
|
||||
- **Markdown rendering** with syntax highlighting (GitHub Flavored Markdown)
|
||||
- **Charts and visualizations** ready with Recharts
|
||||
- **SEO optimization** with dynamic sitemap and robots.txt generation
|
||||
- **Session tracking UI** with device information and revocation controls
|
||||
|
||||
## Development
|
||||
### 🧪 **Comprehensive Testing**
|
||||
- **Backend Testing**: ~97% unit test coverage
|
||||
- Unit, integration, and security tests
|
||||
- Async database testing with SQLAlchemy
|
||||
- API endpoint testing with fixtures
|
||||
- Security vulnerability tests (JWT attacks, session hijacking, privilege escalation)
|
||||
- **Frontend Unit Tests**: ~97% coverage with Jest
|
||||
- Component testing
|
||||
- Hook testing
|
||||
- Utility function testing
|
||||
- **End-to-End Tests**: Playwright with zero flaky tests
|
||||
- Complete user flows (auth, navigation, settings)
|
||||
- Parallel execution for speed
|
||||
- Visual regression testing ready
|
||||
|
||||
### Running Tests
|
||||
### 📚 **Developer Experience**
|
||||
- Auto-generated TypeScript API client from OpenAPI spec
|
||||
- Interactive API documentation (Swagger + ReDoc)
|
||||
- Database migrations with Alembic helper script
|
||||
- Hot reload in development for both frontend and backend
|
||||
- Comprehensive code documentation and design system docs
|
||||
- Live component playground at `/dev` with code examples
|
||||
- Docker support for easy deployment
|
||||
- VSCode workspace settings included
|
||||
|
||||
```bash
|
||||
# Backend tests
|
||||
cd backend
|
||||
source .venv/bin/activate
|
||||
pytest tests/ -v
|
||||
### 📊 **Ready for Production**
|
||||
- Docker + docker-compose setup
|
||||
- Environment-based configuration
|
||||
- Database connection pooling
|
||||
- Error handling and logging
|
||||
- Health check endpoints
|
||||
- Production security headers
|
||||
- Rate limiting on sensitive endpoints
|
||||
- SEO optimization with dynamic sitemaps and robots.txt
|
||||
- Multi-language SEO with locale-specific metadata
|
||||
- Performance monitoring and bundle analysis
|
||||
|
||||
# With coverage
|
||||
pytest tests/ --cov=app --cov-report=html
|
||||
```
|
||||
---
|
||||
|
||||
### Database Migrations
|
||||
## 📸 Screenshots
|
||||
|
||||
```bash
|
||||
# Using the migration helper
|
||||
python migrate.py generate "your migration message"
|
||||
python migrate.py apply
|
||||
<details>
|
||||
<summary>Click to view screenshots</summary>
|
||||
|
||||
# Or using alembic directly
|
||||
alembic revision --autogenerate -m "your message"
|
||||
alembic upgrade head
|
||||
```
|
||||
### Landing Page
|
||||

|
||||
|
||||
### First Superuser
|
||||
|
||||
The backend automatically creates a superuser on initialization. Configure via environment variables:
|
||||
|
||||
```bash
|
||||
FIRST_SUPERUSER_EMAIL=admin@example.com
|
||||
FIRST_SUPERUSER_PASSWORD=admin123
|
||||
```
|
||||
|
||||
If not configured, defaults to `admin@example.com` / `admin123`.
|
||||
|
||||
## Deployment
|
||||
|
||||
### Option 1: Build and Deploy Locally
|
||||
|
||||
For production with local builds:
|
||||
|
||||
```bash
|
||||
docker-compose up -d
|
||||
```
|
||||
|
||||
### Option 2: Deploy with Pre-built Images
|
||||
|
||||
For deployment using images from a container registry:
|
||||
|
||||
1. Build and push your images:
|
||||
```bash
|
||||
# Build images
|
||||
docker-compose build
|
||||
|
||||
# Tag for your registry
|
||||
docker tag fast-next-template-backend:latest your-registry/your-project-backend:latest
|
||||
docker tag fast-next-template-frontend:latest your-registry/your-project-frontend:latest
|
||||
|
||||
# Push to registry
|
||||
docker push your-registry/your-project-backend:latest
|
||||
docker push your-registry/your-project-frontend:latest
|
||||
```
|
||||
|
||||
2. Update `docker-compose.deploy.yml` with your image references:
|
||||
```yaml
|
||||
services:
|
||||
backend:
|
||||
image: your-registry/your-project-backend:latest
|
||||
frontend:
|
||||
image: your-registry/your-project-frontend:latest
|
||||
```
|
||||
|
||||
3. Deploy:
|
||||
```bash
|
||||
docker-compose -f docker-compose.deploy.yml up -d
|
||||
```
|
||||
|
||||
### Environment Variables
|
||||
|
||||
Create a `.env` file based on `.env.template`:
|
||||
|
||||
```bash
|
||||
# Project
|
||||
PROJECT_NAME=MyApp
|
||||
VERSION=1.0.0
|
||||
|
||||
# Database
|
||||
POSTGRES_USER=postgres
|
||||
POSTGRES_PASSWORD=your-secure-password
|
||||
POSTGRES_DB=app
|
||||
POSTGRES_HOST=db
|
||||
POSTGRES_PORT=5432
|
||||
|
||||
# Backend
|
||||
BACKEND_PORT=8000
|
||||
SECRET_KEY=your-secret-key-change-this-in-production
|
||||
ENVIRONMENT=production
|
||||
DEBUG=false
|
||||
BACKEND_CORS_ORIGINS=["http://localhost:3000"]
|
||||
|
||||
# First Superuser
|
||||
FIRST_SUPERUSER_EMAIL=admin@example.com
|
||||
FIRST_SUPERUSER_PASSWORD=admin123
|
||||
|
||||
# Frontend
|
||||
NEXT_PUBLIC_API_URL=http://localhost:8000/api/v1
|
||||
```
|
||||
|
||||
## API Documentation
|
||||
|
||||
Once the backend is running, visit:
|
||||
- Swagger UI: http://localhost:8000/docs
|
||||
- ReDoc: http://localhost:8000/redoc
|
||||
|
||||
## Available Endpoints
|
||||
|
||||
### Authentication
|
||||
- `POST /api/v1/auth/register` - User registration
|
||||
- `POST /api/v1/auth/login` - User login (JSON)
|
||||
- `POST /api/v1/auth/login/oauth` - OAuth2-compatible login
|
||||
- `POST /api/v1/auth/refresh` - Refresh access token
|
||||
- `POST /api/v1/auth/change-password` - Change password
|
||||
- `GET /api/v1/auth/me` - Get current user info
|
||||

|
||||
|
||||
## Contributing
|
||||
|
||||
This is a template project. Feel free to fork and customize for your needs.
|
||||
|
||||
## License
|
||||
### Admin Dashboard
|
||||

|
||||
|
||||
MIT License - feel free to use this template for your projects.
|
||||
|
||||
|
||||
### Design System
|
||||

|
||||
|
||||
</details>
|
||||
|
||||
---
|
||||
|
||||
## 🎭 Demo Mode
|
||||
|
||||
**Try the frontend without a backend!** Perfect for:
|
||||
- **Free deployment** on Vercel (no backend costs)
|
||||
- **Portfolio showcasing** with live demos
|
||||
- **Client presentations** without infrastructure setup
|
||||
|
||||
### Quick Start
|
||||
|
||||
```bash
|
||||
cd frontend
|
||||
echo "NEXT_PUBLIC_DEMO_MODE=true" > .env.local
|
||||
bun run dev
|
||||
```
|
||||
|
||||
**Demo Credentials:**
|
||||
- Regular user: `demo@example.com` / `DemoPass123`
|
||||
- Admin user: `admin@example.com` / `AdminPass123`
|
||||
|
||||
Demo mode uses [Mock Service Worker (MSW)](https://mswjs.io/) to intercept API calls in the browser. Your code remains unchanged - the same components work with both real and mocked backends.
|
||||
|
||||
**Key Features:**
|
||||
- ✅ Zero backend required
|
||||
- ✅ All features functional (auth, admin, stats)
|
||||
- ✅ Realistic network delays and errors
|
||||
- ✅ Does NOT interfere with tests (97%+ coverage maintained)
|
||||
- ✅ One-line toggle: `NEXT_PUBLIC_DEMO_MODE=true`
|
||||
|
||||
📖 **[Complete Demo Mode Documentation](./frontend/docs/DEMO_MODE.md)**
|
||||
|
||||
---
|
||||
|
||||
## 🚀 Tech Stack
|
||||
|
||||
### Backend
|
||||
- **[FastAPI](https://fastapi.tiangolo.com/)** - Modern async Python web framework
|
||||
- **[SQLAlchemy 2.0](https://www.sqlalchemy.org/)** - Powerful ORM with async support
|
||||
- **[PostgreSQL](https://www.postgresql.org/)** - Robust relational database
|
||||
- **[Alembic](https://alembic.sqlalchemy.org/)** - Database migrations
|
||||
- **[Pydantic v2](https://docs.pydantic.dev/)** - Data validation with type hints
|
||||
- **[pytest](https://pytest.org/)** - Testing framework with async support
|
||||
|
||||
### Frontend
|
||||
- **[Next.js 16](https://nextjs.org/)** - React framework with App Router
|
||||
- **[React 19](https://react.dev/)** - UI library
|
||||
- **[TypeScript](https://www.typescriptlang.org/)** - Type-safe JavaScript
|
||||
- **[TailwindCSS](https://tailwindcss.com/)** - Utility-first CSS framework
|
||||
- **[shadcn/ui](https://ui.shadcn.com/)** - Beautiful, accessible component library
|
||||
- **[next-intl](https://next-intl.dev/)** - Internationalization (i18n) with type safety
|
||||
- **[TanStack Query](https://tanstack.com/query)** - Powerful data fetching/caching
|
||||
- **[Zustand](https://zustand-demo.pmnd.rs/)** - Lightweight state management
|
||||
- **[Framer Motion](https://www.framer.com/motion/)** - Production-ready animation library
|
||||
- **[Sonner](https://sonner.emilkowal.ski/)** - Beautiful toast notifications
|
||||
- **[Recharts](https://recharts.org/)** - Composable charting library
|
||||
- **[React Markdown](https://github.com/remarkjs/react-markdown)** - Markdown rendering with GFM support
|
||||
- **[Playwright](https://playwright.dev/)** - End-to-end testing
|
||||
|
||||
### DevOps
|
||||
- **[Docker](https://www.docker.com/)** - Containerization
|
||||
- **[docker-compose](https://docs.docker.com/compose/)** - Multi-container orchestration
|
||||
- **GitHub Actions** (coming soon) - CI/CD pipelines
|
||||
|
||||
---
|
||||
|
||||
## 📋 Prerequisites
|
||||
|
||||
- **Docker & Docker Compose** (recommended) - [Install Docker](https://docs.docker.com/get-docker/)
|
||||
- **OR manually:**
|
||||
- Python 3.12+
|
||||
- Node.js 18+ (Node 20+ recommended)
|
||||
- PostgreSQL 15+
|
||||
|
||||
---
|
||||
|
||||
## 🏃 Quick Start (Docker)
|
||||
|
||||
The fastest way to get started is with Docker:
|
||||
|
||||
```bash
|
||||
# Clone the repository
|
||||
git clone https://github.com/cardosofelipe/pragma-stack.git
|
||||
cd fast-next-template
|
||||
|
||||
# Copy environment file
|
||||
cp .env.template .env
|
||||
|
||||
# Start all services (backend, frontend, database)
|
||||
docker-compose up
|
||||
|
||||
# In another terminal, run database migrations
|
||||
docker-compose exec backend alembic upgrade head
|
||||
|
||||
# Create first superuser (optional)
|
||||
docker-compose exec backend python -c "from app.init_db import init_db; import asyncio; asyncio.run(init_db())"
|
||||
```
|
||||
|
||||
**That's it! 🎉**
|
||||
|
||||
- Frontend: http://localhost:3000
|
||||
- Backend API: http://localhost:8000
|
||||
- API Docs: http://localhost:8000/docs
|
||||
|
||||
Default superuser credentials:
|
||||
- Email: `admin@example.com`
|
||||
- Password: `admin123`
|
||||
|
||||
**⚠️ Change these immediately in production!**
|
||||
|
||||
---
|
||||
|
||||
## 🛠️ Manual Setup (Development)
|
||||
|
||||
### Backend Setup
|
||||
|
||||
```bash
|
||||
cd backend
|
||||
|
||||
# Create virtual environment
|
||||
python -m venv .venv
|
||||
source .venv/bin/activate # On Windows: .venv\Scripts\activate
|
||||
|
||||
# Install dependencies
|
||||
pip install -r requirements.txt
|
||||
|
||||
# Setup environment
|
||||
cp .env.example .env
|
||||
# Edit .env with your database credentials
|
||||
|
||||
# Run migrations
|
||||
alembic upgrade head
|
||||
|
||||
# Initialize database with first superuser
|
||||
python -c "from app.init_db import init_db; import asyncio; asyncio.run(init_db())"
|
||||
|
||||
# Start development server
|
||||
uvicorn app.main:app --reload --host 0.0.0.0 --port 8000
|
||||
```
|
||||
|
||||
### Frontend Setup
|
||||
|
||||
```bash
|
||||
cd frontend
|
||||
|
||||
# Install dependencies
|
||||
bun install
|
||||
|
||||
# Setup environment
|
||||
cp .env.local.example .env.local
|
||||
# Edit .env.local with your backend URL
|
||||
|
||||
# Generate API client
|
||||
bun run generate:api
|
||||
|
||||
# Start development server
|
||||
bun run dev
|
||||
```
|
||||
|
||||
Visit http://localhost:3000 to see your app!
|
||||
|
||||
---
|
||||
|
||||
## 📂 Project Structure
|
||||
|
||||
```
|
||||
├── backend/ # FastAPI backend
|
||||
│ ├── app/
|
||||
│ │ ├── api/ # API routes and dependencies
|
||||
│ │ ├── core/ # Core functionality (auth, config, database)
|
||||
│ │ ├── repositories/ # Repository pattern (database operations)
|
||||
│ │ ├── models/ # SQLAlchemy models
|
||||
│ │ ├── schemas/ # Pydantic schemas
|
||||
│ │ ├── services/ # Business logic
|
||||
│ │ └── utils/ # Utilities
|
||||
│ ├── tests/ # Backend tests (97% coverage)
|
||||
│ ├── alembic/ # Database migrations
|
||||
│ └── docs/ # Backend documentation
|
||||
│
|
||||
├── frontend/ # Next.js frontend
|
||||
│ ├── src/
|
||||
│ │ ├── app/ # Next.js App Router pages
|
||||
│ │ ├── components/ # React components
|
||||
│ │ ├── lib/ # Libraries and utilities
|
||||
│ │ │ ├── api/ # API client (auto-generated)
|
||||
│ │ │ └── stores/ # Zustand stores
|
||||
│ │ └── hooks/ # Custom React hooks
|
||||
│ ├── e2e/ # Playwright E2E tests
|
||||
│ ├── tests/ # Unit tests (Jest)
|
||||
│ └── docs/ # Frontend documentation
|
||||
│ └── design-system/ # Comprehensive design system docs
|
||||
│
|
||||
├── docker-compose.yml # Docker orchestration
|
||||
├── docker-compose.dev.yml # Development with hot reload
|
||||
└── README.md # You are here!
|
||||
```
|
||||
|
||||
---
|
||||
|
||||
## 🧪 Testing
|
||||
|
||||
This template takes testing seriously with comprehensive coverage across all layers:
|
||||
|
||||
### Backend Unit & Integration Tests
|
||||
|
||||
**High coverage (~97%)** across all critical paths including security-focused tests.
|
||||
|
||||
```bash
|
||||
cd backend
|
||||
|
||||
# Run all tests
|
||||
IS_TEST=True pytest
|
||||
|
||||
# Run with coverage report
|
||||
IS_TEST=True pytest --cov=app --cov-report=term-missing
|
||||
|
||||
# Run specific test file
|
||||
IS_TEST=True pytest tests/api/test_auth.py -v
|
||||
|
||||
# Generate HTML coverage report
|
||||
IS_TEST=True pytest --cov=app --cov-report=html
|
||||
open htmlcov/index.html
|
||||
```
|
||||
|
||||
**Test types:**
|
||||
- **Unit tests**: Repository operations, utilities, business logic
|
||||
- **Integration tests**: API endpoints with database
|
||||
- **Security tests**: JWT algorithm attacks, session hijacking, privilege escalation
|
||||
- **Error handling tests**: Database failures, validation errors
|
||||
|
||||
### Frontend Unit Tests
|
||||
|
||||
**High coverage (~97%)** with Jest and React Testing Library.
|
||||
|
||||
```bash
|
||||
cd frontend
|
||||
|
||||
# Run unit tests
|
||||
bun run test
|
||||
|
||||
# Run with coverage
|
||||
bun run test:coverage
|
||||
|
||||
# Watch mode
|
||||
bun run test:watch
|
||||
```
|
||||
|
||||
**Test types:**
|
||||
- Component rendering and interactions
|
||||
- Custom hooks behavior
|
||||
- State management
|
||||
- Utility functions
|
||||
- API integration mocks
|
||||
|
||||
### End-to-End Tests
|
||||
|
||||
**Zero flaky tests** with Playwright covering complete user journeys.
|
||||
|
||||
```bash
|
||||
cd frontend
|
||||
|
||||
# Run E2E tests
|
||||
bun run test:e2e
|
||||
|
||||
# Run E2E tests in UI mode (recommended for development)
|
||||
bun run test:e2e:ui
|
||||
|
||||
# Run specific test file
|
||||
npx playwright test auth-login.spec.ts
|
||||
|
||||
# Generate test report
|
||||
npx playwright show-report
|
||||
```
|
||||
|
||||
**Test coverage:**
|
||||
- Complete authentication flows
|
||||
- Navigation and routing
|
||||
- Form submissions and validation
|
||||
- Settings and profile management
|
||||
- Session management
|
||||
- Admin panel workflows (in progress)
|
||||
|
||||
---
|
||||
|
||||
## 🤖 AI-Friendly Documentation
|
||||
|
||||
This project includes comprehensive documentation designed for AI coding assistants:
|
||||
|
||||
- **[AGENTS.md](./AGENTS.md)** - Framework-agnostic AI assistant context for PragmaStack
|
||||
- **[CLAUDE.md](./CLAUDE.md)** - Claude Code-specific guidance
|
||||
|
||||
These files provide AI assistants with the **PragmaStack** architecture, patterns, and best practices.
|
||||
|
||||
---
|
||||
|
||||
## 🗄️ Database Migrations
|
||||
|
||||
The template uses Alembic for database migrations:
|
||||
|
||||
```bash
|
||||
cd backend
|
||||
|
||||
# Generate migration from model changes
|
||||
python migrate.py generate "description of changes"
|
||||
|
||||
# Apply migrations
|
||||
python migrate.py apply
|
||||
|
||||
# Or do both in one command
|
||||
python migrate.py auto "description"
|
||||
|
||||
# View migration history
|
||||
python migrate.py list
|
||||
|
||||
# Check current revision
|
||||
python migrate.py current
|
||||
```
|
||||
|
||||
---
|
||||
|
||||
## 📖 Documentation
|
||||
|
||||
### AI Assistant Documentation
|
||||
|
||||
- **[AGENTS.md](./AGENTS.md)** - Framework-agnostic AI coding assistant context
|
||||
- **[CLAUDE.md](./CLAUDE.md)** - Claude Code-specific guidance and preferences
|
||||
|
||||
### Backend Documentation
|
||||
|
||||
- **[ARCHITECTURE.md](./backend/docs/ARCHITECTURE.md)** - System architecture and design patterns
|
||||
- **[CODING_STANDARDS.md](./backend/docs/CODING_STANDARDS.md)** - Code quality standards
|
||||
- **[COMMON_PITFALLS.md](./backend/docs/COMMON_PITFALLS.md)** - Common mistakes to avoid
|
||||
- **[FEATURE_EXAMPLE.md](./backend/docs/FEATURE_EXAMPLE.md)** - Step-by-step feature guide
|
||||
|
||||
### Frontend Documentation
|
||||
|
||||
- **[PragmaStack Design System](./frontend/docs/design-system/)** - Complete design system guide
|
||||
- Quick start, foundations (colors, typography, spacing)
|
||||
- Component library guide
|
||||
- Layout patterns, spacing philosophy
|
||||
- Forms, accessibility, AI guidelines
|
||||
- **[E2E Testing Guide](./frontend/e2e/README.md)** - E2E testing setup and best practices
|
||||
|
||||
### API Documentation
|
||||
|
||||
When the backend is running:
|
||||
- **Swagger UI**: http://localhost:8000/docs
|
||||
- **ReDoc**: http://localhost:8000/redoc
|
||||
- **OpenAPI JSON**: http://localhost:8000/api/v1/openapi.json
|
||||
|
||||
---
|
||||
|
||||
## 🚢 Deployment
|
||||
|
||||
### Docker Production Deployment
|
||||
|
||||
```bash
|
||||
# Build and start all services
|
||||
docker-compose up -d
|
||||
|
||||
# Run migrations
|
||||
docker-compose exec backend alembic upgrade head
|
||||
|
||||
# View logs
|
||||
docker-compose logs -f
|
||||
|
||||
# Stop services
|
||||
docker-compose down
|
||||
```
|
||||
|
||||
### Production Checklist
|
||||
|
||||
- [ ] Change default superuser credentials
|
||||
- [ ] Set strong `SECRET_KEY` in backend `.env`
|
||||
- [ ] Configure production database (PostgreSQL)
|
||||
- [ ] Set `ENVIRONMENT=production` in backend
|
||||
- [ ] Configure CORS origins for your domain
|
||||
- [ ] Setup SSL/TLS certificates
|
||||
- [ ] Configure email service for password resets
|
||||
- [ ] Setup monitoring and logging
|
||||
- [ ] Configure backup strategy
|
||||
- [ ] Review and adjust rate limits
|
||||
- [ ] Test security headers
|
||||
|
||||
---
|
||||
|
||||
## 🛣️ Roadmap & Status
|
||||
|
||||
### ✅ Completed
|
||||
- [x] Authentication system (JWT, refresh tokens, session management, OAuth)
|
||||
- [x] User management (full lifecycle, profile, password change)
|
||||
- [x] Organization system with RBAC (Owner, Admin, Member)
|
||||
- [x] Admin panel (users, organizations, sessions, statistics)
|
||||
- [x] **Internationalization (i18n)** with next-intl (English + Italian)
|
||||
- [x] Backend testing infrastructure (~97% coverage)
|
||||
- [x] Frontend unit testing infrastructure (~97% coverage)
|
||||
- [x] Frontend E2E testing (Playwright, zero flaky tests)
|
||||
- [x] Design system documentation
|
||||
- [x] **Marketing landing page** with animated components
|
||||
- [x] **`/dev` documentation portal** with live component examples
|
||||
- [x] **Toast notifications** system (Sonner)
|
||||
- [x] **Charts and visualizations** (Recharts)
|
||||
- [x] **Animation system** (Framer Motion)
|
||||
- [x] **Markdown rendering** with syntax highlighting
|
||||
- [x] **SEO optimization** (sitemap, robots.txt, locale-aware metadata)
|
||||
- [x] Database migrations with helper script
|
||||
- [x] Docker deployment
|
||||
- [x] API documentation (OpenAPI/Swagger)
|
||||
|
||||
### 🚧 In Progress
|
||||
- [ ] Email integration (templates ready, SMTP pending)
|
||||
|
||||
### 🔮 Planned
|
||||
- [ ] GitHub Actions CI/CD pipelines
|
||||
- [ ] Dynamic test coverage badges from CI
|
||||
- [ ] E2E test coverage reporting
|
||||
- [ ] OAuth token encryption at rest (security hardening)
|
||||
- [ ] Additional languages (Spanish, French, German, etc.)
|
||||
- [ ] SSO/SAML authentication
|
||||
- [ ] Real-time notifications with WebSockets
|
||||
- [ ] Webhook system
|
||||
- [ ] File upload/storage (S3-compatible)
|
||||
- [ ] Audit logging system
|
||||
- [ ] API versioning example
|
||||
|
||||
|
||||
---
|
||||
|
||||
## 🤝 Contributing
|
||||
|
||||
Contributions are welcome! Whether you're fixing bugs, improving documentation, or proposing new features, we'd love your help.
|
||||
|
||||
### How to Contribute
|
||||
|
||||
1. **Fork the repository**
|
||||
2. **Create a feature branch** (`git checkout -b feature/amazing-feature`)
|
||||
3. **Make your changes**
|
||||
- Follow existing code style
|
||||
- Add tests for new features
|
||||
- Update documentation as needed
|
||||
4. **Run tests** to ensure everything works
|
||||
5. **Commit your changes** (`git commit -m 'Add amazing feature'`)
|
||||
6. **Push to your branch** (`git push origin feature/amazing-feature`)
|
||||
7. **Open a Pull Request**
|
||||
|
||||
### Development Guidelines
|
||||
|
||||
- Write tests for new features (aim for >90% coverage)
|
||||
- Follow the existing architecture patterns
|
||||
- Update documentation when adding features
|
||||
- Keep commits atomic and well-described
|
||||
- Be respectful and constructive in discussions
|
||||
|
||||
### Reporting Issues
|
||||
|
||||
Found a bug? Have a suggestion? [Open an issue](https://github.com/cardosofelipe/pragma-stack/issues)!
|
||||
|
||||
Please include:
|
||||
- Clear description of the issue/suggestion
|
||||
- Steps to reproduce (for bugs)
|
||||
- Expected vs. actual behavior
|
||||
- Environment details (OS, Python/Node version, etc.)
|
||||
|
||||
---
|
||||
|
||||
## 📄 License
|
||||
|
||||
This project is licensed under the **MIT License** - see the [LICENSE](./LICENSE) file for details.
|
||||
|
||||
**TL;DR**: You can use this template for any purpose, commercial or non-commercial. Attribution is appreciated but not required!
|
||||
|
||||
---
|
||||
|
||||
## 🙏 Acknowledgments
|
||||
|
||||
This template is built on the shoulders of giants:
|
||||
|
||||
- [FastAPI](https://fastapi.tiangolo.com/) by Sebastián Ramírez
|
||||
- [Next.js](https://nextjs.org/) by Vercel
|
||||
- [shadcn/ui](https://ui.shadcn.com/) by shadcn
|
||||
- [TanStack Query](https://tanstack.com/query) by Tanner Linsley
|
||||
- [Playwright](https://playwright.dev/) by Microsoft
|
||||
- And countless other open-source projects that make modern development possible
|
||||
|
||||
---
|
||||
|
||||
## 💬 Questions?
|
||||
|
||||
- **Documentation**: Check the `/docs` folders in backend and frontend
|
||||
- **Issues**: [GitHub Issues](https://github.com/cardosofelipe/pragma-stack/issues)
|
||||
- **Discussions**: [GitHub Discussions](https://github.com/cardosofelipe/pragma-stack/discussions)
|
||||
|
||||
---
|
||||
|
||||
## ⭐ Star This Repo
|
||||
|
||||
If this template saves you time, consider giving it a star! It helps others discover the project and motivates continued development.
|
||||
|
||||
**Happy coding! 🚀**
|
||||
|
||||
---
|
||||
|
||||
<div align="center">
|
||||
Made with ❤️ by a developer who got tired of rebuilding the same boilerplate
|
||||
</div>
|
||||
|
||||
@@ -1,5 +1,6 @@
|
||||
[run]
|
||||
source = app
|
||||
concurrency = thread,greenlet
|
||||
omit =
|
||||
# Migration files - these are generated code and shouldn't be tested
|
||||
app/alembic/versions/*
|
||||
@@ -10,16 +11,19 @@ omit =
|
||||
app/utils/auth_test_utils.py
|
||||
|
||||
# Async implementations not yet in use
|
||||
app/crud/base_async.py
|
||||
app/repositories/base_async.py
|
||||
app/core/database_async.py
|
||||
|
||||
# CLI scripts - run manually, not tested
|
||||
app/init_db.py
|
||||
|
||||
# __init__ files with no logic
|
||||
app/__init__.py
|
||||
app/api/__init__.py
|
||||
app/api/routes/__init__.py
|
||||
app/api/dependencies/__init__.py
|
||||
app/core/__init__.py
|
||||
app/crud/__init__.py
|
||||
app/repositories/__init__.py
|
||||
app/models/__init__.py
|
||||
app/schemas/__init__.py
|
||||
app/services/__init__.py
|
||||
@@ -61,6 +65,10 @@ exclude_lines =
|
||||
# Pass statements (often in abstract base classes or placeholders)
|
||||
pass
|
||||
|
||||
# Skip test environment checks (production-only code)
|
||||
if os\.getenv\("IS_TEST".*\) == "True":
|
||||
if os\.getenv\("IS_TEST".*\) != "True":
|
||||
|
||||
[html]
|
||||
directory = htmlcov
|
||||
|
||||
|
||||
@@ -1,2 +1,17 @@
|
||||
.venv
|
||||
*.iml
|
||||
*.iml
|
||||
|
||||
# Python build and cache artifacts
|
||||
__pycache__/
|
||||
.pytest_cache/
|
||||
.mypy_cache/
|
||||
.ruff_cache/
|
||||
*.pyc
|
||||
*.pyo
|
||||
|
||||
# Packaging artifacts
|
||||
*.egg-info/
|
||||
build/
|
||||
dist/
|
||||
htmlcov/
|
||||
.uv_cache/
|
||||
44
backend/.pre-commit-config.yaml
Normal file
44
backend/.pre-commit-config.yaml
Normal file
@@ -0,0 +1,44 @@
|
||||
# Pre-commit hooks for backend quality and security checks.
|
||||
#
|
||||
# Install:
|
||||
# cd backend && uv run pre-commit install
|
||||
#
|
||||
# Run manually on all files:
|
||||
# cd backend && uv run pre-commit run --all-files
|
||||
#
|
||||
# Skip hooks temporarily:
|
||||
# git commit --no-verify
|
||||
#
|
||||
repos:
|
||||
# ── Code Quality ──────────────────────────────────────────────────────────
|
||||
- repo: https://github.com/astral-sh/ruff-pre-commit
|
||||
rev: v0.14.4
|
||||
hooks:
|
||||
- id: ruff
|
||||
args: [--fix, --exit-non-zero-on-fix]
|
||||
- id: ruff-format
|
||||
|
||||
# ── General File Hygiene ──────────────────────────────────────────────────
|
||||
- repo: https://github.com/pre-commit/pre-commit-hooks
|
||||
rev: v5.0.0
|
||||
hooks:
|
||||
- id: trailing-whitespace
|
||||
- id: end-of-file-fixer
|
||||
- id: check-yaml
|
||||
- id: check-toml
|
||||
- id: check-merge-conflict
|
||||
- id: check-added-large-files
|
||||
args: [--maxkb=500]
|
||||
- id: debug-statements
|
||||
|
||||
# ── Security ──────────────────────────────────────────────────────────────
|
||||
- repo: https://github.com/Yelp/detect-secrets
|
||||
rev: v1.5.0
|
||||
hooks:
|
||||
- id: detect-secrets
|
||||
args: ['--baseline', '.secrets.baseline']
|
||||
exclude: |
|
||||
(?x)^(
|
||||
.*\.lock$|
|
||||
.*\.svg$
|
||||
)$
|
||||
1073
backend/.secrets.baseline
Normal file
1073
backend/.secrets.baseline
Normal file
File diff suppressed because it is too large
Load Diff
@@ -1,6 +0,0 @@
|
||||
Requirement already satisfied: alembic in ./.venv/lib/python3.12/site-packages (1.14.1)
|
||||
Requirement already satisfied: SQLAlchemy>=1.3.0 in ./.venv/lib/python3.12/site-packages (from alembic) (2.0.38)
|
||||
Requirement already satisfied: Mako in ./.venv/lib/python3.12/site-packages (from alembic) (1.3.9)
|
||||
Requirement already satisfied: typing-extensions>=4 in ./.venv/lib/python3.12/site-packages (from alembic) (4.12.2)
|
||||
Requirement already satisfied: greenlet!=0.4.17 in ./.venv/lib/python3.12/site-packages (from SQLAlchemy>=1.3.0->alembic) (3.1.1)
|
||||
Requirement already satisfied: MarkupSafe>=0.9.2 in ./.venv/lib/python3.12/site-packages (from Mako->alembic) (3.0.2)
|
||||
@@ -1,53 +1,67 @@
|
||||
# Development stage
|
||||
FROM python:3.12-slim AS development
|
||||
|
||||
# Create non-root user
|
||||
RUN groupadd -r appuser && useradd -r -g appuser appuser
|
||||
|
||||
WORKDIR /app
|
||||
ENV PYTHONDONTWRITEBYTECODE=1 \
|
||||
PYTHONUNBUFFERED=1 \
|
||||
PYTHONPATH=/app
|
||||
PYTHONPATH=/app \
|
||||
UV_COMPILE_BYTECODE=1 \
|
||||
UV_LINK_MODE=copy \
|
||||
UV_NO_CACHE=1
|
||||
|
||||
# Install system dependencies and uv
|
||||
RUN apt-get update && \
|
||||
apt-get install -y --no-install-recommends gcc postgresql-client curl && \
|
||||
apt-get install -y --no-install-recommends gcc postgresql-client curl ca-certificates && \
|
||||
curl -LsSf https://astral.sh/uv/install.sh | sh && \
|
||||
mv /root/.local/bin/uv* /usr/local/bin/ && \
|
||||
apt-get clean && \
|
||||
rm -rf /var/lib/apt/lists/*
|
||||
|
||||
COPY requirements.txt .
|
||||
RUN pip install --no-cache-dir -r requirements.txt
|
||||
# Copy dependency files
|
||||
COPY pyproject.toml uv.lock ./
|
||||
|
||||
# Install dependencies using uv (development mode with dev dependencies)
|
||||
RUN uv sync --extra dev --frozen
|
||||
|
||||
# Copy application code
|
||||
COPY . .
|
||||
COPY entrypoint.sh /usr/local/bin/
|
||||
RUN chmod +x /usr/local/bin/entrypoint.sh
|
||||
|
||||
# Set ownership to non-root user
|
||||
RUN chown -R appuser:appuser /app
|
||||
|
||||
# Switch to non-root user
|
||||
USER appuser
|
||||
# Note: Running as root in development for bind mount compatibility
|
||||
# Production stage uses non-root user for security
|
||||
|
||||
ENTRYPOINT ["/usr/local/bin/entrypoint.sh"]
|
||||
|
||||
# Production stage
|
||||
FROM python:3.12-slim AS production
|
||||
# Production stage — Alpine eliminates glibc CVEs (e.g. CVE-2026-0861)
|
||||
FROM python:3.12-alpine AS production
|
||||
|
||||
# Create non-root user
|
||||
RUN groupadd -r appuser && useradd -r -g appuser appuser
|
||||
RUN addgroup -S appuser && adduser -S -G appuser appuser
|
||||
|
||||
WORKDIR /app
|
||||
ENV PYTHONDONTWRITEBYTECODE=1 \
|
||||
PYTHONUNBUFFERED=1 \
|
||||
PYTHONPATH=/app
|
||||
PYTHONPATH=/app \
|
||||
UV_COMPILE_BYTECODE=1 \
|
||||
UV_LINK_MODE=copy \
|
||||
UV_NO_CACHE=1
|
||||
|
||||
RUN apt-get update && \
|
||||
apt-get install -y --no-install-recommends postgresql-client curl && \
|
||||
apt-get clean && \
|
||||
rm -rf /var/lib/apt/lists/*
|
||||
# Install system dependencies and uv
|
||||
RUN apk add --no-cache postgresql-client curl ca-certificates && \
|
||||
curl -LsSf https://astral.sh/uv/install.sh | sh && \
|
||||
mv /root/.local/bin/uv* /usr/local/bin/
|
||||
|
||||
COPY requirements.txt .
|
||||
RUN pip install --no-cache-dir -r requirements.txt
|
||||
# Copy dependency files
|
||||
COPY pyproject.toml uv.lock ./
|
||||
|
||||
# Install build dependencies, compile Python packages, then remove build deps
|
||||
RUN apk add --no-cache --virtual .build-deps \
|
||||
gcc g++ musl-dev python3-dev linux-headers libffi-dev openssl-dev && \
|
||||
uv sync --frozen --no-dev && \
|
||||
apk del .build-deps
|
||||
|
||||
# Copy application code
|
||||
COPY . .
|
||||
COPY entrypoint.sh /usr/local/bin/
|
||||
RUN chmod +x /usr/local/bin/entrypoint.sh
|
||||
@@ -63,4 +77,4 @@ HEALTHCHECK --interval=30s --timeout=10s --start-period=40s --retries=3 \
|
||||
CMD curl -f http://localhost:8000/health || exit 1
|
||||
|
||||
ENTRYPOINT ["/usr/local/bin/entrypoint.sh"]
|
||||
CMD ["uvicorn", "app.main:app", "--host", "0.0.0.0", "--port", "8000"]
|
||||
CMD ["uv", "run", "uvicorn", "app.main:app", "--host", "0.0.0.0", "--port", "8000"]
|
||||
|
||||
220
backend/Makefile
Normal file
220
backend/Makefile
Normal file
@@ -0,0 +1,220 @@
|
||||
.PHONY: help lint lint-fix format format-check type-check test test-cov validate clean install-dev sync check-docker install-e2e test-e2e test-e2e-schema test-all dep-audit license-check audit validate-all check benchmark benchmark-check benchmark-save scan-image test-api-security
|
||||
|
||||
# Prevent a stale VIRTUAL_ENV in the caller's shell from confusing uv
|
||||
unexport VIRTUAL_ENV
|
||||
|
||||
# Default target
|
||||
help:
|
||||
@echo "🚀 FastAPI Backend - Development Commands"
|
||||
@echo ""
|
||||
@echo "Setup:"
|
||||
@echo " make install-dev - Install all dependencies with uv (includes dev)"
|
||||
@echo " make install-e2e - Install E2E test dependencies (requires Docker)"
|
||||
@echo " make sync - Sync dependencies from uv.lock"
|
||||
@echo ""
|
||||
@echo "Quality Checks:"
|
||||
@echo " make lint - Run Ruff linter (check only)"
|
||||
@echo " make lint-fix - Run Ruff linter with auto-fix"
|
||||
@echo " make format - Format code with Ruff"
|
||||
@echo " make format-check - Check if code is formatted"
|
||||
@echo " make type-check - Run pyright type checking"
|
||||
@echo " make validate - Run all checks (lint + format + types + schema fuzz)"
|
||||
@echo ""
|
||||
@echo "Performance:"
|
||||
@echo " make benchmark - Run performance benchmarks"
|
||||
@echo " make benchmark-save - Run benchmarks and save as baseline"
|
||||
@echo " make benchmark-check - Run benchmarks and compare against baseline"
|
||||
@echo ""
|
||||
@echo "Security & Audit:"
|
||||
@echo " make dep-audit - Scan dependencies for known vulnerabilities"
|
||||
@echo " make license-check - Check dependency license compliance"
|
||||
@echo " make audit - Run all security audits (deps + licenses)"
|
||||
@echo " make scan-image - Scan Docker image for CVEs (requires trivy)"
|
||||
@echo " make validate-all - Run all quality + security checks"
|
||||
@echo " make check - Full pipeline: quality + security + tests"
|
||||
@echo ""
|
||||
@echo "Testing:"
|
||||
@echo " make test - Run pytest (unit/integration, SQLite)"
|
||||
@echo " make test-cov - Run pytest with coverage report"
|
||||
@echo " make test-e2e - Run E2E tests (PostgreSQL, requires Docker)"
|
||||
@echo " make test-e2e-schema - Run Schemathesis API schema tests"
|
||||
@echo " make test-all - Run all tests (unit + E2E)"
|
||||
@echo " make check-docker - Check if Docker is available"
|
||||
@echo " make check - Full pipeline: quality + security + tests"
|
||||
@echo ""
|
||||
@echo "Cleanup:"
|
||||
@echo " make clean - Remove cache and build artifacts"
|
||||
|
||||
# ============================================================================
|
||||
# Setup & Cleanup
|
||||
# ============================================================================
|
||||
|
||||
install-dev:
|
||||
@echo "📦 Installing all dependencies with uv (includes dev)..."
|
||||
@uv sync --extra dev
|
||||
@echo "✅ Development environment ready!"
|
||||
|
||||
sync:
|
||||
@echo "🔄 Syncing dependencies from uv.lock..."
|
||||
@uv sync --extra dev
|
||||
@echo "✅ Dependencies synced!"
|
||||
|
||||
# ============================================================================
|
||||
# Code Quality
|
||||
# ============================================================================
|
||||
|
||||
lint:
|
||||
@echo "🔍 Running Ruff linter..."
|
||||
@uv run ruff check app/ tests/
|
||||
|
||||
lint-fix:
|
||||
@echo "🔧 Running Ruff linter with auto-fix..."
|
||||
@uv run ruff check --fix app/ tests/
|
||||
|
||||
format:
|
||||
@echo "✨ Formatting code with Ruff..."
|
||||
@uv run ruff format app/ tests/
|
||||
|
||||
format-check:
|
||||
@echo "📋 Checking code formatting..."
|
||||
@uv run ruff format --check app/ tests/
|
||||
|
||||
type-check:
|
||||
@echo "🔎 Running pyright type checking..."
|
||||
@uv run pyright app/
|
||||
|
||||
validate: lint format-check type-check test-api-security
|
||||
@echo "✅ All quality checks passed!"
|
||||
|
||||
# API Security Testing (Schemathesis property-based fuzzing)
|
||||
test-api-security: check-docker
|
||||
@echo "🔐 Running Schemathesis API security fuzzing..."
|
||||
@IS_TEST=True PYTHONPATH=. uv run pytest tests/e2e/ -v -m "schemathesis" --tb=short -n 0
|
||||
@echo "✅ API schema security tests passed!"
|
||||
|
||||
# ============================================================================
|
||||
# Security & Audit
|
||||
# ============================================================================
|
||||
|
||||
dep-audit:
|
||||
@echo "🔒 Scanning dependencies for known vulnerabilities..."
|
||||
@uv run pip-audit --desc --skip-editable
|
||||
@echo "✅ No known vulnerabilities found!"
|
||||
|
||||
license-check:
|
||||
@echo "📜 Checking dependency license compliance..."
|
||||
@uv run pip-licenses --fail-on="GPL-3.0-or-later;AGPL-3.0-or-later" --format=plain > /dev/null
|
||||
@echo "✅ All dependency licenses are compliant!"
|
||||
|
||||
audit: dep-audit license-check
|
||||
@echo "✅ All security audits passed!"
|
||||
|
||||
scan-image: check-docker
|
||||
@echo "🐳 Scanning Docker image for OS-level CVEs with Trivy..."
|
||||
@docker build -t pragma-backend:scan -q --target production .
|
||||
@if command -v trivy > /dev/null 2>&1; then \
|
||||
trivy image --severity HIGH,CRITICAL --exit-code 1 pragma-backend:scan; \
|
||||
else \
|
||||
echo "ℹ️ Trivy not found locally, using Docker to run Trivy..."; \
|
||||
docker run --rm -v /var/run/docker.sock:/var/run/docker.sock aquasec/trivy image --severity HIGH,CRITICAL --exit-code 1 pragma-backend:scan; \
|
||||
fi
|
||||
@echo "✅ No HIGH/CRITICAL CVEs found in Docker image!"
|
||||
|
||||
validate-all: validate audit
|
||||
@echo "✅ All quality + security checks passed!"
|
||||
|
||||
check: validate-all test
|
||||
@echo "✅ Full validation pipeline complete!"
|
||||
|
||||
# ============================================================================
|
||||
# Testing
|
||||
# ============================================================================
|
||||
|
||||
test:
|
||||
@echo "🧪 Running tests..."
|
||||
@IS_TEST=True PYTHONPATH=. uv run pytest
|
||||
|
||||
test-cov:
|
||||
@echo "🧪 Running tests with coverage..."
|
||||
@IS_TEST=True PYTHONPATH=. uv run pytest --cov=app --cov-report=term-missing --cov-report=html -n 16
|
||||
@echo "📊 Coverage report generated in htmlcov/index.html"
|
||||
|
||||
# ============================================================================
|
||||
# E2E Testing (requires Docker)
|
||||
# ============================================================================
|
||||
|
||||
check-docker:
|
||||
@docker info > /dev/null 2>&1 || (echo ""; \
|
||||
echo "Docker is not running!"; \
|
||||
echo ""; \
|
||||
echo "E2E tests require Docker to be running."; \
|
||||
echo "Please start Docker Desktop or Docker Engine and try again."; \
|
||||
echo ""; \
|
||||
echo "Quick start:"; \
|
||||
echo " macOS/Windows: Open Docker Desktop"; \
|
||||
echo " Linux: sudo systemctl start docker"; \
|
||||
echo ""; \
|
||||
exit 1)
|
||||
@echo "Docker is available"
|
||||
|
||||
install-e2e:
|
||||
@echo "📦 Installing E2E test dependencies..."
|
||||
@uv sync --extra dev --extra e2e
|
||||
@echo "✅ E2E dependencies installed!"
|
||||
|
||||
test-e2e: check-docker
|
||||
@echo "🧪 Running E2E tests with PostgreSQL..."
|
||||
@IS_TEST=True PYTHONPATH=. uv run pytest tests/e2e/ -v --tb=short -n 0
|
||||
@echo "✅ E2E tests complete!"
|
||||
|
||||
test-e2e-schema: check-docker
|
||||
@echo "🧪 Running Schemathesis API schema tests..."
|
||||
@IS_TEST=True PYTHONPATH=. uv run pytest tests/e2e/ -v -m "schemathesis" --tb=short -n 0
|
||||
|
||||
# ============================================================================
|
||||
# Performance Benchmarks
|
||||
# ============================================================================
|
||||
|
||||
benchmark:
|
||||
@echo "⏱️ Running performance benchmarks..."
|
||||
@IS_TEST=True PYTHONPATH=. uv run pytest tests/benchmarks/ -v --benchmark-only --benchmark-sort=mean -p no:xdist --override-ini='addopts='
|
||||
|
||||
benchmark-save:
|
||||
@echo "⏱️ Running benchmarks and saving baseline..."
|
||||
@IS_TEST=True PYTHONPATH=. uv run pytest tests/benchmarks/ -v --benchmark-only --benchmark-save=baseline --benchmark-sort=mean -p no:xdist --override-ini='addopts='
|
||||
@echo "✅ Benchmark baseline saved to .benchmarks/"
|
||||
|
||||
benchmark-check:
|
||||
@echo "⏱️ Running benchmarks and comparing against baseline..."
|
||||
@if find .benchmarks -name '*_baseline*' -print -quit 2>/dev/null | grep -q .; then \
|
||||
IS_TEST=True PYTHONPATH=. uv run pytest tests/benchmarks/ -v --benchmark-only --benchmark-compare=0001_baseline --benchmark-sort=mean --benchmark-compare-fail=mean:200% -p no:xdist --override-ini='addopts='; \
|
||||
echo "✅ No performance regressions detected!"; \
|
||||
else \
|
||||
echo "⚠️ No benchmark baseline found. Run 'make benchmark-save' first to create one."; \
|
||||
echo " Running benchmarks without comparison..."; \
|
||||
IS_TEST=True PYTHONPATH=. uv run pytest tests/benchmarks/ -v --benchmark-only --benchmark-save=baseline --benchmark-sort=mean -p no:xdist --override-ini='addopts='; \
|
||||
echo "✅ Benchmark baseline created. Future runs of 'make benchmark-check' will compare against it."; \
|
||||
fi
|
||||
|
||||
test-all:
|
||||
@echo "🧪 Running ALL tests (unit + E2E)..."
|
||||
@$(MAKE) test
|
||||
@$(MAKE) test-e2e
|
||||
|
||||
# ============================================================================
|
||||
# Cleanup
|
||||
# ============================================================================
|
||||
|
||||
clean:
|
||||
@echo "🧹 Cleaning up..."
|
||||
@find . -type d -name "__pycache__" -exec rm -rf {} + 2>/dev/null || true
|
||||
@find . -type d -name ".pytest_cache" -exec rm -rf {} + 2>/dev/null || true
|
||||
@find . -type d -name ".pyright" -exec rm -rf {} + 2>/dev/null || true
|
||||
@find . -type d -name ".ruff_cache" -exec rm -rf {} + 2>/dev/null || true
|
||||
@find . -type d -name "*.egg-info" -exec rm -rf {} + 2>/dev/null || true
|
||||
@find . -type d -name "htmlcov" -exec rm -rf {} + 2>/dev/null || true
|
||||
@find . -type d -name "build" -exec rm -rf {} + 2>/dev/null || true
|
||||
@find . -type d -name ".uv_cache" -exec rm -rf {} + 2>/dev/null || true
|
||||
@find . -type f -name ".coverage" -delete 2>/dev/null || true
|
||||
@find . -type f -name "*.pyc" -delete 2>/dev/null || true
|
||||
@echo "✅ Cleanup complete!"
|
||||
707
backend/README.md
Normal file
707
backend/README.md
Normal file
@@ -0,0 +1,707 @@
|
||||
# PragmaStack Backend API
|
||||
|
||||
> The pragmatic, production-ready FastAPI backend for PragmaStack.
|
||||
|
||||
## Overview
|
||||
|
||||
Opinionated, secure, and fast. This backend provides the solid foundation you need to ship features, not boilerplate.
|
||||
|
||||
Features:
|
||||
|
||||
- **Authentication**: JWT with refresh tokens, session management, device tracking
|
||||
- **Database**: Async PostgreSQL with SQLAlchemy 2.0, Alembic migrations
|
||||
- **Security**: Rate limiting, CORS, CSP headers, password hashing (bcrypt)
|
||||
- **Multi-tenancy**: Organization-based access control with roles (Owner/Admin/Member)
|
||||
- **Testing**: 97%+ coverage with security-focused test suite
|
||||
- **Performance**: Async throughout, connection pooling, optimized queries
|
||||
- **Modern Tooling**: uv for dependencies, Ruff for linting/formatting, Pyright for type checking
|
||||
- **Security Auditing**: Automated dependency vulnerability scanning, license compliance, secrets detection
|
||||
- **Pre-commit Hooks**: Ruff, detect-secrets, and standard checks on every commit
|
||||
|
||||
## Quick Start
|
||||
|
||||
### Prerequisites
|
||||
|
||||
- Python 3.12+
|
||||
- PostgreSQL 14+ (or SQLite for development)
|
||||
- **[uv](https://docs.astral.sh/uv/)** - Modern Python package manager (replaces pip)
|
||||
|
||||
### Installation
|
||||
|
||||
```bash
|
||||
# Install uv (if not already installed)
|
||||
curl -LsSf https://astral.sh/uv/install.sh | sh
|
||||
|
||||
# Install all dependencies (production + dev)
|
||||
cd backend
|
||||
uv sync --extra dev
|
||||
|
||||
# Or use the Makefile
|
||||
make install-dev
|
||||
|
||||
# Copy environment template
|
||||
cp .env.example .env
|
||||
# Edit .env with your configuration
|
||||
```
|
||||
|
||||
**Why uv?**
|
||||
- 🚀 10-100x faster than pip
|
||||
- 🔒 Reproducible builds via `uv.lock` lockfile
|
||||
- 📦 Better dependency resolution
|
||||
- ⚡ Built by Astral (creators of Ruff)
|
||||
|
||||
### Database Setup
|
||||
|
||||
```bash
|
||||
# Run migrations
|
||||
python migrate.py apply
|
||||
|
||||
# Or use Alembic directly
|
||||
alembic upgrade head
|
||||
```
|
||||
|
||||
### Run Development Server
|
||||
|
||||
```bash
|
||||
# Using uv
|
||||
uv run uvicorn app.main:app --reload --host 0.0.0.0 --port 8000
|
||||
|
||||
# Or activate environment first
|
||||
source .venv/bin/activate # On Windows: .venv\Scripts\activate
|
||||
uvicorn app.main:app --reload --host 0.0.0.0 --port 8000
|
||||
```
|
||||
|
||||
API will be available at:
|
||||
- **API**: http://localhost:8000
|
||||
- **Swagger Docs**: http://localhost:8000/docs
|
||||
- **ReDoc**: http://localhost:8000/redoc
|
||||
|
||||
---
|
||||
|
||||
## Dependency Management with uv
|
||||
|
||||
### Understanding uv
|
||||
|
||||
**uv** is the modern standard for Python dependency management, built in Rust for speed and reliability.
|
||||
|
||||
**Key files:**
|
||||
- `pyproject.toml` - Declares dependencies and tool configurations
|
||||
- `uv.lock` - Locks exact versions for reproducible builds (commit to git)
|
||||
|
||||
### Common Commands
|
||||
|
||||
#### Installing Dependencies
|
||||
|
||||
```bash
|
||||
# Install all dependencies from lockfile
|
||||
uv sync --extra dev
|
||||
|
||||
# Install only production dependencies (no dev tools)
|
||||
uv sync
|
||||
|
||||
# Or use the Makefile
|
||||
make install-dev # Install with dev dependencies
|
||||
make sync # Sync from lockfile
|
||||
```
|
||||
|
||||
#### Adding Dependencies
|
||||
|
||||
```bash
|
||||
# Add a production dependency
|
||||
uv add httpx
|
||||
|
||||
# Add a development dependency
|
||||
uv add --dev pytest-mock
|
||||
|
||||
# Add with version constraint
|
||||
uv add "fastapi>=0.115.0,<0.116.0"
|
||||
|
||||
# Add exact version
|
||||
uv add "pydantic==2.10.6"
|
||||
```
|
||||
|
||||
After adding dependencies, **commit both `pyproject.toml` and `uv.lock`** to git.
|
||||
|
||||
#### Removing Dependencies
|
||||
|
||||
```bash
|
||||
# Remove a package
|
||||
uv remove httpx
|
||||
|
||||
# Remove a dev dependency
|
||||
uv remove --dev pytest-mock
|
||||
```
|
||||
|
||||
#### Updating Dependencies
|
||||
|
||||
```bash
|
||||
# Update all packages to latest compatible versions
|
||||
uv sync --upgrade
|
||||
|
||||
# Update a specific package
|
||||
uv add --upgrade fastapi
|
||||
|
||||
# Check for outdated packages
|
||||
uv pip list --outdated
|
||||
```
|
||||
|
||||
#### Running Commands in uv Environment
|
||||
|
||||
```bash
|
||||
# Run any Python command via uv (no activation needed)
|
||||
uv run python script.py
|
||||
uv run pytest
|
||||
uv run pyright app/
|
||||
|
||||
# Or activate the virtual environment
|
||||
source .venv/bin/activate
|
||||
python script.py
|
||||
pytest
|
||||
```
|
||||
|
||||
### Makefile Commands
|
||||
|
||||
We provide convenient Makefile commands that use uv:
|
||||
|
||||
```bash
|
||||
# Setup
|
||||
make install-dev # Install all dependencies (prod + dev)
|
||||
make sync # Sync from lockfile
|
||||
|
||||
# Code Quality
|
||||
make lint # Run Ruff linter (check only)
|
||||
make lint-fix # Run Ruff with auto-fix
|
||||
make format # Format code with Ruff
|
||||
make format-check # Check if code is formatted
|
||||
make type-check # Run Pyright type checking
|
||||
make validate # Run all checks (lint + format + types)
|
||||
|
||||
# Security & Audit
|
||||
make dep-audit # Scan dependencies for known vulnerabilities (CVEs)
|
||||
make license-check # Check dependency license compliance
|
||||
make audit # Run all security audits (deps + licenses)
|
||||
make validate-all # Run all quality + security checks
|
||||
make check # Full pipeline: quality + security + tests
|
||||
|
||||
# Testing
|
||||
make test # Run all tests
|
||||
make test-cov # Run tests with coverage report
|
||||
make test-e2e # Run E2E tests (PostgreSQL, requires Docker)
|
||||
make test-e2e-schema # Run Schemathesis API schema tests
|
||||
make test-all # Run all tests (unit + E2E)
|
||||
|
||||
# Utilities
|
||||
make clean # Remove cache and build artifacts
|
||||
make help # Show all commands
|
||||
```
|
||||
|
||||
### Dependency Workflow Example
|
||||
|
||||
```bash
|
||||
# 1. Clone repository
|
||||
git clone <repo-url>
|
||||
cd backend
|
||||
|
||||
# 2. Install dependencies
|
||||
make install-dev
|
||||
|
||||
# 3. Make changes, add a new dependency
|
||||
uv add httpx
|
||||
|
||||
# 4. Test your changes
|
||||
make test
|
||||
|
||||
# 5. Commit (includes uv.lock)
|
||||
git add pyproject.toml uv.lock
|
||||
git commit -m "Add httpx dependency"
|
||||
|
||||
# 6. Other developers pull and sync
|
||||
git pull
|
||||
make sync # Uses the committed uv.lock
|
||||
```
|
||||
|
||||
### Troubleshooting uv
|
||||
|
||||
**Dependencies not found after install:**
|
||||
```bash
|
||||
# Make sure you're using uv run or activated environment
|
||||
uv run pytest # Option 1: Run via uv
|
||||
source .venv/bin/activate # Option 2: Activate first
|
||||
pytest
|
||||
```
|
||||
|
||||
**Lockfile out of sync:**
|
||||
```bash
|
||||
# Regenerate lockfile
|
||||
uv lock
|
||||
|
||||
# Force reinstall from lockfile
|
||||
uv sync --reinstall
|
||||
```
|
||||
|
||||
**uv not found:**
|
||||
```bash
|
||||
# Install uv globally
|
||||
curl -LsSf https://astral.sh/uv/install.sh | sh
|
||||
|
||||
# Add to PATH if needed
|
||||
export PATH="$HOME/.cargo/bin:$PATH"
|
||||
```
|
||||
|
||||
---
|
||||
|
||||
## Development
|
||||
|
||||
### Project Structure
|
||||
|
||||
```
|
||||
app/
|
||||
├── api/ # API routes and dependencies
|
||||
│ ├── routes/ # Endpoint implementations
|
||||
│ └── dependencies/ # Auth, permissions, etc.
|
||||
├── core/ # Core functionality
|
||||
│ ├── config.py # Settings management
|
||||
│ ├── database.py # Database engine setup
|
||||
│ ├── auth.py # JWT token handling
|
||||
│ └── exceptions.py # Custom exceptions
|
||||
├── repositories/ # Repository pattern (database operations)
|
||||
├── models/ # SQLAlchemy ORM models
|
||||
├── schemas/ # Pydantic request/response schemas
|
||||
├── services/ # Business logic layer
|
||||
└── utils/ # Utility functions
|
||||
```
|
||||
|
||||
See [docs/ARCHITECTURE.md](docs/ARCHITECTURE.md) for detailed architecture documentation.
|
||||
|
||||
### Configuration
|
||||
|
||||
Environment variables (`.env`):
|
||||
|
||||
```bash
|
||||
# Database
|
||||
POSTGRES_USER=postgres
|
||||
POSTGRES_PASSWORD=your_password
|
||||
POSTGRES_HOST=localhost
|
||||
POSTGRES_PORT=5432
|
||||
POSTGRES_DB=app_db
|
||||
|
||||
# Security (IMPORTANT: Change these!)
|
||||
SECRET_KEY=your-secret-key-min-32-chars-change-in-production
|
||||
ENVIRONMENT=development # development | production
|
||||
|
||||
# Optional
|
||||
BACKEND_CORS_ORIGINS=["http://localhost:3000"]
|
||||
CSP_MODE=relaxed # strict | relaxed | disabled
|
||||
|
||||
# First superuser (auto-created on startup)
|
||||
FIRST_SUPERUSER_EMAIL=admin@example.com
|
||||
FIRST_SUPERUSER_PASSWORD=SecurePass123!
|
||||
```
|
||||
|
||||
⚠️ **Security Note**: Never commit `.env` files. Use strong, unique values in production.
|
||||
|
||||
### Database Migrations
|
||||
|
||||
We use Alembic for database migrations with a helper script:
|
||||
|
||||
```bash
|
||||
# Generate migration from model changes
|
||||
python migrate.py generate "add user preferences"
|
||||
|
||||
# Apply migrations
|
||||
python migrate.py apply
|
||||
|
||||
# Generate and apply in one step
|
||||
python migrate.py auto "add user preferences"
|
||||
|
||||
# Check current version
|
||||
python migrate.py current
|
||||
|
||||
# List all migrations
|
||||
python migrate.py list
|
||||
```
|
||||
|
||||
Manual Alembic usage:
|
||||
|
||||
```bash
|
||||
# Generate migration
|
||||
alembic revision --autogenerate -m "description"
|
||||
|
||||
# Apply migrations
|
||||
alembic upgrade head
|
||||
|
||||
# Rollback one migration
|
||||
alembic downgrade -1
|
||||
```
|
||||
|
||||
### Testing
|
||||
|
||||
```bash
|
||||
# Using Makefile (recommended)
|
||||
make test # Run all tests
|
||||
make test-cov # Run with coverage report
|
||||
|
||||
# Using uv directly
|
||||
IS_TEST=True uv run pytest
|
||||
IS_TEST=True uv run pytest --cov=app --cov-report=term-missing -n 0
|
||||
|
||||
# Run specific test file
|
||||
IS_TEST=True uv run pytest tests/api/test_auth.py -v
|
||||
|
||||
# Run single test
|
||||
IS_TEST=True uv run pytest tests/api/test_auth.py::TestLogin::test_login_success -v
|
||||
|
||||
# Generate HTML coverage report
|
||||
IS_TEST=True uv run pytest --cov=app --cov-report=html -n 0
|
||||
open htmlcov/index.html
|
||||
```
|
||||
|
||||
**Test Environment**: Uses SQLite in-memory database. Tests run in parallel via pytest-xdist.
|
||||
|
||||
### Code Quality
|
||||
|
||||
```bash
|
||||
# Using Makefile (recommended)
|
||||
make lint # Ruff linting
|
||||
make format # Ruff formatting
|
||||
make type-check # Pyright type checking
|
||||
make validate # All checks at once
|
||||
|
||||
# Security audits
|
||||
make dep-audit # Scan dependencies for CVEs
|
||||
make license-check # Check license compliance
|
||||
make audit # All security audits
|
||||
make validate-all # Quality + security checks
|
||||
make check # Full pipeline: quality + security + tests
|
||||
|
||||
# Using uv directly
|
||||
uv run ruff check app/ tests/
|
||||
uv run ruff format app/ tests/
|
||||
uv run pyright app/
|
||||
```
|
||||
|
||||
**Tools:**
|
||||
- **Ruff**: All-in-one linting, formatting, and import sorting (replaces Black, Flake8, isort)
|
||||
- **Pyright**: Static type checking (strict mode)
|
||||
- **pip-audit**: Dependency vulnerability scanning against the OSV database
|
||||
- **pip-licenses**: Dependency license compliance checking
|
||||
- **detect-secrets**: Hardcoded secrets/credentials detection
|
||||
- **pre-commit**: Git hook framework for automated checks on every commit
|
||||
|
||||
All configurations are in `pyproject.toml`.
|
||||
|
||||
---
|
||||
|
||||
## API Documentation
|
||||
|
||||
Once the server is running, interactive API documentation is available:
|
||||
|
||||
- **Swagger UI**: http://localhost:8000/docs
|
||||
- Try out endpoints directly
|
||||
- See request/response schemas
|
||||
- View authentication requirements
|
||||
|
||||
- **ReDoc**: http://localhost:8000/redoc
|
||||
- Alternative documentation interface
|
||||
- Better for reading/printing
|
||||
|
||||
- **OpenAPI JSON**: http://localhost:8000/api/v1/openapi.json
|
||||
- Raw OpenAPI 3.0 specification
|
||||
- Use for client generation
|
||||
|
||||
---
|
||||
|
||||
## Authentication
|
||||
|
||||
### Token-Based Authentication
|
||||
|
||||
The API uses JWT tokens for authentication:
|
||||
|
||||
1. **Login**: `POST /api/v1/auth/login`
|
||||
- Returns access token (15 min expiry) and refresh token (7 day expiry)
|
||||
- Session tracked with device information
|
||||
|
||||
2. **Refresh**: `POST /api/v1/auth/refresh`
|
||||
- Exchange refresh token for new access token
|
||||
- Validates session is still active
|
||||
|
||||
3. **Logout**: `POST /api/v1/auth/logout`
|
||||
- Invalidates current session
|
||||
- Use `logout-all` to invalidate all user sessions
|
||||
|
||||
### Using Protected Endpoints
|
||||
|
||||
Include access token in Authorization header:
|
||||
|
||||
```bash
|
||||
curl -H "Authorization: Bearer <access_token>" \
|
||||
http://localhost:8000/api/v1/users/me
|
||||
```
|
||||
|
||||
### Roles & Permissions
|
||||
|
||||
- **Superuser**: Full system access (user/org management)
|
||||
- **Organization Roles**:
|
||||
- `Owner`: Full control of organization
|
||||
- `Admin`: Can manage members (except owners)
|
||||
- `Member`: Read-only access
|
||||
|
||||
---
|
||||
|
||||
## Common Tasks
|
||||
|
||||
### Create a Superuser
|
||||
|
||||
Superusers are created automatically on startup using `FIRST_SUPERUSER_EMAIL` and `FIRST_SUPERUSER_PASSWORD` from `.env`.
|
||||
|
||||
To create additional superusers, update a user via SQL or admin API.
|
||||
|
||||
### Add a New API Endpoint
|
||||
|
||||
See [docs/FEATURE_EXAMPLE.md](docs/FEATURE_EXAMPLE.md) for step-by-step guide.
|
||||
|
||||
Quick overview:
|
||||
1. Create Pydantic schemas in `app/schemas/`
|
||||
2. Create repository in `app/repositories/`
|
||||
3. Create route in `app/api/routes/`
|
||||
4. Register router in `app/api/main.py`
|
||||
5. Write tests in `tests/api/`
|
||||
|
||||
### Database Health Check
|
||||
|
||||
```bash
|
||||
# Check database connection
|
||||
python migrate.py check
|
||||
|
||||
# Health endpoint
|
||||
curl http://localhost:8000/health
|
||||
```
|
||||
|
||||
---
|
||||
|
||||
## Docker Support
|
||||
|
||||
The Dockerfile uses **uv** for fast, reproducible builds:
|
||||
|
||||
```bash
|
||||
# Development with hot reload
|
||||
docker-compose -f docker-compose.dev.yml up
|
||||
|
||||
# Production
|
||||
docker-compose up -d
|
||||
|
||||
# Rebuild after changes
|
||||
docker-compose build backend
|
||||
```
|
||||
|
||||
**Docker features:**
|
||||
- Multi-stage builds (development + production)
|
||||
- uv for fast dependency installation
|
||||
- `uv.lock` ensures exact versions in containers
|
||||
- Development stage includes dev dependencies
|
||||
- Production stage optimized for size and security
|
||||
|
||||
---
|
||||
|
||||
## Troubleshooting
|
||||
|
||||
### Common Issues
|
||||
|
||||
**Module Import Errors**
|
||||
```bash
|
||||
# Ensure dependencies are installed
|
||||
make install-dev
|
||||
|
||||
# Or sync from lockfile
|
||||
make sync
|
||||
|
||||
# Verify Python environment
|
||||
uv run python --version
|
||||
```
|
||||
|
||||
**uv command not found**
|
||||
```bash
|
||||
# Install uv globally
|
||||
curl -LsSf https://astral.sh/uv/install.sh | sh
|
||||
|
||||
# Add to PATH (add to ~/.bashrc or ~/.zshrc)
|
||||
export PATH="$HOME/.cargo/bin:$PATH"
|
||||
```
|
||||
|
||||
**Database Connection Failed**
|
||||
```bash
|
||||
# Check PostgreSQL is running
|
||||
sudo systemctl status postgresql
|
||||
|
||||
# Verify credentials in .env
|
||||
cat .env | grep POSTGRES
|
||||
```
|
||||
|
||||
**Migration Conflicts**
|
||||
```bash
|
||||
# Check migration history
|
||||
python migrate.py list
|
||||
|
||||
# Downgrade and retry
|
||||
alembic downgrade -1
|
||||
alembic upgrade head
|
||||
```
|
||||
|
||||
**Tests Failing**
|
||||
```bash
|
||||
# Run with verbose output
|
||||
make test
|
||||
|
||||
# Run single test to isolate issue
|
||||
IS_TEST=True uv run pytest tests/api/test_auth.py::TestLogin::test_login_success -vv
|
||||
```
|
||||
|
||||
**Dependencies out of sync**
|
||||
```bash
|
||||
# Regenerate lockfile from pyproject.toml
|
||||
uv lock
|
||||
|
||||
# Reinstall everything
|
||||
make install-dev
|
||||
```
|
||||
|
||||
### Getting Help
|
||||
|
||||
See our detailed documentation:
|
||||
|
||||
- [ARCHITECTURE.md](docs/ARCHITECTURE.md) - System design and patterns
|
||||
- [CODING_STANDARDS.md](docs/CODING_STANDARDS.md) - Code quality guidelines
|
||||
- [COMMON_PITFALLS.md](docs/COMMON_PITFALLS.md) - Mistakes to avoid
|
||||
- [FEATURE_EXAMPLE.md](docs/FEATURE_EXAMPLE.md) - Adding new features
|
||||
|
||||
---
|
||||
|
||||
## Performance
|
||||
|
||||
### Database Connection Pooling
|
||||
|
||||
Configured in `app/core/config.py`:
|
||||
- Pool size: 20 connections
|
||||
- Max overflow: 50 connections
|
||||
- Pool timeout: 30 seconds
|
||||
- Connection recycling: 1 hour
|
||||
|
||||
### Async Operations
|
||||
|
||||
- All I/O operations use async/await
|
||||
- CPU-intensive operations (bcrypt) run in thread pool
|
||||
- No blocking calls in request handlers
|
||||
|
||||
### Query Optimization
|
||||
|
||||
- N+1 query prevention via eager loading
|
||||
- Bulk operations for admin actions
|
||||
- Indexed foreign keys and common lookups
|
||||
|
||||
---
|
||||
|
||||
## Security
|
||||
|
||||
### Built-in Security Features
|
||||
|
||||
- **Password Security**: bcrypt hashing, strength validation, common password blocking
|
||||
- **Token Security**: HMAC-SHA256 signed, short-lived access tokens, algorithm validation
|
||||
- **Session Management**: Database-backed, device tracking, revocation support
|
||||
- **Rate Limiting**: Per-endpoint limits on auth/sensitive operations
|
||||
- **CORS**: Explicit origins, methods, and headers only
|
||||
- **Security Headers**: CSP, HSTS, X-Frame-Options, etc.
|
||||
- **Input Validation**: Pydantic schemas, SQL injection prevention (ORM)
|
||||
|
||||
### Security Auditing
|
||||
|
||||
Automated, deterministic security checks are built into the development workflow:
|
||||
|
||||
```bash
|
||||
# Scan dependencies for known vulnerabilities (CVEs)
|
||||
make dep-audit
|
||||
|
||||
# Check dependency license compliance (blocks GPL-3.0/AGPL)
|
||||
make license-check
|
||||
|
||||
# Run all security audits
|
||||
make audit
|
||||
|
||||
# Full pipeline: quality + security + tests
|
||||
make check
|
||||
```
|
||||
|
||||
**Pre-commit hooks** automatically run on every commit:
|
||||
- **Ruff** lint + format checks
|
||||
- **detect-secrets** blocks commits containing hardcoded secrets
|
||||
- **Standard checks**: trailing whitespace, YAML/TOML validation, merge conflict detection, large file prevention
|
||||
|
||||
Setup pre-commit hooks:
|
||||
```bash
|
||||
uv run pre-commit install
|
||||
```
|
||||
|
||||
### Security Best Practices
|
||||
|
||||
1. **Never commit secrets**: Use `.env` files (git-ignored), enforced by detect-secrets pre-commit hook
|
||||
2. **Strong SECRET_KEY**: Min 32 chars, cryptographically random
|
||||
3. **HTTPS in production**: Required for token security
|
||||
4. **Regular updates**: Keep dependencies current (`uv sync --upgrade`), run `make dep-audit` to check for CVEs
|
||||
5. **Audit logs**: Monitor authentication events
|
||||
6. **Run `make check` before pushing**: Validates quality, security, and tests in one command
|
||||
|
||||
---
|
||||
|
||||
## Monitoring
|
||||
|
||||
### Health Check
|
||||
|
||||
```bash
|
||||
curl http://localhost:8000/health
|
||||
```
|
||||
|
||||
Returns:
|
||||
- API version
|
||||
- Environment
|
||||
- Database connectivity
|
||||
- Timestamp
|
||||
|
||||
### Logging
|
||||
|
||||
Logs are written to stdout with structured format:
|
||||
|
||||
```python
|
||||
# Configure log level
|
||||
logging.basicConfig(level=logging.INFO)
|
||||
|
||||
# In production, use JSON logs for log aggregation
|
||||
```
|
||||
|
||||
---
|
||||
|
||||
## Additional Resources
|
||||
|
||||
### Official Documentation
|
||||
- **uv**: https://docs.astral.sh/uv/
|
||||
- **FastAPI**: https://fastapi.tiangolo.com
|
||||
- **SQLAlchemy 2.0**: https://docs.sqlalchemy.org/en/20/
|
||||
- **Pydantic**: https://docs.pydantic.dev/
|
||||
- **Alembic**: https://alembic.sqlalchemy.org/
|
||||
- **Ruff**: https://docs.astral.sh/ruff/
|
||||
|
||||
### Our Documentation
|
||||
- [Root README](../README.md) - Project-wide information
|
||||
- [CLAUDE.md](../CLAUDE.md) - Comprehensive development guide
|
||||
|
||||
---
|
||||
|
||||
**Built with modern Python tooling:**
|
||||
- 🚀 **uv** - 10-100x faster dependency management
|
||||
- ⚡ **Ruff** - 10-100x faster linting & formatting
|
||||
- 🔍 **Pyright** - Static type checking (strict mode)
|
||||
- ✅ **pytest** - Comprehensive test suite
|
||||
- 🔒 **pip-audit** - Dependency vulnerability scanning
|
||||
- 🔑 **detect-secrets** - Hardcoded secrets detection
|
||||
- 📜 **pip-licenses** - License compliance checking
|
||||
- 🪝 **pre-commit** - Automated git hooks
|
||||
|
||||
**All configured in a single `pyproject.toml` file!**
|
||||
@@ -2,6 +2,13 @@
|
||||
script_location = app/alembic
|
||||
sqlalchemy.url = postgresql://postgres:postgres@db:5432/app
|
||||
|
||||
# Use sequential naming: 0001_message.py, 0002_message.py, etc.
|
||||
# The rev_id is still used internally but filename is cleaner
|
||||
file_template = %%(rev)s_%%(slug)s
|
||||
|
||||
# Allow specifying custom revision IDs via --rev-id flag
|
||||
revision_environment = true
|
||||
|
||||
[loggers]
|
||||
keys = root,sqlalchemy,alembic
|
||||
|
||||
|
||||
0
backend/app/__init__.py
Normal file → Executable file
0
backend/app/__init__.py
Normal file → Executable file
@@ -2,10 +2,10 @@ import sys
|
||||
from logging.config import fileConfig
|
||||
from pathlib import Path
|
||||
|
||||
from sqlalchemy import engine_from_config
|
||||
from sqlalchemy import pool
|
||||
|
||||
from alembic import context
|
||||
from sqlalchemy import create_engine, engine_from_config, pool, text
|
||||
from sqlalchemy.engine.url import make_url
|
||||
from sqlalchemy.exc import OperationalError
|
||||
|
||||
# Get the path to the app directory (parent of 'alembic')
|
||||
app_dir = Path(__file__).resolve().parent.parent
|
||||
@@ -14,7 +14,6 @@ sys.path.append(str(app_dir.parent))
|
||||
|
||||
# Import Core modules
|
||||
from app.core.config import settings
|
||||
from app.core.database import Base
|
||||
|
||||
# Import all models to ensure they're registered with SQLAlchemy
|
||||
from app.models import *
|
||||
@@ -23,6 +22,25 @@ from app.models import *
|
||||
# access to the values within the .ini file in use.
|
||||
config = context.config
|
||||
|
||||
|
||||
def include_object(object, name, type_, reflected, compare_to):
|
||||
"""
|
||||
Filter objects for autogenerate.
|
||||
|
||||
Skip comparing functional indexes (like LOWER(column)) and partial indexes
|
||||
(with WHERE clauses) as Alembic cannot reliably detect these from models.
|
||||
These should be managed manually via dedicated performance migrations.
|
||||
|
||||
Convention: Any index starting with "ix_perf_" is automatically excluded.
|
||||
This allows adding new performance indexes without updating this file.
|
||||
"""
|
||||
if type_ == "index" and name:
|
||||
# Convention-based: any index prefixed with ix_perf_ is manual
|
||||
if name.startswith("ix_perf_"):
|
||||
return False
|
||||
return True
|
||||
|
||||
|
||||
# Interpret the config file for Python logging.
|
||||
# This line sets up loggers basically.
|
||||
if config.config_file_name is not None:
|
||||
@@ -36,6 +54,53 @@ target_metadata = Base.metadata
|
||||
config.set_main_option("sqlalchemy.url", settings.database_url)
|
||||
|
||||
|
||||
def ensure_database_exists(db_url: str) -> None:
|
||||
"""
|
||||
Ensure the target PostgreSQL database exists.
|
||||
If connection to the target DB fails because it doesn't exist, connect to the
|
||||
default 'postgres' database and create it. Safe to call multiple times.
|
||||
"""
|
||||
try:
|
||||
# First, try connecting to the target database
|
||||
test_engine = create_engine(db_url, poolclass=pool.NullPool)
|
||||
with test_engine.connect() as conn:
|
||||
conn.execute(text("SELECT 1"))
|
||||
test_engine.dispose()
|
||||
return
|
||||
except OperationalError:
|
||||
# Likely the database does not exist; proceed to create it
|
||||
pass
|
||||
|
||||
url = make_url(db_url)
|
||||
# Only handle PostgreSQL here
|
||||
if url.get_backend_name() != "postgresql":
|
||||
return
|
||||
|
||||
target_db = url.database
|
||||
if not target_db:
|
||||
return
|
||||
|
||||
# Build admin URL pointing to the default 'postgres' database
|
||||
admin_url = url.set(database="postgres")
|
||||
|
||||
# CREATE DATABASE cannot run inside a transaction
|
||||
admin_engine = create_engine(
|
||||
str(admin_url), isolation_level="AUTOCOMMIT", poolclass=pool.NullPool
|
||||
)
|
||||
try:
|
||||
with admin_engine.connect() as conn:
|
||||
exists = conn.execute(
|
||||
text("SELECT 1 FROM pg_database WHERE datname = :dbname"),
|
||||
{"dbname": target_db},
|
||||
).scalar()
|
||||
if not exists:
|
||||
# Quote the database name safely
|
||||
dbname_quoted = '"' + target_db.replace('"', '""') + '"'
|
||||
conn.execute(text(f"CREATE DATABASE {dbname_quoted}"))
|
||||
finally:
|
||||
admin_engine.dispose()
|
||||
|
||||
|
||||
def run_migrations_offline() -> None:
|
||||
"""Run migrations in 'offline' mode.
|
||||
|
||||
@@ -54,6 +119,8 @@ def run_migrations_offline() -> None:
|
||||
target_metadata=target_metadata,
|
||||
literal_binds=True,
|
||||
dialect_opts={"paramstyle": "named"},
|
||||
compare_type=True,
|
||||
include_object=include_object,
|
||||
)
|
||||
|
||||
with context.begin_transaction():
|
||||
@@ -67,6 +134,9 @@ def run_migrations_online() -> None:
|
||||
and associate a connection with the context.
|
||||
|
||||
"""
|
||||
# Ensure the target database exists (handles first-run cases)
|
||||
ensure_database_exists(settings.database_url)
|
||||
|
||||
connectable = engine_from_config(
|
||||
config.get_section(config.config_ini_section, {}),
|
||||
prefix="sqlalchemy.",
|
||||
@@ -75,7 +145,10 @@ def run_migrations_online() -> None:
|
||||
|
||||
with connectable.connect() as connection:
|
||||
context.configure(
|
||||
connection=connection, target_metadata=target_metadata
|
||||
connection=connection,
|
||||
target_metadata=target_metadata,
|
||||
compare_type=True,
|
||||
include_object=include_object,
|
||||
)
|
||||
|
||||
with context.begin_transaction():
|
||||
@@ -85,4 +158,4 @@ def run_migrations_online() -> None:
|
||||
if context.is_offline_mode():
|
||||
run_migrations_offline()
|
||||
else:
|
||||
run_migrations_online()
|
||||
run_migrations_online()
|
||||
|
||||
446
backend/app/alembic/versions/0001_initial_models.py
Normal file
446
backend/app/alembic/versions/0001_initial_models.py
Normal file
@@ -0,0 +1,446 @@
|
||||
"""initial models
|
||||
|
||||
Revision ID: 0001
|
||||
Revises:
|
||||
Create Date: 2025-11-27 09:08:09.464506
|
||||
|
||||
"""
|
||||
|
||||
from collections.abc import Sequence
|
||||
|
||||
import sqlalchemy as sa
|
||||
from alembic import op
|
||||
from sqlalchemy.dialects import postgresql
|
||||
|
||||
# revision identifiers, used by Alembic.
|
||||
revision: str = "0001"
|
||||
down_revision: str | None = None
|
||||
branch_labels: str | Sequence[str] | None = None
|
||||
depends_on: str | Sequence[str] | None = None
|
||||
|
||||
|
||||
def upgrade() -> None:
|
||||
# ### commands auto generated by Alembic - please adjust! ###
|
||||
op.create_table(
|
||||
"oauth_states",
|
||||
sa.Column("state", sa.String(length=255), nullable=False),
|
||||
sa.Column("code_verifier", sa.String(length=128), nullable=True),
|
||||
sa.Column("nonce", sa.String(length=255), nullable=True),
|
||||
sa.Column("provider", sa.String(length=50), nullable=False),
|
||||
sa.Column("redirect_uri", sa.String(length=500), nullable=True),
|
||||
sa.Column("user_id", sa.UUID(), nullable=True),
|
||||
sa.Column("expires_at", sa.DateTime(timezone=True), nullable=False),
|
||||
sa.Column("id", sa.UUID(), nullable=False),
|
||||
sa.Column("created_at", sa.DateTime(timezone=True), nullable=False),
|
||||
sa.Column("updated_at", sa.DateTime(timezone=True), nullable=False),
|
||||
sa.PrimaryKeyConstraint("id"),
|
||||
)
|
||||
op.create_index(
|
||||
op.f("ix_oauth_states_state"), "oauth_states", ["state"], unique=True
|
||||
)
|
||||
op.create_table(
|
||||
"organizations",
|
||||
sa.Column("name", sa.String(length=255), nullable=False),
|
||||
sa.Column("slug", sa.String(length=255), nullable=False),
|
||||
sa.Column("description", sa.Text(), nullable=True),
|
||||
sa.Column("is_active", sa.Boolean(), nullable=False),
|
||||
sa.Column("settings", postgresql.JSONB(astext_type=sa.Text()), nullable=True),
|
||||
sa.Column("id", sa.UUID(), nullable=False),
|
||||
sa.Column("created_at", sa.DateTime(timezone=True), nullable=False),
|
||||
sa.Column("updated_at", sa.DateTime(timezone=True), nullable=False),
|
||||
sa.PrimaryKeyConstraint("id"),
|
||||
)
|
||||
op.create_index(
|
||||
op.f("ix_organizations_is_active"), "organizations", ["is_active"], unique=False
|
||||
)
|
||||
op.create_index(
|
||||
op.f("ix_organizations_name"), "organizations", ["name"], unique=False
|
||||
)
|
||||
op.create_index(
|
||||
"ix_organizations_name_active",
|
||||
"organizations",
|
||||
["name", "is_active"],
|
||||
unique=False,
|
||||
)
|
||||
op.create_index(
|
||||
op.f("ix_organizations_slug"), "organizations", ["slug"], unique=True
|
||||
)
|
||||
op.create_index(
|
||||
"ix_organizations_slug_active",
|
||||
"organizations",
|
||||
["slug", "is_active"],
|
||||
unique=False,
|
||||
)
|
||||
op.create_table(
|
||||
"users",
|
||||
sa.Column("email", sa.String(length=255), nullable=False),
|
||||
sa.Column("password_hash", sa.String(length=255), nullable=True),
|
||||
sa.Column("first_name", sa.String(length=100), nullable=False),
|
||||
sa.Column("last_name", sa.String(length=100), nullable=True),
|
||||
sa.Column("phone_number", sa.String(length=20), nullable=True),
|
||||
sa.Column("is_active", sa.Boolean(), nullable=False),
|
||||
sa.Column("is_superuser", sa.Boolean(), nullable=False),
|
||||
sa.Column(
|
||||
"preferences", postgresql.JSONB(astext_type=sa.Text()), nullable=True
|
||||
),
|
||||
sa.Column("locale", sa.String(length=10), nullable=True),
|
||||
sa.Column("deleted_at", sa.DateTime(timezone=True), nullable=True),
|
||||
sa.Column("id", sa.UUID(), nullable=False),
|
||||
sa.Column("created_at", sa.DateTime(timezone=True), nullable=False),
|
||||
sa.Column("updated_at", sa.DateTime(timezone=True), nullable=False),
|
||||
sa.PrimaryKeyConstraint("id"),
|
||||
)
|
||||
op.create_index(op.f("ix_users_deleted_at"), "users", ["deleted_at"], unique=False)
|
||||
op.create_index(op.f("ix_users_email"), "users", ["email"], unique=True)
|
||||
op.create_index(op.f("ix_users_is_active"), "users", ["is_active"], unique=False)
|
||||
op.create_index(
|
||||
op.f("ix_users_is_superuser"), "users", ["is_superuser"], unique=False
|
||||
)
|
||||
op.create_index(op.f("ix_users_locale"), "users", ["locale"], unique=False)
|
||||
op.create_table(
|
||||
"oauth_accounts",
|
||||
sa.Column("user_id", sa.UUID(), nullable=False),
|
||||
sa.Column("provider", sa.String(length=50), nullable=False),
|
||||
sa.Column("provider_user_id", sa.String(length=255), nullable=False),
|
||||
sa.Column("provider_email", sa.String(length=255), nullable=True),
|
||||
sa.Column("access_token_encrypted", sa.String(length=2048), nullable=True),
|
||||
sa.Column("refresh_token_encrypted", sa.String(length=2048), nullable=True),
|
||||
sa.Column("token_expires_at", sa.DateTime(timezone=True), nullable=True),
|
||||
sa.Column("id", sa.UUID(), nullable=False),
|
||||
sa.Column("created_at", sa.DateTime(timezone=True), nullable=False),
|
||||
sa.Column("updated_at", sa.DateTime(timezone=True), nullable=False),
|
||||
sa.ForeignKeyConstraint(["user_id"], ["users.id"], ondelete="CASCADE"),
|
||||
sa.PrimaryKeyConstraint("id"),
|
||||
sa.UniqueConstraint(
|
||||
"provider", "provider_user_id", name="uq_oauth_provider_user"
|
||||
),
|
||||
)
|
||||
op.create_index(
|
||||
op.f("ix_oauth_accounts_provider"), "oauth_accounts", ["provider"], unique=False
|
||||
)
|
||||
op.create_index(
|
||||
op.f("ix_oauth_accounts_provider_email"),
|
||||
"oauth_accounts",
|
||||
["provider_email"],
|
||||
unique=False,
|
||||
)
|
||||
op.create_index(
|
||||
op.f("ix_oauth_accounts_user_id"), "oauth_accounts", ["user_id"], unique=False
|
||||
)
|
||||
op.create_index(
|
||||
"ix_oauth_accounts_user_provider",
|
||||
"oauth_accounts",
|
||||
["user_id", "provider"],
|
||||
unique=False,
|
||||
)
|
||||
op.create_table(
|
||||
"oauth_clients",
|
||||
sa.Column("client_id", sa.String(length=64), nullable=False),
|
||||
sa.Column("client_secret_hash", sa.String(length=255), nullable=True),
|
||||
sa.Column("client_name", sa.String(length=255), nullable=False),
|
||||
sa.Column("client_description", sa.String(length=1000), nullable=True),
|
||||
sa.Column("client_type", sa.String(length=20), nullable=False),
|
||||
sa.Column(
|
||||
"redirect_uris", postgresql.JSONB(astext_type=sa.Text()), nullable=False
|
||||
),
|
||||
sa.Column(
|
||||
"allowed_scopes", postgresql.JSONB(astext_type=sa.Text()), nullable=False
|
||||
),
|
||||
sa.Column("access_token_lifetime", sa.String(length=10), nullable=False),
|
||||
sa.Column("refresh_token_lifetime", sa.String(length=10), nullable=False),
|
||||
sa.Column("is_active", sa.Boolean(), nullable=False),
|
||||
sa.Column("owner_user_id", sa.UUID(), nullable=True),
|
||||
sa.Column("mcp_server_url", sa.String(length=2048), nullable=True),
|
||||
sa.Column("id", sa.UUID(), nullable=False),
|
||||
sa.Column("created_at", sa.DateTime(timezone=True), nullable=False),
|
||||
sa.Column("updated_at", sa.DateTime(timezone=True), nullable=False),
|
||||
sa.ForeignKeyConstraint(["owner_user_id"], ["users.id"], ondelete="SET NULL"),
|
||||
sa.PrimaryKeyConstraint("id"),
|
||||
)
|
||||
op.create_index(
|
||||
op.f("ix_oauth_clients_client_id"), "oauth_clients", ["client_id"], unique=True
|
||||
)
|
||||
op.create_index(
|
||||
op.f("ix_oauth_clients_is_active"), "oauth_clients", ["is_active"], unique=False
|
||||
)
|
||||
op.create_table(
|
||||
"user_organizations",
|
||||
sa.Column("user_id", sa.UUID(), nullable=False),
|
||||
sa.Column("organization_id", sa.UUID(), nullable=False),
|
||||
sa.Column(
|
||||
"role",
|
||||
sa.Enum("OWNER", "ADMIN", "MEMBER", "GUEST", name="organizationrole"),
|
||||
nullable=False,
|
||||
),
|
||||
sa.Column("is_active", sa.Boolean(), nullable=False),
|
||||
sa.Column("custom_permissions", sa.String(length=500), nullable=True),
|
||||
sa.Column("created_at", sa.DateTime(timezone=True), nullable=False),
|
||||
sa.Column("updated_at", sa.DateTime(timezone=True), nullable=False),
|
||||
sa.ForeignKeyConstraint(
|
||||
["organization_id"], ["organizations.id"], ondelete="CASCADE"
|
||||
),
|
||||
sa.ForeignKeyConstraint(["user_id"], ["users.id"], ondelete="CASCADE"),
|
||||
sa.PrimaryKeyConstraint("user_id", "organization_id"),
|
||||
)
|
||||
op.create_index(
|
||||
"ix_user_org_org_active",
|
||||
"user_organizations",
|
||||
["organization_id", "is_active"],
|
||||
unique=False,
|
||||
)
|
||||
op.create_index("ix_user_org_role", "user_organizations", ["role"], unique=False)
|
||||
op.create_index(
|
||||
"ix_user_org_user_active",
|
||||
"user_organizations",
|
||||
["user_id", "is_active"],
|
||||
unique=False,
|
||||
)
|
||||
op.create_index(
|
||||
op.f("ix_user_organizations_is_active"),
|
||||
"user_organizations",
|
||||
["is_active"],
|
||||
unique=False,
|
||||
)
|
||||
op.create_table(
|
||||
"user_sessions",
|
||||
sa.Column("user_id", sa.UUID(), nullable=False),
|
||||
sa.Column("refresh_token_jti", sa.String(length=255), nullable=False),
|
||||
sa.Column("device_name", sa.String(length=255), nullable=True),
|
||||
sa.Column("device_id", sa.String(length=255), nullable=True),
|
||||
sa.Column("ip_address", sa.String(length=45), nullable=True),
|
||||
sa.Column("user_agent", sa.String(length=500), nullable=True),
|
||||
sa.Column("last_used_at", sa.DateTime(timezone=True), nullable=False),
|
||||
sa.Column("expires_at", sa.DateTime(timezone=True), nullable=False),
|
||||
sa.Column("is_active", sa.Boolean(), nullable=False),
|
||||
sa.Column("location_city", sa.String(length=100), nullable=True),
|
||||
sa.Column("location_country", sa.String(length=100), nullable=True),
|
||||
sa.Column("id", sa.UUID(), nullable=False),
|
||||
sa.Column("created_at", sa.DateTime(timezone=True), nullable=False),
|
||||
sa.Column("updated_at", sa.DateTime(timezone=True), nullable=False),
|
||||
sa.ForeignKeyConstraint(["user_id"], ["users.id"], ondelete="CASCADE"),
|
||||
sa.PrimaryKeyConstraint("id"),
|
||||
)
|
||||
op.create_index(
|
||||
op.f("ix_user_sessions_is_active"), "user_sessions", ["is_active"], unique=False
|
||||
)
|
||||
op.create_index(
|
||||
"ix_user_sessions_jti_active",
|
||||
"user_sessions",
|
||||
["refresh_token_jti", "is_active"],
|
||||
unique=False,
|
||||
)
|
||||
op.create_index(
|
||||
op.f("ix_user_sessions_refresh_token_jti"),
|
||||
"user_sessions",
|
||||
["refresh_token_jti"],
|
||||
unique=True,
|
||||
)
|
||||
op.create_index(
|
||||
"ix_user_sessions_user_active",
|
||||
"user_sessions",
|
||||
["user_id", "is_active"],
|
||||
unique=False,
|
||||
)
|
||||
op.create_index(
|
||||
op.f("ix_user_sessions_user_id"), "user_sessions", ["user_id"], unique=False
|
||||
)
|
||||
op.create_table(
|
||||
"oauth_authorization_codes",
|
||||
sa.Column("code", sa.String(length=128), nullable=False),
|
||||
sa.Column("client_id", sa.String(length=64), nullable=False),
|
||||
sa.Column("user_id", sa.UUID(), nullable=False),
|
||||
sa.Column("redirect_uri", sa.String(length=2048), nullable=False),
|
||||
sa.Column("scope", sa.String(length=1000), nullable=False),
|
||||
sa.Column("code_challenge", sa.String(length=128), nullable=True),
|
||||
sa.Column("code_challenge_method", sa.String(length=10), nullable=True),
|
||||
sa.Column("state", sa.String(length=256), nullable=True),
|
||||
sa.Column("nonce", sa.String(length=256), nullable=True),
|
||||
sa.Column("expires_at", sa.DateTime(timezone=True), nullable=False),
|
||||
sa.Column("used", sa.Boolean(), nullable=False),
|
||||
sa.Column("id", sa.UUID(), nullable=False),
|
||||
sa.Column("created_at", sa.DateTime(timezone=True), nullable=False),
|
||||
sa.Column("updated_at", sa.DateTime(timezone=True), nullable=False),
|
||||
sa.ForeignKeyConstraint(
|
||||
["client_id"], ["oauth_clients.client_id"], ondelete="CASCADE"
|
||||
),
|
||||
sa.ForeignKeyConstraint(["user_id"], ["users.id"], ondelete="CASCADE"),
|
||||
sa.PrimaryKeyConstraint("id"),
|
||||
)
|
||||
op.create_index(
|
||||
"ix_oauth_authorization_codes_client_user",
|
||||
"oauth_authorization_codes",
|
||||
["client_id", "user_id"],
|
||||
unique=False,
|
||||
)
|
||||
op.create_index(
|
||||
op.f("ix_oauth_authorization_codes_code"),
|
||||
"oauth_authorization_codes",
|
||||
["code"],
|
||||
unique=True,
|
||||
)
|
||||
op.create_index(
|
||||
"ix_oauth_authorization_codes_expires_at",
|
||||
"oauth_authorization_codes",
|
||||
["expires_at"],
|
||||
unique=False,
|
||||
)
|
||||
op.create_table(
|
||||
"oauth_consents",
|
||||
sa.Column("user_id", sa.UUID(), nullable=False),
|
||||
sa.Column("client_id", sa.String(length=64), nullable=False),
|
||||
sa.Column("granted_scopes", sa.String(length=1000), nullable=False),
|
||||
sa.Column("id", sa.UUID(), nullable=False),
|
||||
sa.Column("created_at", sa.DateTime(timezone=True), nullable=False),
|
||||
sa.Column("updated_at", sa.DateTime(timezone=True), nullable=False),
|
||||
sa.ForeignKeyConstraint(
|
||||
["client_id"], ["oauth_clients.client_id"], ondelete="CASCADE"
|
||||
),
|
||||
sa.ForeignKeyConstraint(["user_id"], ["users.id"], ondelete="CASCADE"),
|
||||
sa.PrimaryKeyConstraint("id"),
|
||||
)
|
||||
op.create_index(
|
||||
"ix_oauth_consents_user_client",
|
||||
"oauth_consents",
|
||||
["user_id", "client_id"],
|
||||
unique=True,
|
||||
)
|
||||
op.create_table(
|
||||
"oauth_provider_refresh_tokens",
|
||||
sa.Column("token_hash", sa.String(length=64), nullable=False),
|
||||
sa.Column("jti", sa.String(length=64), nullable=False),
|
||||
sa.Column("client_id", sa.String(length=64), nullable=False),
|
||||
sa.Column("user_id", sa.UUID(), nullable=False),
|
||||
sa.Column("scope", sa.String(length=1000), nullable=False),
|
||||
sa.Column("expires_at", sa.DateTime(timezone=True), nullable=False),
|
||||
sa.Column("revoked", sa.Boolean(), nullable=False),
|
||||
sa.Column("last_used_at", sa.DateTime(timezone=True), nullable=True),
|
||||
sa.Column("device_info", sa.String(length=500), nullable=True),
|
||||
sa.Column("ip_address", sa.String(length=45), nullable=True),
|
||||
sa.Column("id", sa.UUID(), nullable=False),
|
||||
sa.Column("created_at", sa.DateTime(timezone=True), nullable=False),
|
||||
sa.Column("updated_at", sa.DateTime(timezone=True), nullable=False),
|
||||
sa.ForeignKeyConstraint(
|
||||
["client_id"], ["oauth_clients.client_id"], ondelete="CASCADE"
|
||||
),
|
||||
sa.ForeignKeyConstraint(["user_id"], ["users.id"], ondelete="CASCADE"),
|
||||
sa.PrimaryKeyConstraint("id"),
|
||||
)
|
||||
op.create_index(
|
||||
"ix_oauth_provider_refresh_tokens_client_user",
|
||||
"oauth_provider_refresh_tokens",
|
||||
["client_id", "user_id"],
|
||||
unique=False,
|
||||
)
|
||||
op.create_index(
|
||||
"ix_oauth_provider_refresh_tokens_expires_at",
|
||||
"oauth_provider_refresh_tokens",
|
||||
["expires_at"],
|
||||
unique=False,
|
||||
)
|
||||
op.create_index(
|
||||
op.f("ix_oauth_provider_refresh_tokens_jti"),
|
||||
"oauth_provider_refresh_tokens",
|
||||
["jti"],
|
||||
unique=True,
|
||||
)
|
||||
op.create_index(
|
||||
op.f("ix_oauth_provider_refresh_tokens_revoked"),
|
||||
"oauth_provider_refresh_tokens",
|
||||
["revoked"],
|
||||
unique=False,
|
||||
)
|
||||
op.create_index(
|
||||
op.f("ix_oauth_provider_refresh_tokens_token_hash"),
|
||||
"oauth_provider_refresh_tokens",
|
||||
["token_hash"],
|
||||
unique=True,
|
||||
)
|
||||
op.create_index(
|
||||
"ix_oauth_provider_refresh_tokens_user_revoked",
|
||||
"oauth_provider_refresh_tokens",
|
||||
["user_id", "revoked"],
|
||||
unique=False,
|
||||
)
|
||||
# ### end Alembic commands ###
|
||||
|
||||
|
||||
def downgrade() -> None:
|
||||
# ### commands auto generated by Alembic - please adjust! ###
|
||||
op.drop_index(
|
||||
"ix_oauth_provider_refresh_tokens_user_revoked",
|
||||
table_name="oauth_provider_refresh_tokens",
|
||||
)
|
||||
op.drop_index(
|
||||
op.f("ix_oauth_provider_refresh_tokens_token_hash"),
|
||||
table_name="oauth_provider_refresh_tokens",
|
||||
)
|
||||
op.drop_index(
|
||||
op.f("ix_oauth_provider_refresh_tokens_revoked"),
|
||||
table_name="oauth_provider_refresh_tokens",
|
||||
)
|
||||
op.drop_index(
|
||||
op.f("ix_oauth_provider_refresh_tokens_jti"),
|
||||
table_name="oauth_provider_refresh_tokens",
|
||||
)
|
||||
op.drop_index(
|
||||
"ix_oauth_provider_refresh_tokens_expires_at",
|
||||
table_name="oauth_provider_refresh_tokens",
|
||||
)
|
||||
op.drop_index(
|
||||
"ix_oauth_provider_refresh_tokens_client_user",
|
||||
table_name="oauth_provider_refresh_tokens",
|
||||
)
|
||||
op.drop_table("oauth_provider_refresh_tokens")
|
||||
op.drop_index("ix_oauth_consents_user_client", table_name="oauth_consents")
|
||||
op.drop_table("oauth_consents")
|
||||
op.drop_index(
|
||||
"ix_oauth_authorization_codes_expires_at",
|
||||
table_name="oauth_authorization_codes",
|
||||
)
|
||||
op.drop_index(
|
||||
op.f("ix_oauth_authorization_codes_code"),
|
||||
table_name="oauth_authorization_codes",
|
||||
)
|
||||
op.drop_index(
|
||||
"ix_oauth_authorization_codes_client_user",
|
||||
table_name="oauth_authorization_codes",
|
||||
)
|
||||
op.drop_table("oauth_authorization_codes")
|
||||
op.drop_index(op.f("ix_user_sessions_user_id"), table_name="user_sessions")
|
||||
op.drop_index("ix_user_sessions_user_active", table_name="user_sessions")
|
||||
op.drop_index(
|
||||
op.f("ix_user_sessions_refresh_token_jti"), table_name="user_sessions"
|
||||
)
|
||||
op.drop_index("ix_user_sessions_jti_active", table_name="user_sessions")
|
||||
op.drop_index(op.f("ix_user_sessions_is_active"), table_name="user_sessions")
|
||||
op.drop_table("user_sessions")
|
||||
op.drop_index(
|
||||
op.f("ix_user_organizations_is_active"), table_name="user_organizations"
|
||||
)
|
||||
op.drop_index("ix_user_org_user_active", table_name="user_organizations")
|
||||
op.drop_index("ix_user_org_role", table_name="user_organizations")
|
||||
op.drop_index("ix_user_org_org_active", table_name="user_organizations")
|
||||
op.drop_table("user_organizations")
|
||||
op.drop_index(op.f("ix_oauth_clients_is_active"), table_name="oauth_clients")
|
||||
op.drop_index(op.f("ix_oauth_clients_client_id"), table_name="oauth_clients")
|
||||
op.drop_table("oauth_clients")
|
||||
op.drop_index("ix_oauth_accounts_user_provider", table_name="oauth_accounts")
|
||||
op.drop_index(op.f("ix_oauth_accounts_user_id"), table_name="oauth_accounts")
|
||||
op.drop_index(op.f("ix_oauth_accounts_provider_email"), table_name="oauth_accounts")
|
||||
op.drop_index(op.f("ix_oauth_accounts_provider"), table_name="oauth_accounts")
|
||||
op.drop_table("oauth_accounts")
|
||||
op.drop_index(op.f("ix_users_locale"), table_name="users")
|
||||
op.drop_index(op.f("ix_users_is_superuser"), table_name="users")
|
||||
op.drop_index(op.f("ix_users_is_active"), table_name="users")
|
||||
op.drop_index(op.f("ix_users_email"), table_name="users")
|
||||
op.drop_index(op.f("ix_users_deleted_at"), table_name="users")
|
||||
op.drop_table("users")
|
||||
op.drop_index("ix_organizations_slug_active", table_name="organizations")
|
||||
op.drop_index(op.f("ix_organizations_slug"), table_name="organizations")
|
||||
op.drop_index("ix_organizations_name_active", table_name="organizations")
|
||||
op.drop_index(op.f("ix_organizations_name"), table_name="organizations")
|
||||
op.drop_index(op.f("ix_organizations_is_active"), table_name="organizations")
|
||||
op.drop_table("organizations")
|
||||
op.drop_index(op.f("ix_oauth_states_state"), table_name="oauth_states")
|
||||
op.drop_table("oauth_states")
|
||||
# ### end Alembic commands ###
|
||||
127
backend/app/alembic/versions/0002_add_performance_indexes.py
Normal file
127
backend/app/alembic/versions/0002_add_performance_indexes.py
Normal file
@@ -0,0 +1,127 @@
|
||||
"""Add performance indexes
|
||||
|
||||
Revision ID: 0002
|
||||
Revises: 0001
|
||||
Create Date: 2025-11-27
|
||||
|
||||
Performance indexes that Alembic cannot auto-detect:
|
||||
- Functional indexes (LOWER expressions)
|
||||
- Partial indexes (WHERE clauses)
|
||||
|
||||
These indexes use the ix_perf_ prefix and are excluded from autogenerate
|
||||
via the include_object() function in env.py.
|
||||
"""
|
||||
|
||||
from collections.abc import Sequence
|
||||
|
||||
import sqlalchemy as sa
|
||||
from alembic import op
|
||||
|
||||
# revision identifiers, used by Alembic.
|
||||
revision: str = "0002"
|
||||
down_revision: str | None = "0001"
|
||||
branch_labels: str | Sequence[str] | None = None
|
||||
depends_on: str | Sequence[str] | None = None
|
||||
|
||||
|
||||
def upgrade() -> None:
|
||||
# ==========================================================================
|
||||
# USERS TABLE - Performance indexes for authentication
|
||||
# ==========================================================================
|
||||
|
||||
# Case-insensitive email lookup for login/registration
|
||||
# Query: SELECT * FROM users WHERE LOWER(email) = LOWER(:email) AND deleted_at IS NULL
|
||||
# Impact: High - every login, registration check, password reset
|
||||
op.create_index(
|
||||
"ix_perf_users_email_lower",
|
||||
"users",
|
||||
[sa.text("LOWER(email)")],
|
||||
unique=False,
|
||||
postgresql_where=sa.text("deleted_at IS NULL"),
|
||||
)
|
||||
|
||||
# Active users lookup (non-soft-deleted)
|
||||
# Query: SELECT * FROM users WHERE deleted_at IS NULL AND ...
|
||||
# Impact: Medium - user listings, admin queries
|
||||
op.create_index(
|
||||
"ix_perf_users_active",
|
||||
"users",
|
||||
["is_active"],
|
||||
unique=False,
|
||||
postgresql_where=sa.text("deleted_at IS NULL"),
|
||||
)
|
||||
|
||||
# ==========================================================================
|
||||
# ORGANIZATIONS TABLE - Performance indexes for multi-tenant lookups
|
||||
# ==========================================================================
|
||||
|
||||
# Case-insensitive slug lookup for URL routing
|
||||
# Query: SELECT * FROM organizations WHERE LOWER(slug) = LOWER(:slug) AND is_active = true
|
||||
# Impact: Medium - every organization page load
|
||||
op.create_index(
|
||||
"ix_perf_organizations_slug_lower",
|
||||
"organizations",
|
||||
[sa.text("LOWER(slug)")],
|
||||
unique=False,
|
||||
postgresql_where=sa.text("is_active = true"),
|
||||
)
|
||||
|
||||
# ==========================================================================
|
||||
# USER SESSIONS TABLE - Performance indexes for session management
|
||||
# ==========================================================================
|
||||
|
||||
# Expired session cleanup
|
||||
# Query: SELECT * FROM user_sessions WHERE expires_at < NOW() AND is_active = true
|
||||
# Impact: Medium - background cleanup jobs
|
||||
op.create_index(
|
||||
"ix_perf_user_sessions_expires",
|
||||
"user_sessions",
|
||||
["expires_at"],
|
||||
unique=False,
|
||||
postgresql_where=sa.text("is_active = true"),
|
||||
)
|
||||
|
||||
# ==========================================================================
|
||||
# OAUTH PROVIDER TOKENS - Performance indexes for token management
|
||||
# ==========================================================================
|
||||
|
||||
# Expired refresh token cleanup
|
||||
# Query: SELECT * FROM oauth_provider_refresh_tokens WHERE expires_at < NOW() AND revoked = false
|
||||
# Impact: Medium - OAuth token cleanup, validation
|
||||
op.create_index(
|
||||
"ix_perf_oauth_refresh_tokens_expires",
|
||||
"oauth_provider_refresh_tokens",
|
||||
["expires_at"],
|
||||
unique=False,
|
||||
postgresql_where=sa.text("revoked = false"),
|
||||
)
|
||||
|
||||
# ==========================================================================
|
||||
# OAUTH AUTHORIZATION CODES - Performance indexes for auth flow
|
||||
# ==========================================================================
|
||||
|
||||
# Expired authorization code cleanup
|
||||
# Query: DELETE FROM oauth_authorization_codes WHERE expires_at < NOW() AND used = false
|
||||
# Impact: Low-Medium - OAuth cleanup jobs
|
||||
op.create_index(
|
||||
"ix_perf_oauth_auth_codes_expires",
|
||||
"oauth_authorization_codes",
|
||||
["expires_at"],
|
||||
unique=False,
|
||||
postgresql_where=sa.text("used = false"),
|
||||
)
|
||||
|
||||
|
||||
def downgrade() -> None:
|
||||
# Drop indexes in reverse order
|
||||
op.drop_index(
|
||||
"ix_perf_oauth_auth_codes_expires", table_name="oauth_authorization_codes"
|
||||
)
|
||||
op.drop_index(
|
||||
"ix_perf_oauth_refresh_tokens_expires",
|
||||
table_name="oauth_provider_refresh_tokens",
|
||||
)
|
||||
op.drop_index("ix_perf_user_sessions_expires", table_name="user_sessions")
|
||||
op.drop_index("ix_perf_organizations_slug_lower", table_name="organizations")
|
||||
op.drop_index("ix_perf_users_active", table_name="users")
|
||||
op.drop_index("ix_perf_users_email_lower", table_name="users")
|
||||
@@ -0,0 +1,35 @@
|
||||
"""rename oauth account token fields drop encrypted suffix
|
||||
|
||||
Revision ID: 0003
|
||||
Revises: 0002
|
||||
Create Date: 2026-02-27 01:03:18.869178
|
||||
|
||||
"""
|
||||
|
||||
from collections.abc import Sequence
|
||||
|
||||
from alembic import op
|
||||
|
||||
# revision identifiers, used by Alembic.
|
||||
revision: str = "0003"
|
||||
down_revision: str | None = "0002"
|
||||
branch_labels: str | Sequence[str] | None = None
|
||||
depends_on: str | Sequence[str] | None = None
|
||||
|
||||
|
||||
def upgrade() -> None:
|
||||
op.alter_column(
|
||||
"oauth_accounts", "access_token_encrypted", new_column_name="access_token"
|
||||
)
|
||||
op.alter_column(
|
||||
"oauth_accounts", "refresh_token_encrypted", new_column_name="refresh_token"
|
||||
)
|
||||
|
||||
|
||||
def downgrade() -> None:
|
||||
op.alter_column(
|
||||
"oauth_accounts", "access_token", new_column_name="access_token_encrypted"
|
||||
)
|
||||
op.alter_column(
|
||||
"oauth_accounts", "refresh_token", new_column_name="refresh_token_encrypted"
|
||||
)
|
||||
@@ -1,34 +0,0 @@
|
||||
"""add_soft_delete_to_users
|
||||
|
||||
Revision ID: 2d0fcec3b06d
|
||||
Revises: 9e4f2a1b8c7d
|
||||
Create Date: 2025-10-30 16:40:21.000021
|
||||
|
||||
"""
|
||||
from typing import Sequence, Union
|
||||
|
||||
from alembic import op
|
||||
import sqlalchemy as sa
|
||||
|
||||
|
||||
# revision identifiers, used by Alembic.
|
||||
revision: str = '2d0fcec3b06d'
|
||||
down_revision: Union[str, None] = '9e4f2a1b8c7d'
|
||||
branch_labels: Union[str, Sequence[str], None] = None
|
||||
depends_on: Union[str, Sequence[str], None] = None
|
||||
|
||||
|
||||
def upgrade() -> None:
|
||||
# Add deleted_at column for soft deletes
|
||||
op.add_column('users', sa.Column('deleted_at', sa.DateTime(timezone=True), nullable=True))
|
||||
|
||||
# Add index on deleted_at for efficient queries
|
||||
op.create_index('ix_users_deleted_at', 'users', ['deleted_at'])
|
||||
|
||||
|
||||
def downgrade() -> None:
|
||||
# Remove index
|
||||
op.drop_index('ix_users_deleted_at', table_name='users')
|
||||
|
||||
# Remove column
|
||||
op.drop_column('users', 'deleted_at')
|
||||
@@ -1,46 +0,0 @@
|
||||
"""Add all initial models
|
||||
|
||||
Revision ID: 38bf9e7e74b3
|
||||
Revises: 7396957cbe80
|
||||
Create Date: 2025-02-28 09:19:33.212278
|
||||
|
||||
"""
|
||||
from typing import Sequence, Union
|
||||
|
||||
from alembic import op
|
||||
import sqlalchemy as sa
|
||||
from sqlalchemy.dialects import postgresql
|
||||
|
||||
# revision identifiers, used by Alembic.
|
||||
revision: str = '38bf9e7e74b3'
|
||||
down_revision: Union[str, None] = '7396957cbe80'
|
||||
branch_labels: Union[str, Sequence[str], None] = None
|
||||
depends_on: Union[str, Sequence[str], None] = None
|
||||
|
||||
|
||||
def upgrade() -> None:
|
||||
|
||||
op.create_table('users',
|
||||
sa.Column('email', sa.String(), nullable=False),
|
||||
sa.Column('password_hash', sa.String(), nullable=False),
|
||||
sa.Column('first_name', sa.String(), nullable=False),
|
||||
sa.Column('last_name', sa.String(), nullable=True),
|
||||
sa.Column('phone_number', sa.String(), nullable=True),
|
||||
sa.Column('is_active', sa.Boolean(), nullable=False),
|
||||
sa.Column('is_superuser', sa.Boolean(), nullable=False),
|
||||
sa.Column('preferences', sa.JSON(), nullable=True),
|
||||
sa.Column('id', sa.UUID(), nullable=False),
|
||||
sa.Column('created_at', sa.DateTime(timezone=True), nullable=False),
|
||||
sa.Column('updated_at', sa.DateTime(timezone=True), nullable=False),
|
||||
sa.PrimaryKeyConstraint('id')
|
||||
)
|
||||
op.create_index(op.f('ix_users_email'), 'users', ['email'], unique=True)
|
||||
|
||||
# ### end Alembic commands ###
|
||||
|
||||
|
||||
def downgrade() -> None:
|
||||
# ### commands auto generated by Alembic - please adjust! ###
|
||||
op.drop_index(op.f('ix_users_email'), table_name='users')
|
||||
op.drop_table('users')
|
||||
# ### end Alembic commands ###
|
||||
@@ -1,102 +0,0 @@
|
||||
"""add_user_sessions_table
|
||||
|
||||
Revision ID: 549b50ea888d
|
||||
Revises: b76c725fc3cf
|
||||
Create Date: 2025-10-31 07:41:18.729544
|
||||
|
||||
"""
|
||||
from typing import Sequence, Union
|
||||
|
||||
from alembic import op
|
||||
import sqlalchemy as sa
|
||||
|
||||
|
||||
# revision identifiers, used by Alembic.
|
||||
revision: str = '549b50ea888d'
|
||||
down_revision: Union[str, None] = 'b76c725fc3cf'
|
||||
branch_labels: Union[str, Sequence[str], None] = None
|
||||
depends_on: Union[str, Sequence[str], None] = None
|
||||
|
||||
|
||||
def upgrade() -> None:
|
||||
# Create user_sessions table for per-device session management
|
||||
op.create_table(
|
||||
'user_sessions',
|
||||
sa.Column('id', sa.UUID(), nullable=False),
|
||||
sa.Column('user_id', sa.UUID(), nullable=False),
|
||||
sa.Column('refresh_token_jti', sa.String(length=255), nullable=False),
|
||||
sa.Column('device_name', sa.String(length=255), nullable=True),
|
||||
sa.Column('device_id', sa.String(length=255), nullable=True),
|
||||
sa.Column('ip_address', sa.String(length=45), nullable=True),
|
||||
sa.Column('user_agent', sa.String(length=500), nullable=True),
|
||||
sa.Column('last_used_at', sa.DateTime(timezone=True), nullable=False),
|
||||
sa.Column('expires_at', sa.DateTime(timezone=True), nullable=False),
|
||||
sa.Column('is_active', sa.Boolean(), nullable=False, server_default='true'),
|
||||
sa.Column('location_city', sa.String(length=100), nullable=True),
|
||||
sa.Column('location_country', sa.String(length=100), nullable=True),
|
||||
sa.Column('created_at', sa.DateTime(timezone=True), nullable=False),
|
||||
sa.Column('updated_at', sa.DateTime(timezone=True), nullable=False),
|
||||
sa.PrimaryKeyConstraint('id')
|
||||
)
|
||||
|
||||
# Create foreign key to users table
|
||||
op.create_foreign_key(
|
||||
'fk_user_sessions_user_id',
|
||||
'user_sessions',
|
||||
'users',
|
||||
['user_id'],
|
||||
['id'],
|
||||
ondelete='CASCADE'
|
||||
)
|
||||
|
||||
# Create indexes for performance
|
||||
# 1. Lookup session by refresh token JTI (most common query)
|
||||
op.create_index(
|
||||
'ix_user_sessions_jti',
|
||||
'user_sessions',
|
||||
['refresh_token_jti'],
|
||||
unique=True
|
||||
)
|
||||
|
||||
# 2. Lookup sessions by user ID
|
||||
op.create_index(
|
||||
'ix_user_sessions_user_id',
|
||||
'user_sessions',
|
||||
['user_id']
|
||||
)
|
||||
|
||||
# 3. Composite index for active sessions by user
|
||||
op.create_index(
|
||||
'ix_user_sessions_user_active',
|
||||
'user_sessions',
|
||||
['user_id', 'is_active']
|
||||
)
|
||||
|
||||
# 4. Index on expires_at for cleanup job
|
||||
op.create_index(
|
||||
'ix_user_sessions_expires_at',
|
||||
'user_sessions',
|
||||
['expires_at']
|
||||
)
|
||||
|
||||
# 5. Composite index for active session lookup by JTI
|
||||
op.create_index(
|
||||
'ix_user_sessions_jti_active',
|
||||
'user_sessions',
|
||||
['refresh_token_jti', 'is_active']
|
||||
)
|
||||
|
||||
|
||||
def downgrade() -> None:
|
||||
# Drop indexes first
|
||||
op.drop_index('ix_user_sessions_jti_active', table_name='user_sessions')
|
||||
op.drop_index('ix_user_sessions_expires_at', table_name='user_sessions')
|
||||
op.drop_index('ix_user_sessions_user_active', table_name='user_sessions')
|
||||
op.drop_index('ix_user_sessions_user_id', table_name='user_sessions')
|
||||
op.drop_index('ix_user_sessions_jti', table_name='user_sessions')
|
||||
|
||||
# Drop foreign key
|
||||
op.drop_constraint('fk_user_sessions_user_id', 'user_sessions', type_='foreignkey')
|
||||
|
||||
# Drop table
|
||||
op.drop_table('user_sessions')
|
||||
@@ -1,26 +0,0 @@
|
||||
"""Initial empty migration
|
||||
|
||||
Revision ID: 7396957cbe80
|
||||
Revises:
|
||||
Create Date: 2025-02-27 12:47:46.445313
|
||||
|
||||
"""
|
||||
from typing import Sequence, Union
|
||||
|
||||
from alembic import op
|
||||
import sqlalchemy as sa
|
||||
|
||||
|
||||
# revision identifiers, used by Alembic.
|
||||
revision: str = '7396957cbe80'
|
||||
down_revision: Union[str, None] = None
|
||||
branch_labels: Union[str, Sequence[str], None] = None
|
||||
depends_on: Union[str, Sequence[str], None] = None
|
||||
|
||||
|
||||
def upgrade() -> None:
|
||||
pass
|
||||
|
||||
|
||||
def downgrade() -> None:
|
||||
pass
|
||||
@@ -1,84 +0,0 @@
|
||||
"""Add missing indexes and fix column types
|
||||
|
||||
Revision ID: 9e4f2a1b8c7d
|
||||
Revises: 38bf9e7e74b3
|
||||
Create Date: 2025-10-30 10:00:00.000000
|
||||
|
||||
"""
|
||||
from typing import Sequence, Union
|
||||
|
||||
from alembic import op
|
||||
import sqlalchemy as sa
|
||||
|
||||
|
||||
# revision identifiers, used by Alembic.
|
||||
revision: str = '9e4f2a1b8c7d'
|
||||
down_revision: Union[str, None] = '38bf9e7e74b3'
|
||||
branch_labels: Union[str, Sequence[str], None] = None
|
||||
depends_on: Union[str, Sequence[str], None] = None
|
||||
|
||||
|
||||
def upgrade() -> None:
|
||||
# Add missing indexes for is_active and is_superuser
|
||||
op.create_index(op.f('ix_users_is_active'), 'users', ['is_active'], unique=False)
|
||||
op.create_index(op.f('ix_users_is_superuser'), 'users', ['is_superuser'], unique=False)
|
||||
|
||||
# Fix column types to match model definitions with explicit lengths
|
||||
op.alter_column('users', 'email',
|
||||
existing_type=sa.String(),
|
||||
type_=sa.String(length=255),
|
||||
nullable=False)
|
||||
|
||||
op.alter_column('users', 'password_hash',
|
||||
existing_type=sa.String(),
|
||||
type_=sa.String(length=255),
|
||||
nullable=False)
|
||||
|
||||
op.alter_column('users', 'first_name',
|
||||
existing_type=sa.String(),
|
||||
type_=sa.String(length=100),
|
||||
nullable=False,
|
||||
server_default='user') # Add server default
|
||||
|
||||
op.alter_column('users', 'last_name',
|
||||
existing_type=sa.String(),
|
||||
type_=sa.String(length=100),
|
||||
nullable=True)
|
||||
|
||||
op.alter_column('users', 'phone_number',
|
||||
existing_type=sa.String(),
|
||||
type_=sa.String(length=20),
|
||||
nullable=True)
|
||||
|
||||
|
||||
def downgrade() -> None:
|
||||
# Revert column types
|
||||
op.alter_column('users', 'phone_number',
|
||||
existing_type=sa.String(length=20),
|
||||
type_=sa.String(),
|
||||
nullable=True)
|
||||
|
||||
op.alter_column('users', 'last_name',
|
||||
existing_type=sa.String(length=100),
|
||||
type_=sa.String(),
|
||||
nullable=True)
|
||||
|
||||
op.alter_column('users', 'first_name',
|
||||
existing_type=sa.String(length=100),
|
||||
type_=sa.String(),
|
||||
nullable=False,
|
||||
server_default=None) # Remove server default
|
||||
|
||||
op.alter_column('users', 'password_hash',
|
||||
existing_type=sa.String(length=255),
|
||||
type_=sa.String(),
|
||||
nullable=False)
|
||||
|
||||
op.alter_column('users', 'email',
|
||||
existing_type=sa.String(length=255),
|
||||
type_=sa.String(),
|
||||
nullable=False)
|
||||
|
||||
# Drop indexes
|
||||
op.drop_index(op.f('ix_users_is_superuser'), table_name='users')
|
||||
op.drop_index(op.f('ix_users_is_active'), table_name='users')
|
||||
@@ -1,52 +0,0 @@
|
||||
"""add_composite_indexes
|
||||
|
||||
Revision ID: b76c725fc3cf
|
||||
Revises: 2d0fcec3b06d
|
||||
Create Date: 2025-10-30 16:41:33.273135
|
||||
|
||||
"""
|
||||
from typing import Sequence, Union
|
||||
|
||||
from alembic import op
|
||||
import sqlalchemy as sa
|
||||
|
||||
|
||||
# revision identifiers, used by Alembic.
|
||||
revision: str = 'b76c725fc3cf'
|
||||
down_revision: Union[str, None] = '2d0fcec3b06d'
|
||||
branch_labels: Union[str, Sequence[str], None] = None
|
||||
depends_on: Union[str, Sequence[str], None] = None
|
||||
|
||||
|
||||
def upgrade() -> None:
|
||||
# Add composite indexes for common query patterns
|
||||
|
||||
# Composite index for filtering active users by role
|
||||
op.create_index(
|
||||
'ix_users_active_superuser',
|
||||
'users',
|
||||
['is_active', 'is_superuser'],
|
||||
postgresql_where=sa.text('deleted_at IS NULL')
|
||||
)
|
||||
|
||||
# Composite index for sorting active users by creation date
|
||||
op.create_index(
|
||||
'ix_users_active_created',
|
||||
'users',
|
||||
['is_active', 'created_at'],
|
||||
postgresql_where=sa.text('deleted_at IS NULL')
|
||||
)
|
||||
|
||||
# Composite index for email lookup of non-deleted users
|
||||
op.create_index(
|
||||
'ix_users_email_not_deleted',
|
||||
'users',
|
||||
['email', 'deleted_at']
|
||||
)
|
||||
|
||||
|
||||
def downgrade() -> None:
|
||||
# Remove composite indexes
|
||||
op.drop_index('ix_users_email_not_deleted', table_name='users')
|
||||
op.drop_index('ix_users_active_created', table_name='users')
|
||||
op.drop_index('ix_users_active_superuser', table_name='users')
|
||||
@@ -1,106 +0,0 @@
|
||||
"""add_organizations_and_user_organizations
|
||||
|
||||
Revision ID: fbf6318a8a36
|
||||
Revises: 549b50ea888d
|
||||
Create Date: 2025-10-31 12:08:05.141353
|
||||
|
||||
"""
|
||||
from typing import Sequence, Union
|
||||
|
||||
from alembic import op
|
||||
import sqlalchemy as sa
|
||||
|
||||
|
||||
# revision identifiers, used by Alembic.
|
||||
revision: str = 'fbf6318a8a36'
|
||||
down_revision: Union[str, None] = '549b50ea888d'
|
||||
branch_labels: Union[str, Sequence[str], None] = None
|
||||
depends_on: Union[str, Sequence[str], None] = None
|
||||
|
||||
|
||||
def upgrade() -> None:
|
||||
# Create organizations table
|
||||
op.create_table(
|
||||
'organizations',
|
||||
sa.Column('id', sa.UUID(), nullable=False),
|
||||
sa.Column('name', sa.String(length=255), nullable=False),
|
||||
sa.Column('slug', sa.String(length=255), nullable=False),
|
||||
sa.Column('description', sa.Text(), nullable=True),
|
||||
sa.Column('is_active', sa.Boolean(), nullable=False, server_default='true'),
|
||||
sa.Column('settings', sa.JSON(), nullable=True),
|
||||
sa.Column('created_at', sa.DateTime(timezone=True), nullable=False),
|
||||
sa.Column('updated_at', sa.DateTime(timezone=True), nullable=False),
|
||||
sa.PrimaryKeyConstraint('id')
|
||||
)
|
||||
|
||||
# Create indexes for organizations
|
||||
op.create_index('ix_organizations_name', 'organizations', ['name'])
|
||||
op.create_index('ix_organizations_slug', 'organizations', ['slug'], unique=True)
|
||||
op.create_index('ix_organizations_is_active', 'organizations', ['is_active'])
|
||||
op.create_index('ix_organizations_name_active', 'organizations', ['name', 'is_active'])
|
||||
op.create_index('ix_organizations_slug_active', 'organizations', ['slug', 'is_active'])
|
||||
|
||||
# Create user_organizations junction table
|
||||
op.create_table(
|
||||
'user_organizations',
|
||||
sa.Column('user_id', sa.UUID(), nullable=False),
|
||||
sa.Column('organization_id', sa.UUID(), nullable=False),
|
||||
sa.Column('role', sa.Enum('OWNER', 'ADMIN', 'MEMBER', 'GUEST', name='organizationrole'), nullable=False, server_default='MEMBER'),
|
||||
sa.Column('is_active', sa.Boolean(), nullable=False, server_default='true'),
|
||||
sa.Column('custom_permissions', sa.String(length=500), nullable=True),
|
||||
sa.Column('created_at', sa.DateTime(timezone=True), nullable=False),
|
||||
sa.Column('updated_at', sa.DateTime(timezone=True), nullable=False),
|
||||
sa.PrimaryKeyConstraint('user_id', 'organization_id')
|
||||
)
|
||||
|
||||
# Create foreign keys
|
||||
op.create_foreign_key(
|
||||
'fk_user_organizations_user_id',
|
||||
'user_organizations',
|
||||
'users',
|
||||
['user_id'],
|
||||
['id'],
|
||||
ondelete='CASCADE'
|
||||
)
|
||||
op.create_foreign_key(
|
||||
'fk_user_organizations_organization_id',
|
||||
'user_organizations',
|
||||
'organizations',
|
||||
['organization_id'],
|
||||
['id'],
|
||||
ondelete='CASCADE'
|
||||
)
|
||||
|
||||
# Create indexes for user_organizations
|
||||
op.create_index('ix_user_organizations_role', 'user_organizations', ['role'])
|
||||
op.create_index('ix_user_organizations_is_active', 'user_organizations', ['is_active'])
|
||||
op.create_index('ix_user_org_user_active', 'user_organizations', ['user_id', 'is_active'])
|
||||
op.create_index('ix_user_org_org_active', 'user_organizations', ['organization_id', 'is_active'])
|
||||
|
||||
|
||||
def downgrade() -> None:
|
||||
# Drop indexes for user_organizations
|
||||
op.drop_index('ix_user_org_org_active', table_name='user_organizations')
|
||||
op.drop_index('ix_user_org_user_active', table_name='user_organizations')
|
||||
op.drop_index('ix_user_organizations_is_active', table_name='user_organizations')
|
||||
op.drop_index('ix_user_organizations_role', table_name='user_organizations')
|
||||
|
||||
# Drop foreign keys
|
||||
op.drop_constraint('fk_user_organizations_organization_id', 'user_organizations', type_='foreignkey')
|
||||
op.drop_constraint('fk_user_organizations_user_id', 'user_organizations', type_='foreignkey')
|
||||
|
||||
# Drop user_organizations table
|
||||
op.drop_table('user_organizations')
|
||||
|
||||
# Drop indexes for organizations
|
||||
op.drop_index('ix_organizations_slug_active', table_name='organizations')
|
||||
op.drop_index('ix_organizations_name_active', table_name='organizations')
|
||||
op.drop_index('ix_organizations_is_active', table_name='organizations')
|
||||
op.drop_index('ix_organizations_slug', table_name='organizations')
|
||||
op.drop_index('ix_organizations_name', table_name='organizations')
|
||||
|
||||
# Drop organizations table
|
||||
op.drop_table('organizations')
|
||||
|
||||
# Drop enum type
|
||||
op.execute('DROP TYPE IF EXISTS organizationrole')
|
||||
56
backend/app/api/dependencies/auth.py
Normal file → Executable file
56
backend/app/api/dependencies/auth.py
Normal file → Executable file
@@ -1,21 +1,19 @@
|
||||
from typing import Optional
|
||||
|
||||
from fastapi import Depends, HTTPException, status, Header
|
||||
from fastapi import Depends, Header, HTTPException, status
|
||||
from fastapi.security import OAuth2PasswordBearer
|
||||
from fastapi.security.utils import get_authorization_scheme_param
|
||||
from sqlalchemy.orm import Session
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
|
||||
from app.core.auth import get_token_data, TokenExpiredError, TokenInvalidError
|
||||
from app.core.auth import TokenExpiredError, TokenInvalidError, get_token_data
|
||||
from app.core.database import get_db
|
||||
from app.models.user import User
|
||||
from app.repositories.user import user_repo
|
||||
|
||||
# OAuth2 configuration
|
||||
oauth2_scheme = OAuth2PasswordBearer(tokenUrl="/api/v1/auth/login")
|
||||
|
||||
|
||||
def get_current_user(
|
||||
db: Session = Depends(get_db),
|
||||
token: str = Depends(oauth2_scheme)
|
||||
async def get_current_user(
|
||||
db: AsyncSession = Depends(get_db), token: str = Depends(oauth2_scheme)
|
||||
) -> User:
|
||||
"""
|
||||
Get the current authenticated user.
|
||||
@@ -34,18 +32,17 @@ def get_current_user(
|
||||
# Decode token and get user ID
|
||||
token_data = get_token_data(token)
|
||||
|
||||
# Get user from database
|
||||
user = db.query(User).filter(User.id == token_data.user_id).first()
|
||||
# Get user from database via repository
|
||||
user = await user_repo.get(db, id=str(token_data.user_id))
|
||||
|
||||
if not user:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_404_NOT_FOUND,
|
||||
detail="User not found"
|
||||
status_code=status.HTTP_404_NOT_FOUND, detail="User not found"
|
||||
)
|
||||
|
||||
if not user.is_active:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_403_FORBIDDEN,
|
||||
detail="Inactive user"
|
||||
status_code=status.HTTP_403_FORBIDDEN, detail="Inactive user"
|
||||
)
|
||||
|
||||
return user
|
||||
@@ -54,19 +51,17 @@ def get_current_user(
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_401_UNAUTHORIZED,
|
||||
detail="Token expired",
|
||||
headers={"WWW-Authenticate": "Bearer"}
|
||||
headers={"WWW-Authenticate": "Bearer"},
|
||||
)
|
||||
except TokenInvalidError:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_401_UNAUTHORIZED,
|
||||
detail="Could not validate credentials",
|
||||
headers={"WWW-Authenticate": "Bearer"}
|
||||
headers={"WWW-Authenticate": "Bearer"},
|
||||
)
|
||||
|
||||
|
||||
def get_current_active_user(
|
||||
current_user: User = Depends(get_current_user)
|
||||
) -> User:
|
||||
def get_current_active_user(current_user: User = Depends(get_current_user)) -> User:
|
||||
"""
|
||||
Check if the current user is active.
|
||||
|
||||
@@ -81,15 +76,12 @@ def get_current_active_user(
|
||||
"""
|
||||
if not current_user.is_active:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_403_FORBIDDEN,
|
||||
detail="Inactive user"
|
||||
status_code=status.HTTP_403_FORBIDDEN, detail="Inactive user"
|
||||
)
|
||||
return current_user
|
||||
|
||||
|
||||
def get_current_superuser(
|
||||
current_user: User = Depends(get_current_user)
|
||||
) -> User:
|
||||
def get_current_superuser(current_user: User = Depends(get_current_user)) -> User:
|
||||
"""
|
||||
Check if the current user is a superuser.
|
||||
|
||||
@@ -104,13 +96,12 @@ def get_current_superuser(
|
||||
"""
|
||||
if not current_user.is_superuser:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_403_FORBIDDEN,
|
||||
detail="Not enough permissions"
|
||||
status_code=status.HTTP_403_FORBIDDEN, detail="Not enough permissions"
|
||||
)
|
||||
return current_user
|
||||
|
||||
|
||||
async def get_optional_token(authorization: str = Header(None)) -> Optional[str]:
|
||||
async def get_optional_token(authorization: str = Header(None)) -> str | None:
|
||||
"""
|
||||
Get the token from the Authorization header without requiring it.
|
||||
|
||||
@@ -133,10 +124,9 @@ async def get_optional_token(authorization: str = Header(None)) -> Optional[str]
|
||||
return token
|
||||
|
||||
|
||||
def get_optional_current_user(
|
||||
db: Session = Depends(get_db),
|
||||
token: Optional[str] = Depends(get_optional_token)
|
||||
) -> Optional[User]:
|
||||
async def get_optional_current_user(
|
||||
db: AsyncSession = Depends(get_db), token: str | None = Depends(get_optional_token)
|
||||
) -> User | None:
|
||||
"""
|
||||
Get the current user if authenticated, otherwise return None.
|
||||
Useful for endpoints that work with both authenticated and unauthenticated users.
|
||||
@@ -153,9 +143,9 @@ def get_optional_current_user(
|
||||
|
||||
try:
|
||||
token_data = get_token_data(token)
|
||||
user = db.query(User).filter(User.id == token_data.user_id).first()
|
||||
user = await user_repo.get(db, id=str(token_data.user_id))
|
||||
if not user or not user.is_active:
|
||||
return None
|
||||
return user
|
||||
except (TokenExpiredError, TokenInvalidError):
|
||||
return None
|
||||
return None
|
||||
|
||||
132
backend/app/api/dependencies/locale.py
Normal file
132
backend/app/api/dependencies/locale.py
Normal file
@@ -0,0 +1,132 @@
|
||||
# app/api/dependencies/locale.py
|
||||
"""
|
||||
Locale detection dependency for internationalization (i18n).
|
||||
|
||||
Implements a three-tier fallback system:
|
||||
1. User's saved preference (if authenticated and user.locale is set)
|
||||
2. Accept-Language header (for unauthenticated users or no saved preference)
|
||||
3. Default to English ("en")
|
||||
"""
|
||||
|
||||
from fastapi import Depends, Request
|
||||
|
||||
from app.api.dependencies.auth import get_optional_current_user
|
||||
from app.models.user import User
|
||||
|
||||
# Supported locales (BCP 47 format)
|
||||
# Template showcases English and Italian
|
||||
# Users can extend by adding more locales here
|
||||
# Note: Stored in lowercase for case-insensitive matching
|
||||
SUPPORTED_LOCALES = {"en", "it", "en-us", "en-gb", "it-it"}
|
||||
DEFAULT_LOCALE = "en"
|
||||
|
||||
|
||||
def parse_accept_language(accept_language: str) -> str | None:
|
||||
"""
|
||||
Parse the Accept-Language header and return the best matching supported locale.
|
||||
|
||||
The Accept-Language header format is:
|
||||
"it-IT,it;q=0.9,en-US;q=0.8,en;q=0.7"
|
||||
|
||||
This function extracts locales in priority order (by quality value) and returns
|
||||
the first one that matches our supported locales.
|
||||
|
||||
Args:
|
||||
accept_language: The Accept-Language header value
|
||||
|
||||
Returns:
|
||||
The best matching locale code, or None if no match found
|
||||
|
||||
Examples:
|
||||
>>> parse_accept_language("it-IT,it;q=0.9,en;q=0.8")
|
||||
"it-IT" # or "it" if it-IT is not supported
|
||||
>>> parse_accept_language("fr-FR,fr;q=0.9")
|
||||
None # French not supported
|
||||
"""
|
||||
if not accept_language:
|
||||
return None
|
||||
|
||||
# Split by comma to get individual locale entries
|
||||
# Format: "locale;q=weight" or just "locale"
|
||||
locales = []
|
||||
for entry in accept_language.split(","):
|
||||
# Remove quality value (;q=0.9) if present
|
||||
locale = entry.split(";")[0].strip()
|
||||
if locale:
|
||||
locales.append(locale)
|
||||
|
||||
# Check each locale in priority order
|
||||
for locale in locales:
|
||||
locale_lower = locale.lower()
|
||||
|
||||
# Try exact match first (e.g., "it-IT")
|
||||
if locale_lower in SUPPORTED_LOCALES:
|
||||
return locale_lower
|
||||
|
||||
# Try language code only (e.g., "it" from "it-IT")
|
||||
lang_code = locale_lower.split("-")[0]
|
||||
if lang_code in SUPPORTED_LOCALES:
|
||||
return lang_code
|
||||
|
||||
return None
|
||||
|
||||
|
||||
async def get_locale(
|
||||
request: Request,
|
||||
current_user: User | None = Depends(get_optional_current_user),
|
||||
) -> str:
|
||||
"""
|
||||
Detect and return the appropriate locale for the current request.
|
||||
|
||||
Three-tier fallback system:
|
||||
1. **User Preference** (highest priority)
|
||||
- If user is authenticated and has a saved locale preference, use it
|
||||
- This persists across sessions and devices
|
||||
|
||||
2. **Accept-Language Header** (second priority)
|
||||
- Parse the Accept-Language header from the request
|
||||
- Match against supported locales
|
||||
- Common for browser requests
|
||||
|
||||
3. **Default Locale** (fallback)
|
||||
- Return "en" (English) if no user preference and no header match
|
||||
|
||||
Args:
|
||||
request: The FastAPI request object (for accessing headers)
|
||||
current_user: The current authenticated user (optional)
|
||||
|
||||
Returns:
|
||||
A valid locale code from SUPPORTED_LOCALES (guaranteed to be supported)
|
||||
|
||||
Examples:
|
||||
>>> # Authenticated user with saved preference
|
||||
>>> await get_locale(request, user_with_locale_it)
|
||||
"it"
|
||||
|
||||
>>> # Unauthenticated user with Italian browser
|
||||
>>> # (request has Accept-Language: it-IT,it;q=0.9)
|
||||
>>> await get_locale(request, None)
|
||||
"it"
|
||||
|
||||
>>> # Unauthenticated user with unsupported language
|
||||
>>> # (request has Accept-Language: fr-FR,fr;q=0.9)
|
||||
>>> await get_locale(request, None)
|
||||
"en"
|
||||
"""
|
||||
# Priority 1: User's saved preference
|
||||
if current_user and current_user.locale:
|
||||
# Validate that saved locale is still supported
|
||||
# (in case SUPPORTED_LOCALES changed after user set preference)
|
||||
locale_value = str(current_user.locale)
|
||||
if locale_value in SUPPORTED_LOCALES:
|
||||
return locale_value
|
||||
|
||||
# Priority 2: Accept-Language header
|
||||
accept_language = request.headers.get("accept-language", "")
|
||||
if accept_language:
|
||||
detected_locale = parse_accept_language(accept_language)
|
||||
if detected_locale:
|
||||
return detected_locale
|
||||
|
||||
# Priority 3: Default fallback
|
||||
return DEFAULT_LOCALE
|
||||
104
backend/app/api/dependencies/permissions.py
Normal file → Executable file
104
backend/app/api/dependencies/permissions.py
Normal file → Executable file
@@ -7,21 +7,20 @@ These dependencies are optional and flexible:
|
||||
- Use require_org_role for organization-specific access control
|
||||
- Projects can choose to use these or implement their own permission system
|
||||
"""
|
||||
from typing import Optional
|
||||
from uuid import UUID
|
||||
from fastapi import Depends, HTTPException, status
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from uuid import UUID
|
||||
|
||||
from fastapi import Depends, HTTPException, status
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
|
||||
from app.api.dependencies.auth import get_current_user
|
||||
from app.core.database import get_db
|
||||
from app.models.user import User
|
||||
from app.models.user_organization import OrganizationRole
|
||||
from app.api.dependencies.auth import get_current_user
|
||||
from app.crud.organization import organization as organization_crud
|
||||
from app.services.organization_service import organization_service
|
||||
|
||||
|
||||
def require_superuser(
|
||||
current_user: User = Depends(get_current_user)
|
||||
) -> User:
|
||||
def require_superuser(current_user: User = Depends(get_current_user)) -> User:
|
||||
"""
|
||||
Dependency to ensure the current user is a superuser.
|
||||
|
||||
@@ -35,23 +34,7 @@ def require_superuser(
|
||||
if not current_user.is_superuser:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_403_FORBIDDEN,
|
||||
detail="Superuser privileges required"
|
||||
)
|
||||
return current_user
|
||||
|
||||
|
||||
def require_active_user(
|
||||
current_user: User = Depends(get_current_user)
|
||||
) -> User:
|
||||
"""
|
||||
Dependency to ensure the current user is active.
|
||||
|
||||
Use this for endpoints that require an active account.
|
||||
"""
|
||||
if not current_user.is_active:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_403_FORBIDDEN,
|
||||
detail="Inactive account"
|
||||
detail="Superuser privileges required",
|
||||
)
|
||||
return current_user
|
||||
|
||||
@@ -73,11 +56,11 @@ class OrganizationPermission:
|
||||
"""
|
||||
self.allowed_roles = allowed_roles
|
||||
|
||||
def __call__(
|
||||
async def __call__(
|
||||
self,
|
||||
organization_id: UUID,
|
||||
current_user: User = Depends(get_current_user),
|
||||
db: Session = Depends(get_db)
|
||||
db: AsyncSession = Depends(get_db),
|
||||
) -> User:
|
||||
"""
|
||||
Check if user has required role in the organization.
|
||||
@@ -98,22 +81,20 @@ class OrganizationPermission:
|
||||
return current_user
|
||||
|
||||
# Get user's role in organization
|
||||
user_role = organization_crud.get_user_role_in_org(
|
||||
db,
|
||||
user_id=current_user.id,
|
||||
organization_id=organization_id
|
||||
user_role = await organization_service.get_user_role_in_org(
|
||||
db, user_id=current_user.id, organization_id=organization_id
|
||||
)
|
||||
|
||||
if not user_role:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_403_FORBIDDEN,
|
||||
detail="Not a member of this organization"
|
||||
detail="Not a member of this organization",
|
||||
)
|
||||
|
||||
if user_role not in self.allowed_roles:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_403_FORBIDDEN,
|
||||
detail=f"Role {user_role} not authorized. Required: {self.allowed_roles}"
|
||||
detail=f"Role {user_role} not authorized. Required: {self.allowed_roles}",
|
||||
)
|
||||
|
||||
return current_user
|
||||
@@ -121,49 +102,18 @@ class OrganizationPermission:
|
||||
|
||||
# Common permission presets for convenience
|
||||
require_org_owner = OrganizationPermission([OrganizationRole.OWNER])
|
||||
require_org_admin = OrganizationPermission([OrganizationRole.OWNER, OrganizationRole.ADMIN])
|
||||
require_org_member = OrganizationPermission([
|
||||
OrganizationRole.OWNER,
|
||||
OrganizationRole.ADMIN,
|
||||
OrganizationRole.MEMBER
|
||||
])
|
||||
require_org_admin = OrganizationPermission(
|
||||
[OrganizationRole.OWNER, OrganizationRole.ADMIN]
|
||||
)
|
||||
require_org_member = OrganizationPermission(
|
||||
[OrganizationRole.OWNER, OrganizationRole.ADMIN, OrganizationRole.MEMBER]
|
||||
)
|
||||
|
||||
|
||||
def get_current_org_role(
|
||||
async def require_org_membership(
|
||||
organization_id: UUID,
|
||||
current_user: User = Depends(get_current_user),
|
||||
db: Session = Depends(get_db)
|
||||
) -> Optional[OrganizationRole]:
|
||||
"""
|
||||
Get the current user's role in an organization.
|
||||
|
||||
This is a non-blocking dependency that returns the role or None.
|
||||
Use this when you want to check permissions conditionally.
|
||||
|
||||
Example:
|
||||
@router.get("/organizations/{org_id}/items")
|
||||
def list_items(
|
||||
org_id: UUID,
|
||||
role: OrganizationRole = Depends(get_current_org_role)
|
||||
):
|
||||
if role in [OrganizationRole.OWNER, OrganizationRole.ADMIN]:
|
||||
# Show admin features
|
||||
...
|
||||
"""
|
||||
if current_user.is_superuser:
|
||||
return OrganizationRole.OWNER
|
||||
|
||||
return organization_crud.get_user_role_in_org(
|
||||
db,
|
||||
user_id=current_user.id,
|
||||
organization_id=organization_id
|
||||
)
|
||||
|
||||
|
||||
def require_org_membership(
|
||||
organization_id: UUID,
|
||||
current_user: User = Depends(get_current_user),
|
||||
db: Session = Depends(get_db)
|
||||
db: AsyncSession = Depends(get_db),
|
||||
) -> User:
|
||||
"""
|
||||
Ensure user is a member of the organization (any role).
|
||||
@@ -173,16 +123,14 @@ def require_org_membership(
|
||||
if current_user.is_superuser:
|
||||
return current_user
|
||||
|
||||
user_role = organization_crud.get_user_role_in_org(
|
||||
db,
|
||||
user_id=current_user.id,
|
||||
organization_id=organization_id
|
||||
user_role = await organization_service.get_user_role_in_org(
|
||||
db, user_id=current_user.id, organization_id=organization_id
|
||||
)
|
||||
|
||||
if not user_role:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_403_FORBIDDEN,
|
||||
detail="Not a member of this organization"
|
||||
detail="Not a member of this organization",
|
||||
)
|
||||
|
||||
return current_user
|
||||
|
||||
41
backend/app/api/dependencies/services.py
Normal file
41
backend/app/api/dependencies/services.py
Normal file
@@ -0,0 +1,41 @@
|
||||
# app/api/dependencies/services.py
|
||||
"""FastAPI dependency functions for service singletons."""
|
||||
|
||||
from app.services import oauth_provider_service
|
||||
from app.services.auth_service import AuthService
|
||||
from app.services.oauth_service import OAuthService
|
||||
from app.services.organization_service import OrganizationService, organization_service
|
||||
from app.services.session_service import SessionService, session_service
|
||||
from app.services.user_service import UserService, user_service
|
||||
|
||||
|
||||
def get_auth_service() -> AuthService:
|
||||
"""Return the AuthService singleton for dependency injection."""
|
||||
from app.services.auth_service import AuthService as _AuthService
|
||||
|
||||
return _AuthService()
|
||||
|
||||
|
||||
def get_user_service() -> UserService:
|
||||
"""Return the UserService singleton for dependency injection."""
|
||||
return user_service
|
||||
|
||||
|
||||
def get_organization_service() -> OrganizationService:
|
||||
"""Return the OrganizationService singleton for dependency injection."""
|
||||
return organization_service
|
||||
|
||||
|
||||
def get_session_service() -> SessionService:
|
||||
"""Return the SessionService singleton for dependency injection."""
|
||||
return session_service
|
||||
|
||||
|
||||
def get_oauth_service() -> OAuthService:
|
||||
"""Return OAuthService for dependency injection."""
|
||||
return OAuthService()
|
||||
|
||||
|
||||
def get_oauth_provider_service():
|
||||
"""Return the oauth_provider_service module for dependency injection."""
|
||||
return oauth_provider_service
|
||||
@@ -1,10 +1,24 @@
|
||||
from fastapi import APIRouter
|
||||
|
||||
from app.api.routes import auth, users, sessions, admin, organizations
|
||||
from app.api.routes import (
|
||||
admin,
|
||||
auth,
|
||||
oauth,
|
||||
oauth_provider,
|
||||
organizations,
|
||||
sessions,
|
||||
users,
|
||||
)
|
||||
|
||||
api_router = APIRouter()
|
||||
api_router.include_router(auth.router, prefix="/auth", tags=["Authentication"])
|
||||
api_router.include_router(oauth.router, prefix="/oauth", tags=["OAuth"])
|
||||
api_router.include_router(
|
||||
oauth_provider.router, prefix="/oauth", tags=["OAuth Provider"]
|
||||
)
|
||||
api_router.include_router(users.router, prefix="/users", tags=["Users"])
|
||||
api_router.include_router(sessions.router, prefix="/sessions", tags=["Sessions"])
|
||||
api_router.include_router(admin.router, prefix="/admin", tags=["Admin"])
|
||||
api_router.include_router(organizations.router, prefix="/organizations", tags=["Organizations"])
|
||||
api_router.include_router(
|
||||
organizations.router, prefix="/organizations", tags=["Organizations"]
|
||||
)
|
||||
|
||||
805
backend/app/api/routes/admin.py
Normal file → Executable file
805
backend/app/api/routes/admin.py
Normal file → Executable file
File diff suppressed because it is too large
Load Diff
470
backend/app/api/routes/auth.py
Normal file → Executable file
470
backend/app/api/routes/auth.py
Normal file → Executable file
@@ -1,37 +1,46 @@
|
||||
# app/api/routes/auth.py
|
||||
import logging
|
||||
import os
|
||||
from datetime import UTC, datetime
|
||||
from typing import Any
|
||||
from datetime import datetime, timezone
|
||||
|
||||
from fastapi import APIRouter, Depends, HTTPException, status, Body, Request
|
||||
from fastapi import APIRouter, Depends, HTTPException, Request, status
|
||||
from fastapi.security import OAuth2PasswordRequestForm
|
||||
from slowapi import Limiter
|
||||
from slowapi.util import get_remote_address
|
||||
from sqlalchemy.orm import Session
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
|
||||
from app.api.dependencies.auth import get_current_user
|
||||
from app.core.auth import TokenExpiredError, TokenInvalidError, decode_token
|
||||
from app.core.auth import (
|
||||
TokenExpiredError,
|
||||
TokenInvalidError,
|
||||
decode_token,
|
||||
)
|
||||
from app.core.database import get_db
|
||||
from app.core.exceptions import (
|
||||
AuthenticationError as AuthError,
|
||||
DatabaseError,
|
||||
DuplicateError,
|
||||
ErrorCode,
|
||||
)
|
||||
from app.models.user import User
|
||||
from app.schemas.common import MessageResponse
|
||||
from app.schemas.sessions import LogoutRequest, SessionCreate
|
||||
from app.schemas.users import (
|
||||
LoginRequest,
|
||||
PasswordResetConfirm,
|
||||
PasswordResetRequest,
|
||||
RefreshTokenRequest,
|
||||
Token,
|
||||
UserCreate,
|
||||
UserResponse,
|
||||
Token,
|
||||
LoginRequest,
|
||||
RefreshTokenRequest,
|
||||
PasswordResetRequest,
|
||||
PasswordResetConfirm
|
||||
)
|
||||
from app.schemas.common import MessageResponse
|
||||
from app.schemas.sessions import SessionCreate, LogoutRequest
|
||||
from app.services.auth_service import AuthService, AuthenticationError
|
||||
from app.services.auth_service import AuthenticationError, AuthService
|
||||
from app.services.email_service import email_service
|
||||
from app.utils.security import create_password_reset_token, verify_password_reset_token
|
||||
from app.services.session_service import session_service
|
||||
from app.services.user_service import user_service
|
||||
from app.utils.device import extract_device_info
|
||||
from app.crud.user import user as user_crud
|
||||
from app.crud.session import session as session_crud
|
||||
from app.core.auth import get_password_hash
|
||||
from app.utils.security import create_password_reset_token, verify_password_reset_token
|
||||
|
||||
router = APIRouter()
|
||||
logger = logging.getLogger(__name__)
|
||||
@@ -44,12 +53,67 @@ IS_TEST = os.getenv("IS_TEST", "False") == "True"
|
||||
RATE_MULTIPLIER = 100 if IS_TEST else 1
|
||||
|
||||
|
||||
@router.post("/register", response_model=UserResponse, status_code=status.HTTP_201_CREATED, operation_id="register")
|
||||
async def _create_login_session(
|
||||
db: AsyncSession,
|
||||
request: Request,
|
||||
user: User,
|
||||
tokens: Token,
|
||||
login_type: str = "login",
|
||||
) -> None:
|
||||
"""
|
||||
Create a session record for successful login.
|
||||
|
||||
This is a best-effort operation - login succeeds even if session creation fails.
|
||||
|
||||
Args:
|
||||
db: Database session
|
||||
request: FastAPI request object for device info extraction
|
||||
user: Authenticated user
|
||||
tokens: Token object containing refresh token with JTI
|
||||
login_type: Type of login for logging ("login" or "oauth")
|
||||
"""
|
||||
try:
|
||||
device_info = extract_device_info(request)
|
||||
|
||||
# Decode refresh token to get JTI and expiration
|
||||
refresh_payload = decode_token(tokens.refresh_token, verify_type="refresh")
|
||||
|
||||
session_data = SessionCreate(
|
||||
user_id=user.id,
|
||||
refresh_token_jti=refresh_payload.jti,
|
||||
device_name=device_info.device_name or "API Client",
|
||||
device_id=device_info.device_id,
|
||||
ip_address=device_info.ip_address,
|
||||
user_agent=device_info.user_agent,
|
||||
last_used_at=datetime.now(UTC),
|
||||
expires_at=datetime.fromtimestamp(refresh_payload.exp, tz=UTC),
|
||||
location_city=device_info.location_city,
|
||||
location_country=device_info.location_country,
|
||||
)
|
||||
|
||||
await session_service.create_session(db, obj_in=session_data)
|
||||
|
||||
logger.info(
|
||||
"%s successful: %s from %s (IP: %s)",
|
||||
login_type.capitalize(),
|
||||
user.email,
|
||||
device_info.device_name,
|
||||
device_info.ip_address,
|
||||
)
|
||||
except Exception as session_err:
|
||||
# Log but don't fail login if session creation fails
|
||||
logger.exception("Failed to create session for %s: %s", user.email, session_err)
|
||||
|
||||
|
||||
@router.post(
|
||||
"/register",
|
||||
response_model=UserResponse,
|
||||
status_code=status.HTTP_201_CREATED,
|
||||
operation_id="register",
|
||||
)
|
||||
@limiter.limit(f"{5 * RATE_MULTIPLIER}/minute")
|
||||
async def register_user(
|
||||
request: Request,
|
||||
user_data: UserCreate,
|
||||
db: Session = Depends(get_db)
|
||||
request: Request, user_data: UserCreate, db: AsyncSession = Depends(get_db)
|
||||
) -> Any:
|
||||
"""
|
||||
Register a new user.
|
||||
@@ -58,28 +122,33 @@ async def register_user(
|
||||
The created user information.
|
||||
"""
|
||||
try:
|
||||
user = AuthService.create_user(db, user_data)
|
||||
user = await AuthService.create_user(db, user_data)
|
||||
return user
|
||||
except AuthenticationError as e:
|
||||
logger.warning(f"Registration failed: {str(e)}")
|
||||
except DuplicateError:
|
||||
# SECURITY: Don't reveal if email exists - generic error message
|
||||
logger.warning("Registration failed: duplicate email %s", user_data.email)
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_409_CONFLICT,
|
||||
detail=str(e)
|
||||
status_code=status.HTTP_400_BAD_REQUEST,
|
||||
detail="Registration failed. Please check your information and try again.",
|
||||
)
|
||||
except AuthError as e:
|
||||
logger.warning("Registration failed: %s", e)
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_400_BAD_REQUEST,
|
||||
detail="Registration failed. Please check your information and try again.",
|
||||
)
|
||||
except Exception as e:
|
||||
logger.error(f"Unexpected error during registration: {str(e)}")
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
|
||||
detail="An unexpected error occurred. Please try again later."
|
||||
logger.exception("Unexpected error during registration: %s", e)
|
||||
raise DatabaseError(
|
||||
message="An unexpected error occurred. Please try again later.",
|
||||
error_code=ErrorCode.INTERNAL_ERROR,
|
||||
)
|
||||
|
||||
|
||||
@router.post("/login", response_model=Token, operation_id="login")
|
||||
@limiter.limit(f"{10 * RATE_MULTIPLIER}/minute")
|
||||
async def login(
|
||||
request: Request,
|
||||
login_data: LoginRequest,
|
||||
db: Session = Depends(get_db)
|
||||
request: Request, login_data: LoginRequest, db: AsyncSession = Depends(get_db)
|
||||
) -> Any:
|
||||
"""
|
||||
Login with username and password.
|
||||
@@ -91,79 +160,45 @@ async def login(
|
||||
"""
|
||||
try:
|
||||
# Attempt to authenticate the user
|
||||
user = AuthService.authenticate_user(db, login_data.email, login_data.password)
|
||||
user = await AuthService.authenticate_user(
|
||||
db, login_data.email, login_data.password
|
||||
)
|
||||
|
||||
# Explicitly check for None result and raise correct exception
|
||||
if user is None:
|
||||
logger.warning(f"Invalid login attempt for: {login_data.email}")
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_401_UNAUTHORIZED,
|
||||
detail="Invalid email or password",
|
||||
headers={"WWW-Authenticate": "Bearer"},
|
||||
logger.warning("Invalid login attempt for: %s", login_data.email)
|
||||
raise AuthError(
|
||||
message="Invalid email or password",
|
||||
error_code=ErrorCode.INVALID_CREDENTIALS,
|
||||
)
|
||||
|
||||
# User is authenticated, generate tokens
|
||||
tokens = AuthService.create_tokens(user)
|
||||
|
||||
# Extract device information and create session record
|
||||
# Session creation is best-effort - we don't fail login if it fails
|
||||
try:
|
||||
device_info = extract_device_info(request)
|
||||
|
||||
# Decode refresh token to get JTI and expiration
|
||||
refresh_payload = decode_token(tokens.refresh_token, verify_type="refresh")
|
||||
|
||||
session_data = SessionCreate(
|
||||
user_id=user.id,
|
||||
refresh_token_jti=refresh_payload.jti,
|
||||
device_name=device_info.device_name,
|
||||
device_id=device_info.device_id,
|
||||
ip_address=device_info.ip_address,
|
||||
user_agent=device_info.user_agent,
|
||||
last_used_at=datetime.now(timezone.utc),
|
||||
expires_at=datetime.fromtimestamp(refresh_payload.exp, tz=timezone.utc),
|
||||
location_city=device_info.location_city,
|
||||
location_country=device_info.location_country,
|
||||
)
|
||||
|
||||
session_crud.create_session(db, obj_in=session_data)
|
||||
|
||||
logger.info(
|
||||
f"User login successful: {user.email} from {device_info.device_name} "
|
||||
f"(IP: {device_info.ip_address})"
|
||||
)
|
||||
except Exception as session_err:
|
||||
# Log but don't fail login if session creation fails
|
||||
logger.error(f"Failed to create session for {user.email}: {str(session_err)}", exc_info=True)
|
||||
# Create session record (best-effort, doesn't fail login)
|
||||
await _create_login_session(db, request, user, tokens, login_type="login")
|
||||
|
||||
return tokens
|
||||
|
||||
except HTTPException:
|
||||
# Re-raise HTTP exceptions without modification
|
||||
raise
|
||||
except AuthenticationError as e:
|
||||
# Handle specific authentication errors like inactive accounts
|
||||
logger.warning(f"Authentication failed: {str(e)}")
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_401_UNAUTHORIZED,
|
||||
detail=str(e),
|
||||
headers={"WWW-Authenticate": "Bearer"},
|
||||
)
|
||||
logger.warning("Authentication failed: %s", e)
|
||||
raise AuthError(message=str(e), error_code=ErrorCode.INVALID_CREDENTIALS)
|
||||
except Exception as e:
|
||||
# Handle unexpected errors
|
||||
logger.error(f"Unexpected error during login: {str(e)}")
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
|
||||
detail="An unexpected error occurred. Please try again later."
|
||||
logger.exception("Unexpected error during login: %s", e)
|
||||
raise DatabaseError(
|
||||
message="An unexpected error occurred. Please try again later.",
|
||||
error_code=ErrorCode.INTERNAL_ERROR,
|
||||
)
|
||||
|
||||
|
||||
@router.post("/login/oauth", response_model=Token, operation_id='login_oauth')
|
||||
@router.post("/login/oauth", response_model=Token, operation_id="login_oauth")
|
||||
@limiter.limit("10/minute")
|
||||
async def login_oauth(
|
||||
request: Request,
|
||||
form_data: OAuth2PasswordRequestForm = Depends(),
|
||||
db: Session = Depends(get_db)
|
||||
request: Request,
|
||||
form_data: OAuth2PasswordRequestForm = Depends(),
|
||||
db: AsyncSession = Depends(get_db),
|
||||
) -> Any:
|
||||
"""
|
||||
OAuth2-compatible login endpoint, used by the OpenAPI UI.
|
||||
@@ -174,74 +209,41 @@ async def login_oauth(
|
||||
Access and refresh tokens.
|
||||
"""
|
||||
try:
|
||||
user = AuthService.authenticate_user(db, form_data.username, form_data.password)
|
||||
user = await AuthService.authenticate_user(
|
||||
db, form_data.username, form_data.password
|
||||
)
|
||||
|
||||
if user is None:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_401_UNAUTHORIZED,
|
||||
detail="Invalid email or password",
|
||||
headers={"WWW-Authenticate": "Bearer"},
|
||||
raise AuthError(
|
||||
message="Invalid email or password",
|
||||
error_code=ErrorCode.INVALID_CREDENTIALS,
|
||||
)
|
||||
|
||||
# Generate tokens
|
||||
tokens = AuthService.create_tokens(user)
|
||||
|
||||
# Extract device information and create session record
|
||||
# Session creation is best-effort - we don't fail login if it fails
|
||||
try:
|
||||
device_info = extract_device_info(request)
|
||||
# Create session record (best-effort, doesn't fail login)
|
||||
await _create_login_session(db, request, user, tokens, login_type="oauth")
|
||||
|
||||
# Decode refresh token to get JTI and expiration
|
||||
refresh_payload = decode_token(tokens.refresh_token, verify_type="refresh")
|
||||
|
||||
session_data = SessionCreate(
|
||||
user_id=user.id,
|
||||
refresh_token_jti=refresh_payload.jti,
|
||||
device_name=device_info.device_name or "API Client",
|
||||
device_id=device_info.device_id,
|
||||
ip_address=device_info.ip_address,
|
||||
user_agent=device_info.user_agent,
|
||||
last_used_at=datetime.now(timezone.utc),
|
||||
expires_at=datetime.fromtimestamp(refresh_payload.exp, tz=timezone.utc),
|
||||
location_city=device_info.location_city,
|
||||
location_country=device_info.location_country,
|
||||
)
|
||||
|
||||
session_crud.create_session(db, obj_in=session_data)
|
||||
|
||||
logger.info(f"OAuth login successful: {user.email} from {device_info.device_name}")
|
||||
except Exception as session_err:
|
||||
logger.error(f"Failed to create session for {user.email}: {str(session_err)}", exc_info=True)
|
||||
|
||||
# Format response for OAuth compatibility
|
||||
return {
|
||||
"access_token": tokens.access_token,
|
||||
"refresh_token": tokens.refresh_token,
|
||||
"token_type": tokens.token_type
|
||||
}
|
||||
except HTTPException:
|
||||
raise
|
||||
# Return full token response with user data
|
||||
return tokens
|
||||
except AuthenticationError as e:
|
||||
logger.warning(f"OAuth authentication failed: {str(e)}")
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_401_UNAUTHORIZED,
|
||||
detail=str(e),
|
||||
headers={"WWW-Authenticate": "Bearer"},
|
||||
)
|
||||
logger.warning("OAuth authentication failed: %s", e)
|
||||
raise AuthError(message=str(e), error_code=ErrorCode.INVALID_CREDENTIALS)
|
||||
except Exception as e:
|
||||
logger.error(f"Unexpected error during OAuth login: {str(e)}")
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
|
||||
detail="An unexpected error occurred. Please try again later."
|
||||
logger.exception("Unexpected error during OAuth login: %s", e)
|
||||
raise DatabaseError(
|
||||
message="An unexpected error occurred. Please try again later.",
|
||||
error_code=ErrorCode.INTERNAL_ERROR,
|
||||
)
|
||||
|
||||
|
||||
@router.post("/refresh", response_model=Token, operation_id="refresh_token")
|
||||
@limiter.limit("30/minute")
|
||||
async def refresh_token(
|
||||
request: Request,
|
||||
refresh_data: RefreshTokenRequest,
|
||||
db: Session = Depends(get_db)
|
||||
request: Request,
|
||||
refresh_data: RefreshTokenRequest,
|
||||
db: AsyncSession = Depends(get_db),
|
||||
) -> Any:
|
||||
"""
|
||||
Refresh access token using a refresh token.
|
||||
@@ -253,13 +255,18 @@ async def refresh_token(
|
||||
"""
|
||||
try:
|
||||
# Decode the refresh token to get the JTI
|
||||
refresh_payload = decode_token(refresh_data.refresh_token, verify_type="refresh")
|
||||
refresh_payload = decode_token(
|
||||
refresh_data.refresh_token, verify_type="refresh"
|
||||
)
|
||||
|
||||
# Check if session exists and is active
|
||||
session = session_crud.get_active_by_jti(db, jti=refresh_payload.jti)
|
||||
session = await session_service.get_active_by_jti(db, jti=refresh_payload.jti)
|
||||
|
||||
if not session:
|
||||
logger.warning(f"Refresh token used for inactive or non-existent session: {refresh_payload.jti}")
|
||||
logger.warning(
|
||||
"Refresh token used for inactive or non-existent session: %s",
|
||||
refresh_payload.jti,
|
||||
)
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_401_UNAUTHORIZED,
|
||||
detail="Session has been revoked. Please log in again.",
|
||||
@@ -267,21 +274,21 @@ async def refresh_token(
|
||||
)
|
||||
|
||||
# Generate new tokens
|
||||
tokens = AuthService.refresh_tokens(db, refresh_data.refresh_token)
|
||||
tokens = await AuthService.refresh_tokens(db, refresh_data.refresh_token)
|
||||
|
||||
# Decode new refresh token to get new JTI
|
||||
new_refresh_payload = decode_token(tokens.refresh_token, verify_type="refresh")
|
||||
|
||||
# Update session with new refresh token JTI and expiration
|
||||
try:
|
||||
session_crud.update_refresh_token(
|
||||
await session_service.update_refresh_token(
|
||||
db,
|
||||
session=session,
|
||||
new_jti=new_refresh_payload.jti,
|
||||
new_expires_at=datetime.fromtimestamp(new_refresh_payload.exp, tz=timezone.utc)
|
||||
new_expires_at=datetime.fromtimestamp(new_refresh_payload.exp, tz=UTC),
|
||||
)
|
||||
except Exception as session_err:
|
||||
logger.error(f"Failed to update session {session.id}: {str(session_err)}", exc_info=True)
|
||||
logger.exception("Failed to update session %s: %s", session.id, session_err)
|
||||
# Continue anyway - tokens are already issued
|
||||
|
||||
return tokens
|
||||
@@ -304,27 +311,13 @@ async def refresh_token(
|
||||
# Re-raise HTTP exceptions (like session revoked)
|
||||
raise
|
||||
except Exception as e:
|
||||
logger.error(f"Unexpected error during token refresh: {str(e)}")
|
||||
logger.error("Unexpected error during token refresh: %s", e)
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
|
||||
detail="An unexpected error occurred. Please try again later."
|
||||
detail="An unexpected error occurred. Please try again later.",
|
||||
)
|
||||
|
||||
|
||||
@router.get("/me", response_model=UserResponse, operation_id="get_current_user_info")
|
||||
@limiter.limit("60/minute")
|
||||
async def get_current_user_info(
|
||||
request: Request,
|
||||
current_user: User = Depends(get_current_user)
|
||||
) -> Any:
|
||||
"""
|
||||
Get current user information.
|
||||
|
||||
Requires authentication.
|
||||
"""
|
||||
return current_user
|
||||
|
||||
|
||||
@router.post(
|
||||
"/password-reset/request",
|
||||
response_model=MessageResponse,
|
||||
@@ -338,13 +331,13 @@ async def get_current_user_info(
|
||||
|
||||
**Rate Limit**: 3 requests/minute
|
||||
""",
|
||||
operation_id="request_password_reset"
|
||||
operation_id="request_password_reset",
|
||||
)
|
||||
@limiter.limit("3/minute")
|
||||
async def request_password_reset(
|
||||
request: Request,
|
||||
reset_request: PasswordResetRequest,
|
||||
db: Session = Depends(get_db)
|
||||
db: AsyncSession = Depends(get_db),
|
||||
) -> Any:
|
||||
"""
|
||||
Request a password reset.
|
||||
@@ -354,7 +347,7 @@ async def request_password_reset(
|
||||
"""
|
||||
try:
|
||||
# Look up user by email
|
||||
user = user_crud.get_by_email(db, email=reset_request.email)
|
||||
user = await user_service.get_by_email(db, email=reset_request.email)
|
||||
|
||||
# Only send email if user exists and is active
|
||||
if user and user.is_active:
|
||||
@@ -363,26 +356,27 @@ async def request_password_reset(
|
||||
|
||||
# Send password reset email
|
||||
await email_service.send_password_reset_email(
|
||||
to_email=user.email,
|
||||
reset_token=reset_token,
|
||||
user_name=user.first_name
|
||||
to_email=user.email, reset_token=reset_token, user_name=user.first_name
|
||||
)
|
||||
logger.info(f"Password reset requested for {user.email}")
|
||||
logger.info("Password reset requested for %s", user.email)
|
||||
else:
|
||||
# Log attempt but don't reveal if email exists
|
||||
logger.warning(f"Password reset requested for non-existent or inactive email: {reset_request.email}")
|
||||
logger.warning(
|
||||
"Password reset requested for non-existent or inactive email: %s",
|
||||
reset_request.email,
|
||||
)
|
||||
|
||||
# Always return success to prevent email enumeration
|
||||
return MessageResponse(
|
||||
success=True,
|
||||
message="If your email is registered, you will receive a password reset link shortly"
|
||||
message="If your email is registered, you will receive a password reset link shortly",
|
||||
)
|
||||
except Exception as e:
|
||||
logger.error(f"Error processing password reset request: {str(e)}", exc_info=True)
|
||||
logger.exception("Error processing password reset request: %s", e)
|
||||
# Still return success to prevent information leakage
|
||||
return MessageResponse(
|
||||
success=True,
|
||||
message="If your email is registered, you will receive a password reset link shortly"
|
||||
message="If your email is registered, you will receive a password reset link shortly",
|
||||
)
|
||||
|
||||
|
||||
@@ -396,13 +390,13 @@ async def request_password_reset(
|
||||
|
||||
**Rate Limit**: 5 requests/minute
|
||||
""",
|
||||
operation_id="confirm_password_reset"
|
||||
operation_id="confirm_password_reset",
|
||||
)
|
||||
@limiter.limit("5/minute")
|
||||
def confirm_password_reset(
|
||||
async def confirm_password_reset(
|
||||
request: Request,
|
||||
reset_confirm: PasswordResetConfirm,
|
||||
db: Session = Depends(get_db)
|
||||
db: AsyncSession = Depends(get_db),
|
||||
) -> Any:
|
||||
"""
|
||||
Confirm password reset with token.
|
||||
@@ -416,44 +410,52 @@ def confirm_password_reset(
|
||||
if not email:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_400_BAD_REQUEST,
|
||||
detail="Invalid or expired password reset token"
|
||||
detail="Invalid or expired password reset token",
|
||||
)
|
||||
|
||||
# Look up user
|
||||
user = user_crud.get_by_email(db, email=email)
|
||||
|
||||
if not user:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_404_NOT_FOUND,
|
||||
detail="User not found"
|
||||
# Reset password via service (validates user exists and is active)
|
||||
try:
|
||||
user = await AuthService.reset_password(
|
||||
db, email=email, new_password=reset_confirm.new_password
|
||||
)
|
||||
except AuthenticationError as e:
|
||||
err_msg = str(e)
|
||||
if "inactive" in err_msg.lower():
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_400_BAD_REQUEST, detail=err_msg
|
||||
)
|
||||
raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail=err_msg)
|
||||
|
||||
if not user.is_active:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_400_BAD_REQUEST,
|
||||
detail="User account is inactive"
|
||||
# SECURITY: Invalidate all existing sessions after password reset
|
||||
# This prevents stolen sessions from being used after password change
|
||||
try:
|
||||
deactivated_count = await session_service.deactivate_all_user_sessions(
|
||||
db, user_id=str(user.id)
|
||||
)
|
||||
logger.info(
|
||||
"Password reset successful for %s, invalidated %s sessions",
|
||||
user.email,
|
||||
deactivated_count,
|
||||
)
|
||||
except Exception as session_error:
|
||||
# Log but don't fail password reset if session invalidation fails
|
||||
logger.error(
|
||||
"Failed to invalidate sessions after password reset: %s", session_error
|
||||
)
|
||||
|
||||
# Update password
|
||||
user.password_hash = get_password_hash(reset_confirm.new_password)
|
||||
db.add(user)
|
||||
db.commit()
|
||||
|
||||
logger.info(f"Password reset successful for {user.email}")
|
||||
|
||||
return MessageResponse(
|
||||
success=True,
|
||||
message="Password has been reset successfully. You can now log in with your new password."
|
||||
message="Password has been reset successfully. All devices have been logged out for security. You can now log in with your new password.",
|
||||
)
|
||||
|
||||
except HTTPException:
|
||||
raise
|
||||
except Exception as e:
|
||||
logger.error(f"Error confirming password reset: {str(e)}", exc_info=True)
|
||||
db.rollback()
|
||||
logger.exception("Error confirming password reset: %s", e)
|
||||
await db.rollback()
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
|
||||
detail="An error occurred while resetting your password"
|
||||
detail="An error occurred while resetting your password",
|
||||
)
|
||||
|
||||
|
||||
@@ -471,14 +473,14 @@ def confirm_password_reset(
|
||||
|
||||
**Rate Limit**: 10 requests/minute
|
||||
""",
|
||||
operation_id="logout"
|
||||
operation_id="logout",
|
||||
)
|
||||
@limiter.limit("10/minute")
|
||||
def logout(
|
||||
async def logout(
|
||||
request: Request,
|
||||
logout_request: LogoutRequest,
|
||||
current_user: User = Depends(get_current_user),
|
||||
db: Session = Depends(get_db)
|
||||
db: AsyncSession = Depends(get_db),
|
||||
) -> Any:
|
||||
"""
|
||||
Logout from current device by deactivating the session.
|
||||
@@ -494,57 +496,57 @@ def logout(
|
||||
try:
|
||||
# Decode refresh token to get JTI
|
||||
try:
|
||||
refresh_payload = decode_token(logout_request.refresh_token, verify_type="refresh")
|
||||
refresh_payload = decode_token(
|
||||
logout_request.refresh_token, verify_type="refresh"
|
||||
)
|
||||
except (TokenExpiredError, TokenInvalidError) as e:
|
||||
# Even if token is expired/invalid, try to deactivate session
|
||||
logger.warning(f"Logout with invalid/expired token: {str(e)}")
|
||||
logger.warning("Logout with invalid/expired token: %s", e)
|
||||
# Don't fail - return success anyway
|
||||
return MessageResponse(
|
||||
success=True,
|
||||
message="Logged out successfully"
|
||||
)
|
||||
return MessageResponse(success=True, message="Logged out successfully")
|
||||
|
||||
# Find the session by JTI
|
||||
session = session_crud.get_by_jti(db, jti=refresh_payload.jti)
|
||||
session = await session_service.get_by_jti(db, jti=refresh_payload.jti)
|
||||
|
||||
if session:
|
||||
# Verify session belongs to current user (security check)
|
||||
if str(session.user_id) != str(current_user.id):
|
||||
logger.warning(
|
||||
f"User {current_user.id} attempted to logout session {session.id} "
|
||||
f"belonging to user {session.user_id}"
|
||||
"User %s attempted to logout session %s belonging to user %s",
|
||||
current_user.id,
|
||||
session.id,
|
||||
session.user_id,
|
||||
)
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_403_FORBIDDEN,
|
||||
detail="You can only logout your own sessions"
|
||||
detail="You can only logout your own sessions",
|
||||
)
|
||||
|
||||
# Deactivate the session
|
||||
session_crud.deactivate(db, session_id=str(session.id))
|
||||
await session_service.deactivate(db, session_id=str(session.id))
|
||||
|
||||
logger.info(
|
||||
f"User {current_user.id} logged out from {session.device_name} "
|
||||
f"(session {session.id})"
|
||||
"User %s logged out from %s (session %s)",
|
||||
current_user.id,
|
||||
session.device_name,
|
||||
session.id,
|
||||
)
|
||||
else:
|
||||
# Session not found - maybe already deleted or never existed
|
||||
# Return success anyway (idempotent)
|
||||
logger.info(f"Logout requested for non-existent session (JTI: {refresh_payload.jti})")
|
||||
logger.info(
|
||||
"Logout requested for non-existent session (JTI: %s)",
|
||||
refresh_payload.jti,
|
||||
)
|
||||
|
||||
return MessageResponse(
|
||||
success=True,
|
||||
message="Logged out successfully"
|
||||
)
|
||||
return MessageResponse(success=True, message="Logged out successfully")
|
||||
|
||||
except HTTPException:
|
||||
raise
|
||||
except Exception as e:
|
||||
logger.error(f"Error during logout for user {current_user.id}: {str(e)}", exc_info=True)
|
||||
logger.exception("Error during logout for user %s: %s", current_user.id, e)
|
||||
# Don't expose error details
|
||||
return MessageResponse(
|
||||
success=True,
|
||||
message="Logged out successfully"
|
||||
)
|
||||
return MessageResponse(success=True, message="Logged out successfully")
|
||||
|
||||
|
||||
@router.post(
|
||||
@@ -560,13 +562,13 @@ def logout(
|
||||
|
||||
**Rate Limit**: 5 requests/minute
|
||||
""",
|
||||
operation_id="logout_all"
|
||||
operation_id="logout_all",
|
||||
)
|
||||
@limiter.limit("5/minute")
|
||||
def logout_all(
|
||||
async def logout_all(
|
||||
request: Request,
|
||||
current_user: User = Depends(get_current_user),
|
||||
db: Session = Depends(get_db)
|
||||
db: AsyncSession = Depends(get_db),
|
||||
) -> Any:
|
||||
"""
|
||||
Logout from all devices by deactivating all user sessions.
|
||||
@@ -580,19 +582,23 @@ def logout_all(
|
||||
"""
|
||||
try:
|
||||
# Deactivate all sessions for this user
|
||||
count = session_crud.deactivate_all_user_sessions(db, user_id=str(current_user.id))
|
||||
count = await session_service.deactivate_all_user_sessions(
|
||||
db, user_id=str(current_user.id)
|
||||
)
|
||||
|
||||
logger.info(f"User {current_user.id} logged out from all devices ({count} sessions)")
|
||||
logger.info(
|
||||
"User %s logged out from all devices (%s sessions)", current_user.id, count
|
||||
)
|
||||
|
||||
return MessageResponse(
|
||||
success=True,
|
||||
message=f"Successfully logged out from all devices ({count} sessions terminated)"
|
||||
message=f"Successfully logged out from all devices ({count} sessions terminated)",
|
||||
)
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error during logout-all for user {current_user.id}: {str(e)}", exc_info=True)
|
||||
db.rollback()
|
||||
logger.exception("Error during logout-all for user %s: %s", current_user.id, e)
|
||||
await db.rollback()
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
|
||||
detail="An error occurred while logging out"
|
||||
detail="An error occurred while logging out",
|
||||
)
|
||||
|
||||
434
backend/app/api/routes/oauth.py
Normal file
434
backend/app/api/routes/oauth.py
Normal file
@@ -0,0 +1,434 @@
|
||||
# app/api/routes/oauth.py
|
||||
"""
|
||||
OAuth routes for social authentication.
|
||||
|
||||
Endpoints:
|
||||
- GET /oauth/providers - List enabled OAuth providers
|
||||
- GET /oauth/authorize/{provider} - Get authorization URL
|
||||
- POST /oauth/callback/{provider} - Handle OAuth callback
|
||||
- GET /oauth/accounts - List linked OAuth accounts
|
||||
- DELETE /oauth/accounts/{provider} - Unlink an OAuth account
|
||||
"""
|
||||
|
||||
import logging
|
||||
import os
|
||||
from datetime import UTC, datetime
|
||||
from typing import Any
|
||||
|
||||
from fastapi import APIRouter, Depends, HTTPException, Query, Request, status
|
||||
from slowapi import Limiter
|
||||
from slowapi.util import get_remote_address
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
|
||||
from app.api.dependencies.auth import get_current_user, get_optional_current_user
|
||||
from app.core.auth import decode_token
|
||||
from app.core.config import settings
|
||||
from app.core.database import get_db
|
||||
from app.core.exceptions import AuthenticationError as AuthError
|
||||
from app.models.user import User
|
||||
from app.schemas.oauth import (
|
||||
OAuthAccountsListResponse,
|
||||
OAuthCallbackRequest,
|
||||
OAuthCallbackResponse,
|
||||
OAuthProvidersResponse,
|
||||
OAuthUnlinkResponse,
|
||||
)
|
||||
from app.schemas.sessions import SessionCreate
|
||||
from app.schemas.users import Token
|
||||
from app.services.oauth_service import OAuthService
|
||||
from app.services.session_service import session_service
|
||||
from app.utils.device import extract_device_info
|
||||
|
||||
router = APIRouter()
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
# Initialize limiter for this router
|
||||
limiter = Limiter(key_func=get_remote_address)
|
||||
|
||||
# Use higher rate limits in test environment
|
||||
IS_TEST = os.getenv("IS_TEST", "False") == "True"
|
||||
RATE_MULTIPLIER = 100 if IS_TEST else 1
|
||||
|
||||
|
||||
async def _create_oauth_login_session(
|
||||
db: AsyncSession,
|
||||
request: Request,
|
||||
user: User,
|
||||
tokens: Token,
|
||||
provider: str,
|
||||
) -> None:
|
||||
"""
|
||||
Create a session record for successful OAuth login.
|
||||
|
||||
This is a best-effort operation - login succeeds even if session creation fails.
|
||||
"""
|
||||
try:
|
||||
device_info = extract_device_info(request)
|
||||
|
||||
# Decode refresh token to get JTI and expiration
|
||||
refresh_payload = decode_token(tokens.refresh_token, verify_type="refresh")
|
||||
|
||||
session_data = SessionCreate(
|
||||
user_id=user.id,
|
||||
refresh_token_jti=refresh_payload.jti,
|
||||
device_name=device_info.device_name or f"OAuth ({provider})",
|
||||
device_id=device_info.device_id,
|
||||
ip_address=device_info.ip_address,
|
||||
user_agent=device_info.user_agent,
|
||||
last_used_at=datetime.now(UTC),
|
||||
expires_at=datetime.fromtimestamp(refresh_payload.exp, tz=UTC),
|
||||
location_city=device_info.location_city,
|
||||
location_country=device_info.location_country,
|
||||
)
|
||||
|
||||
await session_service.create_session(db, obj_in=session_data)
|
||||
|
||||
logger.info(
|
||||
"OAuth login successful: %s via %s from %s (IP: %s)",
|
||||
user.email,
|
||||
provider,
|
||||
device_info.device_name,
|
||||
device_info.ip_address,
|
||||
)
|
||||
except Exception as session_err:
|
||||
# Log but don't fail login if session creation fails
|
||||
logger.exception(
|
||||
"Failed to create session for OAuth login %s: %s", user.email, session_err
|
||||
)
|
||||
|
||||
|
||||
@router.get(
|
||||
"/providers",
|
||||
response_model=OAuthProvidersResponse,
|
||||
summary="List OAuth Providers",
|
||||
description="""
|
||||
Get list of enabled OAuth providers for the login/register UI.
|
||||
|
||||
Returns:
|
||||
List of enabled providers with display info.
|
||||
""",
|
||||
operation_id="list_oauth_providers",
|
||||
)
|
||||
async def list_providers() -> Any:
|
||||
"""
|
||||
Get list of enabled OAuth providers.
|
||||
|
||||
This endpoint is public (no authentication required) as it's needed
|
||||
for the login/register UI to display available social login options.
|
||||
"""
|
||||
return OAuthService.get_enabled_providers()
|
||||
|
||||
|
||||
@router.get(
|
||||
"/authorize/{provider}",
|
||||
response_model=dict,
|
||||
summary="Get OAuth Authorization URL",
|
||||
description="""
|
||||
Get the authorization URL to redirect the user to the OAuth provider.
|
||||
|
||||
The frontend should redirect the user to the returned URL.
|
||||
After authentication, the provider will redirect back to the callback URL.
|
||||
|
||||
**Rate Limit**: 10 requests/minute
|
||||
""",
|
||||
operation_id="get_oauth_authorization_url",
|
||||
)
|
||||
@limiter.limit(f"{10 * RATE_MULTIPLIER}/minute")
|
||||
async def get_authorization_url(
|
||||
request: Request,
|
||||
provider: str,
|
||||
redirect_uri: str = Query(
|
||||
..., description="Frontend callback URL after OAuth completes"
|
||||
),
|
||||
current_user: User | None = Depends(get_optional_current_user),
|
||||
db: AsyncSession = Depends(get_db),
|
||||
) -> Any:
|
||||
"""
|
||||
Get OAuth authorization URL.
|
||||
|
||||
Args:
|
||||
provider: OAuth provider (google, github)
|
||||
redirect_uri: Frontend callback URL
|
||||
current_user: Current user (optional, for account linking)
|
||||
db: Database session
|
||||
|
||||
Returns:
|
||||
dict with authorization_url and state
|
||||
"""
|
||||
if not settings.OAUTH_ENABLED:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_400_BAD_REQUEST,
|
||||
detail="OAuth is not enabled",
|
||||
)
|
||||
|
||||
try:
|
||||
# If user is logged in, this is an account linking flow
|
||||
user_id = str(current_user.id) if current_user else None
|
||||
|
||||
url, state = await OAuthService.create_authorization_url(
|
||||
db,
|
||||
provider=provider,
|
||||
redirect_uri=redirect_uri,
|
||||
user_id=user_id,
|
||||
)
|
||||
|
||||
return {
|
||||
"authorization_url": url,
|
||||
"state": state,
|
||||
}
|
||||
|
||||
except AuthError as e:
|
||||
logger.warning("OAuth authorization failed: %s", e)
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_400_BAD_REQUEST,
|
||||
detail=str(e),
|
||||
)
|
||||
except Exception as e:
|
||||
logger.exception("OAuth authorization error: %s", e)
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
|
||||
detail="Failed to create authorization URL",
|
||||
)
|
||||
|
||||
|
||||
@router.post(
|
||||
"/callback/{provider}",
|
||||
response_model=OAuthCallbackResponse,
|
||||
summary="OAuth Callback",
|
||||
description="""
|
||||
Handle OAuth callback from provider.
|
||||
|
||||
The frontend should call this endpoint with the code and state
|
||||
parameters received from the OAuth provider redirect.
|
||||
|
||||
Returns:
|
||||
JWT tokens for the authenticated user.
|
||||
|
||||
**Rate Limit**: 10 requests/minute
|
||||
""",
|
||||
operation_id="handle_oauth_callback",
|
||||
)
|
||||
@limiter.limit(f"{10 * RATE_MULTIPLIER}/minute")
|
||||
async def handle_callback(
|
||||
request: Request,
|
||||
provider: str,
|
||||
callback_data: OAuthCallbackRequest,
|
||||
redirect_uri: str = Query(
|
||||
..., description="Must match the redirect_uri used in authorization"
|
||||
),
|
||||
db: AsyncSession = Depends(get_db),
|
||||
) -> Any:
|
||||
"""
|
||||
Handle OAuth callback.
|
||||
|
||||
Args:
|
||||
provider: OAuth provider (google, github)
|
||||
callback_data: Code and state from provider
|
||||
redirect_uri: Original redirect URI (for validation)
|
||||
db: Database session
|
||||
|
||||
Returns:
|
||||
OAuthCallbackResponse with tokens
|
||||
"""
|
||||
if not settings.OAUTH_ENABLED:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_400_BAD_REQUEST,
|
||||
detail="OAuth is not enabled",
|
||||
)
|
||||
|
||||
try:
|
||||
result = await OAuthService.handle_callback(
|
||||
db,
|
||||
code=callback_data.code,
|
||||
state=callback_data.state,
|
||||
redirect_uri=redirect_uri,
|
||||
)
|
||||
|
||||
# Create session for the login (need to get the user first)
|
||||
# Note: This requires fetching the user from the token
|
||||
# For now, we skip session creation here as the result doesn't include user info
|
||||
# The session will be created on next request if needed
|
||||
|
||||
return result
|
||||
|
||||
except AuthError as e:
|
||||
logger.warning("OAuth callback failed: %s", e)
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_401_UNAUTHORIZED,
|
||||
detail=str(e),
|
||||
)
|
||||
except Exception as e:
|
||||
logger.exception("OAuth callback error: %s", e)
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
|
||||
detail="OAuth authentication failed",
|
||||
)
|
||||
|
||||
|
||||
@router.get(
|
||||
"/accounts",
|
||||
response_model=OAuthAccountsListResponse,
|
||||
summary="List Linked OAuth Accounts",
|
||||
description="""
|
||||
Get list of OAuth accounts linked to the current user.
|
||||
|
||||
Requires authentication.
|
||||
""",
|
||||
operation_id="list_oauth_accounts",
|
||||
)
|
||||
async def list_accounts(
|
||||
current_user: User = Depends(get_current_user),
|
||||
db: AsyncSession = Depends(get_db),
|
||||
) -> Any:
|
||||
"""
|
||||
List OAuth accounts linked to the current user.
|
||||
|
||||
Args:
|
||||
current_user: Current authenticated user
|
||||
db: Database session
|
||||
|
||||
Returns:
|
||||
List of linked OAuth accounts
|
||||
"""
|
||||
accounts = await OAuthService.get_user_accounts(db, user_id=current_user.id)
|
||||
return OAuthAccountsListResponse(accounts=accounts)
|
||||
|
||||
|
||||
@router.delete(
|
||||
"/accounts/{provider}",
|
||||
response_model=OAuthUnlinkResponse,
|
||||
summary="Unlink OAuth Account",
|
||||
description="""
|
||||
Unlink an OAuth provider from the current user.
|
||||
|
||||
The user must have either a password set or another OAuth provider
|
||||
linked to ensure they can still log in.
|
||||
|
||||
**Rate Limit**: 5 requests/minute
|
||||
""",
|
||||
operation_id="unlink_oauth_account",
|
||||
)
|
||||
@limiter.limit(f"{5 * RATE_MULTIPLIER}/minute")
|
||||
async def unlink_account(
|
||||
request: Request,
|
||||
provider: str,
|
||||
current_user: User = Depends(get_current_user),
|
||||
db: AsyncSession = Depends(get_db),
|
||||
) -> Any:
|
||||
"""
|
||||
Unlink an OAuth provider from the current user.
|
||||
|
||||
Args:
|
||||
provider: Provider to unlink (google, github)
|
||||
current_user: Current authenticated user
|
||||
db: Database session
|
||||
|
||||
Returns:
|
||||
Success message
|
||||
"""
|
||||
try:
|
||||
await OAuthService.unlink_provider(
|
||||
db,
|
||||
user=current_user,
|
||||
provider=provider,
|
||||
)
|
||||
|
||||
return OAuthUnlinkResponse(
|
||||
success=True,
|
||||
message=f"{provider.capitalize()} account unlinked successfully",
|
||||
)
|
||||
|
||||
except AuthError as e:
|
||||
logger.warning("OAuth unlink failed for %s: %s", current_user.email, e)
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_400_BAD_REQUEST,
|
||||
detail=str(e),
|
||||
)
|
||||
except Exception as e:
|
||||
logger.exception("OAuth unlink error: %s", e)
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
|
||||
detail="Failed to unlink OAuth account",
|
||||
)
|
||||
|
||||
|
||||
@router.post(
|
||||
"/link/{provider}",
|
||||
response_model=dict,
|
||||
summary="Start Account Linking",
|
||||
description="""
|
||||
Start the OAuth flow to link a new provider to the current user.
|
||||
|
||||
This is a convenience endpoint that redirects to /authorize/{provider}
|
||||
with the current user context.
|
||||
|
||||
**Rate Limit**: 10 requests/minute
|
||||
""",
|
||||
operation_id="start_oauth_link",
|
||||
)
|
||||
@limiter.limit(f"{10 * RATE_MULTIPLIER}/minute")
|
||||
async def start_link(
|
||||
request: Request,
|
||||
provider: str,
|
||||
redirect_uri: str = Query(
|
||||
..., description="Frontend callback URL after OAuth completes"
|
||||
),
|
||||
current_user: User = Depends(get_current_user),
|
||||
db: AsyncSession = Depends(get_db),
|
||||
) -> Any:
|
||||
"""
|
||||
Start OAuth account linking flow.
|
||||
|
||||
This endpoint requires authentication and will initiate an OAuth flow
|
||||
to link a new provider to the current user's account.
|
||||
|
||||
Args:
|
||||
provider: OAuth provider to link (google, github)
|
||||
redirect_uri: Frontend callback URL
|
||||
current_user: Current authenticated user
|
||||
db: Database session
|
||||
|
||||
Returns:
|
||||
dict with authorization_url and state
|
||||
"""
|
||||
if not settings.OAUTH_ENABLED:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_400_BAD_REQUEST,
|
||||
detail="OAuth is not enabled",
|
||||
)
|
||||
|
||||
# Check if user already has this provider linked
|
||||
existing = await OAuthService.get_user_account_by_provider(
|
||||
db, user_id=current_user.id, provider=provider
|
||||
)
|
||||
if existing:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_400_BAD_REQUEST,
|
||||
detail=f"You already have a {provider} account linked",
|
||||
)
|
||||
|
||||
try:
|
||||
url, state = await OAuthService.create_authorization_url(
|
||||
db,
|
||||
provider=provider,
|
||||
redirect_uri=redirect_uri,
|
||||
user_id=str(current_user.id),
|
||||
)
|
||||
|
||||
return {
|
||||
"authorization_url": url,
|
||||
"state": state,
|
||||
}
|
||||
|
||||
except AuthError as e:
|
||||
logger.warning("OAuth link authorization failed: %s", e)
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_400_BAD_REQUEST,
|
||||
detail=str(e),
|
||||
)
|
||||
except Exception as e:
|
||||
logger.exception("OAuth link error: %s", e)
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
|
||||
detail="Failed to create authorization URL",
|
||||
)
|
||||
824
backend/app/api/routes/oauth_provider.py
Normal file
824
backend/app/api/routes/oauth_provider.py
Normal file
@@ -0,0 +1,824 @@
|
||||
# app/api/routes/oauth_provider.py
|
||||
"""
|
||||
OAuth Provider routes (Authorization Server mode) for MCP integration.
|
||||
|
||||
Implements OAuth 2.0 Authorization Server endpoints:
|
||||
- GET /.well-known/oauth-authorization-server - Server metadata (RFC 8414)
|
||||
- GET /oauth/provider/authorize - Authorization endpoint
|
||||
- POST /oauth/provider/token - Token endpoint
|
||||
- POST /oauth/provider/revoke - Token revocation (RFC 7009)
|
||||
- POST /oauth/provider/introspect - Token introspection (RFC 7662)
|
||||
- Client management endpoints
|
||||
|
||||
Security features:
|
||||
- PKCE required for public clients (S256)
|
||||
- CSRF protection via state parameter
|
||||
- Secure token handling
|
||||
- Rate limiting on sensitive endpoints
|
||||
"""
|
||||
|
||||
import logging
|
||||
from typing import Any
|
||||
from urllib.parse import urlencode
|
||||
|
||||
from fastapi import APIRouter, Depends, Form, HTTPException, Query, Request, status
|
||||
from fastapi.responses import RedirectResponse
|
||||
from slowapi import Limiter
|
||||
from slowapi.util import get_remote_address
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
|
||||
from app.api.dependencies.auth import (
|
||||
get_current_active_user,
|
||||
get_current_superuser,
|
||||
get_optional_current_user,
|
||||
)
|
||||
from app.core.config import settings
|
||||
from app.core.database import get_db
|
||||
from app.models.user import User
|
||||
from app.schemas.oauth import (
|
||||
OAuthClientCreate,
|
||||
OAuthClientResponse,
|
||||
OAuthServerMetadata,
|
||||
OAuthTokenIntrospectionResponse,
|
||||
OAuthTokenResponse,
|
||||
)
|
||||
from app.services import oauth_provider_service as provider_service
|
||||
|
||||
router = APIRouter()
|
||||
# Separate router for RFC 8414 well-known endpoint (registered at root level)
|
||||
wellknown_router = APIRouter()
|
||||
logger = logging.getLogger(__name__)
|
||||
limiter = Limiter(key_func=get_remote_address)
|
||||
|
||||
|
||||
def require_provider_enabled():
|
||||
"""Dependency to check if OAuth provider mode is enabled."""
|
||||
if not settings.OAUTH_PROVIDER_ENABLED:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_404_NOT_FOUND,
|
||||
detail="OAuth provider mode is not enabled. Set OAUTH_PROVIDER_ENABLED=true",
|
||||
)
|
||||
|
||||
|
||||
# ============================================================================
|
||||
# Server Metadata (RFC 8414)
|
||||
# ============================================================================
|
||||
|
||||
|
||||
@wellknown_router.get(
|
||||
"/.well-known/oauth-authorization-server",
|
||||
response_model=OAuthServerMetadata,
|
||||
summary="OAuth Server Metadata",
|
||||
description="""
|
||||
OAuth 2.0 Authorization Server Metadata (RFC 8414).
|
||||
|
||||
Returns server metadata including supported endpoints, scopes,
|
||||
and capabilities. MCP clients use this to discover the server.
|
||||
|
||||
Note: This endpoint is at the root level per RFC 8414.
|
||||
""",
|
||||
operation_id="get_oauth_server_metadata",
|
||||
tags=["OAuth Provider"],
|
||||
)
|
||||
async def get_server_metadata(
|
||||
_: None = Depends(require_provider_enabled),
|
||||
) -> OAuthServerMetadata:
|
||||
"""Get OAuth 2.0 server metadata."""
|
||||
base_url = settings.OAUTH_ISSUER.rstrip("/")
|
||||
|
||||
return OAuthServerMetadata(
|
||||
issuer=base_url,
|
||||
authorization_endpoint=f"{base_url}/api/v1/oauth/provider/authorize",
|
||||
token_endpoint=f"{base_url}/api/v1/oauth/provider/token",
|
||||
revocation_endpoint=f"{base_url}/api/v1/oauth/provider/revoke",
|
||||
introspection_endpoint=f"{base_url}/api/v1/oauth/provider/introspect",
|
||||
registration_endpoint=None, # Dynamic registration not supported
|
||||
scopes_supported=[
|
||||
"openid",
|
||||
"profile",
|
||||
"email",
|
||||
"read:users",
|
||||
"write:users",
|
||||
"read:organizations",
|
||||
"write:organizations",
|
||||
"admin",
|
||||
],
|
||||
response_types_supported=["code"],
|
||||
grant_types_supported=["authorization_code", "refresh_token"],
|
||||
code_challenge_methods_supported=["S256"],
|
||||
token_endpoint_auth_methods_supported=[
|
||||
"client_secret_basic",
|
||||
"client_secret_post",
|
||||
"none", # For public clients with PKCE
|
||||
],
|
||||
)
|
||||
|
||||
|
||||
# ============================================================================
|
||||
# Authorization Endpoint
|
||||
# ============================================================================
|
||||
|
||||
|
||||
@router.get(
|
||||
"/provider/authorize",
|
||||
summary="Authorization Endpoint",
|
||||
description="""
|
||||
OAuth 2.0 Authorization Endpoint.
|
||||
|
||||
Initiates the authorization code flow:
|
||||
1. Validates client and parameters
|
||||
2. Checks if user is authenticated (redirects to login if not)
|
||||
3. Checks existing consent
|
||||
4. Redirects to consent page if needed
|
||||
5. Issues authorization code and redirects back to client
|
||||
|
||||
Required parameters:
|
||||
- response_type: Must be "code"
|
||||
- client_id: Registered client ID
|
||||
- redirect_uri: Must match registered URI
|
||||
|
||||
Recommended parameters:
|
||||
- state: CSRF protection
|
||||
- code_challenge + code_challenge_method: PKCE (required for public clients)
|
||||
- scope: Requested permissions
|
||||
""",
|
||||
operation_id="oauth_provider_authorize",
|
||||
tags=["OAuth Provider"],
|
||||
)
|
||||
@limiter.limit("30/minute")
|
||||
async def authorize(
|
||||
request: Request,
|
||||
response_type: str = Query(..., description="Must be 'code'"),
|
||||
client_id: str = Query(..., description="OAuth client ID"),
|
||||
redirect_uri: str = Query(..., description="Redirect URI"),
|
||||
scope: str = Query(default="", description="Requested scopes (space-separated)"),
|
||||
state: str = Query(default="", description="CSRF state parameter"),
|
||||
code_challenge: str | None = Query(default=None, description="PKCE code challenge"),
|
||||
code_challenge_method: str | None = Query(
|
||||
default=None, description="PKCE method (S256)"
|
||||
),
|
||||
nonce: str | None = Query(default=None, description="OpenID Connect nonce"),
|
||||
db: AsyncSession = Depends(get_db),
|
||||
_: None = Depends(require_provider_enabled),
|
||||
current_user: User | None = Depends(get_optional_current_user),
|
||||
) -> Any:
|
||||
"""
|
||||
Authorization endpoint - initiates OAuth flow.
|
||||
|
||||
If user is not authenticated, redirects to login with return URL.
|
||||
If user has not consented, redirects to consent page.
|
||||
If all checks pass, generates code and redirects to client.
|
||||
"""
|
||||
# Validate response_type
|
||||
if response_type != "code":
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_400_BAD_REQUEST,
|
||||
detail="invalid_request: response_type must be 'code'",
|
||||
)
|
||||
|
||||
# Validate PKCE method if provided - ONLY S256 is allowed (RFC 7636 Section 4.3)
|
||||
# "plain" method provides no security benefit and MUST NOT be used
|
||||
if code_challenge_method and code_challenge_method != "S256":
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_400_BAD_REQUEST,
|
||||
detail="invalid_request: code_challenge_method must be 'S256' (plain is not supported)",
|
||||
)
|
||||
|
||||
# Validate client
|
||||
try:
|
||||
client = await provider_service.get_client(db, client_id)
|
||||
if not client:
|
||||
raise provider_service.InvalidClientError("Unknown client_id")
|
||||
provider_service.validate_redirect_uri(client, redirect_uri)
|
||||
except provider_service.OAuthProviderError as e:
|
||||
# For client/redirect errors, we can't safely redirect - show error
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_400_BAD_REQUEST,
|
||||
detail=f"{e.error}: {e.error_description}",
|
||||
)
|
||||
|
||||
# Validate and filter scopes
|
||||
try:
|
||||
requested_scopes = provider_service.parse_scope(scope)
|
||||
valid_scopes = provider_service.validate_scopes(client, requested_scopes)
|
||||
except provider_service.InvalidScopeError as e:
|
||||
# Redirect with error
|
||||
scope_error_params: dict[str, str] = {"error": e.error}
|
||||
if e.error_description:
|
||||
scope_error_params["error_description"] = e.error_description
|
||||
if state:
|
||||
scope_error_params["state"] = state
|
||||
return RedirectResponse(
|
||||
url=f"{redirect_uri}?{urlencode(scope_error_params)}",
|
||||
status_code=status.HTTP_302_FOUND,
|
||||
)
|
||||
|
||||
# Public clients MUST use PKCE
|
||||
if client.client_type == "public":
|
||||
if not code_challenge or code_challenge_method != "S256":
|
||||
pkce_error_params: dict[str, str] = {
|
||||
"error": "invalid_request",
|
||||
"error_description": "PKCE with S256 is required for public clients",
|
||||
}
|
||||
if state:
|
||||
pkce_error_params["state"] = state
|
||||
return RedirectResponse(
|
||||
url=f"{redirect_uri}?{urlencode(pkce_error_params)}",
|
||||
status_code=status.HTTP_302_FOUND,
|
||||
)
|
||||
|
||||
# If user is not authenticated, redirect to login
|
||||
if not current_user:
|
||||
# Store authorization request in session and redirect to login
|
||||
# The frontend will handle the return URL
|
||||
login_url = f"{settings.FRONTEND_URL}/login"
|
||||
return_params = urlencode(
|
||||
{
|
||||
"oauth_authorize": "true",
|
||||
"client_id": client_id,
|
||||
"redirect_uri": redirect_uri,
|
||||
"scope": " ".join(valid_scopes),
|
||||
"state": state,
|
||||
"code_challenge": code_challenge or "",
|
||||
"code_challenge_method": code_challenge_method or "",
|
||||
"nonce": nonce or "",
|
||||
}
|
||||
)
|
||||
return RedirectResponse(
|
||||
url=f"{login_url}?return_to=/auth/consent?{return_params}",
|
||||
status_code=status.HTTP_302_FOUND,
|
||||
)
|
||||
|
||||
# Check if user has already consented
|
||||
has_consent = await provider_service.check_consent(
|
||||
db, current_user.id, client_id, valid_scopes
|
||||
)
|
||||
|
||||
if not has_consent:
|
||||
# Redirect to consent page
|
||||
consent_params = urlencode(
|
||||
{
|
||||
"client_id": client_id,
|
||||
"client_name": client.client_name,
|
||||
"redirect_uri": redirect_uri,
|
||||
"scope": " ".join(valid_scopes),
|
||||
"state": state,
|
||||
"code_challenge": code_challenge or "",
|
||||
"code_challenge_method": code_challenge_method or "",
|
||||
"nonce": nonce or "",
|
||||
}
|
||||
)
|
||||
return RedirectResponse(
|
||||
url=f"{settings.FRONTEND_URL}/auth/consent?{consent_params}",
|
||||
status_code=status.HTTP_302_FOUND,
|
||||
)
|
||||
|
||||
# User is authenticated and has consented - issue authorization code
|
||||
try:
|
||||
code = await provider_service.create_authorization_code(
|
||||
db=db,
|
||||
client=client,
|
||||
user=current_user,
|
||||
redirect_uri=redirect_uri,
|
||||
scope=" ".join(valid_scopes),
|
||||
code_challenge=code_challenge,
|
||||
code_challenge_method=code_challenge_method,
|
||||
state=state,
|
||||
nonce=nonce,
|
||||
)
|
||||
except provider_service.OAuthProviderError as e:
|
||||
error_params: dict[str, str] = {"error": e.error}
|
||||
if e.error_description:
|
||||
error_params["error_description"] = e.error_description
|
||||
if state:
|
||||
error_params["state"] = state
|
||||
return RedirectResponse(
|
||||
url=f"{redirect_uri}?{urlencode(error_params)}",
|
||||
status_code=status.HTTP_302_FOUND,
|
||||
)
|
||||
|
||||
# Success - redirect with code
|
||||
success_params = {"code": code}
|
||||
if state:
|
||||
success_params["state"] = state
|
||||
|
||||
return RedirectResponse(
|
||||
url=f"{redirect_uri}?{urlencode(success_params)}",
|
||||
status_code=status.HTTP_302_FOUND,
|
||||
)
|
||||
|
||||
|
||||
@router.post(
|
||||
"/provider/authorize/consent",
|
||||
summary="Submit Authorization Consent",
|
||||
description="""
|
||||
Submit user consent for OAuth authorization.
|
||||
|
||||
Called by the consent page after user approves or denies.
|
||||
""",
|
||||
operation_id="oauth_provider_consent",
|
||||
tags=["OAuth Provider"],
|
||||
)
|
||||
@limiter.limit("30/minute")
|
||||
async def submit_consent(
|
||||
request: Request,
|
||||
approved: bool = Form(..., description="Whether user approved"),
|
||||
client_id: str = Form(..., description="OAuth client ID"),
|
||||
redirect_uri: str = Form(..., description="Redirect URI"),
|
||||
scope: str = Form(default="", description="Granted scopes"),
|
||||
state: str = Form(default="", description="CSRF state parameter"),
|
||||
code_challenge: str | None = Form(default=None),
|
||||
code_challenge_method: str | None = Form(default=None),
|
||||
nonce: str | None = Form(default=None),
|
||||
db: AsyncSession = Depends(get_db),
|
||||
_: None = Depends(require_provider_enabled),
|
||||
current_user: User = Depends(get_current_active_user),
|
||||
) -> Any:
|
||||
"""Process consent form submission."""
|
||||
# Validate client
|
||||
try:
|
||||
client = await provider_service.get_client(db, client_id)
|
||||
if not client:
|
||||
raise provider_service.InvalidClientError("Unknown client_id")
|
||||
provider_service.validate_redirect_uri(client, redirect_uri)
|
||||
except provider_service.OAuthProviderError as e:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_400_BAD_REQUEST,
|
||||
detail=f"{e.error}: {e.error_description}",
|
||||
)
|
||||
|
||||
# If user denied, redirect with error
|
||||
if not approved:
|
||||
denied_params: dict[str, str] = {
|
||||
"error": "access_denied",
|
||||
"error_description": "User denied authorization",
|
||||
}
|
||||
if state:
|
||||
denied_params["state"] = state
|
||||
return RedirectResponse(
|
||||
url=f"{redirect_uri}?{urlencode(denied_params)}",
|
||||
status_code=status.HTTP_302_FOUND,
|
||||
)
|
||||
|
||||
# Parse and validate scopes
|
||||
granted_scopes = provider_service.parse_scope(scope)
|
||||
valid_scopes = provider_service.validate_scopes(client, granted_scopes)
|
||||
|
||||
# Record consent
|
||||
await provider_service.grant_consent(db, current_user.id, client_id, valid_scopes)
|
||||
|
||||
# Generate authorization code
|
||||
try:
|
||||
code = await provider_service.create_authorization_code(
|
||||
db=db,
|
||||
client=client,
|
||||
user=current_user,
|
||||
redirect_uri=redirect_uri,
|
||||
scope=" ".join(valid_scopes),
|
||||
code_challenge=code_challenge,
|
||||
code_challenge_method=code_challenge_method,
|
||||
state=state,
|
||||
nonce=nonce,
|
||||
)
|
||||
except provider_service.OAuthProviderError as e:
|
||||
error_params: dict[str, str] = {"error": e.error}
|
||||
if e.error_description:
|
||||
error_params["error_description"] = e.error_description
|
||||
if state:
|
||||
error_params["state"] = state
|
||||
return RedirectResponse(
|
||||
url=f"{redirect_uri}?{urlencode(error_params)}",
|
||||
status_code=status.HTTP_302_FOUND,
|
||||
)
|
||||
|
||||
# Success
|
||||
success_params = {"code": code}
|
||||
if state:
|
||||
success_params["state"] = state
|
||||
|
||||
return RedirectResponse(
|
||||
url=f"{redirect_uri}?{urlencode(success_params)}",
|
||||
status_code=status.HTTP_302_FOUND,
|
||||
)
|
||||
|
||||
|
||||
# ============================================================================
|
||||
# Token Endpoint
|
||||
# ============================================================================
|
||||
|
||||
|
||||
@router.post(
|
||||
"/provider/token",
|
||||
response_model=OAuthTokenResponse,
|
||||
summary="Token Endpoint",
|
||||
description="""
|
||||
OAuth 2.0 Token Endpoint.
|
||||
|
||||
Supports:
|
||||
- authorization_code: Exchange code for tokens
|
||||
- refresh_token: Refresh access token
|
||||
|
||||
Client authentication:
|
||||
- Confidential clients: client_secret (Basic auth or POST body)
|
||||
- Public clients: No secret, but PKCE code_verifier required
|
||||
""",
|
||||
operation_id="oauth_provider_token",
|
||||
tags=["OAuth Provider"],
|
||||
)
|
||||
@limiter.limit("60/minute")
|
||||
async def token(
|
||||
request: Request,
|
||||
grant_type: str = Form(..., description="Grant type"),
|
||||
code: str | None = Form(default=None, description="Authorization code"),
|
||||
redirect_uri: str | None = Form(default=None, description="Redirect URI"),
|
||||
client_id: str | None = Form(default=None, description="Client ID"),
|
||||
client_secret: str | None = Form(default=None, description="Client secret"),
|
||||
code_verifier: str | None = Form(default=None, description="PKCE code verifier"),
|
||||
refresh_token: str | None = Form(default=None, description="Refresh token"),
|
||||
scope: str | None = Form(default=None, description="Scope (for refresh)"),
|
||||
db: AsyncSession = Depends(get_db),
|
||||
_: None = Depends(require_provider_enabled),
|
||||
) -> OAuthTokenResponse:
|
||||
"""Token endpoint - exchange code for tokens or refresh."""
|
||||
# Extract client credentials from Basic auth if not in body
|
||||
if not client_id:
|
||||
auth_header = request.headers.get("Authorization", "")
|
||||
if auth_header.startswith("Basic "):
|
||||
import base64
|
||||
|
||||
try:
|
||||
decoded = base64.b64decode(auth_header[6:]).decode()
|
||||
client_id, client_secret = decoded.split(":", 1)
|
||||
except Exception as e:
|
||||
# Log malformed Basic auth for security monitoring
|
||||
logger.warning(
|
||||
"Malformed Basic auth header in token request: %s", type(e).__name__
|
||||
)
|
||||
# Fall back to form body
|
||||
|
||||
if not client_id:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_401_UNAUTHORIZED,
|
||||
detail="invalid_client: client_id required",
|
||||
headers={"WWW-Authenticate": "Basic"},
|
||||
)
|
||||
|
||||
# Get device info
|
||||
device_info = request.headers.get("User-Agent", "")[:500]
|
||||
ip_address = get_remote_address(request)
|
||||
|
||||
try:
|
||||
if grant_type == "authorization_code":
|
||||
if not code:
|
||||
raise provider_service.InvalidRequestError("code required")
|
||||
if not redirect_uri:
|
||||
raise provider_service.InvalidRequestError("redirect_uri required")
|
||||
|
||||
result = await provider_service.exchange_authorization_code(
|
||||
db=db,
|
||||
code=code,
|
||||
client_id=client_id,
|
||||
redirect_uri=redirect_uri,
|
||||
code_verifier=code_verifier,
|
||||
client_secret=client_secret,
|
||||
device_info=device_info,
|
||||
ip_address=ip_address,
|
||||
)
|
||||
|
||||
elif grant_type == "refresh_token":
|
||||
if not refresh_token:
|
||||
raise provider_service.InvalidRequestError("refresh_token required")
|
||||
|
||||
result = await provider_service.refresh_tokens(
|
||||
db=db,
|
||||
refresh_token=refresh_token,
|
||||
client_id=client_id,
|
||||
client_secret=client_secret,
|
||||
scope=scope,
|
||||
device_info=device_info,
|
||||
ip_address=ip_address,
|
||||
)
|
||||
|
||||
else:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_400_BAD_REQUEST,
|
||||
detail="unsupported_grant_type: Must be authorization_code or refresh_token",
|
||||
)
|
||||
|
||||
return OAuthTokenResponse(**result)
|
||||
|
||||
except provider_service.InvalidClientError as e:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_401_UNAUTHORIZED,
|
||||
detail=f"{e.error}: {e.error_description}",
|
||||
headers={"WWW-Authenticate": "Basic"},
|
||||
)
|
||||
except provider_service.OAuthProviderError as e:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_400_BAD_REQUEST,
|
||||
detail=f"{e.error}: {e.error_description}",
|
||||
)
|
||||
|
||||
|
||||
# ============================================================================
|
||||
# Token Revocation (RFC 7009)
|
||||
# ============================================================================
|
||||
|
||||
|
||||
@router.post(
|
||||
"/provider/revoke",
|
||||
status_code=status.HTTP_200_OK,
|
||||
summary="Token Revocation Endpoint",
|
||||
description="""
|
||||
OAuth 2.0 Token Revocation Endpoint (RFC 7009).
|
||||
|
||||
Revokes an access token or refresh token.
|
||||
Always returns 200 OK (even if token is invalid) per spec.
|
||||
""",
|
||||
operation_id="oauth_provider_revoke",
|
||||
tags=["OAuth Provider"],
|
||||
)
|
||||
@limiter.limit("30/minute")
|
||||
async def revoke(
|
||||
request: Request,
|
||||
token: str = Form(..., description="Token to revoke"),
|
||||
token_type_hint: str | None = Form(
|
||||
default=None, description="Token type hint (access_token, refresh_token)"
|
||||
),
|
||||
client_id: str | None = Form(default=None, description="Client ID"),
|
||||
client_secret: str | None = Form(default=None, description="Client secret"),
|
||||
db: AsyncSession = Depends(get_db),
|
||||
_: None = Depends(require_provider_enabled),
|
||||
) -> dict[str, str]:
|
||||
"""Revoke a token."""
|
||||
# Extract client credentials from Basic auth if not in body
|
||||
if not client_id:
|
||||
auth_header = request.headers.get("Authorization", "")
|
||||
if auth_header.startswith("Basic "):
|
||||
import base64
|
||||
|
||||
try:
|
||||
decoded = base64.b64decode(auth_header[6:]).decode()
|
||||
client_id, client_secret = decoded.split(":", 1)
|
||||
except Exception as e:
|
||||
# Log malformed Basic auth for security monitoring
|
||||
logger.warning(
|
||||
"Malformed Basic auth header in revoke request: %s",
|
||||
type(e).__name__,
|
||||
)
|
||||
# Fall back to form body
|
||||
|
||||
try:
|
||||
await provider_service.revoke_token(
|
||||
db=db,
|
||||
token=token,
|
||||
token_type_hint=token_type_hint,
|
||||
client_id=client_id,
|
||||
client_secret=client_secret,
|
||||
)
|
||||
except provider_service.InvalidClientError:
|
||||
# Per RFC 7009, we should return 200 OK even for errors
|
||||
# But client authentication errors can return 401
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_401_UNAUTHORIZED,
|
||||
detail="invalid_client",
|
||||
headers={"WWW-Authenticate": "Basic"},
|
||||
)
|
||||
except Exception as e:
|
||||
# Log but don't expose errors per RFC 7009
|
||||
logger.warning("Token revocation error: %s", e)
|
||||
|
||||
# Always return 200 OK per RFC 7009
|
||||
return {"status": "ok"}
|
||||
|
||||
|
||||
# ============================================================================
|
||||
# Token Introspection (RFC 7662)
|
||||
# ============================================================================
|
||||
|
||||
|
||||
@router.post(
|
||||
"/provider/introspect",
|
||||
response_model=OAuthTokenIntrospectionResponse,
|
||||
summary="Token Introspection Endpoint",
|
||||
description="""
|
||||
OAuth 2.0 Token Introspection Endpoint (RFC 7662).
|
||||
|
||||
Allows resource servers to query the authorization server
|
||||
to determine the active state and metadata of a token.
|
||||
""",
|
||||
operation_id="oauth_provider_introspect",
|
||||
tags=["OAuth Provider"],
|
||||
)
|
||||
@limiter.limit("120/minute")
|
||||
async def introspect(
|
||||
request: Request,
|
||||
token: str = Form(..., description="Token to introspect"),
|
||||
token_type_hint: str | None = Form(
|
||||
default=None, description="Token type hint (access_token, refresh_token)"
|
||||
),
|
||||
client_id: str | None = Form(default=None, description="Client ID"),
|
||||
client_secret: str | None = Form(default=None, description="Client secret"),
|
||||
db: AsyncSession = Depends(get_db),
|
||||
_: None = Depends(require_provider_enabled),
|
||||
) -> OAuthTokenIntrospectionResponse:
|
||||
"""Introspect a token."""
|
||||
# Extract client credentials from Basic auth if not in body
|
||||
if not client_id:
|
||||
auth_header = request.headers.get("Authorization", "")
|
||||
if auth_header.startswith("Basic "):
|
||||
import base64
|
||||
|
||||
try:
|
||||
decoded = base64.b64decode(auth_header[6:]).decode()
|
||||
client_id, client_secret = decoded.split(":", 1)
|
||||
except Exception as e:
|
||||
# Log malformed Basic auth for security monitoring
|
||||
logger.warning(
|
||||
"Malformed Basic auth header in introspect request: %s",
|
||||
type(e).__name__,
|
||||
)
|
||||
# Fall back to form body
|
||||
|
||||
try:
|
||||
result = await provider_service.introspect_token(
|
||||
db=db,
|
||||
token=token,
|
||||
token_type_hint=token_type_hint,
|
||||
client_id=client_id,
|
||||
client_secret=client_secret,
|
||||
)
|
||||
return OAuthTokenIntrospectionResponse(**result)
|
||||
except provider_service.InvalidClientError:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_401_UNAUTHORIZED,
|
||||
detail="invalid_client",
|
||||
headers={"WWW-Authenticate": "Basic"},
|
||||
)
|
||||
except Exception as e:
|
||||
logger.warning("Token introspection error: %s", e)
|
||||
return OAuthTokenIntrospectionResponse(active=False) # pyright: ignore[reportCallIssue]
|
||||
|
||||
|
||||
# ============================================================================
|
||||
# Client Management (Admin)
|
||||
# ============================================================================
|
||||
|
||||
|
||||
@router.post(
|
||||
"/provider/clients",
|
||||
response_model=dict,
|
||||
summary="Register OAuth Client",
|
||||
description="""
|
||||
Register a new OAuth client (admin only).
|
||||
|
||||
Creates an MCP client that can authenticate against this API.
|
||||
Returns client_id and client_secret (for confidential clients).
|
||||
|
||||
**Important:** Store the client_secret securely - it won't be shown again!
|
||||
""",
|
||||
operation_id="register_oauth_client",
|
||||
tags=["OAuth Provider Admin"],
|
||||
)
|
||||
async def register_client(
|
||||
client_name: str = Form(..., description="Client application name"),
|
||||
redirect_uris: str = Form(..., description="Comma-separated redirect URIs"),
|
||||
client_type: str = Form(default="public", description="public or confidential"),
|
||||
scopes: str = Form(
|
||||
default="openid profile email",
|
||||
description="Allowed scopes (space-separated)",
|
||||
),
|
||||
mcp_server_url: str | None = Form(default=None, description="MCP server URL"),
|
||||
db: AsyncSession = Depends(get_db),
|
||||
_: None = Depends(require_provider_enabled),
|
||||
current_user: User = Depends(get_current_superuser),
|
||||
) -> dict:
|
||||
"""Register a new OAuth client."""
|
||||
# Parse redirect URIs
|
||||
uris = [uri.strip() for uri in redirect_uris.split(",") if uri.strip()]
|
||||
if not uris:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_400_BAD_REQUEST,
|
||||
detail="At least one redirect_uri is required",
|
||||
)
|
||||
|
||||
# Parse scopes
|
||||
allowed_scopes = [s.strip() for s in scopes.split() if s.strip()]
|
||||
|
||||
client_data = OAuthClientCreate(
|
||||
client_name=client_name,
|
||||
client_description=None,
|
||||
redirect_uris=uris,
|
||||
allowed_scopes=allowed_scopes,
|
||||
client_type=client_type,
|
||||
)
|
||||
|
||||
client, secret = await provider_service.register_client(db, client_data)
|
||||
|
||||
# Update MCP server URL if provided
|
||||
if mcp_server_url:
|
||||
client.mcp_server_url = mcp_server_url
|
||||
await db.commit()
|
||||
|
||||
result = {
|
||||
"client_id": client.client_id,
|
||||
"client_name": client.client_name,
|
||||
"client_type": client.client_type,
|
||||
"redirect_uris": client.redirect_uris,
|
||||
"allowed_scopes": client.allowed_scopes,
|
||||
}
|
||||
|
||||
if secret:
|
||||
result["client_secret"] = secret
|
||||
result["warning"] = (
|
||||
"Store the client_secret securely! It will not be shown again."
|
||||
)
|
||||
|
||||
return result
|
||||
|
||||
|
||||
@router.get(
|
||||
"/provider/clients",
|
||||
response_model=list[OAuthClientResponse],
|
||||
summary="List OAuth Clients",
|
||||
description="List all registered OAuth clients (admin only).",
|
||||
operation_id="list_oauth_clients",
|
||||
tags=["OAuth Provider Admin"],
|
||||
)
|
||||
async def list_clients(
|
||||
db: AsyncSession = Depends(get_db),
|
||||
_: None = Depends(require_provider_enabled),
|
||||
current_user: User = Depends(get_current_superuser),
|
||||
) -> list[OAuthClientResponse]:
|
||||
"""List all OAuth clients."""
|
||||
clients = await provider_service.list_clients(db)
|
||||
return [OAuthClientResponse.model_validate(c) for c in clients]
|
||||
|
||||
|
||||
@router.delete(
|
||||
"/provider/clients/{client_id}",
|
||||
status_code=status.HTTP_204_NO_CONTENT,
|
||||
summary="Delete OAuth Client",
|
||||
description="Delete an OAuth client (admin only). Revokes all tokens.",
|
||||
operation_id="delete_oauth_client",
|
||||
tags=["OAuth Provider Admin"],
|
||||
)
|
||||
async def delete_client(
|
||||
client_id: str,
|
||||
db: AsyncSession = Depends(get_db),
|
||||
_: None = Depends(require_provider_enabled),
|
||||
current_user: User = Depends(get_current_superuser),
|
||||
) -> None:
|
||||
"""Delete an OAuth client."""
|
||||
client = await provider_service.get_client(db, client_id)
|
||||
if not client:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_404_NOT_FOUND,
|
||||
detail="Client not found",
|
||||
)
|
||||
|
||||
await provider_service.delete_client_by_id(db, client_id=client_id)
|
||||
|
||||
|
||||
# ============================================================================
|
||||
# User Consent Management
|
||||
# ============================================================================
|
||||
|
||||
|
||||
@router.get(
|
||||
"/provider/consents",
|
||||
summary="List My Consents",
|
||||
description="List OAuth applications the current user has authorized.",
|
||||
operation_id="list_my_oauth_consents",
|
||||
tags=["OAuth Provider"],
|
||||
)
|
||||
async def list_my_consents(
|
||||
db: AsyncSession = Depends(get_db),
|
||||
_: None = Depends(require_provider_enabled),
|
||||
current_user: User = Depends(get_current_active_user),
|
||||
) -> list[dict]:
|
||||
"""List applications the user has authorized."""
|
||||
return await provider_service.list_user_consents(db, user_id=current_user.id)
|
||||
|
||||
|
||||
@router.delete(
|
||||
"/provider/consents/{client_id}",
|
||||
status_code=status.HTTP_204_NO_CONTENT,
|
||||
summary="Revoke My Consent",
|
||||
description="Revoke authorization for an OAuth application. Also revokes all tokens.",
|
||||
operation_id="revoke_my_oauth_consent",
|
||||
tags=["OAuth Provider"],
|
||||
)
|
||||
async def revoke_my_consent(
|
||||
client_id: str,
|
||||
db: AsyncSession = Depends(get_db),
|
||||
_: None = Depends(require_provider_enabled),
|
||||
current_user: User = Depends(get_current_active_user),
|
||||
) -> None:
|
||||
"""Revoke consent for an application."""
|
||||
revoked = await provider_service.revoke_consent(db, current_user.id, client_id)
|
||||
if not revoked:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_404_NOT_FOUND,
|
||||
detail="No consent found for this client",
|
||||
)
|
||||
125
backend/app/api/routes/organizations.py
Normal file → Executable file
125
backend/app/api/routes/organizations.py
Normal file → Executable file
@@ -4,31 +4,29 @@ Organization endpoints for regular users.
|
||||
|
||||
These endpoints allow users to view and manage organizations they belong to.
|
||||
"""
|
||||
|
||||
import logging
|
||||
from typing import Any, List, Optional
|
||||
from typing import Any
|
||||
from uuid import UUID
|
||||
|
||||
from fastapi import APIRouter, Depends, Query, status
|
||||
from sqlalchemy.orm import Session
|
||||
from fastapi import APIRouter, Depends, Query
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
|
||||
from app.api.dependencies.auth import get_current_user
|
||||
from app.api.dependencies.permissions import require_org_admin, require_org_membership, get_current_org_role
|
||||
from app.api.dependencies.permissions import require_org_admin, require_org_membership
|
||||
from app.core.database import get_db
|
||||
from app.crud.organization import organization as organization_crud
|
||||
from app.models.user import User
|
||||
from app.models.user_organization import OrganizationRole
|
||||
from app.schemas.organizations import (
|
||||
OrganizationResponse,
|
||||
OrganizationMemberResponse,
|
||||
OrganizationUpdate
|
||||
)
|
||||
from app.schemas.common import (
|
||||
PaginationParams,
|
||||
PaginatedResponse,
|
||||
MessageResponse,
|
||||
create_pagination_meta
|
||||
PaginationParams,
|
||||
create_pagination_meta,
|
||||
)
|
||||
from app.core.exceptions import NotFoundError, AuthorizationError, ErrorCode
|
||||
from app.schemas.organizations import (
|
||||
OrganizationMemberResponse,
|
||||
OrganizationResponse,
|
||||
OrganizationUpdate,
|
||||
)
|
||||
from app.services.organization_service import organization_service
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
@@ -37,37 +35,32 @@ router = APIRouter()
|
||||
|
||||
@router.get(
|
||||
"/me",
|
||||
response_model=List[OrganizationResponse],
|
||||
response_model=list[OrganizationResponse],
|
||||
summary="Get My Organizations",
|
||||
description="Get all organizations the current user belongs to",
|
||||
operation_id="get_my_organizations"
|
||||
operation_id="get_my_organizations",
|
||||
)
|
||||
def get_my_organizations(
|
||||
async def get_my_organizations(
|
||||
is_active: bool = Query(True, description="Filter by active membership"),
|
||||
current_user: User = Depends(get_current_user),
|
||||
db: Session = Depends(get_db)
|
||||
db: AsyncSession = Depends(get_db),
|
||||
) -> Any:
|
||||
"""
|
||||
Get all organizations the current user belongs to.
|
||||
|
||||
Returns organizations with member count for each.
|
||||
Uses optimized single query to avoid N+1 problem.
|
||||
"""
|
||||
try:
|
||||
orgs = organization_crud.get_user_organizations(
|
||||
db,
|
||||
user_id=current_user.id,
|
||||
is_active=is_active
|
||||
# Get all org data in single query with JOIN and subquery
|
||||
orgs_data = await organization_service.get_user_organizations_with_details(
|
||||
db, user_id=current_user.id, is_active=is_active
|
||||
)
|
||||
|
||||
# Add member count and role to each organization
|
||||
# Transform to response objects
|
||||
orgs_with_data = []
|
||||
for org in orgs:
|
||||
role = organization_crud.get_user_role_in_org(
|
||||
db,
|
||||
user_id=current_user.id,
|
||||
organization_id=org.id
|
||||
)
|
||||
|
||||
for item in orgs_data:
|
||||
org = item["organization"]
|
||||
org_dict = {
|
||||
"id": org.id,
|
||||
"name": org.name,
|
||||
@@ -77,14 +70,14 @@ def get_my_organizations(
|
||||
"settings": org.settings,
|
||||
"created_at": org.created_at,
|
||||
"updated_at": org.updated_at,
|
||||
"member_count": organization_crud.get_member_count(db, organization_id=org.id)
|
||||
"member_count": item["member_count"],
|
||||
}
|
||||
orgs_with_data.append(OrganizationResponse(**org_dict))
|
||||
|
||||
return orgs_with_data
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error getting user organizations: {str(e)}", exc_info=True)
|
||||
logger.exception("Error getting user organizations: %s", e)
|
||||
raise
|
||||
|
||||
|
||||
@@ -93,12 +86,12 @@ def get_my_organizations(
|
||||
response_model=OrganizationResponse,
|
||||
summary="Get Organization Details",
|
||||
description="Get details of an organization the user belongs to",
|
||||
operation_id="get_organization"
|
||||
operation_id="get_organization",
|
||||
)
|
||||
def get_organization(
|
||||
async def get_organization(
|
||||
organization_id: UUID,
|
||||
current_user: User = Depends(require_org_membership),
|
||||
db: Session = Depends(get_db)
|
||||
db: AsyncSession = Depends(get_db),
|
||||
) -> Any:
|
||||
"""
|
||||
Get details of a specific organization.
|
||||
@@ -106,13 +99,7 @@ def get_organization(
|
||||
User must be a member of the organization.
|
||||
"""
|
||||
try:
|
||||
org = organization_crud.get(db, id=organization_id)
|
||||
if not org:
|
||||
raise NotFoundError(
|
||||
detail=f"Organization {organization_id} not found",
|
||||
error_code=ErrorCode.NOT_FOUND
|
||||
)
|
||||
|
||||
org = await organization_service.get_organization(db, str(organization_id))
|
||||
org_dict = {
|
||||
"id": org.id,
|
||||
"name": org.name,
|
||||
@@ -122,14 +109,14 @@ def get_organization(
|
||||
"settings": org.settings,
|
||||
"created_at": org.created_at,
|
||||
"updated_at": org.updated_at,
|
||||
"member_count": organization_crud.get_member_count(db, organization_id=org.id)
|
||||
"member_count": await organization_service.get_member_count(
|
||||
db, organization_id=org.id
|
||||
),
|
||||
}
|
||||
return OrganizationResponse(**org_dict)
|
||||
|
||||
except NotFoundError:
|
||||
raise
|
||||
except Exception as e:
|
||||
logger.error(f"Error getting organization: {str(e)}", exc_info=True)
|
||||
logger.exception("Error getting organization: %s", e)
|
||||
raise
|
||||
|
||||
|
||||
@@ -138,14 +125,14 @@ def get_organization(
|
||||
response_model=PaginatedResponse[OrganizationMemberResponse],
|
||||
summary="Get Organization Members",
|
||||
description="Get all members of an organization (members can view)",
|
||||
operation_id="get_organization_members"
|
||||
operation_id="get_organization_members",
|
||||
)
|
||||
def get_organization_members(
|
||||
async def get_organization_members(
|
||||
organization_id: UUID,
|
||||
pagination: PaginationParams = Depends(),
|
||||
is_active: bool = Query(True, description="Filter by active status"),
|
||||
current_user: User = Depends(require_org_membership),
|
||||
db: Session = Depends(get_db)
|
||||
db: AsyncSession = Depends(get_db),
|
||||
) -> Any:
|
||||
"""
|
||||
Get all members of an organization.
|
||||
@@ -153,12 +140,12 @@ def get_organization_members(
|
||||
User must be a member of the organization to view members.
|
||||
"""
|
||||
try:
|
||||
members, total = organization_crud.get_organization_members(
|
||||
members, total = await organization_service.get_organization_members(
|
||||
db,
|
||||
organization_id=organization_id,
|
||||
skip=pagination.offset,
|
||||
limit=pagination.limit,
|
||||
is_active=is_active
|
||||
is_active=is_active,
|
||||
)
|
||||
|
||||
member_responses = [OrganizationMemberResponse(**member) for member in members]
|
||||
@@ -167,13 +154,13 @@ def get_organization_members(
|
||||
total=total,
|
||||
page=pagination.page,
|
||||
limit=pagination.limit,
|
||||
items_count=len(member_responses)
|
||||
items_count=len(member_responses),
|
||||
)
|
||||
|
||||
return PaginatedResponse(data=member_responses, pagination=pagination_meta)
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error getting organization members: {str(e)}", exc_info=True)
|
||||
logger.exception("Error getting organization members: %s", e)
|
||||
raise
|
||||
|
||||
|
||||
@@ -182,13 +169,13 @@ def get_organization_members(
|
||||
response_model=OrganizationResponse,
|
||||
summary="Update Organization",
|
||||
description="Update organization details (admin/owner only)",
|
||||
operation_id="update_organization"
|
||||
operation_id="update_organization",
|
||||
)
|
||||
def update_organization(
|
||||
async def update_organization(
|
||||
organization_id: UUID,
|
||||
org_in: OrganizationUpdate,
|
||||
current_user: User = Depends(require_org_admin),
|
||||
db: Session = Depends(get_db)
|
||||
db: AsyncSession = Depends(get_db),
|
||||
) -> Any:
|
||||
"""
|
||||
Update organization details.
|
||||
@@ -196,15 +183,13 @@ def update_organization(
|
||||
Requires owner or admin role in the organization.
|
||||
"""
|
||||
try:
|
||||
org = organization_crud.get(db, id=organization_id)
|
||||
if not org:
|
||||
raise NotFoundError(
|
||||
detail=f"Organization {organization_id} not found",
|
||||
error_code=ErrorCode.NOT_FOUND
|
||||
)
|
||||
|
||||
updated_org = organization_crud.update(db, db_obj=org, obj_in=org_in)
|
||||
logger.info(f"User {current_user.email} updated organization {updated_org.name}")
|
||||
org = await organization_service.get_organization(db, str(organization_id))
|
||||
updated_org = await organization_service.update_organization(
|
||||
db, org=org, obj_in=org_in
|
||||
)
|
||||
logger.info(
|
||||
"User %s updated organization %s", current_user.email, updated_org.name
|
||||
)
|
||||
|
||||
org_dict = {
|
||||
"id": updated_org.id,
|
||||
@@ -215,12 +200,12 @@ def update_organization(
|
||||
"settings": updated_org.settings,
|
||||
"created_at": updated_org.created_at,
|
||||
"updated_at": updated_org.updated_at,
|
||||
"member_count": organization_crud.get_member_count(db, organization_id=updated_org.id)
|
||||
"member_count": await organization_service.get_member_count(
|
||||
db, organization_id=updated_org.id
|
||||
),
|
||||
}
|
||||
return OrganizationResponse(**org_dict)
|
||||
|
||||
except NotFoundError:
|
||||
raise
|
||||
except Exception as e:
|
||||
logger.error(f"Error updating organization: {str(e)}", exc_info=True)
|
||||
logger.exception("Error updating organization: %s", e)
|
||||
raise
|
||||
|
||||
122
backend/app/api/routes/sessions.py
Normal file → Executable file
122
backend/app/api/routes/sessions.py
Normal file → Executable file
@@ -3,23 +3,24 @@ Session management endpoints.
|
||||
|
||||
Allows users to view and manage their active sessions across devices.
|
||||
"""
|
||||
|
||||
import logging
|
||||
from typing import Any, List
|
||||
from typing import Any
|
||||
from uuid import UUID
|
||||
|
||||
from fastapi import APIRouter, Depends, HTTPException, status, Request
|
||||
from fastapi import APIRouter, Depends, HTTPException, Request, status
|
||||
from slowapi import Limiter
|
||||
from slowapi.util import get_remote_address
|
||||
from sqlalchemy.orm import Session
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
|
||||
from app.api.dependencies.auth import get_current_user
|
||||
from app.core.database import get_db
|
||||
from app.core.auth import decode_token
|
||||
from app.core.database import get_db
|
||||
from app.core.exceptions import AuthorizationError, ErrorCode, NotFoundError
|
||||
from app.models.user import User
|
||||
from app.schemas.sessions import SessionResponse, SessionListResponse
|
||||
from app.schemas.common import MessageResponse
|
||||
from app.crud.session import session as session_crud
|
||||
from app.core.exceptions import NotFoundError, AuthorizationError, ErrorCode
|
||||
from app.schemas.sessions import SessionListResponse, SessionResponse
|
||||
from app.services.session_service import session_service
|
||||
|
||||
router = APIRouter()
|
||||
logger = logging.getLogger(__name__)
|
||||
@@ -39,13 +40,13 @@ limiter = Limiter(key_func=get_remote_address)
|
||||
|
||||
**Rate Limit**: 30 requests/minute
|
||||
""",
|
||||
operation_id="list_my_sessions"
|
||||
operation_id="list_my_sessions",
|
||||
)
|
||||
@limiter.limit("30/minute")
|
||||
def list_my_sessions(
|
||||
async def list_my_sessions(
|
||||
request: Request,
|
||||
current_user: User = Depends(get_current_user),
|
||||
db: Session = Depends(get_db)
|
||||
db: AsyncSession = Depends(get_db),
|
||||
) -> Any:
|
||||
"""
|
||||
List all active sessions for the current user.
|
||||
@@ -59,23 +60,21 @@ def list_my_sessions(
|
||||
"""
|
||||
try:
|
||||
# Get all active sessions for user
|
||||
sessions = session_crud.get_user_sessions(
|
||||
db,
|
||||
user_id=str(current_user.id),
|
||||
active_only=True
|
||||
sessions = await session_service.get_user_sessions(
|
||||
db, user_id=str(current_user.id), active_only=True
|
||||
)
|
||||
|
||||
# Try to identify current session from Authorization header
|
||||
current_session_jti = None
|
||||
auth_header = request.headers.get("authorization")
|
||||
if auth_header and auth_header.startswith("Bearer "):
|
||||
try:
|
||||
access_token = auth_header.split(" ")[1]
|
||||
token_payload = decode_token(access_token)
|
||||
decode_token(access_token)
|
||||
# Note: Access tokens don't have JTI by default, but we can try
|
||||
# For now, we'll mark current based on most recent activity
|
||||
except Exception:
|
||||
pass
|
||||
except Exception as e:
|
||||
# Optional token parsing - silently ignore failures
|
||||
logger.debug("Failed to decode access token for session marking: %s", e)
|
||||
|
||||
# Convert to response format
|
||||
session_responses = []
|
||||
@@ -90,22 +89,25 @@ def list_my_sessions(
|
||||
last_used_at=s.last_used_at,
|
||||
created_at=s.created_at,
|
||||
expires_at=s.expires_at,
|
||||
is_current=(s == sessions[0] if sessions else False) # Most recent = current
|
||||
is_current=(
|
||||
s == sessions[0] if sessions else False
|
||||
), # Most recent = current
|
||||
)
|
||||
session_responses.append(session_response)
|
||||
|
||||
logger.info(f"User {current_user.id} listed {len(session_responses)} active sessions")
|
||||
logger.info(
|
||||
"User %s listed %s active sessions", current_user.id, len(session_responses)
|
||||
)
|
||||
|
||||
return SessionListResponse(
|
||||
sessions=session_responses,
|
||||
total=len(session_responses)
|
||||
sessions=session_responses, total=len(session_responses)
|
||||
)
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error listing sessions for user {current_user.id}: {str(e)}", exc_info=True)
|
||||
logger.exception("Error listing sessions for user %s: %s", current_user.id, e)
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
|
||||
detail="Failed to retrieve sessions"
|
||||
detail="Failed to retrieve sessions",
|
||||
)
|
||||
|
||||
|
||||
@@ -122,14 +124,14 @@ def list_my_sessions(
|
||||
|
||||
**Rate Limit**: 10 requests/minute
|
||||
""",
|
||||
operation_id="revoke_session"
|
||||
operation_id="revoke_session",
|
||||
)
|
||||
@limiter.limit("10/minute")
|
||||
def revoke_session(
|
||||
async def revoke_session(
|
||||
request: Request,
|
||||
session_id: UUID,
|
||||
current_user: User = Depends(get_current_user),
|
||||
db: Session = Depends(get_db)
|
||||
db: AsyncSession = Depends(get_db),
|
||||
) -> Any:
|
||||
"""
|
||||
Revoke a specific session by ID.
|
||||
@@ -144,45 +146,49 @@ def revoke_session(
|
||||
"""
|
||||
try:
|
||||
# Get the session
|
||||
session = session_crud.get(db, id=str(session_id))
|
||||
session = await session_service.get_session(db, str(session_id))
|
||||
|
||||
if not session:
|
||||
raise NotFoundError(
|
||||
message=f"Session {session_id} not found",
|
||||
error_code=ErrorCode.NOT_FOUND
|
||||
error_code=ErrorCode.NOT_FOUND,
|
||||
)
|
||||
|
||||
# Verify session belongs to current user
|
||||
if str(session.user_id) != str(current_user.id):
|
||||
logger.warning(
|
||||
f"User {current_user.id} attempted to revoke session {session_id} "
|
||||
f"belonging to user {session.user_id}"
|
||||
"User %s attempted to revoke session %s belonging to user %s",
|
||||
current_user.id,
|
||||
session_id,
|
||||
session.user_id,
|
||||
)
|
||||
raise AuthorizationError(
|
||||
message="You can only revoke your own sessions",
|
||||
error_code=ErrorCode.INSUFFICIENT_PERMISSIONS
|
||||
error_code=ErrorCode.INSUFFICIENT_PERMISSIONS,
|
||||
)
|
||||
|
||||
# Deactivate the session
|
||||
session_crud.deactivate(db, session_id=str(session_id))
|
||||
await session_service.deactivate(db, session_id=str(session_id))
|
||||
|
||||
logger.info(
|
||||
f"User {current_user.id} revoked session {session_id} "
|
||||
f"({session.device_name})"
|
||||
"User %s revoked session %s (%s)",
|
||||
current_user.id,
|
||||
session_id,
|
||||
session.device_name,
|
||||
)
|
||||
|
||||
return MessageResponse(
|
||||
success=True,
|
||||
message=f"Session revoked: {session.device_name or 'Unknown device'}"
|
||||
message=f"Session revoked: {session.device_name or 'Unknown device'}",
|
||||
)
|
||||
|
||||
except (NotFoundError, AuthorizationError):
|
||||
raise
|
||||
except Exception as e:
|
||||
logger.error(f"Error revoking session {session_id}: {str(e)}", exc_info=True)
|
||||
logger.exception("Error revoking session %s: %s", session_id, e)
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
|
||||
detail="Failed to revoke session"
|
||||
detail="Failed to revoke session",
|
||||
)
|
||||
|
||||
|
||||
@@ -198,13 +204,13 @@ def revoke_session(
|
||||
|
||||
**Rate Limit**: 5 requests/minute
|
||||
""",
|
||||
operation_id="cleanup_expired_sessions"
|
||||
operation_id="cleanup_expired_sessions",
|
||||
)
|
||||
@limiter.limit("5/minute")
|
||||
def cleanup_expired_sessions(
|
||||
async def cleanup_expired_sessions(
|
||||
request: Request,
|
||||
current_user: User = Depends(get_current_user),
|
||||
db: Session = Depends(get_db)
|
||||
db: AsyncSession = Depends(get_db),
|
||||
) -> Any:
|
||||
"""
|
||||
Cleanup expired sessions for the current user.
|
||||
@@ -217,35 +223,25 @@ def cleanup_expired_sessions(
|
||||
Success message with count of sessions cleaned
|
||||
"""
|
||||
try:
|
||||
from datetime import datetime, timezone
|
||||
|
||||
# Get all sessions for user
|
||||
all_sessions = session_crud.get_user_sessions(
|
||||
db,
|
||||
user_id=str(current_user.id),
|
||||
active_only=False
|
||||
# Use optimized bulk DELETE instead of N individual deletes
|
||||
deleted_count = await session_service.cleanup_expired_for_user(
|
||||
db, user_id=str(current_user.id)
|
||||
)
|
||||
|
||||
# Delete expired and inactive sessions
|
||||
deleted_count = 0
|
||||
for s in all_sessions:
|
||||
if not s.is_active and s.expires_at < datetime.now(timezone.utc):
|
||||
db.delete(s)
|
||||
deleted_count += 1
|
||||
|
||||
db.commit()
|
||||
|
||||
logger.info(f"User {current_user.id} cleaned up {deleted_count} expired sessions")
|
||||
logger.info(
|
||||
"User %s cleaned up %s expired sessions", current_user.id, deleted_count
|
||||
)
|
||||
|
||||
return MessageResponse(
|
||||
success=True,
|
||||
message=f"Cleaned up {deleted_count} expired sessions"
|
||||
success=True, message=f"Cleaned up {deleted_count} expired sessions"
|
||||
)
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error cleaning up sessions for user {current_user.id}: {str(e)}", exc_info=True)
|
||||
db.rollback()
|
||||
logger.exception(
|
||||
"Error cleaning up sessions for user %s: %s", current_user.id, e
|
||||
)
|
||||
await db.rollback()
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
|
||||
detail="Failed to cleanup sessions"
|
||||
detail="Failed to cleanup sessions",
|
||||
)
|
||||
|
||||
195
backend/app/api/routes/users.py
Normal file → Executable file
195
backend/app/api/routes/users.py
Normal file → Executable file
@@ -1,33 +1,30 @@
|
||||
"""
|
||||
User management endpoints for CRUD operations.
|
||||
User management endpoints for database operations.
|
||||
"""
|
||||
|
||||
import logging
|
||||
from typing import Any, Optional
|
||||
from typing import Any
|
||||
from uuid import UUID
|
||||
|
||||
from fastapi import APIRouter, Depends, Query, status, Request
|
||||
from sqlalchemy.orm import Session
|
||||
from fastapi import APIRouter, Depends, Query, Request, status
|
||||
from slowapi import Limiter
|
||||
from slowapi.util import get_remote_address
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
|
||||
from app.api.dependencies.auth import get_current_user, get_current_superuser
|
||||
from app.api.dependencies.auth import get_current_superuser, get_current_user
|
||||
from app.core.database import get_db
|
||||
from app.crud.user import user as user_crud
|
||||
from app.core.exceptions import AuthorizationError, ErrorCode
|
||||
from app.models.user import User
|
||||
from app.schemas.users import UserResponse, UserUpdate, PasswordChange
|
||||
from app.schemas.common import (
|
||||
PaginationParams,
|
||||
PaginatedResponse,
|
||||
MessageResponse,
|
||||
PaginatedResponse,
|
||||
PaginationParams,
|
||||
SortParams,
|
||||
create_pagination_meta
|
||||
)
|
||||
from app.services.auth_service import AuthService, AuthenticationError
|
||||
from app.core.exceptions import (
|
||||
NotFoundError,
|
||||
AuthorizationError,
|
||||
ErrorCode
|
||||
create_pagination_meta,
|
||||
)
|
||||
from app.schemas.users import PasswordChange, UserResponse, UserUpdate
|
||||
from app.services.auth_service import AuthenticationError, AuthService
|
||||
from app.services.user_service import user_service
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
@@ -50,15 +47,15 @@ limiter = Limiter(key_func=get_remote_address)
|
||||
|
||||
**Rate Limit**: 60 requests/minute
|
||||
""",
|
||||
operation_id="list_users"
|
||||
operation_id="list_users",
|
||||
)
|
||||
def list_users(
|
||||
async def list_users(
|
||||
pagination: PaginationParams = Depends(),
|
||||
sort: SortParams = Depends(),
|
||||
is_active: Optional[bool] = Query(None, description="Filter by active status"),
|
||||
is_superuser: Optional[bool] = Query(None, description="Filter by superuser status"),
|
||||
is_active: bool | None = Query(None, description="Filter by active status"),
|
||||
is_superuser: bool | None = Query(None, description="Filter by superuser status"),
|
||||
current_user: User = Depends(get_current_superuser),
|
||||
db: Session = Depends(get_db)
|
||||
db: AsyncSession = Depends(get_db),
|
||||
) -> Any:
|
||||
"""
|
||||
List all users with pagination, filtering, and sorting.
|
||||
@@ -74,13 +71,13 @@ def list_users(
|
||||
filters["is_superuser"] = is_superuser
|
||||
|
||||
# Get paginated users with total count
|
||||
users, total = user_crud.get_multi_with_total(
|
||||
users, total = await user_service.list_users(
|
||||
db,
|
||||
skip=pagination.offset,
|
||||
limit=pagination.limit,
|
||||
sort_by=sort.sort_by,
|
||||
sort_order=sort.sort_order.value if sort.sort_order else "asc",
|
||||
filters=filters if filters else None
|
||||
filters=filters if filters else None,
|
||||
)
|
||||
|
||||
# Create pagination metadata
|
||||
@@ -88,15 +85,12 @@ def list_users(
|
||||
total=total,
|
||||
page=pagination.page,
|
||||
limit=pagination.limit,
|
||||
items_count=len(users)
|
||||
items_count=len(users),
|
||||
)
|
||||
|
||||
return PaginatedResponse(
|
||||
data=users,
|
||||
pagination=pagination_meta
|
||||
)
|
||||
return PaginatedResponse(data=users, pagination=pagination_meta)
|
||||
except Exception as e:
|
||||
logger.error(f"Error listing users: {str(e)}", exc_info=True)
|
||||
logger.exception("Error listing users: %s", e)
|
||||
raise
|
||||
|
||||
|
||||
@@ -111,10 +105,10 @@ def list_users(
|
||||
|
||||
**Rate Limit**: 60 requests/minute
|
||||
""",
|
||||
operation_id="get_current_user_profile"
|
||||
operation_id="get_current_user_profile",
|
||||
)
|
||||
def get_current_user_profile(
|
||||
current_user: User = Depends(get_current_user)
|
||||
async def get_current_user_profile(
|
||||
current_user: User = Depends(get_current_user),
|
||||
) -> Any:
|
||||
"""Get current user's profile."""
|
||||
return current_user
|
||||
@@ -133,39 +127,29 @@ def get_current_user_profile(
|
||||
|
||||
**Rate Limit**: 30 requests/minute
|
||||
""",
|
||||
operation_id="update_current_user"
|
||||
operation_id="update_current_user",
|
||||
)
|
||||
def update_current_user(
|
||||
async def update_current_user(
|
||||
user_update: UserUpdate,
|
||||
current_user: User = Depends(get_current_user),
|
||||
db: Session = Depends(get_db)
|
||||
db: AsyncSession = Depends(get_db),
|
||||
) -> Any:
|
||||
"""
|
||||
Update current user's profile.
|
||||
|
||||
Users cannot elevate their own permissions (is_superuser).
|
||||
Users cannot elevate their own permissions (protected by UserUpdate schema validator).
|
||||
"""
|
||||
# Prevent users from making themselves superuser
|
||||
if getattr(user_update, 'is_superuser', None) is not None:
|
||||
logger.warning(f"User {current_user.id} attempted to modify is_superuser field")
|
||||
raise AuthorizationError(
|
||||
message="Cannot modify superuser status",
|
||||
error_code=ErrorCode.INSUFFICIENT_PERMISSIONS
|
||||
)
|
||||
|
||||
try:
|
||||
updated_user = user_crud.update(
|
||||
db,
|
||||
db_obj=current_user,
|
||||
obj_in=user_update
|
||||
updated_user = await user_service.update_user(
|
||||
db, user=current_user, obj_in=user_update
|
||||
)
|
||||
logger.info(f"User {current_user.id} updated their profile")
|
||||
logger.info("User %s updated their profile", current_user.id)
|
||||
return updated_user
|
||||
except ValueError as e:
|
||||
logger.error(f"Error updating user {current_user.id}: {str(e)}")
|
||||
logger.error("Error updating user %s: %s", current_user.id, e)
|
||||
raise
|
||||
except Exception as e:
|
||||
logger.error(f"Unexpected error updating user {current_user.id}: {str(e)}", exc_info=True)
|
||||
logger.exception("Unexpected error updating user %s: %s", current_user.id, e)
|
||||
raise
|
||||
|
||||
|
||||
@@ -183,12 +167,12 @@ def update_current_user(
|
||||
|
||||
**Rate Limit**: 60 requests/minute
|
||||
""",
|
||||
operation_id="get_user_by_id"
|
||||
operation_id="get_user_by_id",
|
||||
)
|
||||
def get_user_by_id(
|
||||
async def get_user_by_id(
|
||||
user_id: UUID,
|
||||
current_user: User = Depends(get_current_user),
|
||||
db: Session = Depends(get_db)
|
||||
db: AsyncSession = Depends(get_db),
|
||||
) -> Any:
|
||||
"""
|
||||
Get user by ID.
|
||||
@@ -198,21 +182,17 @@ def get_user_by_id(
|
||||
# Check permissions
|
||||
if str(user_id) != str(current_user.id) and not current_user.is_superuser:
|
||||
logger.warning(
|
||||
f"User {current_user.id} attempted to access user {user_id} without permission"
|
||||
"User %s attempted to access user %s without permission",
|
||||
current_user.id,
|
||||
user_id,
|
||||
)
|
||||
raise AuthorizationError(
|
||||
message="Not enough permissions to view this user",
|
||||
error_code=ErrorCode.INSUFFICIENT_PERMISSIONS
|
||||
error_code=ErrorCode.INSUFFICIENT_PERMISSIONS,
|
||||
)
|
||||
|
||||
# Get user
|
||||
user = user_crud.get(db, id=str(user_id))
|
||||
if not user:
|
||||
raise NotFoundError(
|
||||
message=f"User with id {user_id} not found",
|
||||
error_code=ErrorCode.USER_NOT_FOUND
|
||||
)
|
||||
|
||||
user = await user_service.get_user(db, str(user_id))
|
||||
return user
|
||||
|
||||
|
||||
@@ -230,57 +210,46 @@ def get_user_by_id(
|
||||
|
||||
**Rate Limit**: 30 requests/minute
|
||||
""",
|
||||
operation_id="update_user"
|
||||
operation_id="update_user",
|
||||
)
|
||||
def update_user(
|
||||
async def update_user(
|
||||
user_id: UUID,
|
||||
user_update: UserUpdate,
|
||||
current_user: User = Depends(get_current_user),
|
||||
db: Session = Depends(get_db)
|
||||
db: AsyncSession = Depends(get_db),
|
||||
) -> Any:
|
||||
"""
|
||||
Update user by ID.
|
||||
|
||||
Users can update their own profile. Superusers can update any profile.
|
||||
Regular users cannot modify is_superuser field.
|
||||
Superuser field modification is prevented by UserUpdate schema validator.
|
||||
"""
|
||||
# Check permissions
|
||||
is_own_profile = str(user_id) == str(current_user.id)
|
||||
|
||||
if not is_own_profile and not current_user.is_superuser:
|
||||
logger.warning(
|
||||
f"User {current_user.id} attempted to update user {user_id} without permission"
|
||||
"User %s attempted to update user %s without permission",
|
||||
current_user.id,
|
||||
user_id,
|
||||
)
|
||||
raise AuthorizationError(
|
||||
message="Not enough permissions to update this user",
|
||||
error_code=ErrorCode.INSUFFICIENT_PERMISSIONS
|
||||
error_code=ErrorCode.INSUFFICIENT_PERMISSIONS,
|
||||
)
|
||||
|
||||
# Get user
|
||||
user = user_crud.get(db, id=str(user_id))
|
||||
if not user:
|
||||
raise NotFoundError(
|
||||
message=f"User with id {user_id} not found",
|
||||
error_code=ErrorCode.USER_NOT_FOUND
|
||||
)
|
||||
|
||||
# Prevent non-superusers from modifying superuser status
|
||||
if getattr(user_update, 'is_superuser', None) is not None and not current_user.is_superuser:
|
||||
logger.warning(f"User {current_user.id} attempted to modify is_superuser field")
|
||||
raise AuthorizationError(
|
||||
message="Cannot modify superuser status",
|
||||
error_code=ErrorCode.INSUFFICIENT_PERMISSIONS
|
||||
)
|
||||
user = await user_service.get_user(db, str(user_id))
|
||||
|
||||
try:
|
||||
updated_user = user_crud.update(db, db_obj=user, obj_in=user_update)
|
||||
logger.info(f"User {user_id} updated by {current_user.id}")
|
||||
updated_user = await user_service.update_user(db, user=user, obj_in=user_update)
|
||||
logger.info("User %s updated by %s", user_id, current_user.id)
|
||||
return updated_user
|
||||
except ValueError as e:
|
||||
logger.error(f"Error updating user {user_id}: {str(e)}")
|
||||
logger.error("Error updating user %s: %s", user_id, e)
|
||||
raise
|
||||
except Exception as e:
|
||||
logger.error(f"Unexpected error updating user {user_id}: {str(e)}", exc_info=True)
|
||||
logger.exception("Unexpected error updating user %s: %s", user_id, e)
|
||||
raise
|
||||
|
||||
|
||||
@@ -297,14 +266,14 @@ def update_user(
|
||||
|
||||
**Rate Limit**: 5 requests/minute
|
||||
""",
|
||||
operation_id="change_current_user_password"
|
||||
operation_id="change_current_user_password",
|
||||
)
|
||||
@limiter.limit("5/minute")
|
||||
def change_current_user_password(
|
||||
async def change_current_user_password(
|
||||
request: Request,
|
||||
password_change: PasswordChange,
|
||||
current_user: User = Depends(get_current_user),
|
||||
db: Session = Depends(get_db)
|
||||
db: AsyncSession = Depends(get_db),
|
||||
) -> Any:
|
||||
"""
|
||||
Change current user's password.
|
||||
@@ -312,27 +281,27 @@ def change_current_user_password(
|
||||
Requires current password for verification.
|
||||
"""
|
||||
try:
|
||||
success = AuthService.change_password(
|
||||
success = await AuthService.change_password(
|
||||
db=db,
|
||||
user_id=current_user.id,
|
||||
current_password=password_change.current_password,
|
||||
new_password=password_change.new_password
|
||||
new_password=password_change.new_password,
|
||||
)
|
||||
|
||||
if success:
|
||||
logger.info(f"User {current_user.id} changed their password")
|
||||
logger.info("User %s changed their password", current_user.id)
|
||||
return MessageResponse(
|
||||
success=True,
|
||||
message="Password changed successfully"
|
||||
success=True, message="Password changed successfully"
|
||||
)
|
||||
except AuthenticationError as e:
|
||||
logger.warning(f"Failed password change attempt for user {current_user.id}: {str(e)}")
|
||||
logger.warning(
|
||||
"Failed password change attempt for user %s: %s", current_user.id, e
|
||||
)
|
||||
raise AuthorizationError(
|
||||
message=str(e),
|
||||
error_code=ErrorCode.INVALID_CREDENTIALS
|
||||
message=str(e), error_code=ErrorCode.INVALID_CREDENTIALS
|
||||
)
|
||||
except Exception as e:
|
||||
logger.error(f"Error changing password for user {current_user.id}: {str(e)}")
|
||||
logger.error("Error changing password for user %s: %s", current_user.id, e)
|
||||
raise
|
||||
|
||||
|
||||
@@ -351,12 +320,12 @@ def change_current_user_password(
|
||||
|
||||
**Note**: This performs a hard delete. Consider implementing soft deletes for production.
|
||||
""",
|
||||
operation_id="delete_user"
|
||||
operation_id="delete_user",
|
||||
)
|
||||
def delete_user(
|
||||
async def delete_user(
|
||||
user_id: UUID,
|
||||
current_user: User = Depends(get_current_superuser),
|
||||
db: Session = Depends(get_db)
|
||||
db: AsyncSession = Depends(get_db),
|
||||
) -> Any:
|
||||
"""
|
||||
Delete user by ID (superuser only).
|
||||
@@ -367,28 +336,22 @@ def delete_user(
|
||||
if str(user_id) == str(current_user.id):
|
||||
raise AuthorizationError(
|
||||
message="Cannot delete your own account",
|
||||
error_code=ErrorCode.INSUFFICIENT_PERMISSIONS
|
||||
error_code=ErrorCode.INSUFFICIENT_PERMISSIONS,
|
||||
)
|
||||
|
||||
# Get user
|
||||
user = user_crud.get(db, id=str(user_id))
|
||||
if not user:
|
||||
raise NotFoundError(
|
||||
message=f"User with id {user_id} not found",
|
||||
error_code=ErrorCode.USER_NOT_FOUND
|
||||
)
|
||||
# Get user (raises NotFoundError if not found)
|
||||
await user_service.get_user(db, str(user_id))
|
||||
|
||||
try:
|
||||
# Use soft delete instead of hard delete
|
||||
user_crud.soft_delete(db, id=str(user_id))
|
||||
logger.info(f"User {user_id} soft-deleted by {current_user.id}")
|
||||
await user_service.soft_delete_user(db, str(user_id))
|
||||
logger.info("User %s soft-deleted by %s", user_id, current_user.id)
|
||||
return MessageResponse(
|
||||
success=True,
|
||||
message=f"User {user_id} deleted successfully"
|
||||
success=True, message=f"User {user_id} deleted successfully"
|
||||
)
|
||||
except ValueError as e:
|
||||
logger.error(f"Error deleting user {user_id}: {str(e)}")
|
||||
logger.error("Error deleting user %s: %s", user_id, e)
|
||||
raise
|
||||
except Exception as e:
|
||||
logger.error(f"Unexpected error deleting user {user_id}: {str(e)}", exc_info=True)
|
||||
logger.exception("Unexpected error deleting user %s: %s", user_id, e)
|
||||
raise
|
||||
|
||||
@@ -1,53 +1,94 @@
|
||||
import logging
|
||||
logging.getLogger('passlib').setLevel(logging.ERROR)
|
||||
|
||||
from datetime import datetime, timedelta, timezone
|
||||
from typing import Any, Dict, Optional, Union
|
||||
import asyncio
|
||||
import uuid
|
||||
from datetime import UTC, datetime, timedelta
|
||||
from functools import partial
|
||||
from typing import Any
|
||||
|
||||
from jose import jwt, JWTError
|
||||
from passlib.context import CryptContext
|
||||
import bcrypt
|
||||
import jwt
|
||||
from jwt.exceptions import (
|
||||
ExpiredSignatureError,
|
||||
InvalidTokenError,
|
||||
MissingRequiredClaimError,
|
||||
)
|
||||
from pydantic import ValidationError
|
||||
|
||||
from app.core.config import settings
|
||||
from app.schemas.users import TokenData, TokenPayload
|
||||
|
||||
|
||||
# Password hashing context
|
||||
pwd_context = CryptContext(schemes=["bcrypt"], deprecated="auto")
|
||||
|
||||
# Custom exceptions for auth
|
||||
class AuthError(Exception):
|
||||
"""Base authentication error"""
|
||||
pass
|
||||
|
||||
|
||||
class TokenExpiredError(AuthError):
|
||||
"""Token has expired"""
|
||||
pass
|
||||
|
||||
|
||||
class TokenInvalidError(AuthError):
|
||||
"""Token is invalid"""
|
||||
pass
|
||||
|
||||
|
||||
class TokenMissingClaimError(AuthError):
|
||||
"""Token is missing a required claim"""
|
||||
pass
|
||||
|
||||
|
||||
def verify_password(plain_password: str, hashed_password: str) -> bool:
|
||||
"""Verify a password against a hash."""
|
||||
return pwd_context.verify(plain_password, hashed_password)
|
||||
"""Verify a password against a bcrypt hash."""
|
||||
return bcrypt.checkpw(
|
||||
plain_password.encode("utf-8"), hashed_password.encode("utf-8")
|
||||
)
|
||||
|
||||
|
||||
def get_password_hash(password: str) -> str:
|
||||
"""Generate a password hash."""
|
||||
return pwd_context.hash(password)
|
||||
"""Generate a bcrypt password hash."""
|
||||
salt = bcrypt.gensalt()
|
||||
return bcrypt.hashpw(password.encode("utf-8"), salt).decode("utf-8")
|
||||
|
||||
|
||||
async def verify_password_async(plain_password: str, hashed_password: str) -> bool:
|
||||
"""
|
||||
Verify a password against a hash asynchronously.
|
||||
|
||||
Runs the CPU-intensive bcrypt operation in a thread pool to avoid
|
||||
blocking the event loop.
|
||||
|
||||
Args:
|
||||
plain_password: Plain text password to verify
|
||||
hashed_password: Hashed password to verify against
|
||||
|
||||
Returns:
|
||||
True if password matches, False otherwise
|
||||
"""
|
||||
loop = asyncio.get_running_loop()
|
||||
return await loop.run_in_executor(
|
||||
None, partial(verify_password, plain_password, hashed_password)
|
||||
)
|
||||
|
||||
|
||||
async def get_password_hash_async(password: str) -> str:
|
||||
"""
|
||||
Generate a password hash asynchronously.
|
||||
|
||||
Runs the CPU-intensive bcrypt operation in a thread pool to avoid
|
||||
blocking the event loop. This is especially important during user
|
||||
registration and password changes.
|
||||
|
||||
Args:
|
||||
password: Plain text password to hash
|
||||
|
||||
Returns:
|
||||
Hashed password string
|
||||
"""
|
||||
loop = asyncio.get_running_loop()
|
||||
return await loop.run_in_executor(None, get_password_hash, password)
|
||||
|
||||
|
||||
def create_access_token(
|
||||
subject: Union[str, Any],
|
||||
expires_delta: Optional[timedelta] = None,
|
||||
claims: Optional[Dict[str, Any]] = None
|
||||
subject: str | Any,
|
||||
expires_delta: timedelta | None = None,
|
||||
claims: dict[str, Any] | None = None,
|
||||
) -> str:
|
||||
"""
|
||||
Create a JWT access token.
|
||||
@@ -61,17 +102,19 @@ def create_access_token(
|
||||
Encoded JWT token
|
||||
"""
|
||||
if expires_delta:
|
||||
expire = datetime.now(timezone.utc) + expires_delta
|
||||
expire = datetime.now(UTC) + expires_delta
|
||||
else:
|
||||
expire = datetime.now(timezone.utc) + timedelta(minutes=settings.ACCESS_TOKEN_EXPIRE_MINUTES)
|
||||
expire = datetime.now(UTC) + timedelta(
|
||||
minutes=settings.ACCESS_TOKEN_EXPIRE_MINUTES
|
||||
)
|
||||
|
||||
# Base token data
|
||||
to_encode = {
|
||||
"sub": str(subject),
|
||||
"exp": expire,
|
||||
"iat": datetime.now(tz=timezone.utc),
|
||||
"iat": datetime.now(tz=UTC),
|
||||
"jti": str(uuid.uuid4()),
|
||||
"type": "access"
|
||||
"type": "access",
|
||||
}
|
||||
|
||||
# Add custom claims
|
||||
@@ -79,18 +122,11 @@ def create_access_token(
|
||||
to_encode.update(claims)
|
||||
|
||||
# Create the JWT
|
||||
encoded_jwt = jwt.encode(
|
||||
to_encode,
|
||||
settings.SECRET_KEY,
|
||||
algorithm=settings.ALGORITHM
|
||||
)
|
||||
|
||||
return encoded_jwt
|
||||
return jwt.encode(to_encode, settings.SECRET_KEY, algorithm=settings.ALGORITHM)
|
||||
|
||||
|
||||
def create_refresh_token(
|
||||
subject: Union[str, Any],
|
||||
expires_delta: Optional[timedelta] = None
|
||||
subject: str | Any, expires_delta: timedelta | None = None
|
||||
) -> str:
|
||||
"""
|
||||
Create a JWT refresh token.
|
||||
@@ -103,28 +139,22 @@ def create_refresh_token(
|
||||
Encoded JWT refresh token
|
||||
"""
|
||||
if expires_delta:
|
||||
expire = datetime.now(timezone.utc) + expires_delta
|
||||
expire = datetime.now(UTC) + expires_delta
|
||||
else:
|
||||
expire = datetime.now(timezone.utc) + timedelta(days=settings.REFRESH_TOKEN_EXPIRE_DAYS)
|
||||
expire = datetime.now(UTC) + timedelta(days=settings.REFRESH_TOKEN_EXPIRE_DAYS)
|
||||
|
||||
to_encode = {
|
||||
"sub": str(subject),
|
||||
"exp": expire,
|
||||
"iat": datetime.now(timezone.utc),
|
||||
"iat": datetime.now(UTC),
|
||||
"jti": str(uuid.uuid4()),
|
||||
"type": "refresh"
|
||||
"type": "refresh",
|
||||
}
|
||||
|
||||
encoded_jwt = jwt.encode(
|
||||
to_encode,
|
||||
settings.SECRET_KEY,
|
||||
algorithm=settings.ALGORITHM
|
||||
)
|
||||
|
||||
return encoded_jwt
|
||||
return jwt.encode(to_encode, settings.SECRET_KEY, algorithm=settings.ALGORITHM)
|
||||
|
||||
|
||||
def decode_token(token: str, verify_type: Optional[str] = None) -> TokenPayload:
|
||||
def decode_token(token: str, verify_type: str | None = None) -> TokenPayload:
|
||||
"""
|
||||
Decode and verify a JWT token.
|
||||
|
||||
@@ -141,12 +171,35 @@ def decode_token(token: str, verify_type: Optional[str] = None) -> TokenPayload:
|
||||
TokenMissingClaimError: If a required claim is missing
|
||||
"""
|
||||
try:
|
||||
# Decode token with strict algorithm validation
|
||||
payload = jwt.decode(
|
||||
token,
|
||||
settings.SECRET_KEY,
|
||||
algorithms=[settings.ALGORITHM]
|
||||
algorithms=[settings.ALGORITHM],
|
||||
options={
|
||||
"verify_signature": True,
|
||||
"verify_exp": True,
|
||||
"verify_iat": True,
|
||||
"require": ["exp", "sub", "iat"],
|
||||
},
|
||||
)
|
||||
|
||||
# SECURITY: Explicitly verify the algorithm to prevent algorithm confusion attacks
|
||||
# Decode header to check algorithm (without verification, just to inspect)
|
||||
header = jwt.get_unverified_header(token)
|
||||
token_algorithm = header.get("alg", "").upper()
|
||||
|
||||
# Reject weak or unexpected algorithms
|
||||
# NOTE: These are defensive checks that provide defense-in-depth.
|
||||
# PyJWT rejects these tokens BEFORE we reach here,
|
||||
# but we keep these checks in case the library changes or is misconfigured.
|
||||
# Coverage: Marked as pragma since library catches first (see tests/core/test_auth_security.py)
|
||||
if token_algorithm == "NONE": # pragma: no cover
|
||||
raise TokenInvalidError("Algorithm 'none' is not allowed")
|
||||
|
||||
if token_algorithm != settings.ALGORITHM.upper(): # pragma: no cover
|
||||
raise TokenInvalidError(f"Invalid algorithm: {token_algorithm}")
|
||||
|
||||
# Check required claims before Pydantic validation
|
||||
if not payload.get("sub"):
|
||||
raise TokenMissingClaimError("Token missing 'sub' claim")
|
||||
@@ -159,10 +212,11 @@ def decode_token(token: str, verify_type: Optional[str] = None) -> TokenPayload:
|
||||
token_data = TokenPayload(**payload)
|
||||
return token_data
|
||||
|
||||
except JWTError as e:
|
||||
# Check if the error is due to an expired token
|
||||
if "expired" in str(e).lower():
|
||||
raise TokenExpiredError("Token has expired")
|
||||
except ExpiredSignatureError:
|
||||
raise TokenExpiredError("Token has expired")
|
||||
except MissingRequiredClaimError as e:
|
||||
raise TokenMissingClaimError(f"Token missing required claim: {e}")
|
||||
except InvalidTokenError:
|
||||
raise TokenInvalidError("Invalid authentication token")
|
||||
except ValidationError:
|
||||
raise TokenInvalidError("Invalid token payload")
|
||||
@@ -182,4 +236,4 @@ def get_token_data(token: str) -> TokenData:
|
||||
user_id = payload.sub
|
||||
is_superuser = payload.is_superuser or False
|
||||
|
||||
return TokenData(user_id=uuid.UUID(user_id), is_superuser=is_superuser)
|
||||
return TokenData(user_id=uuid.UUID(user_id), is_superuser=is_superuser)
|
||||
|
||||
@@ -1,18 +1,22 @@
|
||||
from pydantic_settings import BaseSettings
|
||||
from typing import Optional, List
|
||||
from pydantic import Field, field_validator
|
||||
import logging
|
||||
|
||||
from pydantic import Field, field_validator
|
||||
from pydantic_settings import BaseSettings
|
||||
|
||||
|
||||
class Settings(BaseSettings):
|
||||
PROJECT_NAME: str = "App"
|
||||
PROJECT_NAME: str = "PragmaStack"
|
||||
VERSION: str = "1.0.0"
|
||||
API_V1_STR: str = "/api/v1"
|
||||
|
||||
# Environment (must be before SECRET_KEY for validation)
|
||||
ENVIRONMENT: str = Field(
|
||||
default="development",
|
||||
description="Environment: development, staging, or production"
|
||||
description="Environment: development, staging, or production",
|
||||
)
|
||||
DEMO_MODE: bool = Field(
|
||||
default=False,
|
||||
description="Enable demo mode (relaxed security, demo users)",
|
||||
)
|
||||
|
||||
# Security: Content Security Policy
|
||||
@@ -20,8 +24,7 @@ class Settings(BaseSettings):
|
||||
# Set to True for strict CSP (blocks most external resources)
|
||||
# Set to "relaxed" for modern frontend development
|
||||
CSP_MODE: str = Field(
|
||||
default="relaxed",
|
||||
description="CSP mode: 'strict', 'relaxed', or 'disabled'"
|
||||
default="relaxed", description="CSP mode: 'strict', 'relaxed', or 'disabled'"
|
||||
)
|
||||
|
||||
# Database configuration
|
||||
@@ -30,7 +33,7 @@ class Settings(BaseSettings):
|
||||
POSTGRES_HOST: str = "localhost"
|
||||
POSTGRES_PORT: str = "5432"
|
||||
POSTGRES_DB: str = "app"
|
||||
DATABASE_URL: Optional[str] = None
|
||||
DATABASE_URL: str | None = None
|
||||
db_pool_size: int = 20 # Default connection pool size
|
||||
db_max_overflow: int = 50 # Maximum overflow connections
|
||||
db_pool_timeout: int = 30 # Seconds to wait for a connection
|
||||
@@ -58,38 +61,90 @@ class Settings(BaseSettings):
|
||||
SECRET_KEY: str = Field(
|
||||
default="dev_only_insecure_key_change_in_production_32chars_min",
|
||||
min_length=32,
|
||||
description="JWT signing key. MUST be changed in production. Generate with: python -c 'import secrets; print(secrets.token_urlsafe(32))'"
|
||||
description="JWT signing key. MUST be changed in production. Generate with: python -c 'import secrets; print(secrets.token_urlsafe(32))'",
|
||||
)
|
||||
ALGORITHM: str = "HS256"
|
||||
ACCESS_TOKEN_EXPIRE_MINUTES: int = 15 # 15 minutes (production standard)
|
||||
REFRESH_TOKEN_EXPIRE_DAYS: int = 7 # 7 days
|
||||
|
||||
# CORS configuration
|
||||
BACKEND_CORS_ORIGINS: List[str] = ["http://localhost:3000"]
|
||||
BACKEND_CORS_ORIGINS: list[str] = ["http://localhost:3000"]
|
||||
|
||||
# Frontend URL for email links
|
||||
FRONTEND_URL: str = Field(
|
||||
default="http://localhost:3000",
|
||||
description="Frontend application URL for email links"
|
||||
description="Frontend application URL for email links",
|
||||
)
|
||||
|
||||
# OAuth Configuration
|
||||
OAUTH_ENABLED: bool = Field(
|
||||
default=False,
|
||||
description="Enable OAuth authentication (social login)",
|
||||
)
|
||||
OAUTH_AUTO_LINK_BY_EMAIL: bool = Field(
|
||||
default=True,
|
||||
description="Automatically link OAuth accounts to existing users with matching email",
|
||||
)
|
||||
OAUTH_STATE_EXPIRE_MINUTES: int = Field(
|
||||
default=10,
|
||||
description="OAuth state parameter expiration time in minutes",
|
||||
)
|
||||
|
||||
# Google OAuth
|
||||
OAUTH_GOOGLE_CLIENT_ID: str | None = Field(
|
||||
default=None,
|
||||
description="Google OAuth client ID from Google Cloud Console",
|
||||
)
|
||||
OAUTH_GOOGLE_CLIENT_SECRET: str | None = Field(
|
||||
default=None,
|
||||
description="Google OAuth client secret from Google Cloud Console",
|
||||
)
|
||||
|
||||
# GitHub OAuth
|
||||
OAUTH_GITHUB_CLIENT_ID: str | None = Field(
|
||||
default=None,
|
||||
description="GitHub OAuth client ID from GitHub Developer Settings",
|
||||
)
|
||||
OAUTH_GITHUB_CLIENT_SECRET: str | None = Field(
|
||||
default=None,
|
||||
description="GitHub OAuth client secret from GitHub Developer Settings",
|
||||
)
|
||||
|
||||
# OAuth Provider Mode (for MCP clients - skeleton)
|
||||
OAUTH_PROVIDER_ENABLED: bool = Field(
|
||||
default=False,
|
||||
description="Enable OAuth provider mode (act as authorization server for MCP clients)",
|
||||
)
|
||||
OAUTH_ISSUER: str = Field(
|
||||
default="http://localhost:8000",
|
||||
description="OAuth issuer URL (your API base URL)",
|
||||
)
|
||||
|
||||
@property
|
||||
def enabled_oauth_providers(self) -> list[str]:
|
||||
"""Get list of enabled OAuth providers based on configured credentials."""
|
||||
providers = []
|
||||
if self.OAUTH_GOOGLE_CLIENT_ID and self.OAUTH_GOOGLE_CLIENT_SECRET:
|
||||
providers.append("google")
|
||||
if self.OAUTH_GITHUB_CLIENT_ID and self.OAUTH_GITHUB_CLIENT_SECRET:
|
||||
providers.append("github")
|
||||
return providers
|
||||
|
||||
# Admin user
|
||||
FIRST_SUPERUSER_EMAIL: Optional[str] = Field(
|
||||
default=None,
|
||||
description="Email for first superuser account"
|
||||
FIRST_SUPERUSER_EMAIL: str | None = Field(
|
||||
default=None, description="Email for first superuser account"
|
||||
)
|
||||
FIRST_SUPERUSER_PASSWORD: Optional[str] = Field(
|
||||
default=None,
|
||||
description="Password for first superuser (min 12 characters)"
|
||||
FIRST_SUPERUSER_PASSWORD: str | None = Field(
|
||||
default=None, description="Password for first superuser (min 12 characters)"
|
||||
)
|
||||
|
||||
@field_validator('SECRET_KEY')
|
||||
@field_validator("SECRET_KEY")
|
||||
@classmethod
|
||||
def validate_secret_key(cls, v: str, info) -> str:
|
||||
"""Validate SECRET_KEY is secure, especially in production."""
|
||||
# Get environment from values if available
|
||||
values_data = info.data if info.data else {}
|
||||
env = values_data.get('ENVIRONMENT', 'development')
|
||||
env = values_data.get("ENVIRONMENT", "development")
|
||||
|
||||
if v.startswith("your_secret_key_here"):
|
||||
if env == "production":
|
||||
@@ -105,22 +160,40 @@ class Settings(BaseSettings):
|
||||
)
|
||||
|
||||
if len(v) < 32:
|
||||
raise ValueError("SECRET_KEY must be at least 32 characters long for security")
|
||||
raise ValueError(
|
||||
"SECRET_KEY must be at least 32 characters long for security"
|
||||
)
|
||||
|
||||
return v
|
||||
|
||||
@field_validator('FIRST_SUPERUSER_PASSWORD')
|
||||
@field_validator("FIRST_SUPERUSER_PASSWORD")
|
||||
@classmethod
|
||||
def validate_superuser_password(cls, v: Optional[str]) -> Optional[str]:
|
||||
def validate_superuser_password(cls, v: str | None, info) -> str | None:
|
||||
"""Validate superuser password strength."""
|
||||
if v is None:
|
||||
return v
|
||||
|
||||
# Get environment from values if available
|
||||
values_data = info.data if info.data else {}
|
||||
demo_mode = values_data.get("DEMO_MODE", False)
|
||||
|
||||
if demo_mode:
|
||||
# In demo mode, allow specific weak passwords for demo accounts
|
||||
demo_passwords = {"Demo123!", "Admin123!"}
|
||||
if v in demo_passwords:
|
||||
return v
|
||||
|
||||
if len(v) < 12:
|
||||
raise ValueError("FIRST_SUPERUSER_PASSWORD must be at least 12 characters")
|
||||
|
||||
# Check for common weak passwords
|
||||
weak_passwords = {'admin123', 'Admin123', 'password123', 'Password123', '123456789012'}
|
||||
weak_passwords = {
|
||||
"admin123",
|
||||
"Admin123",
|
||||
"password123",
|
||||
"Password123",
|
||||
"123456789012",
|
||||
}
|
||||
if v in weak_passwords:
|
||||
raise ValueError(
|
||||
"FIRST_SUPERUSER_PASSWORD is too weak. "
|
||||
@@ -143,8 +216,8 @@ class Settings(BaseSettings):
|
||||
"env_file": "../.env",
|
||||
"env_file_encoding": "utf-8",
|
||||
"case_sensitive": True,
|
||||
"extra": "ignore" # Ignore extra fields from .env (e.g., frontend-specific vars)
|
||||
"extra": "ignore", # Ignore extra fields from .env (e.g., frontend-specific vars)
|
||||
}
|
||||
|
||||
|
||||
settings = Settings()
|
||||
settings = Settings()
|
||||
|
||||
212
backend/app/core/database.py
Normal file → Executable file
212
backend/app/core/database.py
Normal file → Executable file
@@ -1,112 +1,186 @@
|
||||
# app/core/database.py
|
||||
"""
|
||||
Database configuration using SQLAlchemy 2.0 and asyncpg.
|
||||
|
||||
This module provides async database connectivity with proper connection pooling
|
||||
and session management for FastAPI endpoints.
|
||||
"""
|
||||
|
||||
import logging
|
||||
from contextlib import contextmanager
|
||||
from typing import Generator
|
||||
from sqlalchemy import create_engine, text
|
||||
from sqlalchemy.ext.declarative import declarative_base
|
||||
from sqlalchemy.orm import sessionmaker, Session
|
||||
from sqlalchemy.ext.compiler import compiles
|
||||
from collections.abc import AsyncGenerator
|
||||
from contextlib import asynccontextmanager
|
||||
|
||||
from sqlalchemy import text
|
||||
from sqlalchemy.dialects.postgresql import JSONB, UUID
|
||||
from sqlalchemy.ext.asyncio import (
|
||||
AsyncEngine,
|
||||
AsyncSession,
|
||||
async_sessionmaker,
|
||||
create_async_engine,
|
||||
)
|
||||
from sqlalchemy.ext.compiler import compiles
|
||||
from sqlalchemy.orm import DeclarativeBase
|
||||
|
||||
from app.core.config import settings
|
||||
|
||||
# Configure logging
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
# SQLite compatibility for testing
|
||||
@compiles(JSONB, 'sqlite')
|
||||
@compiles(JSONB, "sqlite")
|
||||
def compile_jsonb_sqlite(type_, compiler, **kw):
|
||||
return "TEXT"
|
||||
|
||||
@compiles(UUID, 'sqlite')
|
||||
|
||||
@compiles(UUID, "sqlite")
|
||||
def compile_uuid_sqlite(type_, compiler, **kw):
|
||||
return "TEXT"
|
||||
|
||||
# Declarative base for models
|
||||
Base = declarative_base()
|
||||
|
||||
# Create engine with optimized settings for PostgreSQL
|
||||
def create_production_engine():
|
||||
return create_engine(
|
||||
settings.database_url,
|
||||
# Connection pool settings
|
||||
pool_size=settings.db_pool_size,
|
||||
max_overflow=settings.db_max_overflow,
|
||||
pool_timeout=settings.db_pool_timeout,
|
||||
pool_recycle=settings.db_pool_recycle,
|
||||
pool_pre_ping=True,
|
||||
# Query execution settings
|
||||
connect_args={
|
||||
"application_name": "eventspace",
|
||||
"keepalives": 1,
|
||||
"keepalives_idle": 60,
|
||||
"keepalives_interval": 10,
|
||||
"keepalives_count": 5,
|
||||
"options": "-c timezone=UTC",
|
||||
},
|
||||
isolation_level="READ COMMITTED",
|
||||
echo=settings.sql_echo,
|
||||
echo_pool=settings.sql_echo_pool,
|
||||
)
|
||||
# Declarative base for models (SQLAlchemy 2.0 style)
|
||||
class Base(DeclarativeBase):
|
||||
"""Base class for all database models."""
|
||||
|
||||
# Default production engine and session factory
|
||||
engine = create_production_engine()
|
||||
SessionLocal = sessionmaker(
|
||||
|
||||
def get_async_database_url(url: str) -> str:
|
||||
"""
|
||||
Convert sync database URL to async URL.
|
||||
|
||||
postgresql:// -> postgresql+asyncpg://
|
||||
sqlite:// -> sqlite+aiosqlite://
|
||||
"""
|
||||
if url.startswith("postgresql://"):
|
||||
return url.replace("postgresql://", "postgresql+asyncpg://")
|
||||
elif url.startswith("sqlite://"):
|
||||
return url.replace("sqlite://", "sqlite+aiosqlite://")
|
||||
return url
|
||||
|
||||
|
||||
# Create async engine with optimized settings
|
||||
def create_async_production_engine() -> AsyncEngine:
|
||||
"""Create an async database engine with production settings."""
|
||||
async_url = get_async_database_url(settings.database_url)
|
||||
|
||||
# Base engine config
|
||||
engine_config = {
|
||||
"pool_size": settings.db_pool_size,
|
||||
"max_overflow": settings.db_max_overflow,
|
||||
"pool_timeout": settings.db_pool_timeout,
|
||||
"pool_recycle": settings.db_pool_recycle,
|
||||
"pool_pre_ping": True,
|
||||
"echo": settings.sql_echo,
|
||||
"echo_pool": settings.sql_echo_pool,
|
||||
}
|
||||
|
||||
# Add PostgreSQL-specific connect_args
|
||||
if "postgresql" in async_url:
|
||||
engine_config["connect_args"] = { # type: ignore[assignment]
|
||||
"server_settings": {
|
||||
"application_name": settings.PROJECT_NAME,
|
||||
"timezone": "UTC",
|
||||
},
|
||||
# asyncpg-specific settings
|
||||
"command_timeout": 60,
|
||||
"timeout": 10,
|
||||
}
|
||||
|
||||
return create_async_engine(async_url, **engine_config)
|
||||
|
||||
|
||||
# Create async engine and session factory
|
||||
engine = create_async_production_engine()
|
||||
SessionLocal = async_sessionmaker(
|
||||
engine,
|
||||
class_=AsyncSession,
|
||||
autocommit=False,
|
||||
autoflush=False,
|
||||
bind=engine,
|
||||
expire_on_commit=False # Prevent unnecessary queries after commit
|
||||
expire_on_commit=False, # Prevent unnecessary queries after commit
|
||||
)
|
||||
|
||||
# FastAPI dependency
|
||||
def get_db() -> Generator[Session, None, None]:
|
||||
|
||||
# FastAPI dependency for async database sessions
|
||||
async def get_db() -> AsyncGenerator[AsyncSession, None]:
|
||||
"""
|
||||
FastAPI dependency that provides a database session.
|
||||
FastAPI dependency that provides an async database session.
|
||||
Automatically closes the session after the request completes.
|
||||
|
||||
Usage:
|
||||
@router.get("/users")
|
||||
async def get_users(db: AsyncSession = Depends(get_db)):
|
||||
result = await db.execute(select(User))
|
||||
return result.scalars().all()
|
||||
"""
|
||||
db = SessionLocal()
|
||||
try:
|
||||
yield db
|
||||
finally:
|
||||
db.close()
|
||||
async with SessionLocal() as session:
|
||||
try:
|
||||
yield session
|
||||
finally:
|
||||
await session.close()
|
||||
|
||||
|
||||
@contextmanager
|
||||
def transaction_scope() -> Generator[Session, None, None]:
|
||||
@asynccontextmanager
|
||||
async def async_transaction_scope() -> AsyncGenerator[AsyncSession, None]:
|
||||
"""
|
||||
Provide a transactional scope for database operations.
|
||||
Provide an async transactional scope for database operations.
|
||||
|
||||
Automatically commits on success or rolls back on exception.
|
||||
Useful for grouping multiple operations in a single transaction.
|
||||
|
||||
Usage:
|
||||
with transaction_scope() as db:
|
||||
user = user_crud.create(db, obj_in=user_create)
|
||||
profile = profile_crud.create(db, obj_in=profile_create)
|
||||
async with async_transaction_scope() as db:
|
||||
user = await user_repo.create(db, obj_in=user_create)
|
||||
profile = await profile_repo.create(db, obj_in=profile_create)
|
||||
# Both operations committed together
|
||||
"""
|
||||
db = SessionLocal()
|
||||
try:
|
||||
yield db
|
||||
db.commit()
|
||||
logger.debug("Transaction committed successfully")
|
||||
except Exception as e:
|
||||
db.rollback()
|
||||
logger.error(f"Transaction failed, rolling back: {str(e)}")
|
||||
raise
|
||||
finally:
|
||||
db.close()
|
||||
async with SessionLocal() as session:
|
||||
try:
|
||||
yield session
|
||||
await session.commit()
|
||||
logger.debug("Async transaction committed successfully")
|
||||
except Exception as e:
|
||||
await session.rollback()
|
||||
logger.error("Async transaction failed, rolling back: %s", e)
|
||||
raise
|
||||
finally:
|
||||
await session.close()
|
||||
|
||||
|
||||
def check_database_health() -> bool:
|
||||
async def check_async_database_health() -> bool:
|
||||
"""
|
||||
Check if database connection is healthy.
|
||||
Check if async database connection is healthy.
|
||||
Returns True if connection is successful, False otherwise.
|
||||
"""
|
||||
try:
|
||||
with transaction_scope() as db:
|
||||
db.execute(text("SELECT 1"))
|
||||
async with async_transaction_scope() as db:
|
||||
await db.execute(text("SELECT 1"))
|
||||
return True
|
||||
except Exception as e:
|
||||
logger.error(f"Database health check failed: {str(e)}")
|
||||
return False
|
||||
logger.error("Async database health check failed: %s", e)
|
||||
return False
|
||||
|
||||
|
||||
# Alias for consistency with main.py
|
||||
check_database_health = check_async_database_health
|
||||
|
||||
|
||||
async def init_async_db() -> None:
|
||||
"""
|
||||
Initialize async database tables.
|
||||
|
||||
This creates all tables defined in the models.
|
||||
Should only be used in development or testing.
|
||||
In production, use Alembic migrations.
|
||||
"""
|
||||
async with engine.begin() as conn:
|
||||
await conn.run_sync(Base.metadata.create_all)
|
||||
logger.info("Async database tables created")
|
||||
|
||||
|
||||
async def close_async_db() -> None:
|
||||
"""
|
||||
Close all async database connections.
|
||||
|
||||
Should be called during application shutdown.
|
||||
"""
|
||||
await engine.dispose()
|
||||
logger.info("Async database connections closed")
|
||||
|
||||
@@ -1,182 +0,0 @@
|
||||
# app/core/database_async.py
|
||||
"""
|
||||
Async database configuration using SQLAlchemy 2.0 and asyncpg.
|
||||
|
||||
This module provides async database connectivity with proper connection pooling
|
||||
and session management for FastAPI endpoints.
|
||||
"""
|
||||
import logging
|
||||
from contextlib import asynccontextmanager
|
||||
from typing import AsyncGenerator
|
||||
|
||||
from sqlalchemy import text
|
||||
from sqlalchemy.ext.asyncio import (
|
||||
AsyncSession,
|
||||
AsyncEngine,
|
||||
create_async_engine,
|
||||
async_sessionmaker,
|
||||
)
|
||||
from sqlalchemy.ext.compiler import compiles
|
||||
from sqlalchemy.dialects.postgresql import JSONB, UUID
|
||||
from sqlalchemy.orm import DeclarativeBase
|
||||
|
||||
from app.core.config import settings
|
||||
|
||||
# Configure logging
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
# SQLite compatibility for testing
|
||||
@compiles(JSONB, 'sqlite')
|
||||
def compile_jsonb_sqlite(type_, compiler, **kw):
|
||||
return "TEXT"
|
||||
|
||||
|
||||
@compiles(UUID, 'sqlite')
|
||||
def compile_uuid_sqlite(type_, compiler, **kw):
|
||||
return "TEXT"
|
||||
|
||||
|
||||
# Declarative base for models (SQLAlchemy 2.0 style)
|
||||
class Base(DeclarativeBase):
|
||||
"""Base class for all database models."""
|
||||
pass
|
||||
|
||||
|
||||
def get_async_database_url(url: str) -> str:
|
||||
"""
|
||||
Convert sync database URL to async URL.
|
||||
|
||||
postgresql:// -> postgresql+asyncpg://
|
||||
sqlite:// -> sqlite+aiosqlite://
|
||||
"""
|
||||
if url.startswith("postgresql://"):
|
||||
return url.replace("postgresql://", "postgresql+asyncpg://")
|
||||
elif url.startswith("sqlite://"):
|
||||
return url.replace("sqlite://", "sqlite+aiosqlite://")
|
||||
return url
|
||||
|
||||
|
||||
# Create async engine with optimized settings
|
||||
def create_async_production_engine() -> AsyncEngine:
|
||||
"""Create an async database engine with production settings."""
|
||||
async_url = get_async_database_url(settings.database_url)
|
||||
|
||||
# Base engine config
|
||||
engine_config = {
|
||||
"pool_size": settings.db_pool_size,
|
||||
"max_overflow": settings.db_max_overflow,
|
||||
"pool_timeout": settings.db_pool_timeout,
|
||||
"pool_recycle": settings.db_pool_recycle,
|
||||
"pool_pre_ping": True,
|
||||
"echo": settings.sql_echo,
|
||||
"echo_pool": settings.sql_echo_pool,
|
||||
}
|
||||
|
||||
# Add PostgreSQL-specific connect_args
|
||||
if "postgresql" in async_url:
|
||||
engine_config["connect_args"] = {
|
||||
"server_settings": {
|
||||
"application_name": "eventspace",
|
||||
"timezone": "UTC",
|
||||
},
|
||||
# asyncpg-specific settings
|
||||
"command_timeout": 60,
|
||||
"timeout": 10,
|
||||
}
|
||||
|
||||
return create_async_engine(async_url, **engine_config)
|
||||
|
||||
|
||||
# Create async engine and session factory
|
||||
async_engine = create_async_production_engine()
|
||||
AsyncSessionLocal = async_sessionmaker(
|
||||
async_engine,
|
||||
class_=AsyncSession,
|
||||
autocommit=False,
|
||||
autoflush=False,
|
||||
expire_on_commit=False, # Prevent unnecessary queries after commit
|
||||
)
|
||||
|
||||
|
||||
# FastAPI dependency for async database sessions
|
||||
async def get_async_db() -> AsyncGenerator[AsyncSession, None]:
|
||||
"""
|
||||
FastAPI dependency that provides an async database session.
|
||||
Automatically closes the session after the request completes.
|
||||
|
||||
Usage:
|
||||
@router.get("/users")
|
||||
async def get_users(db: AsyncSession = Depends(get_async_db)):
|
||||
result = await db.execute(select(User))
|
||||
return result.scalars().all()
|
||||
"""
|
||||
async with AsyncSessionLocal() as session:
|
||||
try:
|
||||
yield session
|
||||
finally:
|
||||
await session.close()
|
||||
|
||||
|
||||
@asynccontextmanager
|
||||
async def async_transaction_scope() -> AsyncGenerator[AsyncSession, None]:
|
||||
"""
|
||||
Provide an async transactional scope for database operations.
|
||||
|
||||
Automatically commits on success or rolls back on exception.
|
||||
Useful for grouping multiple operations in a single transaction.
|
||||
|
||||
Usage:
|
||||
async with async_transaction_scope() as db:
|
||||
user = await user_crud.create(db, obj_in=user_create)
|
||||
profile = await profile_crud.create(db, obj_in=profile_create)
|
||||
# Both operations committed together
|
||||
"""
|
||||
async with AsyncSessionLocal() as session:
|
||||
try:
|
||||
yield session
|
||||
await session.commit()
|
||||
logger.debug("Async transaction committed successfully")
|
||||
except Exception as e:
|
||||
await session.rollback()
|
||||
logger.error(f"Async transaction failed, rolling back: {str(e)}")
|
||||
raise
|
||||
finally:
|
||||
await session.close()
|
||||
|
||||
|
||||
async def check_async_database_health() -> bool:
|
||||
"""
|
||||
Check if async database connection is healthy.
|
||||
Returns True if connection is successful, False otherwise.
|
||||
"""
|
||||
try:
|
||||
async with async_transaction_scope() as db:
|
||||
await db.execute(text("SELECT 1"))
|
||||
return True
|
||||
except Exception as e:
|
||||
logger.error(f"Async database health check failed: {str(e)}")
|
||||
return False
|
||||
|
||||
|
||||
async def init_async_db() -> None:
|
||||
"""
|
||||
Initialize async database tables.
|
||||
|
||||
This creates all tables defined in the models.
|
||||
Should only be used in development or testing.
|
||||
In production, use Alembic migrations.
|
||||
"""
|
||||
async with async_engine.begin() as conn:
|
||||
await conn.run_sync(Base.metadata.create_all)
|
||||
logger.info("Async database tables created")
|
||||
|
||||
|
||||
async def close_async_db() -> None:
|
||||
"""
|
||||
Close all async database connections.
|
||||
|
||||
Should be called during application shutdown.
|
||||
"""
|
||||
await async_engine.dispose()
|
||||
logger.info("Async database connections closed")
|
||||
366
backend/app/core/demo_data.json
Normal file
366
backend/app/core/demo_data.json
Normal file
@@ -0,0 +1,366 @@
|
||||
{
|
||||
"organizations": [
|
||||
{
|
||||
"name": "Acme Corp",
|
||||
"slug": "acme-corp",
|
||||
"description": "A leading provider of coyote-catching equipment."
|
||||
},
|
||||
{
|
||||
"name": "Globex Corporation",
|
||||
"slug": "globex",
|
||||
"description": "We own the East Coast."
|
||||
},
|
||||
{
|
||||
"name": "Soylent Corp",
|
||||
"slug": "soylent",
|
||||
"description": "Making food for the future."
|
||||
},
|
||||
{
|
||||
"name": "Initech",
|
||||
"slug": "initech",
|
||||
"description": "Software for the soul."
|
||||
},
|
||||
{
|
||||
"name": "Umbrella Corporation",
|
||||
"slug": "umbrella",
|
||||
"description": "Our business is life itself."
|
||||
},
|
||||
{
|
||||
"name": "Massive Dynamic",
|
||||
"slug": "massive-dynamic",
|
||||
"description": "What don't we do?"
|
||||
}
|
||||
],
|
||||
"users": [
|
||||
{
|
||||
"email": "demo@example.com",
|
||||
"password": "DemoPass1234!",
|
||||
"first_name": "Demo",
|
||||
"last_name": "User",
|
||||
"is_superuser": false,
|
||||
"organization_slug": "acme-corp",
|
||||
"role": "member",
|
||||
"is_active": true
|
||||
},
|
||||
{
|
||||
"email": "alice@acme.com",
|
||||
"password": "Demo123!",
|
||||
"first_name": "Alice",
|
||||
"last_name": "Smith",
|
||||
"is_superuser": false,
|
||||
"organization_slug": "acme-corp",
|
||||
"role": "admin",
|
||||
"is_active": true
|
||||
},
|
||||
{
|
||||
"email": "bob@acme.com",
|
||||
"password": "Demo123!",
|
||||
"first_name": "Bob",
|
||||
"last_name": "Jones",
|
||||
"is_superuser": false,
|
||||
"organization_slug": "acme-corp",
|
||||
"role": "member",
|
||||
"is_active": true
|
||||
},
|
||||
{
|
||||
"email": "charlie@acme.com",
|
||||
"password": "Demo123!",
|
||||
"first_name": "Charlie",
|
||||
"last_name": "Brown",
|
||||
"is_superuser": false,
|
||||
"organization_slug": "acme-corp",
|
||||
"role": "member",
|
||||
"is_active": false
|
||||
},
|
||||
{
|
||||
"email": "diana@acme.com",
|
||||
"password": "Demo123!",
|
||||
"first_name": "Diana",
|
||||
"last_name": "Prince",
|
||||
"is_superuser": false,
|
||||
"organization_slug": "acme-corp",
|
||||
"role": "member",
|
||||
"is_active": true
|
||||
},
|
||||
{
|
||||
"email": "carol@globex.com",
|
||||
"password": "Demo123!",
|
||||
"first_name": "Carol",
|
||||
"last_name": "Williams",
|
||||
"is_superuser": false,
|
||||
"organization_slug": "globex",
|
||||
"role": "owner",
|
||||
"is_active": true
|
||||
},
|
||||
{
|
||||
"email": "dan@globex.com",
|
||||
"password": "Demo123!",
|
||||
"first_name": "Dan",
|
||||
"last_name": "Miller",
|
||||
"is_superuser": false,
|
||||
"organization_slug": "globex",
|
||||
"role": "member",
|
||||
"is_active": true
|
||||
},
|
||||
{
|
||||
"email": "ellen@globex.com",
|
||||
"password": "Demo123!",
|
||||
"first_name": "Ellen",
|
||||
"last_name": "Ripley",
|
||||
"is_superuser": false,
|
||||
"organization_slug": "globex",
|
||||
"role": "member",
|
||||
"is_active": true
|
||||
},
|
||||
{
|
||||
"email": "fred@globex.com",
|
||||
"password": "Demo123!",
|
||||
"first_name": "Fred",
|
||||
"last_name": "Flintstone",
|
||||
"is_superuser": false,
|
||||
"organization_slug": "globex",
|
||||
"role": "member",
|
||||
"is_active": true
|
||||
},
|
||||
{
|
||||
"email": "dave@soylent.com",
|
||||
"password": "Demo123!",
|
||||
"first_name": "Dave",
|
||||
"last_name": "Brown",
|
||||
"is_superuser": false,
|
||||
"organization_slug": "soylent",
|
||||
"role": "member",
|
||||
"is_active": true
|
||||
},
|
||||
{
|
||||
"email": "gina@soylent.com",
|
||||
"password": "Demo123!",
|
||||
"first_name": "Gina",
|
||||
"last_name": "Torres",
|
||||
"is_superuser": false,
|
||||
"organization_slug": "soylent",
|
||||
"role": "member",
|
||||
"is_active": true
|
||||
},
|
||||
{
|
||||
"email": "harry@soylent.com",
|
||||
"password": "Demo123!",
|
||||
"first_name": "Harry",
|
||||
"last_name": "Potter",
|
||||
"is_superuser": false,
|
||||
"organization_slug": "soylent",
|
||||
"role": "admin",
|
||||
"is_active": true
|
||||
},
|
||||
{
|
||||
"email": "eve@initech.com",
|
||||
"password": "Demo123!",
|
||||
"first_name": "Eve",
|
||||
"last_name": "Davis",
|
||||
"is_superuser": false,
|
||||
"organization_slug": "initech",
|
||||
"role": "admin",
|
||||
"is_active": true
|
||||
},
|
||||
{
|
||||
"email": "iris@initech.com",
|
||||
"password": "Demo123!",
|
||||
"first_name": "Iris",
|
||||
"last_name": "West",
|
||||
"is_superuser": false,
|
||||
"organization_slug": "initech",
|
||||
"role": "member",
|
||||
"is_active": true
|
||||
},
|
||||
{
|
||||
"email": "jack@initech.com",
|
||||
"password": "Demo123!",
|
||||
"first_name": "Jack",
|
||||
"last_name": "Sparrow",
|
||||
"is_superuser": false,
|
||||
"organization_slug": "initech",
|
||||
"role": "member",
|
||||
"is_active": false
|
||||
},
|
||||
{
|
||||
"email": "frank@umbrella.com",
|
||||
"password": "Demo123!",
|
||||
"first_name": "Frank",
|
||||
"last_name": "Miller",
|
||||
"is_superuser": false,
|
||||
"organization_slug": "umbrella",
|
||||
"role": "member",
|
||||
"is_active": true
|
||||
},
|
||||
{
|
||||
"email": "george@umbrella.com",
|
||||
"password": "Demo123!",
|
||||
"first_name": "George",
|
||||
"last_name": "Costanza",
|
||||
"is_superuser": false,
|
||||
"organization_slug": "umbrella",
|
||||
"role": "member",
|
||||
"is_active": false
|
||||
},
|
||||
{
|
||||
"email": "kate@umbrella.com",
|
||||
"password": "Demo123!",
|
||||
"first_name": "Kate",
|
||||
"last_name": "Bishop",
|
||||
"is_superuser": false,
|
||||
"organization_slug": "umbrella",
|
||||
"role": "member",
|
||||
"is_active": true
|
||||
},
|
||||
{
|
||||
"email": "leo@massive.com",
|
||||
"password": "Demo123!",
|
||||
"first_name": "Leo",
|
||||
"last_name": "Messi",
|
||||
"is_superuser": false,
|
||||
"organization_slug": "massive-dynamic",
|
||||
"role": "owner",
|
||||
"is_active": true
|
||||
},
|
||||
{
|
||||
"email": "mary@massive.com",
|
||||
"password": "Demo123!",
|
||||
"first_name": "Mary",
|
||||
"last_name": "Jane",
|
||||
"is_superuser": false,
|
||||
"organization_slug": "massive-dynamic",
|
||||
"role": "member",
|
||||
"is_active": true
|
||||
},
|
||||
{
|
||||
"email": "nathan@massive.com",
|
||||
"password": "Demo123!",
|
||||
"first_name": "Nathan",
|
||||
"last_name": "Drake",
|
||||
"is_superuser": false,
|
||||
"organization_slug": "massive-dynamic",
|
||||
"role": "member",
|
||||
"is_active": true
|
||||
},
|
||||
{
|
||||
"email": "olivia@massive.com",
|
||||
"password": "Demo123!",
|
||||
"first_name": "Olivia",
|
||||
"last_name": "Dunham",
|
||||
"is_superuser": false,
|
||||
"organization_slug": "massive-dynamic",
|
||||
"role": "admin",
|
||||
"is_active": true
|
||||
},
|
||||
{
|
||||
"email": "peter@massive.com",
|
||||
"password": "Demo123!",
|
||||
"first_name": "Peter",
|
||||
"last_name": "Parker",
|
||||
"is_superuser": false,
|
||||
"organization_slug": "massive-dynamic",
|
||||
"role": "member",
|
||||
"is_active": true
|
||||
},
|
||||
{
|
||||
"email": "quinn@massive.com",
|
||||
"password": "Demo123!",
|
||||
"first_name": "Quinn",
|
||||
"last_name": "Mallory",
|
||||
"is_superuser": false,
|
||||
"organization_slug": "massive-dynamic",
|
||||
"role": "member",
|
||||
"is_active": true
|
||||
},
|
||||
{
|
||||
"email": "grace@example.com",
|
||||
"password": "Demo123!",
|
||||
"first_name": "Grace",
|
||||
"last_name": "Hopper",
|
||||
"is_superuser": false,
|
||||
"organization_slug": null,
|
||||
"role": null,
|
||||
"is_active": true
|
||||
},
|
||||
{
|
||||
"email": "heidi@example.com",
|
||||
"password": "Demo123!",
|
||||
"first_name": "Heidi",
|
||||
"last_name": "Klum",
|
||||
"is_superuser": false,
|
||||
"organization_slug": null,
|
||||
"role": null,
|
||||
"is_active": true
|
||||
},
|
||||
{
|
||||
"email": "ivan@example.com",
|
||||
"password": "Demo123!",
|
||||
"first_name": "Ivan",
|
||||
"last_name": "Drago",
|
||||
"is_superuser": false,
|
||||
"organization_slug": null,
|
||||
"role": null,
|
||||
"is_active": false
|
||||
},
|
||||
{
|
||||
"email": "rachel@example.com",
|
||||
"password": "Demo123!",
|
||||
"first_name": "Rachel",
|
||||
"last_name": "Green",
|
||||
"is_superuser": false,
|
||||
"organization_slug": null,
|
||||
"role": null,
|
||||
"is_active": true
|
||||
},
|
||||
{
|
||||
"email": "sam@example.com",
|
||||
"password": "Demo123!",
|
||||
"first_name": "Sam",
|
||||
"last_name": "Wilson",
|
||||
"is_superuser": false,
|
||||
"organization_slug": null,
|
||||
"role": null,
|
||||
"is_active": true
|
||||
},
|
||||
{
|
||||
"email": "tony@example.com",
|
||||
"password": "Demo123!",
|
||||
"first_name": "Tony",
|
||||
"last_name": "Stark",
|
||||
"is_superuser": false,
|
||||
"organization_slug": null,
|
||||
"role": null,
|
||||
"is_active": true
|
||||
},
|
||||
{
|
||||
"email": "una@example.com",
|
||||
"password": "Demo123!",
|
||||
"first_name": "Una",
|
||||
"last_name": "Chin-Riley",
|
||||
"is_superuser": false,
|
||||
"organization_slug": null,
|
||||
"role": null,
|
||||
"is_active": false
|
||||
},
|
||||
{
|
||||
"email": "victor@example.com",
|
||||
"password": "Demo123!",
|
||||
"first_name": "Victor",
|
||||
"last_name": "Von Doom",
|
||||
"is_superuser": false,
|
||||
"organization_slug": null,
|
||||
"role": null,
|
||||
"is_active": true
|
||||
},
|
||||
{
|
||||
"email": "wanda@example.com",
|
||||
"password": "Demo123!",
|
||||
"first_name": "Wanda",
|
||||
"last_name": "Maximoff",
|
||||
"is_superuser": false,
|
||||
"organization_slug": null,
|
||||
"role": null,
|
||||
"is_active": true
|
||||
}
|
||||
]
|
||||
}
|
||||
@@ -1,11 +1,12 @@
|
||||
"""
|
||||
Custom exceptions and global exception handlers for the API.
|
||||
"""
|
||||
|
||||
import logging
|
||||
from typing import Optional, Union, List
|
||||
|
||||
from fastapi import HTTPException, Request, status
|
||||
from fastapi.responses import JSONResponse
|
||||
from fastapi.exceptions import RequestValidationError
|
||||
from fastapi.responses import JSONResponse
|
||||
from pydantic import ValidationError
|
||||
|
||||
from app.schemas.errors import ErrorCode, ErrorDetail, ErrorResponse
|
||||
@@ -26,17 +27,13 @@ class APIException(HTTPException):
|
||||
status_code: int,
|
||||
error_code: ErrorCode,
|
||||
message: str,
|
||||
field: Optional[str] = None,
|
||||
headers: Optional[dict] = None
|
||||
field: str | None = None,
|
||||
headers: dict | None = None,
|
||||
):
|
||||
self.error_code = error_code
|
||||
self.field = field
|
||||
self.message = message
|
||||
super().__init__(
|
||||
status_code=status_code,
|
||||
detail=message,
|
||||
headers=headers
|
||||
)
|
||||
super().__init__(status_code=status_code, detail=message, headers=headers)
|
||||
|
||||
|
||||
class AuthenticationError(APIException):
|
||||
@@ -46,14 +43,14 @@ class AuthenticationError(APIException):
|
||||
self,
|
||||
message: str = "Authentication failed",
|
||||
error_code: ErrorCode = ErrorCode.INVALID_CREDENTIALS,
|
||||
field: Optional[str] = None
|
||||
field: str | None = None,
|
||||
):
|
||||
super().__init__(
|
||||
status_code=status.HTTP_401_UNAUTHORIZED,
|
||||
error_code=error_code,
|
||||
message=message,
|
||||
field=field,
|
||||
headers={"WWW-Authenticate": "Bearer"}
|
||||
headers={"WWW-Authenticate": "Bearer"},
|
||||
)
|
||||
|
||||
|
||||
@@ -63,12 +60,12 @@ class AuthorizationError(APIException):
|
||||
def __init__(
|
||||
self,
|
||||
message: str = "Insufficient permissions",
|
||||
error_code: ErrorCode = ErrorCode.INSUFFICIENT_PERMISSIONS
|
||||
error_code: ErrorCode = ErrorCode.INSUFFICIENT_PERMISSIONS,
|
||||
):
|
||||
super().__init__(
|
||||
status_code=status.HTTP_403_FORBIDDEN,
|
||||
error_code=error_code,
|
||||
message=message
|
||||
message=message,
|
||||
)
|
||||
|
||||
|
||||
@@ -78,12 +75,12 @@ class NotFoundError(APIException):
|
||||
def __init__(
|
||||
self,
|
||||
message: str = "Resource not found",
|
||||
error_code: ErrorCode = ErrorCode.NOT_FOUND
|
||||
error_code: ErrorCode = ErrorCode.NOT_FOUND,
|
||||
):
|
||||
super().__init__(
|
||||
status_code=status.HTTP_404_NOT_FOUND,
|
||||
error_code=error_code,
|
||||
message=message
|
||||
message=message,
|
||||
)
|
||||
|
||||
|
||||
@@ -94,13 +91,13 @@ class DuplicateError(APIException):
|
||||
self,
|
||||
message: str = "Resource already exists",
|
||||
error_code: ErrorCode = ErrorCode.DUPLICATE_ENTRY,
|
||||
field: Optional[str] = None
|
||||
field: str | None = None,
|
||||
):
|
||||
super().__init__(
|
||||
status_code=status.HTTP_409_CONFLICT,
|
||||
error_code=error_code,
|
||||
message=message,
|
||||
field=field
|
||||
field=field,
|
||||
)
|
||||
|
||||
|
||||
@@ -111,13 +108,13 @@ class ValidationException(APIException):
|
||||
self,
|
||||
message: str = "Validation error",
|
||||
error_code: ErrorCode = ErrorCode.VALIDATION_ERROR,
|
||||
field: Optional[str] = None
|
||||
field: str | None = None,
|
||||
):
|
||||
super().__init__(
|
||||
status_code=status.HTTP_422_UNPROCESSABLE_ENTITY,
|
||||
error_code=error_code,
|
||||
message=message,
|
||||
field=field
|
||||
field=field,
|
||||
)
|
||||
|
||||
|
||||
@@ -127,12 +124,12 @@ class DatabaseError(APIException):
|
||||
def __init__(
|
||||
self,
|
||||
message: str = "Database operation failed",
|
||||
error_code: ErrorCode = ErrorCode.DATABASE_ERROR
|
||||
error_code: ErrorCode = ErrorCode.DATABASE_ERROR,
|
||||
):
|
||||
super().__init__(
|
||||
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
|
||||
error_code=error_code,
|
||||
message=message
|
||||
message=message,
|
||||
)
|
||||
|
||||
|
||||
@@ -146,28 +143,26 @@ async def api_exception_handler(request: Request, exc: APIException) -> JSONResp
|
||||
Returns a standardized error response with error code and message.
|
||||
"""
|
||||
logger.warning(
|
||||
f"API exception: {exc.error_code} - {exc.message} "
|
||||
f"(status: {exc.status_code}, path: {request.url.path})"
|
||||
"API exception: %s - %s (status: %s, path: %s)",
|
||||
exc.error_code,
|
||||
exc.message,
|
||||
exc.status_code,
|
||||
request.url.path,
|
||||
)
|
||||
|
||||
error_response = ErrorResponse(
|
||||
errors=[ErrorDetail(
|
||||
code=exc.error_code,
|
||||
message=exc.message,
|
||||
field=exc.field
|
||||
)]
|
||||
errors=[ErrorDetail(code=exc.error_code, message=exc.message, field=exc.field)]
|
||||
)
|
||||
|
||||
return JSONResponse(
|
||||
status_code=exc.status_code,
|
||||
content=error_response.model_dump(),
|
||||
headers=exc.headers
|
||||
headers=exc.headers,
|
||||
)
|
||||
|
||||
|
||||
async def validation_exception_handler(
|
||||
request: Request,
|
||||
exc: Union[RequestValidationError, ValidationError]
|
||||
request: Request, exc: RequestValidationError | ValidationError
|
||||
) -> JSONResponse:
|
||||
"""
|
||||
Handler for Pydantic validation errors.
|
||||
@@ -188,22 +183,21 @@ async def validation_exception_handler(
|
||||
# Skip 'body' or 'query' prefix in location
|
||||
field = ".".join(str(x) for x in error["loc"][1:])
|
||||
|
||||
errors.append(ErrorDetail(
|
||||
code=ErrorCode.VALIDATION_ERROR,
|
||||
message=error["msg"],
|
||||
field=field
|
||||
))
|
||||
errors.append(
|
||||
ErrorDetail(
|
||||
code=ErrorCode.VALIDATION_ERROR, message=error["msg"], field=field
|
||||
)
|
||||
)
|
||||
|
||||
logger.warning(
|
||||
f"Validation error: {len(errors)} errors "
|
||||
f"(path: {request.url.path})"
|
||||
"Validation error: %s errors (path: %s)", len(errors), request.url.path
|
||||
)
|
||||
|
||||
error_response = ErrorResponse(errors=errors)
|
||||
|
||||
return JSONResponse(
|
||||
status_code=status.HTTP_422_UNPROCESSABLE_ENTITY,
|
||||
content=error_response.model_dump()
|
||||
content=error_response.model_dump(),
|
||||
)
|
||||
|
||||
|
||||
@@ -225,26 +219,24 @@ async def http_exception_handler(request: Request, exc: HTTPException) -> JSONRe
|
||||
}
|
||||
|
||||
error_code = status_code_to_error_code.get(
|
||||
exc.status_code,
|
||||
ErrorCode.INTERNAL_ERROR
|
||||
exc.status_code, ErrorCode.INTERNAL_ERROR
|
||||
)
|
||||
|
||||
logger.warning(
|
||||
f"HTTP exception: {exc.status_code} - {exc.detail} "
|
||||
f"(path: {request.url.path})"
|
||||
"HTTP exception: %s - %s (path: %s)",
|
||||
exc.status_code,
|
||||
exc.detail,
|
||||
request.url.path,
|
||||
)
|
||||
|
||||
error_response = ErrorResponse(
|
||||
errors=[ErrorDetail(
|
||||
code=error_code,
|
||||
message=str(exc.detail)
|
||||
)]
|
||||
errors=[ErrorDetail(code=error_code, message=str(exc.detail), field=None)]
|
||||
)
|
||||
|
||||
return JSONResponse(
|
||||
status_code=exc.status_code,
|
||||
content=error_response.model_dump(),
|
||||
headers=exc.headers
|
||||
headers=exc.headers,
|
||||
)
|
||||
|
||||
|
||||
@@ -255,27 +247,26 @@ async def unhandled_exception_handler(request: Request, exc: Exception) -> JSONR
|
||||
Logs the full exception and returns a generic error response to avoid
|
||||
leaking sensitive information in production.
|
||||
"""
|
||||
logger.error(
|
||||
f"Unhandled exception: {type(exc).__name__} - {str(exc)} "
|
||||
f"(path: {request.url.path})",
|
||||
exc_info=True
|
||||
logger.exception(
|
||||
"Unhandled exception: %s - %s (path: %s)",
|
||||
type(exc).__name__,
|
||||
exc,
|
||||
request.url.path,
|
||||
)
|
||||
|
||||
# In production, don't expose internal error details
|
||||
from app.core.config import settings
|
||||
|
||||
if settings.ENVIRONMENT == "production":
|
||||
message = "An internal error occurred. Please try again later."
|
||||
else:
|
||||
message = f"{type(exc).__name__}: {str(exc)}"
|
||||
message = f"{type(exc).__name__}: {exc!s}"
|
||||
|
||||
error_response = ErrorResponse(
|
||||
errors=[ErrorDetail(
|
||||
code=ErrorCode.INTERNAL_ERROR,
|
||||
message=message
|
||||
)]
|
||||
errors=[ErrorDetail(code=ErrorCode.INTERNAL_ERROR, message=message, field=None)]
|
||||
)
|
||||
|
||||
return JSONResponse(
|
||||
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
|
||||
content=error_response.model_dump()
|
||||
content=error_response.model_dump(),
|
||||
)
|
||||
|
||||
26
backend/app/core/repository_exceptions.py
Normal file
26
backend/app/core/repository_exceptions.py
Normal file
@@ -0,0 +1,26 @@
|
||||
"""
|
||||
Custom exceptions for the repository layer.
|
||||
|
||||
These exceptions allow services and routes to handle database-level errors
|
||||
with proper semantics, without leaking SQLAlchemy internals.
|
||||
"""
|
||||
|
||||
|
||||
class RepositoryError(Exception):
|
||||
"""Base for all repository-layer errors."""
|
||||
|
||||
|
||||
class DuplicateEntryError(RepositoryError):
|
||||
"""Raised on unique constraint violations. Maps to HTTP 409 Conflict."""
|
||||
|
||||
|
||||
class IntegrityConstraintError(RepositoryError):
|
||||
"""Raised on FK or check constraint violations."""
|
||||
|
||||
|
||||
class RecordNotFoundError(RepositoryError):
|
||||
"""Raised when an expected record doesn't exist."""
|
||||
|
||||
|
||||
class InvalidInputError(RepositoryError):
|
||||
"""Raised on bad pagination params, invalid UUIDs, or other invalid inputs."""
|
||||
@@ -1,6 +0,0 @@
|
||||
# app/crud/__init__.py
|
||||
from .user import user
|
||||
from .session import session as session_crud
|
||||
from .organization import organization
|
||||
|
||||
__all__ = ["user", "session_crud", "organization"]
|
||||
@@ -1,304 +0,0 @@
|
||||
from typing import Any, Dict, Generic, List, Optional, Type, TypeVar, Union, Tuple
|
||||
from datetime import datetime, timezone
|
||||
from fastapi.encoders import jsonable_encoder
|
||||
from pydantic import BaseModel
|
||||
from sqlalchemy.orm import Session
|
||||
from sqlalchemy.exc import IntegrityError, OperationalError, DataError
|
||||
from sqlalchemy import func, asc, desc
|
||||
from app.core.database import Base
|
||||
import logging
|
||||
import uuid
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
ModelType = TypeVar("ModelType", bound=Base)
|
||||
CreateSchemaType = TypeVar("CreateSchemaType", bound=BaseModel)
|
||||
UpdateSchemaType = TypeVar("UpdateSchemaType", bound=BaseModel)
|
||||
|
||||
|
||||
class CRUDBase(Generic[ModelType, CreateSchemaType, UpdateSchemaType]):
|
||||
def __init__(self, model: Type[ModelType]):
|
||||
"""
|
||||
CRUD object with default methods to Create, Read, Update, Delete (CRUD).
|
||||
|
||||
Parameters:
|
||||
model: A SQLAlchemy model class
|
||||
"""
|
||||
self.model = model
|
||||
|
||||
def get(self, db: Session, id: str) -> Optional[ModelType]:
|
||||
"""Get a single record by ID with UUID validation."""
|
||||
# Validate UUID format and convert to UUID object if string
|
||||
try:
|
||||
if isinstance(id, uuid.UUID):
|
||||
uuid_obj = id
|
||||
else:
|
||||
uuid_obj = uuid.UUID(str(id))
|
||||
except (ValueError, AttributeError, TypeError) as e:
|
||||
logger.warning(f"Invalid UUID format: {id} - {str(e)}")
|
||||
return None
|
||||
|
||||
try:
|
||||
return db.query(self.model).filter(self.model.id == uuid_obj).first()
|
||||
except Exception as e:
|
||||
logger.error(f"Error retrieving {self.model.__name__} with id {id}: {str(e)}")
|
||||
raise
|
||||
|
||||
def get_multi(
|
||||
self, db: Session, *, skip: int = 0, limit: int = 100
|
||||
) -> List[ModelType]:
|
||||
"""Get multiple records with pagination validation."""
|
||||
# Validate pagination parameters
|
||||
if skip < 0:
|
||||
raise ValueError("skip must be non-negative")
|
||||
if limit < 0:
|
||||
raise ValueError("limit must be non-negative")
|
||||
if limit > 1000:
|
||||
raise ValueError("Maximum limit is 1000")
|
||||
|
||||
try:
|
||||
return db.query(self.model).offset(skip).limit(limit).all()
|
||||
except Exception as e:
|
||||
logger.error(f"Error retrieving multiple {self.model.__name__} records: {str(e)}")
|
||||
raise
|
||||
|
||||
def create(self, db: Session, *, obj_in: CreateSchemaType) -> ModelType:
|
||||
"""Create a new record with error handling."""
|
||||
try:
|
||||
obj_in_data = jsonable_encoder(obj_in)
|
||||
db_obj = self.model(**obj_in_data)
|
||||
db.add(db_obj)
|
||||
db.commit()
|
||||
db.refresh(db_obj)
|
||||
return db_obj
|
||||
except IntegrityError as e:
|
||||
db.rollback()
|
||||
error_msg = str(e.orig) if hasattr(e, 'orig') else str(e)
|
||||
if "unique" in error_msg.lower() or "duplicate" in error_msg.lower():
|
||||
logger.warning(f"Duplicate entry attempted for {self.model.__name__}: {error_msg}")
|
||||
raise ValueError(f"A {self.model.__name__} with this data already exists")
|
||||
logger.error(f"Integrity error creating {self.model.__name__}: {error_msg}")
|
||||
raise ValueError(f"Database integrity error: {error_msg}")
|
||||
except (OperationalError, DataError) as e:
|
||||
db.rollback()
|
||||
logger.error(f"Database error creating {self.model.__name__}: {str(e)}")
|
||||
raise ValueError(f"Database operation failed: {str(e)}")
|
||||
except Exception as e:
|
||||
db.rollback()
|
||||
logger.error(f"Unexpected error creating {self.model.__name__}: {str(e)}", exc_info=True)
|
||||
raise
|
||||
|
||||
def update(
|
||||
self,
|
||||
db: Session,
|
||||
*,
|
||||
db_obj: ModelType,
|
||||
obj_in: Union[UpdateSchemaType, Dict[str, Any]]
|
||||
) -> ModelType:
|
||||
"""Update a record with error handling."""
|
||||
try:
|
||||
obj_data = jsonable_encoder(db_obj)
|
||||
if isinstance(obj_in, dict):
|
||||
update_data = obj_in
|
||||
else:
|
||||
update_data = obj_in.model_dump(exclude_unset=True)
|
||||
for field in obj_data:
|
||||
if field in update_data:
|
||||
setattr(db_obj, field, update_data[field])
|
||||
db.add(db_obj)
|
||||
db.commit()
|
||||
db.refresh(db_obj)
|
||||
return db_obj
|
||||
except IntegrityError as e:
|
||||
db.rollback()
|
||||
error_msg = str(e.orig) if hasattr(e, 'orig') else str(e)
|
||||
if "unique" in error_msg.lower() or "duplicate" in error_msg.lower():
|
||||
logger.warning(f"Duplicate entry attempted for {self.model.__name__}: {error_msg}")
|
||||
raise ValueError(f"A {self.model.__name__} with this data already exists")
|
||||
logger.error(f"Integrity error updating {self.model.__name__}: {error_msg}")
|
||||
raise ValueError(f"Database integrity error: {error_msg}")
|
||||
except (OperationalError, DataError) as e:
|
||||
db.rollback()
|
||||
logger.error(f"Database error updating {self.model.__name__}: {str(e)}")
|
||||
raise ValueError(f"Database operation failed: {str(e)}")
|
||||
except Exception as e:
|
||||
db.rollback()
|
||||
logger.error(f"Unexpected error updating {self.model.__name__}: {str(e)}", exc_info=True)
|
||||
raise
|
||||
|
||||
def remove(self, db: Session, *, id: str) -> Optional[ModelType]:
|
||||
"""Delete a record with error handling and null check."""
|
||||
# Validate UUID format and convert to UUID object if string
|
||||
try:
|
||||
if isinstance(id, uuid.UUID):
|
||||
uuid_obj = id
|
||||
else:
|
||||
uuid_obj = uuid.UUID(str(id))
|
||||
except (ValueError, AttributeError, TypeError) as e:
|
||||
logger.warning(f"Invalid UUID format for deletion: {id} - {str(e)}")
|
||||
return None
|
||||
|
||||
try:
|
||||
obj = db.query(self.model).filter(self.model.id == uuid_obj).first()
|
||||
if obj is None:
|
||||
logger.warning(f"{self.model.__name__} with id {id} not found for deletion")
|
||||
return None
|
||||
|
||||
db.delete(obj)
|
||||
db.commit()
|
||||
return obj
|
||||
except IntegrityError as e:
|
||||
db.rollback()
|
||||
error_msg = str(e.orig) if hasattr(e, 'orig') else str(e)
|
||||
logger.error(f"Integrity error deleting {self.model.__name__}: {error_msg}")
|
||||
raise ValueError(f"Cannot delete {self.model.__name__}: referenced by other records")
|
||||
except Exception as e:
|
||||
db.rollback()
|
||||
logger.error(f"Error deleting {self.model.__name__} with id {id}: {str(e)}", exc_info=True)
|
||||
raise
|
||||
|
||||
def get_multi_with_total(
|
||||
self,
|
||||
db: Session,
|
||||
*,
|
||||
skip: int = 0,
|
||||
limit: int = 100,
|
||||
sort_by: Optional[str] = None,
|
||||
sort_order: str = "asc",
|
||||
filters: Optional[Dict[str, Any]] = None
|
||||
) -> Tuple[List[ModelType], int]:
|
||||
"""
|
||||
Get multiple records with total count, filtering, and sorting.
|
||||
|
||||
Args:
|
||||
db: Database session
|
||||
skip: Number of records to skip
|
||||
limit: Maximum number of records to return
|
||||
sort_by: Field name to sort by (must be a valid model attribute)
|
||||
sort_order: Sort order ("asc" or "desc")
|
||||
filters: Dictionary of filters (field_name: value)
|
||||
|
||||
Returns:
|
||||
Tuple of (items, total_count)
|
||||
"""
|
||||
# Validate pagination parameters
|
||||
if skip < 0:
|
||||
raise ValueError("skip must be non-negative")
|
||||
if limit < 0:
|
||||
raise ValueError("limit must be non-negative")
|
||||
if limit > 1000:
|
||||
raise ValueError("Maximum limit is 1000")
|
||||
|
||||
try:
|
||||
# Build base query
|
||||
query = db.query(self.model)
|
||||
|
||||
# Exclude soft-deleted records by default
|
||||
if hasattr(self.model, 'deleted_at'):
|
||||
query = query.filter(self.model.deleted_at.is_(None))
|
||||
|
||||
# Apply filters
|
||||
if filters:
|
||||
for field, value in filters.items():
|
||||
if hasattr(self.model, field) and value is not None:
|
||||
query = query.filter(getattr(self.model, field) == value)
|
||||
|
||||
# Get total count (before pagination)
|
||||
total = query.count()
|
||||
|
||||
# Apply sorting
|
||||
if sort_by and hasattr(self.model, sort_by):
|
||||
sort_column = getattr(self.model, sort_by)
|
||||
if sort_order.lower() == "desc":
|
||||
query = query.order_by(desc(sort_column))
|
||||
else:
|
||||
query = query.order_by(asc(sort_column))
|
||||
|
||||
# Apply pagination
|
||||
items = query.offset(skip).limit(limit).all()
|
||||
|
||||
return items, total
|
||||
except Exception as e:
|
||||
logger.error(f"Error retrieving paginated {self.model.__name__} records: {str(e)}")
|
||||
raise
|
||||
|
||||
def soft_delete(self, db: Session, *, id: str) -> Optional[ModelType]:
|
||||
"""
|
||||
Soft delete a record by setting deleted_at timestamp.
|
||||
|
||||
Only works if the model has a 'deleted_at' column.
|
||||
"""
|
||||
# Validate UUID format and convert to UUID object if string
|
||||
try:
|
||||
if isinstance(id, uuid.UUID):
|
||||
uuid_obj = id
|
||||
else:
|
||||
uuid_obj = uuid.UUID(str(id))
|
||||
except (ValueError, AttributeError, TypeError) as e:
|
||||
logger.warning(f"Invalid UUID format for soft deletion: {id} - {str(e)}")
|
||||
return None
|
||||
|
||||
try:
|
||||
obj = db.query(self.model).filter(self.model.id == uuid_obj).first()
|
||||
|
||||
if obj is None:
|
||||
logger.warning(f"{self.model.__name__} with id {id} not found for soft deletion")
|
||||
return None
|
||||
|
||||
# Check if model supports soft deletes
|
||||
if not hasattr(self.model, 'deleted_at'):
|
||||
logger.error(f"{self.model.__name__} does not support soft deletes")
|
||||
raise ValueError(f"{self.model.__name__} does not have a deleted_at column")
|
||||
|
||||
# Set deleted_at timestamp
|
||||
obj.deleted_at = datetime.now(timezone.utc)
|
||||
db.add(obj)
|
||||
db.commit()
|
||||
db.refresh(obj)
|
||||
return obj
|
||||
except Exception as e:
|
||||
db.rollback()
|
||||
logger.error(f"Error soft deleting {self.model.__name__} with id {id}: {str(e)}", exc_info=True)
|
||||
raise
|
||||
|
||||
def restore(self, db: Session, *, id: str) -> Optional[ModelType]:
|
||||
"""
|
||||
Restore a soft-deleted record by clearing the deleted_at timestamp.
|
||||
|
||||
Only works if the model has a 'deleted_at' column.
|
||||
"""
|
||||
# Validate UUID format
|
||||
try:
|
||||
if isinstance(id, uuid.UUID):
|
||||
uuid_obj = id
|
||||
else:
|
||||
uuid_obj = uuid.UUID(str(id))
|
||||
except (ValueError, AttributeError, TypeError) as e:
|
||||
logger.warning(f"Invalid UUID format for restoration: {id} - {str(e)}")
|
||||
return None
|
||||
|
||||
try:
|
||||
# Find the soft-deleted record
|
||||
if hasattr(self.model, 'deleted_at'):
|
||||
obj = db.query(self.model).filter(
|
||||
self.model.id == uuid_obj,
|
||||
self.model.deleted_at.isnot(None)
|
||||
).first()
|
||||
else:
|
||||
logger.error(f"{self.model.__name__} does not support soft deletes")
|
||||
raise ValueError(f"{self.model.__name__} does not have a deleted_at column")
|
||||
|
||||
if obj is None:
|
||||
logger.warning(f"Soft-deleted {self.model.__name__} with id {id} not found for restoration")
|
||||
return None
|
||||
|
||||
# Clear deleted_at timestamp
|
||||
obj.deleted_at = None
|
||||
db.add(obj)
|
||||
db.commit()
|
||||
db.refresh(obj)
|
||||
return obj
|
||||
except Exception as e:
|
||||
db.rollback()
|
||||
logger.error(f"Error restoring {self.model.__name__} with id {id}: {str(e)}", exc_info=True)
|
||||
raise
|
||||
@@ -1,228 +0,0 @@
|
||||
# app/crud/base_async.py
|
||||
"""
|
||||
Async CRUD operations base class using SQLAlchemy 2.0 async patterns.
|
||||
|
||||
Provides reusable create, read, update, and delete operations for all models.
|
||||
"""
|
||||
from typing import Any, Dict, Generic, List, Optional, Type, TypeVar, Union, Tuple
|
||||
import logging
|
||||
import uuid
|
||||
|
||||
from fastapi.encoders import jsonable_encoder
|
||||
from pydantic import BaseModel
|
||||
from sqlalchemy import func, select
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
from sqlalchemy.exc import IntegrityError, OperationalError, DataError
|
||||
|
||||
from app.core.database_async import Base
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
ModelType = TypeVar("ModelType", bound=Base)
|
||||
CreateSchemaType = TypeVar("CreateSchemaType", bound=BaseModel)
|
||||
UpdateSchemaType = TypeVar("UpdateSchemaType", bound=BaseModel)
|
||||
|
||||
|
||||
class CRUDBaseAsync(Generic[ModelType, CreateSchemaType, UpdateSchemaType]):
|
||||
"""Async CRUD operations for a model."""
|
||||
|
||||
def __init__(self, model: Type[ModelType]):
|
||||
"""
|
||||
CRUD object with default async methods to Create, Read, Update, Delete.
|
||||
|
||||
Parameters:
|
||||
model: A SQLAlchemy model class
|
||||
"""
|
||||
self.model = model
|
||||
|
||||
async def get(self, db: AsyncSession, id: str) -> Optional[ModelType]:
|
||||
"""Get a single record by ID with UUID validation."""
|
||||
# Validate UUID format and convert to UUID object if string
|
||||
try:
|
||||
if isinstance(id, uuid.UUID):
|
||||
uuid_obj = id
|
||||
else:
|
||||
uuid_obj = uuid.UUID(str(id))
|
||||
except (ValueError, AttributeError, TypeError) as e:
|
||||
logger.warning(f"Invalid UUID format: {id} - {str(e)}")
|
||||
return None
|
||||
|
||||
try:
|
||||
result = await db.execute(
|
||||
select(self.model).where(self.model.id == uuid_obj)
|
||||
)
|
||||
return result.scalar_one_or_none()
|
||||
except Exception as e:
|
||||
logger.error(f"Error retrieving {self.model.__name__} with id {id}: {str(e)}")
|
||||
raise
|
||||
|
||||
async def get_multi(
|
||||
self, db: AsyncSession, *, skip: int = 0, limit: int = 100
|
||||
) -> List[ModelType]:
|
||||
"""Get multiple records with pagination validation."""
|
||||
# Validate pagination parameters
|
||||
if skip < 0:
|
||||
raise ValueError("skip must be non-negative")
|
||||
if limit < 0:
|
||||
raise ValueError("limit must be non-negative")
|
||||
if limit > 1000:
|
||||
raise ValueError("Maximum limit is 1000")
|
||||
|
||||
try:
|
||||
result = await db.execute(
|
||||
select(self.model).offset(skip).limit(limit)
|
||||
)
|
||||
return list(result.scalars().all())
|
||||
except Exception as e:
|
||||
logger.error(f"Error retrieving multiple {self.model.__name__} records: {str(e)}")
|
||||
raise
|
||||
|
||||
async def create(self, db: AsyncSession, *, obj_in: CreateSchemaType) -> ModelType:
|
||||
"""Create a new record with error handling."""
|
||||
try:
|
||||
obj_in_data = jsonable_encoder(obj_in)
|
||||
db_obj = self.model(**obj_in_data)
|
||||
db.add(db_obj)
|
||||
await db.commit()
|
||||
await db.refresh(db_obj)
|
||||
return db_obj
|
||||
except IntegrityError as e:
|
||||
await db.rollback()
|
||||
error_msg = str(e.orig) if hasattr(e, 'orig') else str(e)
|
||||
if "unique" in error_msg.lower() or "duplicate" in error_msg.lower():
|
||||
logger.warning(f"Duplicate entry attempted for {self.model.__name__}: {error_msg}")
|
||||
raise ValueError(f"A {self.model.__name__} with this data already exists")
|
||||
logger.error(f"Integrity error creating {self.model.__name__}: {error_msg}")
|
||||
raise ValueError(f"Database integrity error: {error_msg}")
|
||||
except (OperationalError, DataError) as e:
|
||||
await db.rollback()
|
||||
logger.error(f"Database error creating {self.model.__name__}: {str(e)}")
|
||||
raise ValueError(f"Database operation failed: {str(e)}")
|
||||
except Exception as e:
|
||||
await db.rollback()
|
||||
logger.error(f"Unexpected error creating {self.model.__name__}: {str(e)}", exc_info=True)
|
||||
raise
|
||||
|
||||
async def update(
|
||||
self,
|
||||
db: AsyncSession,
|
||||
*,
|
||||
db_obj: ModelType,
|
||||
obj_in: Union[UpdateSchemaType, Dict[str, Any]]
|
||||
) -> ModelType:
|
||||
"""Update a record with error handling."""
|
||||
try:
|
||||
obj_data = jsonable_encoder(db_obj)
|
||||
if isinstance(obj_in, dict):
|
||||
update_data = obj_in
|
||||
else:
|
||||
update_data = obj_in.model_dump(exclude_unset=True)
|
||||
|
||||
for field in obj_data:
|
||||
if field in update_data:
|
||||
setattr(db_obj, field, update_data[field])
|
||||
|
||||
db.add(db_obj)
|
||||
await db.commit()
|
||||
await db.refresh(db_obj)
|
||||
return db_obj
|
||||
except IntegrityError as e:
|
||||
await db.rollback()
|
||||
error_msg = str(e.orig) if hasattr(e, 'orig') else str(e)
|
||||
if "unique" in error_msg.lower() or "duplicate" in error_msg.lower():
|
||||
logger.warning(f"Duplicate entry attempted for {self.model.__name__}: {error_msg}")
|
||||
raise ValueError(f"A {self.model.__name__} with this data already exists")
|
||||
logger.error(f"Integrity error updating {self.model.__name__}: {error_msg}")
|
||||
raise ValueError(f"Database integrity error: {error_msg}")
|
||||
except (OperationalError, DataError) as e:
|
||||
await db.rollback()
|
||||
logger.error(f"Database error updating {self.model.__name__}: {str(e)}")
|
||||
raise ValueError(f"Database operation failed: {str(e)}")
|
||||
except Exception as e:
|
||||
await db.rollback()
|
||||
logger.error(f"Unexpected error updating {self.model.__name__}: {str(e)}", exc_info=True)
|
||||
raise
|
||||
|
||||
async def remove(self, db: AsyncSession, *, id: str) -> Optional[ModelType]:
|
||||
"""Delete a record with error handling and null check."""
|
||||
# Validate UUID format and convert to UUID object if string
|
||||
try:
|
||||
if isinstance(id, uuid.UUID):
|
||||
uuid_obj = id
|
||||
else:
|
||||
uuid_obj = uuid.UUID(str(id))
|
||||
except (ValueError, AttributeError, TypeError) as e:
|
||||
logger.warning(f"Invalid UUID format for deletion: {id} - {str(e)}")
|
||||
return None
|
||||
|
||||
try:
|
||||
result = await db.execute(
|
||||
select(self.model).where(self.model.id == uuid_obj)
|
||||
)
|
||||
obj = result.scalar_one_or_none()
|
||||
|
||||
if obj is None:
|
||||
logger.warning(f"{self.model.__name__} with id {id} not found for deletion")
|
||||
return None
|
||||
|
||||
await db.delete(obj)
|
||||
await db.commit()
|
||||
return obj
|
||||
except IntegrityError as e:
|
||||
await db.rollback()
|
||||
error_msg = str(e.orig) if hasattr(e, 'orig') else str(e)
|
||||
logger.error(f"Integrity error deleting {self.model.__name__}: {error_msg}")
|
||||
raise ValueError(f"Cannot delete {self.model.__name__}: referenced by other records")
|
||||
except Exception as e:
|
||||
await db.rollback()
|
||||
logger.error(f"Error deleting {self.model.__name__} with id {id}: {str(e)}", exc_info=True)
|
||||
raise
|
||||
|
||||
async def get_multi_with_total(
|
||||
self, db: AsyncSession, *, skip: int = 0, limit: int = 100
|
||||
) -> Tuple[List[ModelType], int]:
|
||||
"""
|
||||
Get multiple records with total count for pagination.
|
||||
|
||||
Returns:
|
||||
Tuple of (items, total_count)
|
||||
"""
|
||||
# Validate pagination parameters
|
||||
if skip < 0:
|
||||
raise ValueError("skip must be non-negative")
|
||||
if limit < 0:
|
||||
raise ValueError("limit must be non-negative")
|
||||
if limit > 1000:
|
||||
raise ValueError("Maximum limit is 1000")
|
||||
|
||||
try:
|
||||
# Get total count
|
||||
count_result = await db.execute(
|
||||
select(func.count(self.model.id))
|
||||
)
|
||||
total = count_result.scalar_one()
|
||||
|
||||
# Get paginated items
|
||||
items_result = await db.execute(
|
||||
select(self.model).offset(skip).limit(limit)
|
||||
)
|
||||
items = list(items_result.scalars().all())
|
||||
|
||||
return items, total
|
||||
except Exception as e:
|
||||
logger.error(f"Error retrieving paginated {self.model.__name__} records: {str(e)}")
|
||||
raise
|
||||
|
||||
async def count(self, db: AsyncSession) -> int:
|
||||
"""Get total count of records."""
|
||||
try:
|
||||
result = await db.execute(select(func.count(self.model.id)))
|
||||
return result.scalar_one()
|
||||
except Exception as e:
|
||||
logger.error(f"Error counting {self.model.__name__} records: {str(e)}")
|
||||
raise
|
||||
|
||||
async def exists(self, db: AsyncSession, id: str) -> bool:
|
||||
"""Check if a record exists by ID."""
|
||||
obj = await self.get(db, id=id)
|
||||
return obj is not None
|
||||
@@ -1,322 +0,0 @@
|
||||
# app/crud/organization.py
|
||||
from typing import Optional, List, Dict, Any, Union
|
||||
from uuid import UUID
|
||||
from sqlalchemy.orm import Session
|
||||
from sqlalchemy.exc import IntegrityError
|
||||
from sqlalchemy import func, or_, and_
|
||||
|
||||
from app.crud.base import CRUDBase
|
||||
from app.models.organization import Organization
|
||||
from app.models.user_organization import UserOrganization, OrganizationRole
|
||||
from app.models.user import User
|
||||
from app.schemas.organizations import (
|
||||
OrganizationCreate,
|
||||
OrganizationUpdate,
|
||||
UserOrganizationCreate,
|
||||
UserOrganizationUpdate
|
||||
)
|
||||
import logging
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class CRUDOrganization(CRUDBase[Organization, OrganizationCreate, OrganizationUpdate]):
|
||||
"""CRUD operations for Organization model."""
|
||||
|
||||
def get_by_slug(self, db: Session, *, slug: str) -> Optional[Organization]:
|
||||
"""Get organization by slug."""
|
||||
return db.query(Organization).filter(Organization.slug == slug).first()
|
||||
|
||||
def create(self, db: Session, *, obj_in: OrganizationCreate) -> Organization:
|
||||
"""Create a new organization with error handling."""
|
||||
try:
|
||||
db_obj = Organization(
|
||||
name=obj_in.name,
|
||||
slug=obj_in.slug,
|
||||
description=obj_in.description,
|
||||
is_active=obj_in.is_active,
|
||||
settings=obj_in.settings or {}
|
||||
)
|
||||
db.add(db_obj)
|
||||
db.commit()
|
||||
db.refresh(db_obj)
|
||||
return db_obj
|
||||
except IntegrityError as e:
|
||||
db.rollback()
|
||||
error_msg = str(e.orig) if hasattr(e, 'orig') else str(e)
|
||||
if "slug" in error_msg.lower():
|
||||
logger.warning(f"Duplicate slug attempted: {obj_in.slug}")
|
||||
raise ValueError(f"Organization with slug '{obj_in.slug}' already exists")
|
||||
logger.error(f"Integrity error creating organization: {error_msg}")
|
||||
raise ValueError(f"Database integrity error: {error_msg}")
|
||||
except Exception as e:
|
||||
db.rollback()
|
||||
logger.error(f"Unexpected error creating organization: {str(e)}", exc_info=True)
|
||||
raise
|
||||
|
||||
def get_multi_with_filters(
|
||||
self,
|
||||
db: Session,
|
||||
*,
|
||||
skip: int = 0,
|
||||
limit: int = 100,
|
||||
is_active: Optional[bool] = None,
|
||||
search: Optional[str] = None,
|
||||
sort_by: str = "created_at",
|
||||
sort_order: str = "desc"
|
||||
) -> tuple[List[Organization], int]:
|
||||
"""
|
||||
Get multiple organizations with filtering, searching, and sorting.
|
||||
|
||||
Returns:
|
||||
Tuple of (organizations list, total count)
|
||||
"""
|
||||
query = db.query(Organization)
|
||||
|
||||
# Apply filters
|
||||
if is_active is not None:
|
||||
query = query.filter(Organization.is_active == is_active)
|
||||
|
||||
if search:
|
||||
search_filter = or_(
|
||||
Organization.name.ilike(f"%{search}%"),
|
||||
Organization.slug.ilike(f"%{search}%"),
|
||||
Organization.description.ilike(f"%{search}%")
|
||||
)
|
||||
query = query.filter(search_filter)
|
||||
|
||||
# Get total count before pagination
|
||||
total = query.count()
|
||||
|
||||
# Apply sorting
|
||||
sort_column = getattr(Organization, sort_by, Organization.created_at)
|
||||
if sort_order == "desc":
|
||||
query = query.order_by(sort_column.desc())
|
||||
else:
|
||||
query = query.order_by(sort_column.asc())
|
||||
|
||||
# Apply pagination
|
||||
organizations = query.offset(skip).limit(limit).all()
|
||||
|
||||
return organizations, total
|
||||
|
||||
def get_member_count(self, db: Session, *, organization_id: UUID) -> int:
|
||||
"""Get the count of active members in an organization."""
|
||||
return db.query(func.count(UserOrganization.user_id)).filter(
|
||||
and_(
|
||||
UserOrganization.organization_id == organization_id,
|
||||
UserOrganization.is_active == True
|
||||
)
|
||||
).scalar() or 0
|
||||
|
||||
def add_user(
|
||||
self,
|
||||
db: Session,
|
||||
*,
|
||||
organization_id: UUID,
|
||||
user_id: UUID,
|
||||
role: OrganizationRole = OrganizationRole.MEMBER,
|
||||
custom_permissions: Optional[str] = None
|
||||
) -> UserOrganization:
|
||||
"""Add a user to an organization with a specific role."""
|
||||
try:
|
||||
# Check if relationship already exists
|
||||
existing = db.query(UserOrganization).filter(
|
||||
and_(
|
||||
UserOrganization.user_id == user_id,
|
||||
UserOrganization.organization_id == organization_id
|
||||
)
|
||||
).first()
|
||||
|
||||
if existing:
|
||||
# Reactivate if inactive, or raise error if already active
|
||||
if not existing.is_active:
|
||||
existing.is_active = True
|
||||
existing.role = role
|
||||
existing.custom_permissions = custom_permissions
|
||||
db.commit()
|
||||
db.refresh(existing)
|
||||
return existing
|
||||
else:
|
||||
raise ValueError("User is already a member of this organization")
|
||||
|
||||
# Create new relationship
|
||||
user_org = UserOrganization(
|
||||
user_id=user_id,
|
||||
organization_id=organization_id,
|
||||
role=role,
|
||||
is_active=True,
|
||||
custom_permissions=custom_permissions
|
||||
)
|
||||
db.add(user_org)
|
||||
db.commit()
|
||||
db.refresh(user_org)
|
||||
return user_org
|
||||
except IntegrityError as e:
|
||||
db.rollback()
|
||||
logger.error(f"Integrity error adding user to organization: {str(e)}")
|
||||
raise ValueError("Failed to add user to organization")
|
||||
except Exception as e:
|
||||
db.rollback()
|
||||
logger.error(f"Error adding user to organization: {str(e)}", exc_info=True)
|
||||
raise
|
||||
|
||||
def remove_user(
|
||||
self,
|
||||
db: Session,
|
||||
*,
|
||||
organization_id: UUID,
|
||||
user_id: UUID
|
||||
) -> bool:
|
||||
"""Remove a user from an organization (soft delete)."""
|
||||
try:
|
||||
user_org = db.query(UserOrganization).filter(
|
||||
and_(
|
||||
UserOrganization.user_id == user_id,
|
||||
UserOrganization.organization_id == organization_id
|
||||
)
|
||||
).first()
|
||||
|
||||
if not user_org:
|
||||
return False
|
||||
|
||||
user_org.is_active = False
|
||||
db.commit()
|
||||
return True
|
||||
except Exception as e:
|
||||
db.rollback()
|
||||
logger.error(f"Error removing user from organization: {str(e)}", exc_info=True)
|
||||
raise
|
||||
|
||||
def update_user_role(
|
||||
self,
|
||||
db: Session,
|
||||
*,
|
||||
organization_id: UUID,
|
||||
user_id: UUID,
|
||||
role: OrganizationRole,
|
||||
custom_permissions: Optional[str] = None
|
||||
) -> Optional[UserOrganization]:
|
||||
"""Update a user's role in an organization."""
|
||||
try:
|
||||
user_org = db.query(UserOrganization).filter(
|
||||
and_(
|
||||
UserOrganization.user_id == user_id,
|
||||
UserOrganization.organization_id == organization_id
|
||||
)
|
||||
).first()
|
||||
|
||||
if not user_org:
|
||||
return None
|
||||
|
||||
user_org.role = role
|
||||
if custom_permissions is not None:
|
||||
user_org.custom_permissions = custom_permissions
|
||||
db.commit()
|
||||
db.refresh(user_org)
|
||||
return user_org
|
||||
except Exception as e:
|
||||
db.rollback()
|
||||
logger.error(f"Error updating user role: {str(e)}", exc_info=True)
|
||||
raise
|
||||
|
||||
def get_organization_members(
|
||||
self,
|
||||
db: Session,
|
||||
*,
|
||||
organization_id: UUID,
|
||||
skip: int = 0,
|
||||
limit: int = 100,
|
||||
is_active: bool = True
|
||||
) -> tuple[List[Dict[str, Any]], int]:
|
||||
"""
|
||||
Get members of an organization with user details.
|
||||
|
||||
Returns:
|
||||
Tuple of (members list with user details, total count)
|
||||
"""
|
||||
query = db.query(UserOrganization, User).join(
|
||||
User, UserOrganization.user_id == User.id
|
||||
).filter(UserOrganization.organization_id == organization_id)
|
||||
|
||||
if is_active is not None:
|
||||
query = query.filter(UserOrganization.is_active == is_active)
|
||||
|
||||
total = query.count()
|
||||
|
||||
results = query.order_by(UserOrganization.created_at.desc()).offset(skip).limit(limit).all()
|
||||
|
||||
members = []
|
||||
for user_org, user in results:
|
||||
members.append({
|
||||
"user_id": user.id,
|
||||
"email": user.email,
|
||||
"first_name": user.first_name,
|
||||
"last_name": user.last_name,
|
||||
"role": user_org.role,
|
||||
"is_active": user_org.is_active,
|
||||
"joined_at": user_org.created_at
|
||||
})
|
||||
|
||||
return members, total
|
||||
|
||||
def get_user_organizations(
|
||||
self,
|
||||
db: Session,
|
||||
*,
|
||||
user_id: UUID,
|
||||
is_active: bool = True
|
||||
) -> List[Organization]:
|
||||
"""Get all organizations a user belongs to."""
|
||||
query = db.query(Organization).join(
|
||||
UserOrganization, Organization.id == UserOrganization.organization_id
|
||||
).filter(UserOrganization.user_id == user_id)
|
||||
|
||||
if is_active is not None:
|
||||
query = query.filter(UserOrganization.is_active == is_active)
|
||||
|
||||
return query.all()
|
||||
|
||||
def get_user_role_in_org(
|
||||
self,
|
||||
db: Session,
|
||||
*,
|
||||
user_id: UUID,
|
||||
organization_id: UUID
|
||||
) -> Optional[OrganizationRole]:
|
||||
"""Get a user's role in a specific organization."""
|
||||
user_org = db.query(UserOrganization).filter(
|
||||
and_(
|
||||
UserOrganization.user_id == user_id,
|
||||
UserOrganization.organization_id == organization_id,
|
||||
UserOrganization.is_active == True
|
||||
)
|
||||
).first()
|
||||
|
||||
return user_org.role if user_org else None
|
||||
|
||||
def is_user_org_owner(
|
||||
self,
|
||||
db: Session,
|
||||
*,
|
||||
user_id: UUID,
|
||||
organization_id: UUID
|
||||
) -> bool:
|
||||
"""Check if a user is an owner of an organization."""
|
||||
role = self.get_user_role_in_org(db, user_id=user_id, organization_id=organization_id)
|
||||
return role == OrganizationRole.OWNER
|
||||
|
||||
def is_user_org_admin(
|
||||
self,
|
||||
db: Session,
|
||||
*,
|
||||
user_id: UUID,
|
||||
organization_id: UUID
|
||||
) -> bool:
|
||||
"""Check if a user is an owner or admin of an organization."""
|
||||
role = self.get_user_role_in_org(db, user_id=user_id, organization_id=organization_id)
|
||||
return role in [OrganizationRole.OWNER, OrganizationRole.ADMIN]
|
||||
|
||||
|
||||
# Create a singleton instance for use across the application
|
||||
organization = CRUDOrganization(Organization)
|
||||
@@ -1,339 +0,0 @@
|
||||
"""
|
||||
CRUD operations for user sessions.
|
||||
"""
|
||||
from datetime import datetime, timezone, timedelta
|
||||
from typing import List, Optional
|
||||
from uuid import UUID
|
||||
from sqlalchemy.orm import Session
|
||||
from sqlalchemy import and_
|
||||
import logging
|
||||
|
||||
from app.crud.base import CRUDBase
|
||||
from app.models.user_session import UserSession
|
||||
from app.schemas.sessions import SessionCreate, SessionUpdate
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class CRUDSession(CRUDBase[UserSession, SessionCreate, SessionUpdate]):
|
||||
"""CRUD operations for user sessions."""
|
||||
|
||||
def get_by_jti(self, db: Session, *, jti: str) -> Optional[UserSession]:
|
||||
"""
|
||||
Get session by refresh token JTI.
|
||||
|
||||
Args:
|
||||
db: Database session
|
||||
jti: Refresh token JWT ID
|
||||
|
||||
Returns:
|
||||
UserSession if found, None otherwise
|
||||
"""
|
||||
try:
|
||||
return db.query(UserSession).filter(
|
||||
UserSession.refresh_token_jti == jti
|
||||
).first()
|
||||
except Exception as e:
|
||||
logger.error(f"Error getting session by JTI {jti}: {str(e)}")
|
||||
raise
|
||||
|
||||
def get_active_by_jti(self, db: Session, *, jti: str) -> Optional[UserSession]:
|
||||
"""
|
||||
Get active session by refresh token JTI.
|
||||
|
||||
Args:
|
||||
db: Database session
|
||||
jti: Refresh token JWT ID
|
||||
|
||||
Returns:
|
||||
Active UserSession if found, None otherwise
|
||||
"""
|
||||
try:
|
||||
return db.query(UserSession).filter(
|
||||
and_(
|
||||
UserSession.refresh_token_jti == jti,
|
||||
UserSession.is_active == True
|
||||
)
|
||||
).first()
|
||||
except Exception as e:
|
||||
logger.error(f"Error getting active session by JTI {jti}: {str(e)}")
|
||||
raise
|
||||
|
||||
def get_user_sessions(
|
||||
self,
|
||||
db: Session,
|
||||
*,
|
||||
user_id: str,
|
||||
active_only: bool = True
|
||||
) -> List[UserSession]:
|
||||
"""
|
||||
Get all sessions for a user.
|
||||
|
||||
Args:
|
||||
db: Database session
|
||||
user_id: User ID
|
||||
active_only: If True, return only active sessions
|
||||
|
||||
Returns:
|
||||
List of UserSession objects
|
||||
"""
|
||||
try:
|
||||
# Convert user_id string to UUID if needed
|
||||
user_uuid = UUID(user_id) if isinstance(user_id, str) else user_id
|
||||
|
||||
query = db.query(UserSession).filter(UserSession.user_id == user_uuid)
|
||||
|
||||
if active_only:
|
||||
query = query.filter(UserSession.is_active == True)
|
||||
|
||||
return query.order_by(UserSession.last_used_at.desc()).all()
|
||||
except Exception as e:
|
||||
logger.error(f"Error getting sessions for user {user_id}: {str(e)}")
|
||||
raise
|
||||
|
||||
def create_session(
|
||||
self,
|
||||
db: Session,
|
||||
*,
|
||||
obj_in: SessionCreate
|
||||
) -> UserSession:
|
||||
"""
|
||||
Create a new user session.
|
||||
|
||||
Args:
|
||||
db: Database session
|
||||
obj_in: SessionCreate schema with session data
|
||||
|
||||
Returns:
|
||||
Created UserSession
|
||||
|
||||
Raises:
|
||||
ValueError: If session creation fails
|
||||
"""
|
||||
try:
|
||||
db_obj = UserSession(
|
||||
user_id=obj_in.user_id,
|
||||
refresh_token_jti=obj_in.refresh_token_jti,
|
||||
device_name=obj_in.device_name,
|
||||
device_id=obj_in.device_id,
|
||||
ip_address=obj_in.ip_address,
|
||||
user_agent=obj_in.user_agent,
|
||||
last_used_at=obj_in.last_used_at,
|
||||
expires_at=obj_in.expires_at,
|
||||
is_active=True,
|
||||
location_city=obj_in.location_city,
|
||||
location_country=obj_in.location_country,
|
||||
)
|
||||
db.add(db_obj)
|
||||
db.commit()
|
||||
db.refresh(db_obj)
|
||||
|
||||
logger.info(
|
||||
f"Session created for user {obj_in.user_id} from {obj_in.device_name} "
|
||||
f"(IP: {obj_in.ip_address})"
|
||||
)
|
||||
|
||||
return db_obj
|
||||
except Exception as e:
|
||||
db.rollback()
|
||||
logger.error(f"Error creating session: {str(e)}", exc_info=True)
|
||||
raise ValueError(f"Failed to create session: {str(e)}")
|
||||
|
||||
def deactivate(self, db: Session, *, session_id: str) -> Optional[UserSession]:
|
||||
"""
|
||||
Deactivate a session (logout from device).
|
||||
|
||||
Args:
|
||||
db: Database session
|
||||
session_id: Session UUID
|
||||
|
||||
Returns:
|
||||
Deactivated UserSession if found, None otherwise
|
||||
"""
|
||||
try:
|
||||
session = self.get(db, id=session_id)
|
||||
if not session:
|
||||
logger.warning(f"Session {session_id} not found for deactivation")
|
||||
return None
|
||||
|
||||
session.is_active = False
|
||||
db.add(session)
|
||||
db.commit()
|
||||
db.refresh(session)
|
||||
|
||||
logger.info(
|
||||
f"Session {session_id} deactivated for user {session.user_id} "
|
||||
f"({session.device_name})"
|
||||
)
|
||||
|
||||
return session
|
||||
except Exception as e:
|
||||
db.rollback()
|
||||
logger.error(f"Error deactivating session {session_id}: {str(e)}")
|
||||
raise
|
||||
|
||||
def deactivate_all_user_sessions(
|
||||
self,
|
||||
db: Session,
|
||||
*,
|
||||
user_id: str
|
||||
) -> int:
|
||||
"""
|
||||
Deactivate all active sessions for a user (logout from all devices).
|
||||
|
||||
Args:
|
||||
db: Database session
|
||||
user_id: User ID
|
||||
|
||||
Returns:
|
||||
Number of sessions deactivated
|
||||
"""
|
||||
try:
|
||||
# Convert user_id string to UUID if needed
|
||||
user_uuid = UUID(user_id) if isinstance(user_id, str) else user_id
|
||||
|
||||
count = db.query(UserSession).filter(
|
||||
and_(
|
||||
UserSession.user_id == user_uuid,
|
||||
UserSession.is_active == True
|
||||
)
|
||||
).update({"is_active": False})
|
||||
|
||||
db.commit()
|
||||
|
||||
logger.info(f"Deactivated {count} sessions for user {user_id}")
|
||||
|
||||
return count
|
||||
except Exception as e:
|
||||
db.rollback()
|
||||
logger.error(f"Error deactivating all sessions for user {user_id}: {str(e)}")
|
||||
raise
|
||||
|
||||
def update_last_used(
|
||||
self,
|
||||
db: Session,
|
||||
*,
|
||||
session: UserSession
|
||||
) -> UserSession:
|
||||
"""
|
||||
Update the last_used_at timestamp for a session.
|
||||
|
||||
Args:
|
||||
db: Database session
|
||||
session: UserSession object
|
||||
|
||||
Returns:
|
||||
Updated UserSession
|
||||
"""
|
||||
try:
|
||||
session.last_used_at = datetime.now(timezone.utc)
|
||||
db.add(session)
|
||||
db.commit()
|
||||
db.refresh(session)
|
||||
return session
|
||||
except Exception as e:
|
||||
db.rollback()
|
||||
logger.error(f"Error updating last_used for session {session.id}: {str(e)}")
|
||||
raise
|
||||
|
||||
def update_refresh_token(
|
||||
self,
|
||||
db: Session,
|
||||
*,
|
||||
session: UserSession,
|
||||
new_jti: str,
|
||||
new_expires_at: datetime
|
||||
) -> UserSession:
|
||||
"""
|
||||
Update session with new refresh token JTI and expiration.
|
||||
|
||||
Called during token refresh.
|
||||
|
||||
Args:
|
||||
db: Database session
|
||||
session: UserSession object
|
||||
new_jti: New refresh token JTI
|
||||
new_expires_at: New expiration datetime
|
||||
|
||||
Returns:
|
||||
Updated UserSession
|
||||
"""
|
||||
try:
|
||||
session.refresh_token_jti = new_jti
|
||||
session.expires_at = new_expires_at
|
||||
session.last_used_at = datetime.now(timezone.utc)
|
||||
db.add(session)
|
||||
db.commit()
|
||||
db.refresh(session)
|
||||
return session
|
||||
except Exception as e:
|
||||
db.rollback()
|
||||
logger.error(f"Error updating refresh token for session {session.id}: {str(e)}")
|
||||
raise
|
||||
|
||||
def cleanup_expired(self, db: Session, *, keep_days: int = 30) -> int:
|
||||
"""
|
||||
Clean up expired sessions.
|
||||
|
||||
Deletes sessions that are:
|
||||
- Expired AND inactive
|
||||
- Older than keep_days
|
||||
|
||||
Args:
|
||||
db: Database session
|
||||
keep_days: Keep inactive sessions for this many days (for audit)
|
||||
|
||||
Returns:
|
||||
Number of sessions deleted
|
||||
"""
|
||||
try:
|
||||
cutoff_date = datetime.now(timezone.utc) - timedelta(days=keep_days)
|
||||
|
||||
# Delete sessions that are:
|
||||
# 1. Expired (expires_at < now) AND inactive
|
||||
# AND
|
||||
# 2. Older than keep_days
|
||||
count = db.query(UserSession).filter(
|
||||
and_(
|
||||
UserSession.is_active == False,
|
||||
UserSession.expires_at < datetime.now(timezone.utc),
|
||||
UserSession.created_at < cutoff_date
|
||||
)
|
||||
).delete()
|
||||
|
||||
db.commit()
|
||||
|
||||
if count > 0:
|
||||
logger.info(f"Cleaned up {count} expired sessions")
|
||||
|
||||
return count
|
||||
except Exception as e:
|
||||
db.rollback()
|
||||
logger.error(f"Error cleaning up expired sessions: {str(e)}")
|
||||
raise
|
||||
|
||||
def get_user_session_count(self, db: Session, *, user_id: str) -> int:
|
||||
"""
|
||||
Get count of active sessions for a user.
|
||||
|
||||
Args:
|
||||
db: Database session
|
||||
user_id: User ID
|
||||
|
||||
Returns:
|
||||
Number of active sessions
|
||||
"""
|
||||
try:
|
||||
return db.query(UserSession).filter(
|
||||
and_(
|
||||
UserSession.user_id == user_id,
|
||||
UserSession.is_active == True
|
||||
)
|
||||
).count()
|
||||
except Exception as e:
|
||||
logger.error(f"Error counting sessions for user {user_id}: {str(e)}")
|
||||
raise
|
||||
|
||||
|
||||
# Create singleton instance
|
||||
session = CRUDSession(UserSession)
|
||||
@@ -1,151 +0,0 @@
|
||||
# app/crud/user.py
|
||||
from typing import Optional, Union, Dict, Any, List, Tuple
|
||||
from sqlalchemy.orm import Session
|
||||
from sqlalchemy.exc import IntegrityError
|
||||
from sqlalchemy import or_, asc, desc
|
||||
from app.crud.base import CRUDBase
|
||||
from app.models.user import User
|
||||
from app.schemas.users import UserCreate, UserUpdate
|
||||
from app.core.auth import get_password_hash
|
||||
import logging
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class CRUDUser(CRUDBase[User, UserCreate, UserUpdate]):
|
||||
def get_by_email(self, db: Session, *, email: str) -> Optional[User]:
|
||||
return db.query(User).filter(User.email == email).first()
|
||||
|
||||
def create(self, db: Session, *, obj_in: UserCreate) -> User:
|
||||
"""Create a new user with password hashing and error handling."""
|
||||
try:
|
||||
db_obj = User(
|
||||
email=obj_in.email,
|
||||
password_hash=get_password_hash(obj_in.password),
|
||||
first_name=obj_in.first_name,
|
||||
last_name=obj_in.last_name,
|
||||
phone_number=obj_in.phone_number if hasattr(obj_in, 'phone_number') else None,
|
||||
is_superuser=obj_in.is_superuser if hasattr(obj_in, 'is_superuser') else False,
|
||||
preferences={}
|
||||
)
|
||||
db.add(db_obj)
|
||||
db.commit()
|
||||
db.refresh(db_obj)
|
||||
return db_obj
|
||||
except IntegrityError as e:
|
||||
db.rollback()
|
||||
error_msg = str(e.orig) if hasattr(e, 'orig') else str(e)
|
||||
if "email" in error_msg.lower():
|
||||
logger.warning(f"Duplicate email attempted: {obj_in.email}")
|
||||
raise ValueError(f"User with email {obj_in.email} already exists")
|
||||
logger.error(f"Integrity error creating user: {error_msg}")
|
||||
raise ValueError(f"Database integrity error: {error_msg}")
|
||||
except Exception as e:
|
||||
db.rollback()
|
||||
logger.error(f"Unexpected error creating user: {str(e)}", exc_info=True)
|
||||
raise
|
||||
|
||||
def update(
|
||||
self,
|
||||
db: Session,
|
||||
*,
|
||||
db_obj: User,
|
||||
obj_in: Union[UserUpdate, Dict[str, Any]]
|
||||
) -> User:
|
||||
if isinstance(obj_in, dict):
|
||||
update_data = obj_in
|
||||
else:
|
||||
update_data = obj_in.model_dump(exclude_unset=True)
|
||||
|
||||
# Handle password separately if it exists in update data
|
||||
if "password" in update_data:
|
||||
update_data["password_hash"] = get_password_hash(update_data["password"])
|
||||
del update_data["password"]
|
||||
|
||||
return super().update(db, db_obj=db_obj, obj_in=update_data)
|
||||
|
||||
def get_multi_with_total(
|
||||
self,
|
||||
db: Session,
|
||||
*,
|
||||
skip: int = 0,
|
||||
limit: int = 100,
|
||||
sort_by: Optional[str] = None,
|
||||
sort_order: str = "asc",
|
||||
filters: Optional[Dict[str, Any]] = None,
|
||||
search: Optional[str] = None
|
||||
) -> Tuple[List[User], int]:
|
||||
"""
|
||||
Get multiple users with total count, filtering, sorting, and search.
|
||||
|
||||
Args:
|
||||
db: Database session
|
||||
skip: Number of records to skip
|
||||
limit: Maximum number of records to return
|
||||
sort_by: Field name to sort by
|
||||
sort_order: Sort order ("asc" or "desc")
|
||||
filters: Dictionary of filters (field_name: value)
|
||||
search: Search term to match against email, first_name, last_name
|
||||
|
||||
Returns:
|
||||
Tuple of (users list, total count)
|
||||
"""
|
||||
# Validate pagination
|
||||
if skip < 0:
|
||||
raise ValueError("skip must be non-negative")
|
||||
if limit < 0:
|
||||
raise ValueError("limit must be non-negative")
|
||||
if limit > 1000:
|
||||
raise ValueError("Maximum limit is 1000")
|
||||
|
||||
try:
|
||||
# Build base query
|
||||
query = db.query(User)
|
||||
|
||||
# Exclude soft-deleted users
|
||||
query = query.filter(User.deleted_at.is_(None))
|
||||
|
||||
# Apply filters
|
||||
if filters:
|
||||
for field, value in filters.items():
|
||||
if hasattr(User, field) and value is not None:
|
||||
query = query.filter(getattr(User, field) == value)
|
||||
|
||||
# Apply search
|
||||
if search:
|
||||
search_filter = or_(
|
||||
User.email.ilike(f"%{search}%"),
|
||||
User.first_name.ilike(f"%{search}%"),
|
||||
User.last_name.ilike(f"%{search}%")
|
||||
)
|
||||
query = query.filter(search_filter)
|
||||
|
||||
# Get total count
|
||||
total = query.count()
|
||||
|
||||
# Apply sorting
|
||||
if sort_by and hasattr(User, sort_by):
|
||||
sort_column = getattr(User, sort_by)
|
||||
if sort_order.lower() == "desc":
|
||||
query = query.order_by(desc(sort_column))
|
||||
else:
|
||||
query = query.order_by(asc(sort_column))
|
||||
|
||||
# Apply pagination
|
||||
users = query.offset(skip).limit(limit).all()
|
||||
|
||||
return users, total
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error retrieving paginated users: {str(e)}")
|
||||
raise
|
||||
|
||||
def is_active(self, user: User) -> bool:
|
||||
return user.is_active
|
||||
|
||||
def is_superuser(self, user: User) -> bool:
|
||||
return user.is_superuser
|
||||
|
||||
|
||||
# Create a singleton instance for use across the application
|
||||
user = CRUDUser(User)
|
||||
@@ -1,16 +1,31 @@
|
||||
# app/init_db.py
|
||||
"""
|
||||
Async database initialization script.
|
||||
|
||||
Creates the first superuser if configured and doesn't already exist.
|
||||
"""
|
||||
|
||||
import asyncio
|
||||
import json
|
||||
import logging
|
||||
from typing import Optional
|
||||
from sqlalchemy.orm import Session
|
||||
import random
|
||||
from datetime import UTC, datetime, timedelta
|
||||
from pathlib import Path
|
||||
|
||||
from sqlalchemy import select, text
|
||||
|
||||
from app.core.config import settings
|
||||
from app.crud.user import user as user_crud
|
||||
from app.core.database import SessionLocal, engine
|
||||
from app.models.organization import Organization
|
||||
from app.models.user import User
|
||||
from app.models.user_organization import UserOrganization
|
||||
from app.repositories.user import user_repo as user_repo
|
||||
from app.schemas.users import UserCreate
|
||||
from app.core.database import engine
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
def init_db(db: Session) -> Optional[UserCreate]:
|
||||
async def init_db() -> User | None:
|
||||
"""
|
||||
Initialize database with first superuser if settings are configured and user doesn't exist.
|
||||
|
||||
@@ -19,58 +34,197 @@ def init_db(db: Session) -> Optional[UserCreate]:
|
||||
"""
|
||||
# Use default values if not set in environment variables
|
||||
superuser_email = settings.FIRST_SUPERUSER_EMAIL or "admin@example.com"
|
||||
superuser_password = settings.FIRST_SUPERUSER_PASSWORD or "Admin123!Change"
|
||||
|
||||
default_password = "AdminPassword123!"
|
||||
if settings.DEMO_MODE:
|
||||
default_password = "AdminPass1234!"
|
||||
|
||||
superuser_password = settings.FIRST_SUPERUSER_PASSWORD or default_password
|
||||
|
||||
if not settings.FIRST_SUPERUSER_EMAIL or not settings.FIRST_SUPERUSER_PASSWORD:
|
||||
logger.warning(
|
||||
"First superuser credentials not configured in settings. "
|
||||
f"Using defaults: {superuser_email}"
|
||||
"Using defaults: %s",
|
||||
superuser_email,
|
||||
)
|
||||
|
||||
async with SessionLocal() as session:
|
||||
try:
|
||||
# Check if superuser already exists
|
||||
existing_user = await user_repo.get_by_email(session, email=superuser_email)
|
||||
|
||||
if existing_user:
|
||||
logger.info("Superuser already exists: %s", existing_user.email)
|
||||
return existing_user
|
||||
|
||||
# Create superuser if doesn't exist
|
||||
user_in = UserCreate(
|
||||
email=superuser_email,
|
||||
password=superuser_password,
|
||||
first_name="Admin",
|
||||
last_name="User",
|
||||
is_superuser=True,
|
||||
)
|
||||
|
||||
user = await user_repo.create(session, obj_in=user_in)
|
||||
await session.commit()
|
||||
await session.refresh(user)
|
||||
|
||||
logger.info("Created first superuser: %s", user.email)
|
||||
|
||||
# Create demo data if in demo mode
|
||||
if settings.DEMO_MODE:
|
||||
await load_demo_data(session)
|
||||
|
||||
return user
|
||||
|
||||
except Exception as e:
|
||||
await session.rollback()
|
||||
logger.error("Error initializing database: %s", e)
|
||||
raise
|
||||
|
||||
|
||||
def _load_json_file(path: Path):
|
||||
with open(path) as f:
|
||||
return json.load(f)
|
||||
|
||||
|
||||
async def load_demo_data(session):
|
||||
"""Load demo data from JSON file."""
|
||||
demo_data_path = Path(__file__).parent / "core" / "demo_data.json"
|
||||
if not demo_data_path.exists():
|
||||
logger.warning("Demo data file not found: %s", demo_data_path)
|
||||
return
|
||||
|
||||
try:
|
||||
# Check if superuser already exists
|
||||
existing_user = user_crud.get_by_email(db, email=superuser_email)
|
||||
# Use asyncio.to_thread to avoid blocking the event loop
|
||||
data = await asyncio.to_thread(_load_json_file, demo_data_path)
|
||||
|
||||
if existing_user:
|
||||
logger.info(f"Superuser already exists: {existing_user.email}")
|
||||
return existing_user
|
||||
# Create Organizations
|
||||
org_map = {}
|
||||
for org_data in data.get("organizations", []):
|
||||
# Check if org exists
|
||||
result = await session.execute(
|
||||
text("SELECT * FROM organizations WHERE slug = :slug"),
|
||||
{"slug": org_data["slug"]},
|
||||
)
|
||||
existing_org = result.first()
|
||||
|
||||
# Create superuser if doesn't exist
|
||||
user_in = UserCreate(
|
||||
email=superuser_email,
|
||||
password=superuser_password,
|
||||
first_name="Admin",
|
||||
last_name="User",
|
||||
is_superuser=True
|
||||
)
|
||||
if not existing_org:
|
||||
org = Organization(
|
||||
name=org_data["name"],
|
||||
slug=org_data["slug"],
|
||||
description=org_data.get("description"),
|
||||
is_active=True,
|
||||
)
|
||||
session.add(org)
|
||||
await session.flush() # Flush to get ID
|
||||
org_map[org.slug] = org
|
||||
logger.info("Created demo organization: %s", org.name)
|
||||
else:
|
||||
# We can't easily get the ORM object from raw SQL result for map without querying again or mapping
|
||||
# So let's just query it properly if we need it for relationships
|
||||
# But for simplicity in this script, let's just assume we created it or it exists.
|
||||
# To properly map for users, we need the ID.
|
||||
# Let's use a simpler approach: just try to create, if slug conflict, skip.
|
||||
pass
|
||||
|
||||
user = user_crud.create(db, obj_in=user_in)
|
||||
logger.info(f"Created first superuser: {user.email}")
|
||||
# Re-query all orgs to build map for users
|
||||
result = await session.execute(select(Organization))
|
||||
orgs = result.scalars().all()
|
||||
org_map = {org.slug: org for org in orgs}
|
||||
|
||||
return user
|
||||
# Create Users
|
||||
for user_data in data.get("users", []):
|
||||
existing_user = await user_repo.get_by_email(
|
||||
session, email=user_data["email"]
|
||||
)
|
||||
if not existing_user:
|
||||
# Create user
|
||||
user_in = UserCreate(
|
||||
email=user_data["email"],
|
||||
password=user_data["password"],
|
||||
first_name=user_data["first_name"],
|
||||
last_name=user_data["last_name"],
|
||||
is_superuser=user_data["is_superuser"],
|
||||
is_active=user_data.get("is_active", True),
|
||||
)
|
||||
user = await user_repo.create(session, obj_in=user_in)
|
||||
|
||||
# Randomize created_at for demo data (last 30 days)
|
||||
# This makes the charts look more realistic
|
||||
days_ago = random.randint(0, 30) # noqa: S311
|
||||
random_time = datetime.now(UTC) - timedelta(days=days_ago)
|
||||
# Add some random hours/minutes variation
|
||||
random_time = random_time.replace(
|
||||
hour=random.randint(0, 23), # noqa: S311
|
||||
minute=random.randint(0, 59), # noqa: S311
|
||||
)
|
||||
|
||||
# Update the timestamp and is_active directly in the database
|
||||
# We do this to ensure the values are persisted correctly
|
||||
await session.execute(
|
||||
text(
|
||||
"UPDATE users SET created_at = :created_at, is_active = :is_active WHERE id = :user_id"
|
||||
),
|
||||
{
|
||||
"created_at": random_time,
|
||||
"is_active": user_data.get("is_active", True),
|
||||
"user_id": user.id,
|
||||
},
|
||||
)
|
||||
|
||||
logger.info(
|
||||
"Created demo user: %s (created %s days ago, active=%s)",
|
||||
user.email,
|
||||
days_ago,
|
||||
user_data.get("is_active", True),
|
||||
)
|
||||
|
||||
# Add to organization if specified
|
||||
org_slug = user_data.get("organization_slug")
|
||||
role = user_data.get("role")
|
||||
if org_slug and org_slug in org_map and role:
|
||||
org = org_map[org_slug]
|
||||
# Check if membership exists (it shouldn't for new user)
|
||||
member = UserOrganization(
|
||||
user_id=user.id, organization_id=org.id, role=role
|
||||
)
|
||||
session.add(member)
|
||||
logger.info("Added %s to %s as %s", user.email, org.name, role)
|
||||
else:
|
||||
logger.info("Demo user already exists: %s", existing_user.email)
|
||||
|
||||
await session.commit()
|
||||
logger.info("Demo data loaded successfully")
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error initializing database: {e}")
|
||||
logger.error("Error loading demo data: %s", e)
|
||||
raise
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
async def main():
|
||||
"""Main entry point for database initialization."""
|
||||
# Configure logging to show info logs
|
||||
logging.basicConfig(
|
||||
level=logging.INFO,
|
||||
format="%(asctime)s - %(name)s - %(levelname)s - %(message)s"
|
||||
format="%(asctime)s - %(name)s - %(levelname)s - %(message)s",
|
||||
)
|
||||
|
||||
with Session(engine) as session:
|
||||
try:
|
||||
user = init_db(session)
|
||||
if user:
|
||||
print(f"✓ Database initialized successfully")
|
||||
print(f"✓ Superuser: {user.email}")
|
||||
else:
|
||||
print("✗ Failed to initialize database")
|
||||
except Exception as e:
|
||||
print(f"✗ Error initializing database: {e}")
|
||||
raise
|
||||
finally:
|
||||
session.close()
|
||||
try:
|
||||
user = await init_db()
|
||||
if user:
|
||||
print("✓ Database initialized successfully")
|
||||
print(f"✓ Superuser: {user.email}")
|
||||
else:
|
||||
print("✗ Failed to initialize database")
|
||||
except Exception as e:
|
||||
print(f"✗ Error initializing database: {e}")
|
||||
raise
|
||||
finally:
|
||||
# Close the engine
|
||||
await engine.dispose()
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
asyncio.run(main())
|
||||
|
||||
187
backend/app/main.py
Normal file → Executable file
187
backend/app/main.py
Normal file → Executable file
@@ -1,26 +1,28 @@
|
||||
import logging
|
||||
from datetime import datetime
|
||||
from typing import Dict, Any
|
||||
import os
|
||||
from contextlib import asynccontextmanager
|
||||
from datetime import UTC, datetime
|
||||
from typing import Any
|
||||
|
||||
from apscheduler.schedulers.asyncio import AsyncIOScheduler
|
||||
from fastapi import FastAPI, status, Request, HTTPException
|
||||
from fastapi import FastAPI, HTTPException, Request, status
|
||||
from fastapi.exceptions import RequestValidationError
|
||||
from fastapi.middleware.cors import CORSMiddleware
|
||||
from fastapi.responses import HTMLResponse, JSONResponse
|
||||
from fastapi.exceptions import RequestValidationError
|
||||
from slowapi import Limiter, _rate_limit_exceeded_handler
|
||||
from slowapi.util import get_remote_address
|
||||
from slowapi.errors import RateLimitExceeded
|
||||
from sqlalchemy import text
|
||||
from slowapi.util import get_remote_address
|
||||
|
||||
from app.api.main import api_router
|
||||
from app.api.routes.oauth_provider import wellknown_router as oauth_wellknown_router
|
||||
from app.core.config import settings
|
||||
from app.core.database import get_db, check_database_health
|
||||
from app.core.database import check_database_health, close_async_db
|
||||
from app.core.exceptions import (
|
||||
APIException,
|
||||
api_exception_handler,
|
||||
validation_exception_handler,
|
||||
http_exception_handler,
|
||||
unhandled_exception_handler
|
||||
unhandled_exception_handler,
|
||||
validation_exception_handler,
|
||||
)
|
||||
|
||||
scheduler = AsyncIOScheduler()
|
||||
@@ -30,11 +32,55 @@ logger = logging.getLogger(__name__)
|
||||
# Initialize rate limiter
|
||||
limiter = Limiter(key_func=get_remote_address)
|
||||
|
||||
logger.info(f"Starting app!!!")
|
||||
|
||||
@asynccontextmanager
|
||||
async def lifespan(app: FastAPI):
|
||||
"""
|
||||
Application lifespan context manager.
|
||||
|
||||
Handles startup and shutdown events for the application.
|
||||
Sets up background jobs and scheduled tasks on startup,
|
||||
cleans up resources on shutdown.
|
||||
"""
|
||||
# Startup
|
||||
logger.info("Application starting up...")
|
||||
|
||||
# Skip scheduler in test environment
|
||||
if os.getenv("IS_TEST", "False") != "True":
|
||||
from app.services.session_cleanup import cleanup_expired_sessions
|
||||
|
||||
# Schedule session cleanup job
|
||||
# Runs daily at 2:00 AM server time
|
||||
scheduler.add_job(
|
||||
cleanup_expired_sessions,
|
||||
"cron",
|
||||
hour=2,
|
||||
minute=0,
|
||||
id="cleanup_expired_sessions",
|
||||
replace_existing=True,
|
||||
)
|
||||
|
||||
scheduler.start()
|
||||
logger.info("Scheduled jobs started: session cleanup (daily at 2 AM)")
|
||||
else:
|
||||
logger.info("Test environment detected - skipping scheduler")
|
||||
|
||||
yield
|
||||
|
||||
# Shutdown
|
||||
logger.info("Application shutting down...")
|
||||
if os.getenv("IS_TEST", "False") != "True":
|
||||
scheduler.shutdown()
|
||||
logger.info("Scheduled jobs stopped")
|
||||
await close_async_db()
|
||||
|
||||
|
||||
logger.info("Starting app!!!")
|
||||
app = FastAPI(
|
||||
title=settings.PROJECT_NAME,
|
||||
version=settings.VERSION,
|
||||
openapi_url=f"{settings.API_V1_STR}/openapi.json"
|
||||
openapi_url=f"{settings.API_V1_STR}/openapi.json",
|
||||
lifespan=lifespan,
|
||||
)
|
||||
|
||||
# Add rate limiter state to app
|
||||
@@ -52,7 +98,14 @@ app.add_middleware(
|
||||
CORSMiddleware,
|
||||
allow_origins=settings.BACKEND_CORS_ORIGINS,
|
||||
allow_credentials=True,
|
||||
allow_methods=["GET", "POST", "PUT", "PATCH", "DELETE", "OPTIONS"], # Explicit methods only
|
||||
allow_methods=[
|
||||
"GET",
|
||||
"POST",
|
||||
"PUT",
|
||||
"PATCH",
|
||||
"DELETE",
|
||||
"OPTIONS",
|
||||
], # Explicit methods only
|
||||
allow_headers=[
|
||||
"Content-Type",
|
||||
"Authorization",
|
||||
@@ -69,6 +122,36 @@ app.add_middleware(
|
||||
)
|
||||
|
||||
|
||||
# Add request size limit middleware
|
||||
@app.middleware("http")
|
||||
async def limit_request_size(request: Request, call_next):
|
||||
"""
|
||||
Limit request body size to prevent DoS attacks via large payloads.
|
||||
|
||||
Max size: 10MB for file uploads and large payloads.
|
||||
"""
|
||||
MAX_REQUEST_SIZE = 10 * 1024 * 1024 # 10MB in bytes
|
||||
|
||||
content_length = request.headers.get("content-length")
|
||||
if content_length and int(content_length) > MAX_REQUEST_SIZE:
|
||||
return JSONResponse(
|
||||
status_code=status.HTTP_413_REQUEST_ENTITY_TOO_LARGE,
|
||||
content={
|
||||
"success": False,
|
||||
"errors": [
|
||||
{
|
||||
"code": "REQUEST_TOO_LARGE",
|
||||
"message": f"Request body too large. Maximum size is {MAX_REQUEST_SIZE // (1024 * 1024)}MB",
|
||||
"field": None,
|
||||
}
|
||||
],
|
||||
},
|
||||
)
|
||||
|
||||
response = await call_next(request)
|
||||
return response
|
||||
|
||||
|
||||
# Add security headers middleware
|
||||
@app.middleware("http")
|
||||
async def add_security_headers(request: Request, call_next):
|
||||
@@ -93,15 +176,19 @@ async def add_security_headers(request: Request, call_next):
|
||||
|
||||
# Enforce HTTPS in production
|
||||
if settings.ENVIRONMENT == "production":
|
||||
response.headers["Strict-Transport-Security"] = "max-age=31536000; includeSubDomains"
|
||||
response.headers["Strict-Transport-Security"] = (
|
||||
"max-age=31536000; includeSubDomains"
|
||||
)
|
||||
|
||||
# Content Security Policy
|
||||
csp_mode = settings.CSP_MODE.lower()
|
||||
|
||||
# Special handling for API docs
|
||||
is_docs = request.url.path in ["/docs", "/redoc"] or \
|
||||
request.url.path.startswith("/docs/") or \
|
||||
request.url.path.startswith("/redoc/")
|
||||
is_docs = (
|
||||
request.url.path in ["/docs", "/redoc"]
|
||||
or request.url.path.startswith("/docs/")
|
||||
or request.url.path.startswith("/redoc/")
|
||||
)
|
||||
|
||||
if csp_mode == "disabled":
|
||||
# No CSP (only for local development/debugging)
|
||||
@@ -192,7 +279,7 @@ async def root():
|
||||
description="Check the health status of the API and its dependencies",
|
||||
response_description="Health status information",
|
||||
tags=["Health"],
|
||||
operation_id="health_check"
|
||||
operation_id="health_check",
|
||||
)
|
||||
async def health_check() -> JSONResponse:
|
||||
"""
|
||||
@@ -206,23 +293,23 @@ async def health_check() -> JSONResponse:
|
||||
- environment: Current environment (development, staging, production)
|
||||
- database: Database connectivity status
|
||||
"""
|
||||
health_status: Dict[str, Any] = {
|
||||
health_status: dict[str, Any] = {
|
||||
"status": "healthy",
|
||||
"timestamp": datetime.utcnow().isoformat() + "Z",
|
||||
"timestamp": datetime.now(UTC).isoformat().replace("+00:00", "Z"),
|
||||
"version": settings.VERSION,
|
||||
"environment": settings.ENVIRONMENT,
|
||||
"checks": {}
|
||||
"checks": {},
|
||||
}
|
||||
|
||||
response_status = status.HTTP_200_OK
|
||||
|
||||
# Database health check using dedicated health check function
|
||||
try:
|
||||
db_healthy = check_database_health()
|
||||
db_healthy = await check_database_health()
|
||||
if db_healthy:
|
||||
health_status["checks"]["database"] = {
|
||||
"status": "healthy",
|
||||
"message": "Database connection successful"
|
||||
"message": "Database connection successful",
|
||||
}
|
||||
else:
|
||||
raise Exception("Database health check returned unhealthy status")
|
||||
@@ -230,60 +317,16 @@ async def health_check() -> JSONResponse:
|
||||
health_status["status"] = "unhealthy"
|
||||
health_status["checks"]["database"] = {
|
||||
"status": "unhealthy",
|
||||
"message": f"Database connection failed: {str(e)}"
|
||||
"message": f"Database connection failed: {e!s}",
|
||||
}
|
||||
response_status = status.HTTP_503_SERVICE_UNAVAILABLE
|
||||
logger.error(f"Health check failed - database error: {e}")
|
||||
logger.error("Health check failed - database error: %s", e)
|
||||
|
||||
return JSONResponse(
|
||||
status_code=response_status,
|
||||
content=health_status
|
||||
)
|
||||
return JSONResponse(status_code=response_status, content=health_status)
|
||||
|
||||
|
||||
app.include_router(api_router, prefix=settings.API_V1_STR)
|
||||
|
||||
|
||||
@app.on_event("startup")
|
||||
async def startup_event():
|
||||
"""
|
||||
Application startup event.
|
||||
|
||||
Sets up background jobs and scheduled tasks.
|
||||
"""
|
||||
import os
|
||||
|
||||
# Skip scheduler in test environment
|
||||
if os.getenv("IS_TEST", "False") == "True":
|
||||
logger.info("Test environment detected - skipping scheduler")
|
||||
return
|
||||
|
||||
from app.services.session_cleanup import cleanup_expired_sessions
|
||||
|
||||
# Schedule session cleanup job
|
||||
# Runs daily at 2:00 AM server time
|
||||
scheduler.add_job(
|
||||
cleanup_expired_sessions,
|
||||
'cron',
|
||||
hour=2,
|
||||
minute=0,
|
||||
id='cleanup_expired_sessions',
|
||||
replace_existing=True
|
||||
)
|
||||
|
||||
scheduler.start()
|
||||
logger.info("Scheduled jobs started: session cleanup (daily at 2 AM)")
|
||||
|
||||
|
||||
@app.on_event("shutdown")
|
||||
async def shutdown_event():
|
||||
"""
|
||||
Application shutdown event.
|
||||
|
||||
Cleans up resources and stops background jobs.
|
||||
"""
|
||||
import os
|
||||
|
||||
if os.getenv("IS_TEST", "False") != "True":
|
||||
scheduler.shutdown()
|
||||
logger.info("Scheduled jobs stopped")
|
||||
# OAuth 2.0 well-known endpoint at root level per RFC 8414
|
||||
# This allows MCP clients to discover the OAuth server metadata at /.well-known/oauth-authorization-server
|
||||
app.include_router(oauth_wellknown_router)
|
||||
|
||||
@@ -2,18 +2,40 @@
|
||||
Models package initialization.
|
||||
Imports all models to ensure they're registered with SQLAlchemy.
|
||||
"""
|
||||
|
||||
# First import Base to avoid circular imports
|
||||
from app.core.database import Base
|
||||
|
||||
from .base import TimestampMixin, UUIDMixin
|
||||
|
||||
# OAuth models (client mode - authenticate via Google/GitHub)
|
||||
from .oauth_account import OAuthAccount
|
||||
|
||||
# OAuth provider models (server mode - act as authorization server for MCP)
|
||||
from .oauth_authorization_code import OAuthAuthorizationCode
|
||||
from .oauth_client import OAuthClient
|
||||
from .oauth_provider_token import OAuthConsent, OAuthProviderRefreshToken
|
||||
from .oauth_state import OAuthState
|
||||
from .organization import Organization
|
||||
|
||||
# Import models
|
||||
from .user import User
|
||||
from .user_organization import OrganizationRole, UserOrganization
|
||||
from .user_session import UserSession
|
||||
from .organization import Organization
|
||||
from .user_organization import UserOrganization, OrganizationRole
|
||||
|
||||
__all__ = [
|
||||
'Base', 'TimestampMixin', 'UUIDMixin',
|
||||
'User', 'UserSession',
|
||||
'Organization', 'UserOrganization', 'OrganizationRole',
|
||||
]
|
||||
"Base",
|
||||
"OAuthAccount",
|
||||
"OAuthAuthorizationCode",
|
||||
"OAuthClient",
|
||||
"OAuthConsent",
|
||||
"OAuthProviderRefreshToken",
|
||||
"OAuthState",
|
||||
"Organization",
|
||||
"OrganizationRole",
|
||||
"TimestampMixin",
|
||||
"UUIDMixin",
|
||||
"User",
|
||||
"UserOrganization",
|
||||
"UserSession",
|
||||
]
|
||||
|
||||
@@ -1,20 +1,28 @@
|
||||
import uuid
|
||||
from datetime import datetime, timezone
|
||||
from datetime import UTC, datetime
|
||||
|
||||
from sqlalchemy import Column, DateTime
|
||||
from sqlalchemy.dialects.postgresql import UUID
|
||||
|
||||
# noinspection PyUnresolvedReferences
|
||||
from app.core.database import Base
|
||||
from app.core.database import Base # Re-exported for other models
|
||||
|
||||
|
||||
class TimestampMixin:
|
||||
"""Mixin to add created_at and updated_at timestamps to models"""
|
||||
created_at = Column(DateTime(timezone=True), default=lambda: datetime.now(timezone.utc), nullable=False)
|
||||
updated_at = Column(DateTime(timezone=True), default=lambda: datetime.now(timezone.utc),
|
||||
onupdate=lambda: datetime.now(timezone.utc), nullable=False)
|
||||
|
||||
created_at = Column(
|
||||
DateTime(timezone=True), default=lambda: datetime.now(UTC), nullable=False
|
||||
)
|
||||
updated_at = Column(
|
||||
DateTime(timezone=True),
|
||||
default=lambda: datetime.now(UTC),
|
||||
onupdate=lambda: datetime.now(UTC),
|
||||
nullable=False,
|
||||
)
|
||||
|
||||
|
||||
class UUIDMixin:
|
||||
"""Mixin to add UUID primary keys to models"""
|
||||
|
||||
id = Column(UUID(as_uuid=True), primary_key=True, default=uuid.uuid4)
|
||||
|
||||
55
backend/app/models/oauth_account.py
Executable file
55
backend/app/models/oauth_account.py
Executable file
@@ -0,0 +1,55 @@
|
||||
"""OAuth account model for linking external OAuth providers to users."""
|
||||
|
||||
from sqlalchemy import Column, DateTime, ForeignKey, Index, String, UniqueConstraint
|
||||
from sqlalchemy.dialects.postgresql import UUID
|
||||
from sqlalchemy.orm import relationship
|
||||
|
||||
from .base import Base, TimestampMixin, UUIDMixin
|
||||
|
||||
|
||||
class OAuthAccount(Base, UUIDMixin, TimestampMixin):
|
||||
"""
|
||||
Links OAuth provider accounts to users.
|
||||
|
||||
Supports multiple OAuth providers per user (e.g., user can have both
|
||||
Google and GitHub connected). Each provider account is uniquely identified
|
||||
by (provider, provider_user_id).
|
||||
"""
|
||||
|
||||
__tablename__ = "oauth_accounts"
|
||||
|
||||
# Link to user
|
||||
user_id = Column(
|
||||
UUID(as_uuid=True),
|
||||
ForeignKey("users.id", ondelete="CASCADE"),
|
||||
nullable=False,
|
||||
index=True,
|
||||
)
|
||||
|
||||
# OAuth provider identification
|
||||
provider = Column(
|
||||
String(50), nullable=False, index=True
|
||||
) # google, github, microsoft
|
||||
provider_user_id = Column(String(255), nullable=False) # Provider's unique user ID
|
||||
provider_email = Column(
|
||||
String(255), nullable=True, index=True
|
||||
) # Email from provider (for reference)
|
||||
|
||||
# Optional: store provider tokens for API access
|
||||
# TODO: Encrypt these at rest in production (requires key management infrastructure)
|
||||
access_token = Column(String(2048), nullable=True)
|
||||
refresh_token = Column(String(2048), nullable=True)
|
||||
token_expires_at = Column(DateTime(timezone=True), nullable=True)
|
||||
|
||||
# Relationship
|
||||
user = relationship("User", back_populates="oauth_accounts")
|
||||
|
||||
__table_args__ = (
|
||||
# Each provider account can only be linked to one user
|
||||
UniqueConstraint("provider", "provider_user_id", name="uq_oauth_provider_user"),
|
||||
# Index for finding all OAuth accounts for a user + provider
|
||||
Index("ix_oauth_accounts_user_provider", "user_id", "provider"),
|
||||
)
|
||||
|
||||
def __repr__(self):
|
||||
return f"<OAuthAccount {self.provider}:{self.provider_user_id}>"
|
||||
100
backend/app/models/oauth_authorization_code.py
Executable file
100
backend/app/models/oauth_authorization_code.py
Executable file
@@ -0,0 +1,100 @@
|
||||
"""OAuth authorization code model for OAuth provider mode."""
|
||||
|
||||
from datetime import UTC, datetime
|
||||
|
||||
from sqlalchemy import Boolean, Column, DateTime, ForeignKey, Index, String
|
||||
from sqlalchemy.dialects.postgresql import UUID
|
||||
from sqlalchemy.orm import relationship
|
||||
|
||||
from .base import Base, TimestampMixin, UUIDMixin
|
||||
|
||||
|
||||
class OAuthAuthorizationCode(Base, UUIDMixin, TimestampMixin):
|
||||
"""
|
||||
OAuth 2.0 Authorization Code for the authorization code flow.
|
||||
|
||||
Authorization codes are:
|
||||
- Single-use (marked as used after exchange)
|
||||
- Short-lived (10 minutes default)
|
||||
- Bound to specific client, user, redirect_uri
|
||||
- Support PKCE (code_challenge/code_challenge_method)
|
||||
|
||||
Security considerations:
|
||||
- Code must be cryptographically random (64 chars, URL-safe)
|
||||
- Must validate redirect_uri matches exactly
|
||||
- Must verify PKCE code_verifier for public clients
|
||||
- Must be consumed within expiration time
|
||||
|
||||
Performance indexes (defined in migration 0002_add_performance_indexes.py):
|
||||
- ix_perf_oauth_auth_codes_expires: expires_at WHERE used = false
|
||||
"""
|
||||
|
||||
__tablename__ = "oauth_authorization_codes"
|
||||
|
||||
# The authorization code (cryptographically random, URL-safe)
|
||||
code = Column(String(128), unique=True, nullable=False, index=True)
|
||||
|
||||
# Client that requested the code
|
||||
client_id = Column(
|
||||
String(64),
|
||||
ForeignKey("oauth_clients.client_id", ondelete="CASCADE"),
|
||||
nullable=False,
|
||||
)
|
||||
|
||||
# User who authorized the request
|
||||
user_id = Column(
|
||||
UUID(as_uuid=True),
|
||||
ForeignKey("users.id", ondelete="CASCADE"),
|
||||
nullable=False,
|
||||
)
|
||||
|
||||
# Redirect URI (must match exactly on token exchange)
|
||||
redirect_uri = Column(String(2048), nullable=False)
|
||||
|
||||
# Granted scopes (space-separated)
|
||||
scope = Column(String(1000), nullable=False, default="")
|
||||
|
||||
# PKCE support (required for public clients)
|
||||
code_challenge = Column(String(128), nullable=True)
|
||||
code_challenge_method = Column(String(10), nullable=True) # "S256" or "plain"
|
||||
|
||||
# State parameter (for CSRF protection, returned to client)
|
||||
state = Column(String(256), nullable=True)
|
||||
|
||||
# Nonce (for OpenID Connect, included in ID token)
|
||||
nonce = Column(String(256), nullable=True)
|
||||
|
||||
# Expiration (codes are short-lived, typically 10 minutes)
|
||||
expires_at = Column(DateTime(timezone=True), nullable=False)
|
||||
|
||||
# Single-use flag (set to True after successful exchange)
|
||||
used = Column(Boolean, default=False, nullable=False)
|
||||
|
||||
# Relationships
|
||||
client = relationship("OAuthClient", backref="authorization_codes")
|
||||
user = relationship("User", backref="oauth_authorization_codes")
|
||||
|
||||
# Indexes for efficient cleanup queries
|
||||
__table_args__ = (
|
||||
Index("ix_oauth_authorization_codes_expires_at", "expires_at"),
|
||||
Index("ix_oauth_authorization_codes_client_user", "client_id", "user_id"),
|
||||
)
|
||||
|
||||
def __repr__(self):
|
||||
return f"<OAuthAuthorizationCode {self.code[:8]}... for {self.client_id}>"
|
||||
|
||||
@property
|
||||
def is_expired(self) -> bool:
|
||||
"""Check if the authorization code has expired."""
|
||||
# Use timezone-aware comparison (datetime.utcnow() is deprecated)
|
||||
now = datetime.now(UTC)
|
||||
expires_at = self.expires_at
|
||||
# Handle both timezone-aware and naive datetimes from DB
|
||||
if expires_at.tzinfo is None:
|
||||
expires_at = expires_at.replace(tzinfo=UTC)
|
||||
return bool(now > expires_at)
|
||||
|
||||
@property
|
||||
def is_valid(self) -> bool:
|
||||
"""Check if the authorization code is valid (not used, not expired)."""
|
||||
return not self.used and not self.is_expired
|
||||
67
backend/app/models/oauth_client.py
Executable file
67
backend/app/models/oauth_client.py
Executable file
@@ -0,0 +1,67 @@
|
||||
"""OAuth client model for OAuth provider mode (MCP clients)."""
|
||||
|
||||
from sqlalchemy import Boolean, Column, ForeignKey, String
|
||||
from sqlalchemy.dialects.postgresql import JSONB, UUID
|
||||
from sqlalchemy.orm import relationship
|
||||
|
||||
from .base import Base, TimestampMixin, UUIDMixin
|
||||
|
||||
|
||||
class OAuthClient(Base, UUIDMixin, TimestampMixin):
|
||||
"""
|
||||
Registered OAuth clients (for OAuth provider mode).
|
||||
|
||||
This model stores third-party applications that can authenticate
|
||||
against this API using OAuth 2.0. Used for MCP (Model Context Protocol)
|
||||
client authentication and API access.
|
||||
|
||||
NOTE: This is a skeleton implementation. The full OAuth provider
|
||||
functionality (authorization endpoint, token endpoint, etc.) can be
|
||||
expanded when needed.
|
||||
"""
|
||||
|
||||
__tablename__ = "oauth_clients"
|
||||
|
||||
# Client credentials
|
||||
client_id = Column(String(64), unique=True, nullable=False, index=True)
|
||||
client_secret_hash = Column(
|
||||
String(255), nullable=True
|
||||
) # NULL for public clients (PKCE)
|
||||
|
||||
# Client metadata
|
||||
client_name = Column(String(255), nullable=False)
|
||||
client_description = Column(String(1000), nullable=True)
|
||||
|
||||
# Client type: "public" (SPA, mobile) or "confidential" (server-side)
|
||||
client_type = Column(String(20), nullable=False, default="public")
|
||||
|
||||
# Allowed redirect URIs (JSON array)
|
||||
redirect_uris = Column(JSONB, nullable=False, default=list)
|
||||
|
||||
# Allowed scopes (JSON array of scope names)
|
||||
allowed_scopes = Column(JSONB, nullable=False, default=list)
|
||||
|
||||
# Token lifetimes (in seconds)
|
||||
access_token_lifetime = Column(String(10), nullable=False, default="3600") # 1 hour
|
||||
refresh_token_lifetime = Column(
|
||||
String(10), nullable=False, default="604800"
|
||||
) # 7 days
|
||||
|
||||
# Status
|
||||
is_active = Column(Boolean, default=True, nullable=False, index=True)
|
||||
|
||||
# Optional: owner user (for user-registered applications)
|
||||
owner_user_id = Column(
|
||||
UUID(as_uuid=True),
|
||||
ForeignKey("users.id", ondelete="SET NULL"),
|
||||
nullable=True,
|
||||
)
|
||||
|
||||
# MCP-specific: URL of the MCP server this client represents
|
||||
mcp_server_url = Column(String(2048), nullable=True)
|
||||
|
||||
# Relationship
|
||||
owner = relationship("User", backref="owned_oauth_clients")
|
||||
|
||||
def __repr__(self):
|
||||
return f"<OAuthClient {self.client_name} ({self.client_id[:8]}...)>"
|
||||
162
backend/app/models/oauth_provider_token.py
Executable file
162
backend/app/models/oauth_provider_token.py
Executable file
@@ -0,0 +1,162 @@
|
||||
"""OAuth provider token models for OAuth provider mode."""
|
||||
|
||||
from datetime import UTC, datetime
|
||||
|
||||
from sqlalchemy import Boolean, Column, DateTime, ForeignKey, Index, String
|
||||
from sqlalchemy.dialects.postgresql import UUID
|
||||
from sqlalchemy.orm import relationship
|
||||
|
||||
from .base import Base, TimestampMixin, UUIDMixin
|
||||
|
||||
|
||||
class OAuthProviderRefreshToken(Base, UUIDMixin, TimestampMixin):
|
||||
"""
|
||||
OAuth 2.0 Refresh Token for the OAuth provider.
|
||||
|
||||
Refresh tokens are:
|
||||
- Opaque (stored as hash in DB, actual token given to client)
|
||||
- Long-lived (configurable, default 30 days)
|
||||
- Revocable (via revoked flag or deletion)
|
||||
- Bound to specific client, user, and scope
|
||||
|
||||
Access tokens are JWTs and not stored in DB (self-contained).
|
||||
This model only tracks refresh tokens for revocation support.
|
||||
|
||||
Security considerations:
|
||||
- Store token hash, not plaintext
|
||||
- Support token rotation (new refresh token on use)
|
||||
- Track last used time for security auditing
|
||||
- Support revocation by user, client, or admin
|
||||
|
||||
Performance indexes (defined in migration 0002_add_performance_indexes.py):
|
||||
- ix_perf_oauth_refresh_tokens_expires: expires_at WHERE revoked = false
|
||||
"""
|
||||
|
||||
__tablename__ = "oauth_provider_refresh_tokens"
|
||||
|
||||
# Hash of the refresh token (SHA-256)
|
||||
# We store hash, not plaintext, for security
|
||||
token_hash = Column(String(64), unique=True, nullable=False, index=True)
|
||||
|
||||
# Unique token ID (JTI) - used in JWT access tokens to reference this refresh token
|
||||
jti = Column(String(64), unique=True, nullable=False, index=True)
|
||||
|
||||
# Client that owns this token
|
||||
client_id = Column(
|
||||
String(64),
|
||||
ForeignKey("oauth_clients.client_id", ondelete="CASCADE"),
|
||||
nullable=False,
|
||||
)
|
||||
|
||||
# User who authorized this token
|
||||
user_id = Column(
|
||||
UUID(as_uuid=True),
|
||||
ForeignKey("users.id", ondelete="CASCADE"),
|
||||
nullable=False,
|
||||
)
|
||||
|
||||
# Granted scopes (space-separated)
|
||||
scope = Column(String(1000), nullable=False, default="")
|
||||
|
||||
# Token expiration
|
||||
expires_at = Column(DateTime(timezone=True), nullable=False)
|
||||
|
||||
# Revocation flag
|
||||
revoked = Column(Boolean, default=False, nullable=False, index=True)
|
||||
|
||||
# Last used timestamp (for security auditing)
|
||||
last_used_at = Column(DateTime(timezone=True), nullable=True)
|
||||
|
||||
# Device/session info (optional, for user visibility)
|
||||
device_info = Column(String(500), nullable=True)
|
||||
ip_address = Column(String(45), nullable=True)
|
||||
|
||||
# Relationships
|
||||
client = relationship("OAuthClient", backref="refresh_tokens")
|
||||
user = relationship("User", backref="oauth_provider_refresh_tokens")
|
||||
|
||||
# Indexes
|
||||
__table_args__ = (
|
||||
Index("ix_oauth_provider_refresh_tokens_expires_at", "expires_at"),
|
||||
Index("ix_oauth_provider_refresh_tokens_client_user", "client_id", "user_id"),
|
||||
Index(
|
||||
"ix_oauth_provider_refresh_tokens_user_revoked",
|
||||
"user_id",
|
||||
"revoked",
|
||||
),
|
||||
)
|
||||
|
||||
def __repr__(self):
|
||||
status = "revoked" if self.revoked else "active"
|
||||
return f"<OAuthProviderRefreshToken {self.jti[:8]}... ({status})>"
|
||||
|
||||
@property
|
||||
def is_expired(self) -> bool:
|
||||
"""Check if the refresh token has expired."""
|
||||
# Use timezone-aware comparison (datetime.utcnow() is deprecated)
|
||||
now = datetime.now(UTC)
|
||||
expires_at = self.expires_at
|
||||
# Handle both timezone-aware and naive datetimes from DB
|
||||
if expires_at.tzinfo is None:
|
||||
expires_at = expires_at.replace(tzinfo=UTC)
|
||||
return bool(now > expires_at)
|
||||
|
||||
@property
|
||||
def is_valid(self) -> bool:
|
||||
"""Check if the refresh token is valid (not revoked, not expired)."""
|
||||
return not self.revoked and not self.is_expired
|
||||
|
||||
|
||||
class OAuthConsent(Base, UUIDMixin, TimestampMixin):
|
||||
"""
|
||||
OAuth consent record - remembers user consent for a client.
|
||||
|
||||
When a user grants consent to an OAuth client, we store the record
|
||||
so they don't have to re-consent on subsequent authorizations
|
||||
(unless scopes change).
|
||||
|
||||
This enables a better UX - users only see consent screen once per client,
|
||||
unless the client requests additional scopes.
|
||||
"""
|
||||
|
||||
__tablename__ = "oauth_consents"
|
||||
|
||||
# User who granted consent
|
||||
user_id = Column(
|
||||
UUID(as_uuid=True),
|
||||
ForeignKey("users.id", ondelete="CASCADE"),
|
||||
nullable=False,
|
||||
)
|
||||
|
||||
# Client that received consent
|
||||
client_id = Column(
|
||||
String(64),
|
||||
ForeignKey("oauth_clients.client_id", ondelete="CASCADE"),
|
||||
nullable=False,
|
||||
)
|
||||
|
||||
# Granted scopes (space-separated)
|
||||
granted_scopes = Column(String(1000), nullable=False, default="")
|
||||
|
||||
# Relationships
|
||||
client = relationship("OAuthClient", backref="consents")
|
||||
user = relationship("User", backref="oauth_consents")
|
||||
|
||||
# Unique constraint: one consent record per user+client
|
||||
__table_args__ = (
|
||||
Index(
|
||||
"ix_oauth_consents_user_client",
|
||||
"user_id",
|
||||
"client_id",
|
||||
unique=True,
|
||||
),
|
||||
)
|
||||
|
||||
def __repr__(self):
|
||||
return f"<OAuthConsent user={self.user_id} client={self.client_id}>"
|
||||
|
||||
def has_scopes(self, requested_scopes: list[str]) -> bool:
|
||||
"""Check if all requested scopes are already granted."""
|
||||
granted = set(self.granted_scopes.split()) if self.granted_scopes else set()
|
||||
requested = set(requested_scopes)
|
||||
return requested.issubset(granted)
|
||||
45
backend/app/models/oauth_state.py
Executable file
45
backend/app/models/oauth_state.py
Executable file
@@ -0,0 +1,45 @@
|
||||
"""OAuth state model for CSRF protection during OAuth flows."""
|
||||
|
||||
from sqlalchemy import Column, DateTime, String
|
||||
from sqlalchemy.dialects.postgresql import UUID
|
||||
|
||||
from .base import Base, TimestampMixin, UUIDMixin
|
||||
|
||||
|
||||
class OAuthState(Base, UUIDMixin, TimestampMixin):
|
||||
"""
|
||||
Temporary storage for OAuth state parameters.
|
||||
|
||||
Prevents CSRF attacks during OAuth flows by storing a random state
|
||||
value that must match on callback. Also stores PKCE code_verifier
|
||||
for the Authorization Code flow with PKCE.
|
||||
|
||||
These records are short-lived (10 minutes by default) and should
|
||||
be deleted after use or expiration.
|
||||
"""
|
||||
|
||||
__tablename__ = "oauth_states"
|
||||
|
||||
# Random state parameter (CSRF protection)
|
||||
state = Column(String(255), unique=True, nullable=False, index=True)
|
||||
|
||||
# PKCE code_verifier (used to generate code_challenge)
|
||||
code_verifier = Column(String(128), nullable=True)
|
||||
|
||||
# OIDC nonce for ID token replay protection
|
||||
nonce = Column(String(255), nullable=True)
|
||||
|
||||
# OAuth provider (google, github, etc.)
|
||||
provider = Column(String(50), nullable=False)
|
||||
|
||||
# Original redirect URI (for callback validation)
|
||||
redirect_uri = Column(String(500), nullable=True)
|
||||
|
||||
# User ID if this is an account linking flow (user is already logged in)
|
||||
user_id = Column(UUID(as_uuid=True), nullable=True)
|
||||
|
||||
# Expiration time
|
||||
expires_at = Column(DateTime(timezone=True), nullable=False)
|
||||
|
||||
def __repr__(self):
|
||||
return f"<OAuthState {self.state[:8]}... ({self.provider})>"
|
||||
@@ -1,5 +1,5 @@
|
||||
# app/models/organization.py
|
||||
from sqlalchemy import Column, String, Boolean, Text, Index
|
||||
from sqlalchemy import Boolean, Column, Index, String, Text
|
||||
from sqlalchemy.dialects.postgresql import JSONB
|
||||
from sqlalchemy.orm import relationship
|
||||
|
||||
@@ -10,8 +10,12 @@ class Organization(Base, UUIDMixin, TimestampMixin):
|
||||
"""
|
||||
Organization model for multi-tenant support.
|
||||
Users can belong to multiple organizations with different roles.
|
||||
|
||||
Performance indexes (defined in migration 0002_add_performance_indexes.py):
|
||||
- ix_perf_organizations_slug_lower: LOWER(slug) WHERE is_active = true
|
||||
"""
|
||||
__tablename__ = 'organizations'
|
||||
|
||||
__tablename__ = "organizations"
|
||||
|
||||
name = Column(String(255), nullable=False, index=True)
|
||||
slug = Column(String(255), unique=True, nullable=False, index=True)
|
||||
@@ -20,11 +24,13 @@ class Organization(Base, UUIDMixin, TimestampMixin):
|
||||
settings = Column(JSONB, default={})
|
||||
|
||||
# Relationships
|
||||
user_organizations = relationship("UserOrganization", back_populates="organization", cascade="all, delete-orphan")
|
||||
user_organizations = relationship(
|
||||
"UserOrganization", back_populates="organization", cascade="all, delete-orphan"
|
||||
)
|
||||
|
||||
__table_args__ = (
|
||||
Index('ix_organizations_name_active', 'name', 'is_active'),
|
||||
Index('ix_organizations_slug_active', 'slug', 'is_active'),
|
||||
Index("ix_organizations_name_active", "name", "is_active"),
|
||||
Index("ix_organizations_slug_active", "slug", "is_active"),
|
||||
)
|
||||
|
||||
def __repr__(self):
|
||||
|
||||
@@ -1,4 +1,4 @@
|
||||
from sqlalchemy import Column, String, Boolean, DateTime
|
||||
from sqlalchemy import Boolean, Column, DateTime, String
|
||||
from sqlalchemy.dialects.postgresql import JSONB
|
||||
from sqlalchemy.orm import relationship
|
||||
|
||||
@@ -6,20 +6,45 @@ from .base import Base, TimestampMixin, UUIDMixin
|
||||
|
||||
|
||||
class User(Base, UUIDMixin, TimestampMixin):
|
||||
__tablename__ = 'users'
|
||||
"""
|
||||
User model for authentication and profile data.
|
||||
|
||||
Performance indexes (defined in migration 0002_add_performance_indexes.py):
|
||||
- ix_perf_users_email_lower: LOWER(email) WHERE deleted_at IS NULL
|
||||
- ix_perf_users_active: is_active WHERE deleted_at IS NULL
|
||||
"""
|
||||
|
||||
__tablename__ = "users"
|
||||
|
||||
email = Column(String(255), unique=True, nullable=False, index=True)
|
||||
password_hash = Column(String(255), nullable=False)
|
||||
# Nullable to support OAuth-only users who never set a password
|
||||
password_hash = Column(String(255), nullable=True)
|
||||
first_name = Column(String(100), nullable=False, default="user")
|
||||
last_name = Column(String(100), nullable=True)
|
||||
phone_number = Column(String(20))
|
||||
is_active = Column(Boolean, default=True, nullable=False, index=True)
|
||||
is_superuser = Column(Boolean, default=False, nullable=False, index=True)
|
||||
preferences = Column(JSONB)
|
||||
locale = Column(String(10), nullable=True, index=True)
|
||||
deleted_at = Column(DateTime(timezone=True), nullable=True, index=True)
|
||||
|
||||
# Relationships
|
||||
user_organizations = relationship("UserOrganization", back_populates="user", cascade="all, delete-orphan")
|
||||
user_organizations = relationship(
|
||||
"UserOrganization", back_populates="user", cascade="all, delete-orphan"
|
||||
)
|
||||
oauth_accounts = relationship(
|
||||
"OAuthAccount", back_populates="user", cascade="all, delete-orphan"
|
||||
)
|
||||
|
||||
@property
|
||||
def has_password(self) -> bool:
|
||||
"""Check if user can login with password (not OAuth-only)."""
|
||||
return self.password_hash is not None
|
||||
|
||||
@property
|
||||
def can_remove_oauth(self) -> bool:
|
||||
"""Check if user can safely remove an OAuth account link."""
|
||||
return self.has_password or len(self.oauth_accounts) > 1
|
||||
|
||||
def __repr__(self):
|
||||
return f"<User {self.email}>"
|
||||
return f"<User {self.email}>"
|
||||
|
||||
@@ -1,7 +1,7 @@
|
||||
# app/models/user_organization.py
|
||||
from enum import Enum as PyEnum
|
||||
|
||||
from sqlalchemy import Column, ForeignKey, Boolean, String, Index, Enum
|
||||
from sqlalchemy import Boolean, Column, Enum, ForeignKey, Index, String
|
||||
from sqlalchemy.dialects.postgresql import UUID as PGUUID
|
||||
from sqlalchemy.orm import relationship
|
||||
|
||||
@@ -14,6 +14,7 @@ class OrganizationRole(str, PyEnum):
|
||||
These provide a baseline role system that can be optionally used.
|
||||
Projects can extend this or implement their own permission system.
|
||||
"""
|
||||
|
||||
OWNER = "owner" # Full control over organization
|
||||
ADMIN = "admin" # Can manage users and settings
|
||||
MEMBER = "member" # Regular member with standard access
|
||||
@@ -25,25 +26,41 @@ class UserOrganization(Base, TimestampMixin):
|
||||
Junction table for many-to-many relationship between Users and Organizations.
|
||||
Includes role information for flexible RBAC.
|
||||
"""
|
||||
__tablename__ = 'user_organizations'
|
||||
|
||||
user_id = Column(PGUUID(as_uuid=True), ForeignKey('users.id', ondelete='CASCADE'), primary_key=True)
|
||||
organization_id = Column(PGUUID(as_uuid=True), ForeignKey('organizations.id', ondelete='CASCADE'), primary_key=True)
|
||||
__tablename__ = "user_organizations"
|
||||
|
||||
role = Column(Enum(OrganizationRole), default=OrganizationRole.MEMBER, nullable=False, index=True)
|
||||
user_id = Column(
|
||||
PGUUID(as_uuid=True),
|
||||
ForeignKey("users.id", ondelete="CASCADE"),
|
||||
primary_key=True,
|
||||
)
|
||||
organization_id = Column(
|
||||
PGUUID(as_uuid=True),
|
||||
ForeignKey("organizations.id", ondelete="CASCADE"),
|
||||
primary_key=True,
|
||||
)
|
||||
|
||||
role: Column[OrganizationRole] = Column(
|
||||
Enum(OrganizationRole),
|
||||
default=OrganizationRole.MEMBER,
|
||||
nullable=False,
|
||||
# Note: index defined in __table_args__ as ix_user_org_role
|
||||
)
|
||||
is_active = Column(Boolean, default=True, nullable=False, index=True)
|
||||
|
||||
# Optional: Custom permissions override for specific users
|
||||
custom_permissions = Column(String(500), nullable=True) # JSON array of permission strings
|
||||
custom_permissions = Column(
|
||||
String(500), nullable=True
|
||||
) # JSON array of permission strings
|
||||
|
||||
# Relationships
|
||||
user = relationship("User", back_populates="user_organizations")
|
||||
organization = relationship("Organization", back_populates="user_organizations")
|
||||
|
||||
__table_args__ = (
|
||||
Index('ix_user_org_user_active', 'user_id', 'is_active'),
|
||||
Index('ix_user_org_org_active', 'organization_id', 'is_active'),
|
||||
Index('ix_user_org_role', 'role'),
|
||||
Index("ix_user_org_user_active", "user_id", "is_active"),
|
||||
Index("ix_user_org_org_active", "organization_id", "is_active"),
|
||||
Index("ix_user_org_role", "role"),
|
||||
)
|
||||
|
||||
def __repr__(self):
|
||||
|
||||
@@ -6,7 +6,10 @@ This allows users to:
|
||||
- Logout from specific devices
|
||||
- Manage their active sessions
|
||||
"""
|
||||
from sqlalchemy import Column, String, Boolean, DateTime, ForeignKey, Index
|
||||
|
||||
from datetime import UTC
|
||||
|
||||
from sqlalchemy import Boolean, Column, DateTime, ForeignKey, Index, String
|
||||
from sqlalchemy.dialects.postgresql import UUID
|
||||
from sqlalchemy.orm import relationship
|
||||
|
||||
@@ -19,20 +22,31 @@ class UserSession(Base, UUIDMixin, TimestampMixin):
|
||||
|
||||
Each time a user logs in from a device, a new session is created.
|
||||
Sessions are identified by the refresh token JTI (JWT ID).
|
||||
|
||||
Performance indexes (defined in migration 0002_add_performance_indexes.py):
|
||||
- ix_perf_user_sessions_expires: expires_at WHERE is_active = true
|
||||
"""
|
||||
__tablename__ = 'user_sessions'
|
||||
|
||||
__tablename__ = "user_sessions"
|
||||
|
||||
# Foreign key to user
|
||||
user_id = Column(UUID(as_uuid=True), ForeignKey('users.id', ondelete='CASCADE'), nullable=False, index=True)
|
||||
user_id = Column(
|
||||
UUID(as_uuid=True),
|
||||
ForeignKey("users.id", ondelete="CASCADE"),
|
||||
nullable=False,
|
||||
index=True,
|
||||
)
|
||||
|
||||
# Refresh token identifier (JWT ID from the refresh token)
|
||||
refresh_token_jti = Column(String(255), unique=True, nullable=False, index=True)
|
||||
|
||||
# Device information
|
||||
device_name = Column(String(255), nullable=True) # "iPhone 14", "Chrome on MacBook"
|
||||
device_id = Column(String(255), nullable=True) # Persistent device identifier (from client)
|
||||
ip_address = Column(String(45), nullable=True) # IPv4 (15 chars) or IPv6 (45 chars)
|
||||
user_agent = Column(String(500), nullable=True) # Browser/app user agent
|
||||
device_id = Column(
|
||||
String(255), nullable=True
|
||||
) # Persistent device identifier (from client)
|
||||
ip_address = Column(String(45), nullable=True) # IPv4 (15 chars) or IPv6 (45 chars)
|
||||
user_agent = Column(String(500), nullable=True) # Browser/app user agent
|
||||
|
||||
# Session timing
|
||||
last_used_at = Column(DateTime(timezone=True), nullable=False)
|
||||
@@ -50,8 +64,8 @@ class UserSession(Base, UUIDMixin, TimestampMixin):
|
||||
|
||||
# Composite indexes for performance (defined in migration)
|
||||
__table_args__ = (
|
||||
Index('ix_user_sessions_user_active', 'user_id', 'is_active'),
|
||||
Index('ix_user_sessions_jti_active', 'refresh_token_jti', 'is_active'),
|
||||
Index("ix_user_sessions_user_active", "user_id", "is_active"),
|
||||
Index("ix_user_sessions_jti_active", "refresh_token_jti", "is_active"),
|
||||
)
|
||||
|
||||
def __repr__(self):
|
||||
@@ -60,21 +74,28 @@ class UserSession(Base, UUIDMixin, TimestampMixin):
|
||||
@property
|
||||
def is_expired(self) -> bool:
|
||||
"""Check if session has expired."""
|
||||
from datetime import datetime, timezone
|
||||
return self.expires_at < datetime.now(timezone.utc)
|
||||
from datetime import datetime
|
||||
|
||||
now = datetime.now(UTC)
|
||||
expires_at = self.expires_at
|
||||
if expires_at.tzinfo is None:
|
||||
expires_at = expires_at.replace(tzinfo=UTC)
|
||||
return bool(expires_at < now)
|
||||
|
||||
def to_dict(self):
|
||||
"""Convert session to dictionary for serialization."""
|
||||
return {
|
||||
'id': str(self.id),
|
||||
'user_id': str(self.user_id),
|
||||
'device_name': self.device_name,
|
||||
'device_id': self.device_id,
|
||||
'ip_address': self.ip_address,
|
||||
'last_used_at': self.last_used_at.isoformat() if self.last_used_at else None,
|
||||
'expires_at': self.expires_at.isoformat() if self.expires_at else None,
|
||||
'is_active': self.is_active,
|
||||
'location_city': self.location_city,
|
||||
'location_country': self.location_country,
|
||||
'created_at': self.created_at.isoformat() if self.created_at else None,
|
||||
"id": str(self.id),
|
||||
"user_id": str(self.user_id),
|
||||
"device_name": self.device_name,
|
||||
"device_id": self.device_id,
|
||||
"ip_address": self.ip_address,
|
||||
"last_used_at": self.last_used_at.isoformat()
|
||||
if self.last_used_at
|
||||
else None,
|
||||
"expires_at": self.expires_at.isoformat() if self.expires_at else None,
|
||||
"is_active": self.is_active,
|
||||
"location_city": self.location_city,
|
||||
"location_country": self.location_country,
|
||||
"created_at": self.created_at.isoformat() if self.created_at else None,
|
||||
}
|
||||
|
||||
39
backend/app/repositories/__init__.py
Normal file
39
backend/app/repositories/__init__.py
Normal file
@@ -0,0 +1,39 @@
|
||||
# app/repositories/__init__.py
|
||||
"""Repository layer — all database access goes through these classes."""
|
||||
|
||||
from app.repositories.oauth_account import OAuthAccountRepository, oauth_account_repo
|
||||
from app.repositories.oauth_authorization_code import (
|
||||
OAuthAuthorizationCodeRepository,
|
||||
oauth_authorization_code_repo,
|
||||
)
|
||||
from app.repositories.oauth_client import OAuthClientRepository, oauth_client_repo
|
||||
from app.repositories.oauth_consent import OAuthConsentRepository, oauth_consent_repo
|
||||
from app.repositories.oauth_provider_token import (
|
||||
OAuthProviderTokenRepository,
|
||||
oauth_provider_token_repo,
|
||||
)
|
||||
from app.repositories.oauth_state import OAuthStateRepository, oauth_state_repo
|
||||
from app.repositories.organization import OrganizationRepository, organization_repo
|
||||
from app.repositories.session import SessionRepository, session_repo
|
||||
from app.repositories.user import UserRepository, user_repo
|
||||
|
||||
__all__ = [
|
||||
"OAuthAccountRepository",
|
||||
"OAuthAuthorizationCodeRepository",
|
||||
"OAuthClientRepository",
|
||||
"OAuthConsentRepository",
|
||||
"OAuthProviderTokenRepository",
|
||||
"OAuthStateRepository",
|
||||
"OrganizationRepository",
|
||||
"SessionRepository",
|
||||
"UserRepository",
|
||||
"oauth_account_repo",
|
||||
"oauth_authorization_code_repo",
|
||||
"oauth_client_repo",
|
||||
"oauth_consent_repo",
|
||||
"oauth_provider_token_repo",
|
||||
"oauth_state_repo",
|
||||
"organization_repo",
|
||||
"session_repo",
|
||||
"user_repo",
|
||||
]
|
||||
420
backend/app/repositories/base.py
Normal file
420
backend/app/repositories/base.py
Normal file
@@ -0,0 +1,420 @@
|
||||
# app/repositories/base.py
|
||||
"""
|
||||
Base repository class for async database operations using SQLAlchemy 2.0 async patterns.
|
||||
|
||||
Provides reusable create, read, update, and delete operations for all models.
|
||||
"""
|
||||
|
||||
import logging
|
||||
import uuid
|
||||
from datetime import UTC
|
||||
from typing import Any, TypeVar
|
||||
|
||||
from fastapi.encoders import jsonable_encoder
|
||||
from pydantic import BaseModel
|
||||
from sqlalchemy import func, select
|
||||
from sqlalchemy.exc import DataError, IntegrityError, OperationalError
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
from sqlalchemy.orm import Load
|
||||
|
||||
from app.core.database import Base
|
||||
from app.core.repository_exceptions import (
|
||||
DuplicateEntryError,
|
||||
IntegrityConstraintError,
|
||||
InvalidInputError,
|
||||
)
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
ModelType = TypeVar("ModelType", bound=Base)
|
||||
CreateSchemaType = TypeVar("CreateSchemaType", bound=BaseModel)
|
||||
UpdateSchemaType = TypeVar("UpdateSchemaType", bound=BaseModel)
|
||||
|
||||
|
||||
class BaseRepository[
|
||||
ModelType: Base,
|
||||
CreateSchemaType: BaseModel,
|
||||
UpdateSchemaType: BaseModel,
|
||||
]:
|
||||
"""Async repository operations for a model."""
|
||||
|
||||
def __init__(self, model: type[ModelType]):
|
||||
"""
|
||||
Repository object with default async methods to Create, Read, Update, Delete.
|
||||
|
||||
Parameters:
|
||||
model: A SQLAlchemy model class
|
||||
"""
|
||||
self.model = model
|
||||
|
||||
async def get(
|
||||
self, db: AsyncSession, id: str, options: list[Load] | None = None
|
||||
) -> ModelType | None:
|
||||
"""
|
||||
Get a single record by ID with UUID validation and optional eager loading.
|
||||
|
||||
Args:
|
||||
db: Database session
|
||||
id: Record UUID
|
||||
options: Optional list of SQLAlchemy load options (e.g., joinedload, selectinload)
|
||||
for eager loading relationships to prevent N+1 queries
|
||||
|
||||
Returns:
|
||||
Model instance or None if not found
|
||||
"""
|
||||
try:
|
||||
if isinstance(id, uuid.UUID):
|
||||
uuid_obj = id
|
||||
else:
|
||||
uuid_obj = uuid.UUID(str(id))
|
||||
except (ValueError, AttributeError, TypeError) as e:
|
||||
logger.warning("Invalid UUID format: %s - %s", id, e)
|
||||
return None
|
||||
|
||||
try:
|
||||
query = select(self.model).where(self.model.id == uuid_obj)
|
||||
|
||||
if options:
|
||||
for option in options:
|
||||
query = query.options(option)
|
||||
|
||||
result = await db.execute(query)
|
||||
return result.scalar_one_or_none()
|
||||
except Exception as e:
|
||||
logger.error(
|
||||
"Error retrieving %s with id %s: %s", self.model.__name__, id, e
|
||||
)
|
||||
raise
|
||||
|
||||
async def get_multi(
|
||||
self,
|
||||
db: AsyncSession,
|
||||
*,
|
||||
skip: int = 0,
|
||||
limit: int = 100,
|
||||
options: list[Load] | None = None,
|
||||
) -> list[ModelType]:
|
||||
"""
|
||||
Get multiple records with pagination validation and optional eager loading.
|
||||
"""
|
||||
if skip < 0:
|
||||
raise InvalidInputError("skip must be non-negative")
|
||||
if limit < 0:
|
||||
raise InvalidInputError("limit must be non-negative")
|
||||
if limit > 1000:
|
||||
raise InvalidInputError("Maximum limit is 1000")
|
||||
|
||||
try:
|
||||
query = select(self.model).order_by(self.model.id).offset(skip).limit(limit)
|
||||
|
||||
if options:
|
||||
for option in options:
|
||||
query = query.options(option)
|
||||
|
||||
result = await db.execute(query)
|
||||
return list(result.scalars().all())
|
||||
except Exception as e:
|
||||
logger.error(
|
||||
"Error retrieving multiple %s records: %s", self.model.__name__, e
|
||||
)
|
||||
raise
|
||||
|
||||
async def create(
|
||||
self, db: AsyncSession, *, obj_in: CreateSchemaType
|
||||
) -> ModelType: # pragma: no cover
|
||||
"""Create a new record with error handling.
|
||||
|
||||
NOTE: This method is defensive code that's never called in practice.
|
||||
All repository subclasses override this method with their own implementations.
|
||||
Marked as pragma: no cover to avoid false coverage gaps.
|
||||
"""
|
||||
try: # pragma: no cover
|
||||
obj_in_data = jsonable_encoder(obj_in)
|
||||
db_obj = self.model(**obj_in_data)
|
||||
db.add(db_obj)
|
||||
await db.commit()
|
||||
await db.refresh(db_obj)
|
||||
return db_obj
|
||||
except IntegrityError as e: # pragma: no cover
|
||||
await db.rollback()
|
||||
error_msg = str(e.orig) if hasattr(e, "orig") else str(e)
|
||||
if "unique" in error_msg.lower() or "duplicate" in error_msg.lower():
|
||||
logger.warning(
|
||||
"Duplicate entry attempted for %s: %s",
|
||||
self.model.__name__,
|
||||
error_msg,
|
||||
)
|
||||
raise DuplicateEntryError(
|
||||
f"A {self.model.__name__} with this data already exists"
|
||||
)
|
||||
logger.error(
|
||||
"Integrity error creating %s: %s", self.model.__name__, error_msg
|
||||
)
|
||||
raise IntegrityConstraintError(f"Database integrity error: {error_msg}")
|
||||
except (OperationalError, DataError) as e: # pragma: no cover
|
||||
await db.rollback()
|
||||
logger.error("Database error creating %s: %s", self.model.__name__, e)
|
||||
raise IntegrityConstraintError(f"Database operation failed: {e!s}")
|
||||
except Exception as e: # pragma: no cover
|
||||
await db.rollback()
|
||||
logger.exception("Unexpected error creating %s: %s", self.model.__name__, e)
|
||||
raise
|
||||
|
||||
async def update(
|
||||
self,
|
||||
db: AsyncSession,
|
||||
*,
|
||||
db_obj: ModelType,
|
||||
obj_in: UpdateSchemaType | dict[str, Any],
|
||||
) -> ModelType:
|
||||
"""Update a record with error handling."""
|
||||
try:
|
||||
obj_data = jsonable_encoder(db_obj)
|
||||
if isinstance(obj_in, dict):
|
||||
update_data = obj_in
|
||||
else:
|
||||
update_data = obj_in.model_dump(exclude_unset=True)
|
||||
|
||||
for field in obj_data:
|
||||
if field in update_data:
|
||||
setattr(db_obj, field, update_data[field])
|
||||
|
||||
db.add(db_obj)
|
||||
await db.commit()
|
||||
await db.refresh(db_obj)
|
||||
return db_obj
|
||||
except IntegrityError as e:
|
||||
await db.rollback()
|
||||
error_msg = str(e.orig) if hasattr(e, "orig") else str(e)
|
||||
if "unique" in error_msg.lower() or "duplicate" in error_msg.lower():
|
||||
logger.warning(
|
||||
"Duplicate entry attempted for %s: %s",
|
||||
self.model.__name__,
|
||||
error_msg,
|
||||
)
|
||||
raise DuplicateEntryError(
|
||||
f"A {self.model.__name__} with this data already exists"
|
||||
)
|
||||
logger.error(
|
||||
"Integrity error updating %s: %s", self.model.__name__, error_msg
|
||||
)
|
||||
raise IntegrityConstraintError(f"Database integrity error: {error_msg}")
|
||||
except (OperationalError, DataError) as e:
|
||||
await db.rollback()
|
||||
logger.error("Database error updating %s: %s", self.model.__name__, e)
|
||||
raise IntegrityConstraintError(f"Database operation failed: {e!s}")
|
||||
except Exception as e:
|
||||
await db.rollback()
|
||||
logger.exception("Unexpected error updating %s: %s", self.model.__name__, e)
|
||||
raise
|
||||
|
||||
async def remove(self, db: AsyncSession, *, id: str) -> ModelType | None:
|
||||
"""Delete a record with error handling and null check."""
|
||||
try:
|
||||
if isinstance(id, uuid.UUID):
|
||||
uuid_obj = id
|
||||
else:
|
||||
uuid_obj = uuid.UUID(str(id))
|
||||
except (ValueError, AttributeError, TypeError) as e:
|
||||
logger.warning("Invalid UUID format for deletion: %s - %s", id, e)
|
||||
return None
|
||||
|
||||
try:
|
||||
result = await db.execute(
|
||||
select(self.model).where(self.model.id == uuid_obj)
|
||||
)
|
||||
obj = result.scalar_one_or_none()
|
||||
|
||||
if obj is None:
|
||||
logger.warning(
|
||||
"%s with id %s not found for deletion", self.model.__name__, id
|
||||
)
|
||||
return None
|
||||
|
||||
await db.delete(obj)
|
||||
await db.commit()
|
||||
return obj
|
||||
except IntegrityError as e:
|
||||
await db.rollback()
|
||||
error_msg = str(e.orig) if hasattr(e, "orig") else str(e)
|
||||
logger.error(
|
||||
"Integrity error deleting %s: %s", self.model.__name__, error_msg
|
||||
)
|
||||
raise IntegrityConstraintError(
|
||||
f"Cannot delete {self.model.__name__}: referenced by other records"
|
||||
)
|
||||
except Exception as e:
|
||||
await db.rollback()
|
||||
logger.exception(
|
||||
"Error deleting %s with id %s: %s", self.model.__name__, id, e
|
||||
)
|
||||
raise
|
||||
|
||||
async def get_multi_with_total(
|
||||
self,
|
||||
db: AsyncSession,
|
||||
*,
|
||||
skip: int = 0,
|
||||
limit: int = 100,
|
||||
sort_by: str | None = None,
|
||||
sort_order: str = "asc",
|
||||
filters: dict[str, Any] | None = None,
|
||||
) -> tuple[list[ModelType], int]: # pragma: no cover
|
||||
"""
|
||||
Get multiple records with total count, filtering, and sorting.
|
||||
|
||||
NOTE: This method is defensive code that's never called in practice.
|
||||
All repository subclasses override this method with their own implementations.
|
||||
Marked as pragma: no cover to avoid false coverage gaps.
|
||||
"""
|
||||
if skip < 0:
|
||||
raise InvalidInputError("skip must be non-negative")
|
||||
if limit < 0:
|
||||
raise InvalidInputError("limit must be non-negative")
|
||||
if limit > 1000:
|
||||
raise InvalidInputError("Maximum limit is 1000")
|
||||
|
||||
try:
|
||||
query = select(self.model)
|
||||
|
||||
if hasattr(self.model, "deleted_at"):
|
||||
query = query.where(self.model.deleted_at.is_(None))
|
||||
|
||||
if filters:
|
||||
for field, value in filters.items():
|
||||
if hasattr(self.model, field) and value is not None:
|
||||
query = query.where(getattr(self.model, field) == value)
|
||||
|
||||
count_query = select(func.count()).select_from(query.alias())
|
||||
count_result = await db.execute(count_query)
|
||||
total = count_result.scalar_one()
|
||||
|
||||
if sort_by and hasattr(self.model, sort_by):
|
||||
sort_column = getattr(self.model, sort_by)
|
||||
if sort_order.lower() == "desc":
|
||||
query = query.order_by(sort_column.desc())
|
||||
else:
|
||||
query = query.order_by(sort_column.asc())
|
||||
else:
|
||||
query = query.order_by(self.model.id)
|
||||
|
||||
query = query.offset(skip).limit(limit)
|
||||
items_result = await db.execute(query)
|
||||
items = list(items_result.scalars().all())
|
||||
|
||||
return items, total
|
||||
except Exception as e: # pragma: no cover
|
||||
logger.error(
|
||||
"Error retrieving paginated %s records: %s", self.model.__name__, e
|
||||
)
|
||||
raise
|
||||
|
||||
async def count(self, db: AsyncSession) -> int:
|
||||
"""Get total count of records."""
|
||||
try:
|
||||
result = await db.execute(select(func.count(self.model.id)))
|
||||
return result.scalar_one()
|
||||
except Exception as e:
|
||||
logger.error("Error counting %s records: %s", self.model.__name__, e)
|
||||
raise
|
||||
|
||||
async def exists(self, db: AsyncSession, id: str) -> bool:
|
||||
"""Check if a record exists by ID."""
|
||||
obj = await self.get(db, id=id)
|
||||
return obj is not None
|
||||
|
||||
async def soft_delete(self, db: AsyncSession, *, id: str) -> ModelType | None:
|
||||
"""
|
||||
Soft delete a record by setting deleted_at timestamp.
|
||||
|
||||
Only works if the model has a 'deleted_at' column.
|
||||
"""
|
||||
from datetime import datetime
|
||||
|
||||
try:
|
||||
if isinstance(id, uuid.UUID):
|
||||
uuid_obj = id
|
||||
else:
|
||||
uuid_obj = uuid.UUID(str(id))
|
||||
except (ValueError, AttributeError, TypeError) as e:
|
||||
logger.warning("Invalid UUID format for soft deletion: %s - %s", id, e)
|
||||
return None
|
||||
|
||||
try:
|
||||
result = await db.execute(
|
||||
select(self.model).where(self.model.id == uuid_obj)
|
||||
)
|
||||
obj = result.scalar_one_or_none()
|
||||
|
||||
if obj is None:
|
||||
logger.warning(
|
||||
"%s with id %s not found for soft deletion", self.model.__name__, id
|
||||
)
|
||||
return None
|
||||
|
||||
if not hasattr(self.model, "deleted_at"):
|
||||
logger.error("%s does not support soft deletes", self.model.__name__)
|
||||
raise InvalidInputError(
|
||||
f"{self.model.__name__} does not have a deleted_at column"
|
||||
)
|
||||
|
||||
obj.deleted_at = datetime.now(UTC)
|
||||
db.add(obj)
|
||||
await db.commit()
|
||||
await db.refresh(obj)
|
||||
return obj
|
||||
except Exception as e:
|
||||
await db.rollback()
|
||||
logger.exception(
|
||||
"Error soft deleting %s with id %s: %s", self.model.__name__, id, e
|
||||
)
|
||||
raise
|
||||
|
||||
async def restore(self, db: AsyncSession, *, id: str) -> ModelType | None:
|
||||
"""
|
||||
Restore a soft-deleted record by clearing the deleted_at timestamp.
|
||||
|
||||
Only works if the model has a 'deleted_at' column.
|
||||
"""
|
||||
try:
|
||||
if isinstance(id, uuid.UUID):
|
||||
uuid_obj = id
|
||||
else:
|
||||
uuid_obj = uuid.UUID(str(id))
|
||||
except (ValueError, AttributeError, TypeError) as e:
|
||||
logger.warning("Invalid UUID format for restoration: %s - %s", id, e)
|
||||
return None
|
||||
|
||||
try:
|
||||
if hasattr(self.model, "deleted_at"):
|
||||
result = await db.execute(
|
||||
select(self.model).where(
|
||||
self.model.id == uuid_obj, self.model.deleted_at.isnot(None)
|
||||
)
|
||||
)
|
||||
obj = result.scalar_one_or_none()
|
||||
else:
|
||||
logger.error("%s does not support soft deletes", self.model.__name__)
|
||||
raise InvalidInputError(
|
||||
f"{self.model.__name__} does not have a deleted_at column"
|
||||
)
|
||||
|
||||
if obj is None:
|
||||
logger.warning(
|
||||
"Soft-deleted %s with id %s not found for restoration",
|
||||
self.model.__name__,
|
||||
id,
|
||||
)
|
||||
return None
|
||||
|
||||
obj.deleted_at = None
|
||||
db.add(obj)
|
||||
await db.commit()
|
||||
await db.refresh(obj)
|
||||
return obj
|
||||
except Exception as e:
|
||||
await db.rollback()
|
||||
logger.exception(
|
||||
"Error restoring %s with id %s: %s", self.model.__name__, id, e
|
||||
)
|
||||
raise
|
||||
249
backend/app/repositories/oauth_account.py
Normal file
249
backend/app/repositories/oauth_account.py
Normal file
@@ -0,0 +1,249 @@
|
||||
# app/repositories/oauth_account.py
|
||||
"""Repository for OAuthAccount model async database operations."""
|
||||
|
||||
import logging
|
||||
from datetime import datetime
|
||||
from uuid import UUID
|
||||
|
||||
from pydantic import BaseModel
|
||||
from sqlalchemy import and_, delete, select
|
||||
from sqlalchemy.exc import IntegrityError
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
from sqlalchemy.orm import joinedload
|
||||
|
||||
from app.core.repository_exceptions import DuplicateEntryError
|
||||
from app.models.oauth_account import OAuthAccount
|
||||
from app.repositories.base import BaseRepository
|
||||
from app.schemas.oauth import OAuthAccountCreate
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class EmptySchema(BaseModel):
|
||||
"""Placeholder schema for repository operations that don't need update schemas."""
|
||||
|
||||
|
||||
class OAuthAccountRepository(
|
||||
BaseRepository[OAuthAccount, OAuthAccountCreate, EmptySchema]
|
||||
):
|
||||
"""Repository for OAuth account links."""
|
||||
|
||||
async def get_by_provider_id(
|
||||
self,
|
||||
db: AsyncSession,
|
||||
*,
|
||||
provider: str,
|
||||
provider_user_id: str,
|
||||
) -> OAuthAccount | None:
|
||||
"""Get OAuth account by provider and provider user ID."""
|
||||
try:
|
||||
result = await db.execute(
|
||||
select(OAuthAccount)
|
||||
.where(
|
||||
and_(
|
||||
OAuthAccount.provider == provider,
|
||||
OAuthAccount.provider_user_id == provider_user_id,
|
||||
)
|
||||
)
|
||||
.options(joinedload(OAuthAccount.user))
|
||||
)
|
||||
return result.scalar_one_or_none()
|
||||
except Exception as e: # pragma: no cover
|
||||
logger.error(
|
||||
"Error getting OAuth account for %s:%s: %s",
|
||||
provider,
|
||||
provider_user_id,
|
||||
e,
|
||||
)
|
||||
raise
|
||||
|
||||
async def get_by_provider_email(
|
||||
self,
|
||||
db: AsyncSession,
|
||||
*,
|
||||
provider: str,
|
||||
email: str,
|
||||
) -> OAuthAccount | None:
|
||||
"""Get OAuth account by provider and email."""
|
||||
try:
|
||||
result = await db.execute(
|
||||
select(OAuthAccount)
|
||||
.where(
|
||||
and_(
|
||||
OAuthAccount.provider == provider,
|
||||
OAuthAccount.provider_email == email,
|
||||
)
|
||||
)
|
||||
.options(joinedload(OAuthAccount.user))
|
||||
)
|
||||
return result.scalar_one_or_none()
|
||||
except Exception as e: # pragma: no cover
|
||||
logger.error(
|
||||
"Error getting OAuth account for %s email %s: %s", provider, email, e
|
||||
)
|
||||
raise
|
||||
|
||||
async def get_user_accounts(
|
||||
self,
|
||||
db: AsyncSession,
|
||||
*,
|
||||
user_id: str | UUID,
|
||||
) -> list[OAuthAccount]:
|
||||
"""Get all OAuth accounts linked to a user."""
|
||||
try:
|
||||
user_uuid = UUID(str(user_id)) if isinstance(user_id, str) else user_id
|
||||
|
||||
result = await db.execute(
|
||||
select(OAuthAccount)
|
||||
.where(OAuthAccount.user_id == user_uuid)
|
||||
.order_by(OAuthAccount.created_at.desc())
|
||||
)
|
||||
return list(result.scalars().all())
|
||||
except Exception as e: # pragma: no cover
|
||||
logger.error("Error getting OAuth accounts for user %s: %s", user_id, e)
|
||||
raise
|
||||
|
||||
async def get_user_account_by_provider(
|
||||
self,
|
||||
db: AsyncSession,
|
||||
*,
|
||||
user_id: str | UUID,
|
||||
provider: str,
|
||||
) -> OAuthAccount | None:
|
||||
"""Get a specific OAuth account for a user and provider."""
|
||||
try:
|
||||
user_uuid = UUID(str(user_id)) if isinstance(user_id, str) else user_id
|
||||
|
||||
result = await db.execute(
|
||||
select(OAuthAccount).where(
|
||||
and_(
|
||||
OAuthAccount.user_id == user_uuid,
|
||||
OAuthAccount.provider == provider,
|
||||
)
|
||||
)
|
||||
)
|
||||
return result.scalar_one_or_none()
|
||||
except Exception as e: # pragma: no cover
|
||||
logger.error(
|
||||
"Error getting OAuth account for user %s, provider %s: %s",
|
||||
user_id,
|
||||
provider,
|
||||
e,
|
||||
)
|
||||
raise
|
||||
|
||||
async def create_account(
|
||||
self, db: AsyncSession, *, obj_in: OAuthAccountCreate
|
||||
) -> OAuthAccount:
|
||||
"""Create a new OAuth account link."""
|
||||
try:
|
||||
db_obj = OAuthAccount(
|
||||
user_id=obj_in.user_id,
|
||||
provider=obj_in.provider,
|
||||
provider_user_id=obj_in.provider_user_id,
|
||||
provider_email=obj_in.provider_email,
|
||||
access_token=obj_in.access_token,
|
||||
refresh_token=obj_in.refresh_token,
|
||||
token_expires_at=obj_in.token_expires_at,
|
||||
)
|
||||
db.add(db_obj)
|
||||
await db.commit()
|
||||
await db.refresh(db_obj)
|
||||
|
||||
logger.info(
|
||||
"OAuth account created: %s linked to user %s",
|
||||
obj_in.provider,
|
||||
obj_in.user_id,
|
||||
)
|
||||
return db_obj
|
||||
except IntegrityError as e: # pragma: no cover
|
||||
await db.rollback()
|
||||
error_msg = str(e.orig) if hasattr(e, "orig") else str(e)
|
||||
if "uq_oauth_provider_user" in error_msg.lower():
|
||||
logger.warning(
|
||||
"OAuth account already exists: %s:%s",
|
||||
obj_in.provider,
|
||||
obj_in.provider_user_id,
|
||||
)
|
||||
raise DuplicateEntryError(
|
||||
f"This {obj_in.provider} account is already linked to another user"
|
||||
)
|
||||
logger.error("Integrity error creating OAuth account: %s", error_msg)
|
||||
raise DuplicateEntryError(f"Failed to create OAuth account: {error_msg}")
|
||||
except Exception as e: # pragma: no cover
|
||||
await db.rollback()
|
||||
logger.exception("Error creating OAuth account: %s", e)
|
||||
raise
|
||||
|
||||
async def delete_account(
|
||||
self,
|
||||
db: AsyncSession,
|
||||
*,
|
||||
user_id: str | UUID,
|
||||
provider: str,
|
||||
) -> bool:
|
||||
"""Delete an OAuth account link."""
|
||||
try:
|
||||
user_uuid = UUID(str(user_id)) if isinstance(user_id, str) else user_id
|
||||
|
||||
result = await db.execute(
|
||||
delete(OAuthAccount).where(
|
||||
and_(
|
||||
OAuthAccount.user_id == user_uuid,
|
||||
OAuthAccount.provider == provider,
|
||||
)
|
||||
)
|
||||
)
|
||||
await db.commit()
|
||||
|
||||
deleted = result.rowcount > 0
|
||||
if deleted:
|
||||
logger.info(
|
||||
"OAuth account deleted: %s unlinked from user %s", provider, user_id
|
||||
)
|
||||
else:
|
||||
logger.warning(
|
||||
"OAuth account not found for deletion: %s for user %s",
|
||||
provider,
|
||||
user_id,
|
||||
)
|
||||
|
||||
return deleted
|
||||
except Exception as e: # pragma: no cover
|
||||
await db.rollback()
|
||||
logger.error(
|
||||
"Error deleting OAuth account %s for user %s: %s", provider, user_id, e
|
||||
)
|
||||
raise
|
||||
|
||||
async def update_tokens(
|
||||
self,
|
||||
db: AsyncSession,
|
||||
*,
|
||||
account: OAuthAccount,
|
||||
access_token: str | None = None,
|
||||
refresh_token: str | None = None,
|
||||
token_expires_at: datetime | None = None,
|
||||
) -> OAuthAccount:
|
||||
"""Update OAuth tokens for an account."""
|
||||
try:
|
||||
if access_token is not None:
|
||||
account.access_token = access_token
|
||||
if refresh_token is not None:
|
||||
account.refresh_token = refresh_token
|
||||
if token_expires_at is not None:
|
||||
account.token_expires_at = token_expires_at
|
||||
|
||||
db.add(account)
|
||||
await db.commit()
|
||||
await db.refresh(account)
|
||||
|
||||
return account
|
||||
except Exception as e: # pragma: no cover
|
||||
await db.rollback()
|
||||
logger.error("Error updating OAuth tokens: %s", e)
|
||||
raise
|
||||
|
||||
|
||||
# Singleton instance
|
||||
oauth_account_repo = OAuthAccountRepository(OAuthAccount)
|
||||
108
backend/app/repositories/oauth_authorization_code.py
Normal file
108
backend/app/repositories/oauth_authorization_code.py
Normal file
@@ -0,0 +1,108 @@
|
||||
# app/repositories/oauth_authorization_code.py
|
||||
"""Repository for OAuthAuthorizationCode model."""
|
||||
|
||||
import logging
|
||||
from datetime import UTC, datetime
|
||||
from uuid import UUID
|
||||
|
||||
from sqlalchemy import and_, delete, select, update
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
|
||||
from app.models.oauth_authorization_code import OAuthAuthorizationCode
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class OAuthAuthorizationCodeRepository:
|
||||
"""Repository for OAuth 2.0 authorization codes."""
|
||||
|
||||
async def create_code(
|
||||
self,
|
||||
db: AsyncSession,
|
||||
*,
|
||||
code: str,
|
||||
client_id: str,
|
||||
user_id: UUID,
|
||||
redirect_uri: str,
|
||||
scope: str,
|
||||
expires_at: datetime,
|
||||
code_challenge: str | None = None,
|
||||
code_challenge_method: str | None = None,
|
||||
state: str | None = None,
|
||||
nonce: str | None = None,
|
||||
) -> OAuthAuthorizationCode:
|
||||
"""Create and persist a new authorization code."""
|
||||
auth_code = OAuthAuthorizationCode(
|
||||
code=code,
|
||||
client_id=client_id,
|
||||
user_id=user_id,
|
||||
redirect_uri=redirect_uri,
|
||||
scope=scope,
|
||||
code_challenge=code_challenge,
|
||||
code_challenge_method=code_challenge_method,
|
||||
state=state,
|
||||
nonce=nonce,
|
||||
expires_at=expires_at,
|
||||
used=False,
|
||||
)
|
||||
db.add(auth_code)
|
||||
await db.commit()
|
||||
return auth_code
|
||||
|
||||
async def consume_code_atomically(
|
||||
self, db: AsyncSession, *, code: str
|
||||
) -> UUID | None:
|
||||
"""
|
||||
Atomically mark a code as used and return its UUID.
|
||||
|
||||
Returns the UUID if the code was found and not yet used, None otherwise.
|
||||
This prevents race conditions per RFC 6749 Section 4.1.2.
|
||||
"""
|
||||
stmt = (
|
||||
update(OAuthAuthorizationCode)
|
||||
.where(
|
||||
and_(
|
||||
OAuthAuthorizationCode.code == code,
|
||||
OAuthAuthorizationCode.used == False, # noqa: E712
|
||||
)
|
||||
)
|
||||
.values(used=True)
|
||||
.returning(OAuthAuthorizationCode.id)
|
||||
)
|
||||
result = await db.execute(stmt)
|
||||
row_id = result.scalar_one_or_none()
|
||||
if row_id is not None:
|
||||
await db.commit()
|
||||
return row_id
|
||||
|
||||
async def get_by_id(
|
||||
self, db: AsyncSession, *, code_id: UUID
|
||||
) -> OAuthAuthorizationCode | None:
|
||||
"""Get authorization code by its UUID primary key."""
|
||||
result = await db.execute(
|
||||
select(OAuthAuthorizationCode).where(OAuthAuthorizationCode.id == code_id)
|
||||
)
|
||||
return result.scalar_one_or_none()
|
||||
|
||||
async def get_by_code(
|
||||
self, db: AsyncSession, *, code: str
|
||||
) -> OAuthAuthorizationCode | None:
|
||||
"""Get authorization code by the code string value."""
|
||||
result = await db.execute(
|
||||
select(OAuthAuthorizationCode).where(OAuthAuthorizationCode.code == code)
|
||||
)
|
||||
return result.scalar_one_or_none()
|
||||
|
||||
async def cleanup_expired(self, db: AsyncSession) -> int:
|
||||
"""Delete all expired authorization codes. Returns count deleted."""
|
||||
result = await db.execute(
|
||||
delete(OAuthAuthorizationCode).where(
|
||||
OAuthAuthorizationCode.expires_at < datetime.now(UTC)
|
||||
)
|
||||
)
|
||||
await db.commit()
|
||||
return result.rowcount # type: ignore[attr-defined]
|
||||
|
||||
|
||||
# Singleton instance
|
||||
oauth_authorization_code_repo = OAuthAuthorizationCodeRepository()
|
||||
201
backend/app/repositories/oauth_client.py
Normal file
201
backend/app/repositories/oauth_client.py
Normal file
@@ -0,0 +1,201 @@
|
||||
# app/repositories/oauth_client.py
|
||||
"""Repository for OAuthClient model async database operations."""
|
||||
|
||||
import logging
|
||||
import secrets
|
||||
from uuid import UUID
|
||||
|
||||
from pydantic import BaseModel
|
||||
from sqlalchemy import and_, delete, select
|
||||
from sqlalchemy.exc import IntegrityError
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
|
||||
from app.core.repository_exceptions import DuplicateEntryError
|
||||
from app.models.oauth_client import OAuthClient
|
||||
from app.repositories.base import BaseRepository
|
||||
from app.schemas.oauth import OAuthClientCreate
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class EmptySchema(BaseModel):
|
||||
"""Placeholder schema for repository operations that don't need update schemas."""
|
||||
|
||||
|
||||
class OAuthClientRepository(
|
||||
BaseRepository[OAuthClient, OAuthClientCreate, EmptySchema]
|
||||
):
|
||||
"""Repository for OAuth clients (provider mode)."""
|
||||
|
||||
async def get_by_client_id(
|
||||
self, db: AsyncSession, *, client_id: str
|
||||
) -> OAuthClient | None:
|
||||
"""Get OAuth client by client_id."""
|
||||
try:
|
||||
result = await db.execute(
|
||||
select(OAuthClient).where(
|
||||
and_(
|
||||
OAuthClient.client_id == client_id,
|
||||
OAuthClient.is_active == True, # noqa: E712
|
||||
)
|
||||
)
|
||||
)
|
||||
return result.scalar_one_or_none()
|
||||
except Exception as e: # pragma: no cover
|
||||
logger.error("Error getting OAuth client %s: %s", client_id, e)
|
||||
raise
|
||||
|
||||
async def create_client(
|
||||
self,
|
||||
db: AsyncSession,
|
||||
*,
|
||||
obj_in: OAuthClientCreate,
|
||||
owner_user_id: UUID | None = None,
|
||||
) -> tuple[OAuthClient, str | None]:
|
||||
"""Create a new OAuth client."""
|
||||
try:
|
||||
client_id = secrets.token_urlsafe(32)
|
||||
|
||||
client_secret = None
|
||||
client_secret_hash = None
|
||||
if obj_in.client_type == "confidential":
|
||||
client_secret = secrets.token_urlsafe(48)
|
||||
from app.core.auth import get_password_hash
|
||||
|
||||
client_secret_hash = get_password_hash(client_secret)
|
||||
|
||||
db_obj = OAuthClient(
|
||||
client_id=client_id,
|
||||
client_secret_hash=client_secret_hash,
|
||||
client_name=obj_in.client_name,
|
||||
client_description=obj_in.client_description,
|
||||
client_type=obj_in.client_type,
|
||||
redirect_uris=obj_in.redirect_uris,
|
||||
allowed_scopes=obj_in.allowed_scopes,
|
||||
owner_user_id=owner_user_id,
|
||||
is_active=True,
|
||||
)
|
||||
db.add(db_obj)
|
||||
await db.commit()
|
||||
await db.refresh(db_obj)
|
||||
|
||||
logger.info(
|
||||
"OAuth client created: %s (%s...)", obj_in.client_name, client_id[:8]
|
||||
)
|
||||
return db_obj, client_secret
|
||||
except IntegrityError as e: # pragma: no cover
|
||||
await db.rollback()
|
||||
error_msg = str(e.orig) if hasattr(e, "orig") else str(e)
|
||||
logger.error("Error creating OAuth client: %s", error_msg)
|
||||
raise DuplicateEntryError(f"Failed to create OAuth client: {error_msg}")
|
||||
except Exception as e: # pragma: no cover
|
||||
await db.rollback()
|
||||
logger.exception("Error creating OAuth client: %s", e)
|
||||
raise
|
||||
|
||||
async def deactivate_client(
|
||||
self, db: AsyncSession, *, client_id: str
|
||||
) -> OAuthClient | None:
|
||||
"""Deactivate an OAuth client."""
|
||||
try:
|
||||
client = await self.get_by_client_id(db, client_id=client_id)
|
||||
if client is None:
|
||||
return None
|
||||
|
||||
client.is_active = False
|
||||
db.add(client)
|
||||
await db.commit()
|
||||
await db.refresh(client)
|
||||
|
||||
logger.info("OAuth client deactivated: %s", client.client_name)
|
||||
return client
|
||||
except Exception as e: # pragma: no cover
|
||||
await db.rollback()
|
||||
logger.error("Error deactivating OAuth client %s: %s", client_id, e)
|
||||
raise
|
||||
|
||||
async def validate_redirect_uri(
|
||||
self, db: AsyncSession, *, client_id: str, redirect_uri: str
|
||||
) -> bool:
|
||||
"""Validate that a redirect URI is allowed for a client."""
|
||||
try:
|
||||
client = await self.get_by_client_id(db, client_id=client_id)
|
||||
if client is None:
|
||||
return False
|
||||
|
||||
return redirect_uri in (client.redirect_uris or [])
|
||||
except Exception as e: # pragma: no cover
|
||||
logger.error("Error validating redirect URI: %s", e)
|
||||
return False
|
||||
|
||||
async def verify_client_secret(
|
||||
self, db: AsyncSession, *, client_id: str, client_secret: str
|
||||
) -> bool:
|
||||
"""Verify client credentials."""
|
||||
try:
|
||||
result = await db.execute(
|
||||
select(OAuthClient).where(
|
||||
and_(
|
||||
OAuthClient.client_id == client_id,
|
||||
OAuthClient.is_active == True, # noqa: E712
|
||||
)
|
||||
)
|
||||
)
|
||||
client = result.scalar_one_or_none()
|
||||
|
||||
if client is None or client.client_secret_hash is None:
|
||||
return False
|
||||
|
||||
from app.core.auth import verify_password
|
||||
|
||||
stored_hash: str = str(client.client_secret_hash)
|
||||
|
||||
if stored_hash.startswith("$2"):
|
||||
return verify_password(client_secret, stored_hash)
|
||||
else:
|
||||
import hashlib
|
||||
|
||||
secret_hash = hashlib.sha256(client_secret.encode()).hexdigest()
|
||||
return secrets.compare_digest(stored_hash, secret_hash)
|
||||
except Exception as e: # pragma: no cover
|
||||
logger.error("Error verifying client secret: %s", e)
|
||||
return False
|
||||
|
||||
async def get_all_clients(
|
||||
self, db: AsyncSession, *, include_inactive: bool = False
|
||||
) -> list[OAuthClient]:
|
||||
"""Get all OAuth clients."""
|
||||
try:
|
||||
query = select(OAuthClient).order_by(OAuthClient.created_at.desc())
|
||||
if not include_inactive:
|
||||
query = query.where(OAuthClient.is_active == True) # noqa: E712
|
||||
|
||||
result = await db.execute(query)
|
||||
return list(result.scalars().all())
|
||||
except Exception as e: # pragma: no cover
|
||||
logger.error("Error getting all OAuth clients: %s", e)
|
||||
raise
|
||||
|
||||
async def delete_client(self, db: AsyncSession, *, client_id: str) -> bool:
|
||||
"""Delete an OAuth client permanently."""
|
||||
try:
|
||||
result = await db.execute(
|
||||
delete(OAuthClient).where(OAuthClient.client_id == client_id)
|
||||
)
|
||||
await db.commit()
|
||||
|
||||
deleted = result.rowcount > 0
|
||||
if deleted:
|
||||
logger.info("OAuth client deleted: %s", client_id)
|
||||
else:
|
||||
logger.warning("OAuth client not found for deletion: %s", client_id)
|
||||
|
||||
return deleted
|
||||
except Exception as e: # pragma: no cover
|
||||
await db.rollback()
|
||||
logger.error("Error deleting OAuth client %s: %s", client_id, e)
|
||||
raise
|
||||
|
||||
|
||||
# Singleton instance
|
||||
oauth_client_repo = OAuthClientRepository(OAuthClient)
|
||||
113
backend/app/repositories/oauth_consent.py
Normal file
113
backend/app/repositories/oauth_consent.py
Normal file
@@ -0,0 +1,113 @@
|
||||
# app/repositories/oauth_consent.py
|
||||
"""Repository for OAuthConsent model."""
|
||||
|
||||
import logging
|
||||
from typing import Any
|
||||
from uuid import UUID
|
||||
|
||||
from sqlalchemy import and_, delete, select
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
|
||||
from app.models.oauth_client import OAuthClient
|
||||
from app.models.oauth_provider_token import OAuthConsent
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class OAuthConsentRepository:
|
||||
"""Repository for OAuth consent records (user grants to clients)."""
|
||||
|
||||
async def get_consent(
|
||||
self, db: AsyncSession, *, user_id: UUID, client_id: str
|
||||
) -> OAuthConsent | None:
|
||||
"""Get the consent record for a user-client pair, or None if not found."""
|
||||
result = await db.execute(
|
||||
select(OAuthConsent).where(
|
||||
and_(
|
||||
OAuthConsent.user_id == user_id,
|
||||
OAuthConsent.client_id == client_id,
|
||||
)
|
||||
)
|
||||
)
|
||||
return result.scalar_one_or_none()
|
||||
|
||||
async def grant_consent(
|
||||
self,
|
||||
db: AsyncSession,
|
||||
*,
|
||||
user_id: UUID,
|
||||
client_id: str,
|
||||
scopes: list[str],
|
||||
) -> OAuthConsent:
|
||||
"""
|
||||
Create or update consent for a user-client pair.
|
||||
|
||||
If consent already exists, the new scopes are merged with existing ones.
|
||||
Returns the created or updated consent record.
|
||||
"""
|
||||
consent = await self.get_consent(db, user_id=user_id, client_id=client_id)
|
||||
|
||||
if consent:
|
||||
existing = (
|
||||
set(consent.granted_scopes.split()) if consent.granted_scopes else set()
|
||||
)
|
||||
merged = existing | set(scopes)
|
||||
consent.granted_scopes = " ".join(sorted(merged)) # type: ignore[assignment]
|
||||
else:
|
||||
consent = OAuthConsent(
|
||||
user_id=user_id,
|
||||
client_id=client_id,
|
||||
granted_scopes=" ".join(sorted(set(scopes))),
|
||||
)
|
||||
db.add(consent)
|
||||
|
||||
await db.commit()
|
||||
await db.refresh(consent)
|
||||
return consent
|
||||
|
||||
async def get_user_consents_with_clients(
|
||||
self, db: AsyncSession, *, user_id: UUID
|
||||
) -> list[dict[str, Any]]:
|
||||
"""Get all consent records for a user joined with client details."""
|
||||
result = await db.execute(
|
||||
select(OAuthConsent, OAuthClient)
|
||||
.join(OAuthClient, OAuthConsent.client_id == OAuthClient.client_id)
|
||||
.where(OAuthConsent.user_id == user_id)
|
||||
)
|
||||
rows = result.all()
|
||||
return [
|
||||
{
|
||||
"client_id": consent.client_id,
|
||||
"client_name": client.client_name,
|
||||
"client_description": client.client_description,
|
||||
"granted_scopes": consent.granted_scopes.split()
|
||||
if consent.granted_scopes
|
||||
else [],
|
||||
"granted_at": consent.created_at.isoformat(),
|
||||
}
|
||||
for consent, client in rows
|
||||
]
|
||||
|
||||
async def revoke_consent(
|
||||
self, db: AsyncSession, *, user_id: UUID, client_id: str
|
||||
) -> bool:
|
||||
"""
|
||||
Delete the consent record for a user-client pair.
|
||||
|
||||
Returns True if a record was found and deleted.
|
||||
Note: Callers are responsible for also revoking associated tokens.
|
||||
"""
|
||||
result = await db.execute(
|
||||
delete(OAuthConsent).where(
|
||||
and_(
|
||||
OAuthConsent.user_id == user_id,
|
||||
OAuthConsent.client_id == client_id,
|
||||
)
|
||||
)
|
||||
)
|
||||
await db.commit()
|
||||
return result.rowcount > 0 # type: ignore[attr-defined]
|
||||
|
||||
|
||||
# Singleton instance
|
||||
oauth_consent_repo = OAuthConsentRepository()
|
||||
142
backend/app/repositories/oauth_provider_token.py
Normal file
142
backend/app/repositories/oauth_provider_token.py
Normal file
@@ -0,0 +1,142 @@
|
||||
# app/repositories/oauth_provider_token.py
|
||||
"""Repository for OAuthProviderRefreshToken model."""
|
||||
|
||||
import logging
|
||||
from datetime import UTC, datetime, timedelta
|
||||
from uuid import UUID
|
||||
|
||||
from sqlalchemy import and_, delete, select, update
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
|
||||
from app.models.oauth_provider_token import OAuthProviderRefreshToken
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class OAuthProviderTokenRepository:
|
||||
"""Repository for OAuth provider refresh tokens."""
|
||||
|
||||
async def create_token(
|
||||
self,
|
||||
db: AsyncSession,
|
||||
*,
|
||||
token_hash: str,
|
||||
jti: str,
|
||||
client_id: str,
|
||||
user_id: UUID,
|
||||
scope: str,
|
||||
expires_at: datetime,
|
||||
device_info: str | None = None,
|
||||
ip_address: str | None = None,
|
||||
) -> OAuthProviderRefreshToken:
|
||||
"""Create and persist a new refresh token record."""
|
||||
token = OAuthProviderRefreshToken(
|
||||
token_hash=token_hash,
|
||||
jti=jti,
|
||||
client_id=client_id,
|
||||
user_id=user_id,
|
||||
scope=scope,
|
||||
expires_at=expires_at,
|
||||
device_info=device_info,
|
||||
ip_address=ip_address,
|
||||
)
|
||||
db.add(token)
|
||||
await db.commit()
|
||||
return token
|
||||
|
||||
async def get_by_token_hash(
|
||||
self, db: AsyncSession, *, token_hash: str
|
||||
) -> OAuthProviderRefreshToken | None:
|
||||
"""Get refresh token record by SHA-256 token hash."""
|
||||
result = await db.execute(
|
||||
select(OAuthProviderRefreshToken).where(
|
||||
OAuthProviderRefreshToken.token_hash == token_hash
|
||||
)
|
||||
)
|
||||
return result.scalar_one_or_none()
|
||||
|
||||
async def get_by_jti(
|
||||
self, db: AsyncSession, *, jti: str
|
||||
) -> OAuthProviderRefreshToken | None:
|
||||
"""Get refresh token record by JWT ID (JTI)."""
|
||||
result = await db.execute(
|
||||
select(OAuthProviderRefreshToken).where(
|
||||
OAuthProviderRefreshToken.jti == jti
|
||||
)
|
||||
)
|
||||
return result.scalar_one_or_none()
|
||||
|
||||
async def revoke(
|
||||
self, db: AsyncSession, *, token: OAuthProviderRefreshToken
|
||||
) -> None:
|
||||
"""Mark a specific token record as revoked."""
|
||||
token.revoked = True # type: ignore[assignment]
|
||||
token.last_used_at = datetime.now(UTC) # type: ignore[assignment]
|
||||
await db.commit()
|
||||
|
||||
async def revoke_all_for_user_client(
|
||||
self, db: AsyncSession, *, user_id: UUID, client_id: str
|
||||
) -> int:
|
||||
"""
|
||||
Revoke all active tokens for a specific user-client pair.
|
||||
|
||||
Used when security incidents are detected (e.g., authorization code reuse).
|
||||
Returns the number of tokens revoked.
|
||||
"""
|
||||
result = await db.execute(
|
||||
update(OAuthProviderRefreshToken)
|
||||
.where(
|
||||
and_(
|
||||
OAuthProviderRefreshToken.user_id == user_id,
|
||||
OAuthProviderRefreshToken.client_id == client_id,
|
||||
OAuthProviderRefreshToken.revoked == False, # noqa: E712
|
||||
)
|
||||
)
|
||||
.values(revoked=True)
|
||||
)
|
||||
count = result.rowcount # type: ignore[attr-defined]
|
||||
if count > 0:
|
||||
await db.commit()
|
||||
return count
|
||||
|
||||
async def revoke_all_for_user(self, db: AsyncSession, *, user_id: UUID) -> int:
|
||||
"""
|
||||
Revoke all active tokens for a user across all clients.
|
||||
|
||||
Used when user changes password or logs out everywhere.
|
||||
Returns the number of tokens revoked.
|
||||
"""
|
||||
result = await db.execute(
|
||||
update(OAuthProviderRefreshToken)
|
||||
.where(
|
||||
and_(
|
||||
OAuthProviderRefreshToken.user_id == user_id,
|
||||
OAuthProviderRefreshToken.revoked == False, # noqa: E712
|
||||
)
|
||||
)
|
||||
.values(revoked=True)
|
||||
)
|
||||
count = result.rowcount # type: ignore[attr-defined]
|
||||
if count > 0:
|
||||
await db.commit()
|
||||
return count
|
||||
|
||||
async def cleanup_expired(self, db: AsyncSession, *, cutoff_days: int = 7) -> int:
|
||||
"""
|
||||
Delete expired refresh tokens older than cutoff_days.
|
||||
|
||||
Should be called periodically (e.g., daily).
|
||||
Returns the number of tokens deleted.
|
||||
"""
|
||||
cutoff = datetime.now(UTC) - timedelta(days=cutoff_days)
|
||||
result = await db.execute(
|
||||
delete(OAuthProviderRefreshToken).where(
|
||||
OAuthProviderRefreshToken.expires_at < cutoff
|
||||
)
|
||||
)
|
||||
await db.commit()
|
||||
return result.rowcount # type: ignore[attr-defined]
|
||||
|
||||
|
||||
# Singleton instance
|
||||
oauth_provider_token_repo = OAuthProviderTokenRepository()
|
||||
113
backend/app/repositories/oauth_state.py
Normal file
113
backend/app/repositories/oauth_state.py
Normal file
@@ -0,0 +1,113 @@
|
||||
# app/repositories/oauth_state.py
|
||||
"""Repository for OAuthState model async database operations."""
|
||||
|
||||
import logging
|
||||
from datetime import UTC, datetime
|
||||
|
||||
from pydantic import BaseModel
|
||||
from sqlalchemy import delete, select
|
||||
from sqlalchemy.exc import IntegrityError
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
|
||||
from app.core.repository_exceptions import DuplicateEntryError
|
||||
from app.models.oauth_state import OAuthState
|
||||
from app.repositories.base import BaseRepository
|
||||
from app.schemas.oauth import OAuthStateCreate
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class EmptySchema(BaseModel):
|
||||
"""Placeholder schema for repository operations that don't need update schemas."""
|
||||
|
||||
|
||||
class OAuthStateRepository(BaseRepository[OAuthState, OAuthStateCreate, EmptySchema]):
|
||||
"""Repository for OAuth state (CSRF protection)."""
|
||||
|
||||
async def create_state(
|
||||
self, db: AsyncSession, *, obj_in: OAuthStateCreate
|
||||
) -> OAuthState:
|
||||
"""Create a new OAuth state for CSRF protection."""
|
||||
try:
|
||||
db_obj = OAuthState(
|
||||
state=obj_in.state,
|
||||
code_verifier=obj_in.code_verifier,
|
||||
nonce=obj_in.nonce,
|
||||
provider=obj_in.provider,
|
||||
redirect_uri=obj_in.redirect_uri,
|
||||
user_id=obj_in.user_id,
|
||||
expires_at=obj_in.expires_at,
|
||||
)
|
||||
db.add(db_obj)
|
||||
await db.commit()
|
||||
await db.refresh(db_obj)
|
||||
|
||||
logger.debug("OAuth state created for %s", obj_in.provider)
|
||||
return db_obj
|
||||
except IntegrityError as e: # pragma: no cover
|
||||
await db.rollback()
|
||||
error_msg = str(e.orig) if hasattr(e, "orig") else str(e)
|
||||
logger.error("OAuth state collision: %s", error_msg)
|
||||
raise DuplicateEntryError("Failed to create OAuth state, please retry")
|
||||
except Exception as e: # pragma: no cover
|
||||
await db.rollback()
|
||||
logger.exception("Error creating OAuth state: %s", e)
|
||||
raise
|
||||
|
||||
async def get_and_consume_state(
|
||||
self, db: AsyncSession, *, state: str
|
||||
) -> OAuthState | None:
|
||||
"""Get and delete OAuth state (consume it)."""
|
||||
try:
|
||||
result = await db.execute(
|
||||
select(OAuthState).where(OAuthState.state == state)
|
||||
)
|
||||
db_obj = result.scalar_one_or_none()
|
||||
|
||||
if db_obj is None:
|
||||
logger.warning("OAuth state not found: %s...", state[:8])
|
||||
return None
|
||||
|
||||
now = datetime.now(UTC)
|
||||
expires_at = db_obj.expires_at
|
||||
if expires_at.tzinfo is None:
|
||||
expires_at = expires_at.replace(tzinfo=UTC)
|
||||
|
||||
if expires_at < now:
|
||||
logger.warning("OAuth state expired: %s...", state[:8])
|
||||
await db.delete(db_obj)
|
||||
await db.commit()
|
||||
return None
|
||||
|
||||
await db.delete(db_obj)
|
||||
await db.commit()
|
||||
|
||||
logger.debug("OAuth state consumed: %s...", state[:8])
|
||||
return db_obj
|
||||
except Exception as e: # pragma: no cover
|
||||
await db.rollback()
|
||||
logger.error("Error consuming OAuth state: %s", e)
|
||||
raise
|
||||
|
||||
async def cleanup_expired(self, db: AsyncSession) -> int:
|
||||
"""Clean up expired OAuth states."""
|
||||
try:
|
||||
now = datetime.now(UTC)
|
||||
|
||||
stmt = delete(OAuthState).where(OAuthState.expires_at < now)
|
||||
result = await db.execute(stmt)
|
||||
await db.commit()
|
||||
|
||||
count = result.rowcount
|
||||
if count > 0:
|
||||
logger.info("Cleaned up %s expired OAuth states", count)
|
||||
|
||||
return count
|
||||
except Exception as e: # pragma: no cover
|
||||
await db.rollback()
|
||||
logger.error("Error cleaning up expired OAuth states: %s", e)
|
||||
raise
|
||||
|
||||
|
||||
# Singleton instance
|
||||
oauth_state_repo = OAuthStateRepository(OAuthState)
|
||||
499
backend/app/repositories/organization.py
Normal file
499
backend/app/repositories/organization.py
Normal file
@@ -0,0 +1,499 @@
|
||||
# app/repositories/organization.py
|
||||
"""Repository for Organization model async database operations using SQLAlchemy 2.0 patterns."""
|
||||
|
||||
import logging
|
||||
from typing import Any
|
||||
from uuid import UUID
|
||||
|
||||
from sqlalchemy import and_, case, func, or_, select
|
||||
from sqlalchemy.exc import IntegrityError
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
|
||||
from app.core.repository_exceptions import DuplicateEntryError, IntegrityConstraintError
|
||||
from app.models.organization import Organization
|
||||
from app.models.user import User
|
||||
from app.models.user_organization import OrganizationRole, UserOrganization
|
||||
from app.repositories.base import BaseRepository
|
||||
from app.schemas.organizations import (
|
||||
OrganizationCreate,
|
||||
OrganizationUpdate,
|
||||
)
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class OrganizationRepository(
|
||||
BaseRepository[Organization, OrganizationCreate, OrganizationUpdate]
|
||||
):
|
||||
"""Repository for Organization model."""
|
||||
|
||||
async def get_by_slug(self, db: AsyncSession, *, slug: str) -> Organization | None:
|
||||
"""Get organization by slug."""
|
||||
try:
|
||||
result = await db.execute(
|
||||
select(Organization).where(Organization.slug == slug)
|
||||
)
|
||||
return result.scalar_one_or_none()
|
||||
except Exception as e:
|
||||
logger.error("Error getting organization by slug %s: %s", slug, e)
|
||||
raise
|
||||
|
||||
async def create(
|
||||
self, db: AsyncSession, *, obj_in: OrganizationCreate
|
||||
) -> Organization:
|
||||
"""Create a new organization with error handling."""
|
||||
try:
|
||||
db_obj = Organization(
|
||||
name=obj_in.name,
|
||||
slug=obj_in.slug,
|
||||
description=obj_in.description,
|
||||
is_active=obj_in.is_active,
|
||||
settings=obj_in.settings or {},
|
||||
)
|
||||
db.add(db_obj)
|
||||
await db.commit()
|
||||
await db.refresh(db_obj)
|
||||
return db_obj
|
||||
except IntegrityError as e:
|
||||
await db.rollback()
|
||||
error_msg = str(e.orig) if hasattr(e, "orig") else str(e)
|
||||
if (
|
||||
"slug" in error_msg.lower()
|
||||
or "unique" in error_msg.lower()
|
||||
or "duplicate" in error_msg.lower()
|
||||
):
|
||||
logger.warning("Duplicate slug attempted: %s", obj_in.slug)
|
||||
raise DuplicateEntryError(
|
||||
f"Organization with slug '{obj_in.slug}' already exists"
|
||||
)
|
||||
logger.error("Integrity error creating organization: %s", error_msg)
|
||||
raise IntegrityConstraintError(f"Database integrity error: {error_msg}")
|
||||
except Exception as e:
|
||||
await db.rollback()
|
||||
logger.exception("Unexpected error creating organization: %s", e)
|
||||
raise
|
||||
|
||||
async def get_multi_with_filters(
|
||||
self,
|
||||
db: AsyncSession,
|
||||
*,
|
||||
skip: int = 0,
|
||||
limit: int = 100,
|
||||
is_active: bool | None = None,
|
||||
search: str | None = None,
|
||||
sort_by: str = "created_at",
|
||||
sort_order: str = "desc",
|
||||
) -> tuple[list[Organization], int]:
|
||||
"""Get multiple organizations with filtering, searching, and sorting."""
|
||||
try:
|
||||
query = select(Organization)
|
||||
|
||||
if is_active is not None:
|
||||
query = query.where(Organization.is_active == is_active)
|
||||
|
||||
if search:
|
||||
search_filter = or_(
|
||||
Organization.name.ilike(f"%{search}%"),
|
||||
Organization.slug.ilike(f"%{search}%"),
|
||||
Organization.description.ilike(f"%{search}%"),
|
||||
)
|
||||
query = query.where(search_filter)
|
||||
|
||||
count_query = select(func.count()).select_from(query.alias())
|
||||
count_result = await db.execute(count_query)
|
||||
total = count_result.scalar_one()
|
||||
|
||||
sort_column = getattr(Organization, sort_by, Organization.created_at)
|
||||
if sort_order == "desc":
|
||||
query = query.order_by(sort_column.desc())
|
||||
else:
|
||||
query = query.order_by(sort_column.asc())
|
||||
|
||||
query = query.offset(skip).limit(limit)
|
||||
result = await db.execute(query)
|
||||
organizations = list(result.scalars().all())
|
||||
|
||||
return organizations, total
|
||||
except Exception as e:
|
||||
logger.error("Error getting organizations with filters: %s", e)
|
||||
raise
|
||||
|
||||
async def get_member_count(self, db: AsyncSession, *, organization_id: UUID) -> int:
|
||||
"""Get the count of active members in an organization."""
|
||||
try:
|
||||
result = await db.execute(
|
||||
select(func.count(UserOrganization.user_id)).where(
|
||||
and_(
|
||||
UserOrganization.organization_id == organization_id,
|
||||
UserOrganization.is_active,
|
||||
)
|
||||
)
|
||||
)
|
||||
return result.scalar_one() or 0
|
||||
except Exception as e:
|
||||
logger.error(
|
||||
"Error getting member count for organization %s: %s", organization_id, e
|
||||
)
|
||||
raise
|
||||
|
||||
async def get_multi_with_member_counts(
|
||||
self,
|
||||
db: AsyncSession,
|
||||
*,
|
||||
skip: int = 0,
|
||||
limit: int = 100,
|
||||
is_active: bool | None = None,
|
||||
search: str | None = None,
|
||||
) -> tuple[list[dict[str, Any]], int]:
|
||||
"""Get organizations with member counts in a SINGLE QUERY using JOIN and GROUP BY."""
|
||||
try:
|
||||
query = (
|
||||
select(
|
||||
Organization,
|
||||
func.count(
|
||||
func.distinct(
|
||||
case(
|
||||
(
|
||||
UserOrganization.is_active,
|
||||
UserOrganization.user_id,
|
||||
),
|
||||
else_=None,
|
||||
)
|
||||
)
|
||||
).label("member_count"),
|
||||
)
|
||||
.outerjoin(
|
||||
UserOrganization,
|
||||
Organization.id == UserOrganization.organization_id,
|
||||
)
|
||||
.group_by(Organization.id)
|
||||
)
|
||||
|
||||
if is_active is not None:
|
||||
query = query.where(Organization.is_active == is_active)
|
||||
|
||||
search_filter = None
|
||||
if search:
|
||||
search_filter = or_(
|
||||
Organization.name.ilike(f"%{search}%"),
|
||||
Organization.slug.ilike(f"%{search}%"),
|
||||
Organization.description.ilike(f"%{search}%"),
|
||||
)
|
||||
query = query.where(search_filter)
|
||||
|
||||
count_query = select(func.count(Organization.id))
|
||||
if is_active is not None:
|
||||
count_query = count_query.where(Organization.is_active == is_active)
|
||||
if search_filter is not None:
|
||||
count_query = count_query.where(search_filter)
|
||||
|
||||
count_result = await db.execute(count_query)
|
||||
total = count_result.scalar_one()
|
||||
|
||||
query = (
|
||||
query.order_by(Organization.created_at.desc()).offset(skip).limit(limit)
|
||||
)
|
||||
|
||||
result = await db.execute(query)
|
||||
rows = result.all()
|
||||
|
||||
orgs_with_counts = [
|
||||
{"organization": org, "member_count": member_count}
|
||||
for org, member_count in rows
|
||||
]
|
||||
|
||||
return orgs_with_counts, total
|
||||
|
||||
except Exception as e:
|
||||
logger.exception("Error getting organizations with member counts: %s", e)
|
||||
raise
|
||||
|
||||
async def add_user(
|
||||
self,
|
||||
db: AsyncSession,
|
||||
*,
|
||||
organization_id: UUID,
|
||||
user_id: UUID,
|
||||
role: OrganizationRole = OrganizationRole.MEMBER,
|
||||
custom_permissions: str | None = None,
|
||||
) -> UserOrganization:
|
||||
"""Add a user to an organization with a specific role."""
|
||||
try:
|
||||
result = await db.execute(
|
||||
select(UserOrganization).where(
|
||||
and_(
|
||||
UserOrganization.user_id == user_id,
|
||||
UserOrganization.organization_id == organization_id,
|
||||
)
|
||||
)
|
||||
)
|
||||
existing = result.scalar_one_or_none()
|
||||
|
||||
if existing:
|
||||
if not existing.is_active:
|
||||
existing.is_active = True
|
||||
existing.role = role
|
||||
existing.custom_permissions = custom_permissions
|
||||
await db.commit()
|
||||
await db.refresh(existing)
|
||||
return existing
|
||||
else:
|
||||
raise DuplicateEntryError(
|
||||
"User is already a member of this organization"
|
||||
)
|
||||
|
||||
user_org = UserOrganization(
|
||||
user_id=user_id,
|
||||
organization_id=organization_id,
|
||||
role=role,
|
||||
is_active=True,
|
||||
custom_permissions=custom_permissions,
|
||||
)
|
||||
db.add(user_org)
|
||||
await db.commit()
|
||||
await db.refresh(user_org)
|
||||
return user_org
|
||||
except IntegrityError as e:
|
||||
await db.rollback()
|
||||
logger.error("Integrity error adding user to organization: %s", e)
|
||||
raise IntegrityConstraintError("Failed to add user to organization")
|
||||
except Exception as e:
|
||||
await db.rollback()
|
||||
logger.exception("Error adding user to organization: %s", e)
|
||||
raise
|
||||
|
||||
async def remove_user(
|
||||
self, db: AsyncSession, *, organization_id: UUID, user_id: UUID
|
||||
) -> bool:
|
||||
"""Remove a user from an organization (soft delete)."""
|
||||
try:
|
||||
result = await db.execute(
|
||||
select(UserOrganization).where(
|
||||
and_(
|
||||
UserOrganization.user_id == user_id,
|
||||
UserOrganization.organization_id == organization_id,
|
||||
)
|
||||
)
|
||||
)
|
||||
user_org = result.scalar_one_or_none()
|
||||
|
||||
if not user_org:
|
||||
return False
|
||||
|
||||
user_org.is_active = False
|
||||
await db.commit()
|
||||
return True
|
||||
except Exception as e:
|
||||
await db.rollback()
|
||||
logger.exception("Error removing user from organization: %s", e)
|
||||
raise
|
||||
|
||||
async def update_user_role(
|
||||
self,
|
||||
db: AsyncSession,
|
||||
*,
|
||||
organization_id: UUID,
|
||||
user_id: UUID,
|
||||
role: OrganizationRole,
|
||||
custom_permissions: str | None = None,
|
||||
) -> UserOrganization | None:
|
||||
"""Update a user's role in an organization."""
|
||||
try:
|
||||
result = await db.execute(
|
||||
select(UserOrganization).where(
|
||||
and_(
|
||||
UserOrganization.user_id == user_id,
|
||||
UserOrganization.organization_id == organization_id,
|
||||
)
|
||||
)
|
||||
)
|
||||
user_org = result.scalar_one_or_none()
|
||||
|
||||
if not user_org:
|
||||
return None
|
||||
|
||||
user_org.role = role
|
||||
if custom_permissions is not None:
|
||||
user_org.custom_permissions = custom_permissions
|
||||
await db.commit()
|
||||
await db.refresh(user_org)
|
||||
return user_org
|
||||
except Exception as e:
|
||||
await db.rollback()
|
||||
logger.exception("Error updating user role: %s", e)
|
||||
raise
|
||||
|
||||
async def get_organization_members(
|
||||
self,
|
||||
db: AsyncSession,
|
||||
*,
|
||||
organization_id: UUID,
|
||||
skip: int = 0,
|
||||
limit: int = 100,
|
||||
is_active: bool | None = True,
|
||||
) -> tuple[list[dict[str, Any]], int]:
|
||||
"""Get members of an organization with user details."""
|
||||
try:
|
||||
query = (
|
||||
select(UserOrganization, User)
|
||||
.join(User, UserOrganization.user_id == User.id)
|
||||
.where(UserOrganization.organization_id == organization_id)
|
||||
)
|
||||
|
||||
if is_active is not None:
|
||||
query = query.where(UserOrganization.is_active == is_active)
|
||||
|
||||
count_query = select(func.count()).select_from(
|
||||
select(UserOrganization)
|
||||
.where(UserOrganization.organization_id == organization_id)
|
||||
.where(
|
||||
UserOrganization.is_active == is_active
|
||||
if is_active is not None
|
||||
else True
|
||||
)
|
||||
.alias()
|
||||
)
|
||||
count_result = await db.execute(count_query)
|
||||
total = count_result.scalar_one()
|
||||
|
||||
query = (
|
||||
query.order_by(UserOrganization.created_at.desc())
|
||||
.offset(skip)
|
||||
.limit(limit)
|
||||
)
|
||||
result = await db.execute(query)
|
||||
results = result.all()
|
||||
|
||||
members = []
|
||||
for user_org, user in results:
|
||||
members.append(
|
||||
{
|
||||
"user_id": user.id,
|
||||
"email": user.email,
|
||||
"first_name": user.first_name,
|
||||
"last_name": user.last_name,
|
||||
"role": user_org.role,
|
||||
"is_active": user_org.is_active,
|
||||
"joined_at": user_org.created_at,
|
||||
}
|
||||
)
|
||||
|
||||
return members, total
|
||||
except Exception as e:
|
||||
logger.error("Error getting organization members: %s", e)
|
||||
raise
|
||||
|
||||
async def get_user_organizations(
|
||||
self, db: AsyncSession, *, user_id: UUID, is_active: bool | None = True
|
||||
) -> list[Organization]:
|
||||
"""Get all organizations a user belongs to."""
|
||||
try:
|
||||
query = (
|
||||
select(Organization)
|
||||
.join(
|
||||
UserOrganization,
|
||||
Organization.id == UserOrganization.organization_id,
|
||||
)
|
||||
.where(UserOrganization.user_id == user_id)
|
||||
)
|
||||
|
||||
if is_active is not None:
|
||||
query = query.where(UserOrganization.is_active == is_active)
|
||||
|
||||
result = await db.execute(query)
|
||||
return list(result.scalars().all())
|
||||
except Exception as e:
|
||||
logger.error("Error getting user organizations: %s", e)
|
||||
raise
|
||||
|
||||
async def get_user_organizations_with_details(
|
||||
self, db: AsyncSession, *, user_id: UUID, is_active: bool | None = True
|
||||
) -> list[dict[str, Any]]:
|
||||
"""Get user's organizations with role and member count in SINGLE QUERY."""
|
||||
try:
|
||||
member_count_subq = (
|
||||
select(
|
||||
UserOrganization.organization_id,
|
||||
func.count(UserOrganization.user_id).label("member_count"),
|
||||
)
|
||||
.where(UserOrganization.is_active)
|
||||
.group_by(UserOrganization.organization_id)
|
||||
.subquery()
|
||||
)
|
||||
|
||||
query = (
|
||||
select(
|
||||
Organization,
|
||||
UserOrganization.role,
|
||||
func.coalesce(member_count_subq.c.member_count, 0).label(
|
||||
"member_count"
|
||||
),
|
||||
)
|
||||
.join(
|
||||
UserOrganization,
|
||||
Organization.id == UserOrganization.organization_id,
|
||||
)
|
||||
.outerjoin(
|
||||
member_count_subq,
|
||||
Organization.id == member_count_subq.c.organization_id,
|
||||
)
|
||||
.where(UserOrganization.user_id == user_id)
|
||||
)
|
||||
|
||||
if is_active is not None:
|
||||
query = query.where(UserOrganization.is_active == is_active)
|
||||
|
||||
result = await db.execute(query)
|
||||
rows = result.all()
|
||||
|
||||
return [
|
||||
{"organization": org, "role": role, "member_count": member_count}
|
||||
for org, role, member_count in rows
|
||||
]
|
||||
|
||||
except Exception as e:
|
||||
logger.exception("Error getting user organizations with details: %s", e)
|
||||
raise
|
||||
|
||||
async def get_user_role_in_org(
|
||||
self, db: AsyncSession, *, user_id: UUID, organization_id: UUID
|
||||
) -> OrganizationRole | None:
|
||||
"""Get a user's role in a specific organization."""
|
||||
try:
|
||||
result = await db.execute(
|
||||
select(UserOrganization).where(
|
||||
and_(
|
||||
UserOrganization.user_id == user_id,
|
||||
UserOrganization.organization_id == organization_id,
|
||||
UserOrganization.is_active,
|
||||
)
|
||||
)
|
||||
)
|
||||
user_org = result.scalar_one_or_none()
|
||||
|
||||
return user_org.role if user_org else None # pyright: ignore[reportReturnType]
|
||||
except Exception as e:
|
||||
logger.error("Error getting user role in org: %s", e)
|
||||
raise
|
||||
|
||||
async def is_user_org_owner(
|
||||
self, db: AsyncSession, *, user_id: UUID, organization_id: UUID
|
||||
) -> bool:
|
||||
"""Check if a user is an owner of an organization."""
|
||||
role = await self.get_user_role_in_org(
|
||||
db, user_id=user_id, organization_id=organization_id
|
||||
)
|
||||
return role == OrganizationRole.OWNER
|
||||
|
||||
async def is_user_org_admin(
|
||||
self, db: AsyncSession, *, user_id: UUID, organization_id: UUID
|
||||
) -> bool:
|
||||
"""Check if a user is an owner or admin of an organization."""
|
||||
role = await self.get_user_role_in_org(
|
||||
db, user_id=user_id, organization_id=organization_id
|
||||
)
|
||||
return role in [OrganizationRole.OWNER, OrganizationRole.ADMIN]
|
||||
|
||||
|
||||
# Singleton instance
|
||||
organization_repo = OrganizationRepository(Organization)
|
||||
333
backend/app/repositories/session.py
Normal file
333
backend/app/repositories/session.py
Normal file
@@ -0,0 +1,333 @@
|
||||
# app/repositories/session.py
|
||||
"""Repository for UserSession model async database operations using SQLAlchemy 2.0 patterns."""
|
||||
|
||||
import logging
|
||||
import uuid
|
||||
from datetime import UTC, datetime, timedelta
|
||||
from uuid import UUID
|
||||
|
||||
from sqlalchemy import and_, delete, func, select, update
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
from sqlalchemy.orm import joinedload
|
||||
|
||||
from app.core.repository_exceptions import IntegrityConstraintError, InvalidInputError
|
||||
from app.models.user_session import UserSession
|
||||
from app.repositories.base import BaseRepository
|
||||
from app.schemas.sessions import SessionCreate, SessionUpdate
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class SessionRepository(BaseRepository[UserSession, SessionCreate, SessionUpdate]):
|
||||
"""Repository for UserSession model."""
|
||||
|
||||
async def get_by_jti(self, db: AsyncSession, *, jti: str) -> UserSession | None:
|
||||
"""Get session by refresh token JTI."""
|
||||
try:
|
||||
result = await db.execute(
|
||||
select(UserSession).where(UserSession.refresh_token_jti == jti)
|
||||
)
|
||||
return result.scalar_one_or_none()
|
||||
except Exception as e:
|
||||
logger.error("Error getting session by JTI %s: %s", jti, e)
|
||||
raise
|
||||
|
||||
async def get_active_by_jti(
|
||||
self, db: AsyncSession, *, jti: str
|
||||
) -> UserSession | None:
|
||||
"""Get active session by refresh token JTI."""
|
||||
try:
|
||||
result = await db.execute(
|
||||
select(UserSession).where(
|
||||
and_(
|
||||
UserSession.refresh_token_jti == jti,
|
||||
UserSession.is_active,
|
||||
)
|
||||
)
|
||||
)
|
||||
return result.scalar_one_or_none()
|
||||
except Exception as e:
|
||||
logger.error("Error getting active session by JTI %s: %s", jti, e)
|
||||
raise
|
||||
|
||||
async def get_user_sessions(
|
||||
self,
|
||||
db: AsyncSession,
|
||||
*,
|
||||
user_id: str,
|
||||
active_only: bool = True,
|
||||
with_user: bool = False,
|
||||
) -> list[UserSession]:
|
||||
"""Get all sessions for a user with optional eager loading."""
|
||||
try:
|
||||
user_uuid = UUID(user_id) if isinstance(user_id, str) else user_id
|
||||
|
||||
query = select(UserSession).where(UserSession.user_id == user_uuid)
|
||||
|
||||
if with_user:
|
||||
query = query.options(joinedload(UserSession.user))
|
||||
|
||||
if active_only:
|
||||
query = query.where(UserSession.is_active)
|
||||
|
||||
query = query.order_by(UserSession.last_used_at.desc())
|
||||
result = await db.execute(query)
|
||||
return list(result.scalars().all())
|
||||
except Exception as e:
|
||||
logger.error("Error getting sessions for user %s: %s", user_id, e)
|
||||
raise
|
||||
|
||||
async def create_session(
|
||||
self, db: AsyncSession, *, obj_in: SessionCreate
|
||||
) -> UserSession:
|
||||
"""Create a new user session."""
|
||||
try:
|
||||
db_obj = UserSession(
|
||||
user_id=obj_in.user_id,
|
||||
refresh_token_jti=obj_in.refresh_token_jti,
|
||||
device_name=obj_in.device_name,
|
||||
device_id=obj_in.device_id,
|
||||
ip_address=obj_in.ip_address,
|
||||
user_agent=obj_in.user_agent,
|
||||
last_used_at=obj_in.last_used_at,
|
||||
expires_at=obj_in.expires_at,
|
||||
is_active=True,
|
||||
location_city=obj_in.location_city,
|
||||
location_country=obj_in.location_country,
|
||||
)
|
||||
db.add(db_obj)
|
||||
await db.commit()
|
||||
await db.refresh(db_obj)
|
||||
|
||||
logger.info(
|
||||
"Session created for user %s from %s (IP: %s)",
|
||||
obj_in.user_id,
|
||||
obj_in.device_name,
|
||||
obj_in.ip_address,
|
||||
)
|
||||
|
||||
return db_obj
|
||||
except Exception as e:
|
||||
await db.rollback()
|
||||
logger.exception("Error creating session: %s", e)
|
||||
raise IntegrityConstraintError(f"Failed to create session: {e!s}")
|
||||
|
||||
async def deactivate(
|
||||
self, db: AsyncSession, *, session_id: str
|
||||
) -> UserSession | None:
|
||||
"""Deactivate a session (logout from device)."""
|
||||
try:
|
||||
session = await self.get(db, id=session_id)
|
||||
if not session:
|
||||
logger.warning("Session %s not found for deactivation", session_id)
|
||||
return None
|
||||
|
||||
session.is_active = False
|
||||
db.add(session)
|
||||
await db.commit()
|
||||
await db.refresh(session)
|
||||
|
||||
logger.info(
|
||||
"Session %s deactivated for user %s (%s)",
|
||||
session_id,
|
||||
session.user_id,
|
||||
session.device_name,
|
||||
)
|
||||
|
||||
return session
|
||||
except Exception as e:
|
||||
await db.rollback()
|
||||
logger.error("Error deactivating session %s: %s", session_id, e)
|
||||
raise
|
||||
|
||||
async def deactivate_all_user_sessions(
|
||||
self, db: AsyncSession, *, user_id: str
|
||||
) -> int:
|
||||
"""Deactivate all active sessions for a user (logout from all devices)."""
|
||||
try:
|
||||
user_uuid = UUID(user_id) if isinstance(user_id, str) else user_id
|
||||
|
||||
stmt = (
|
||||
update(UserSession)
|
||||
.where(and_(UserSession.user_id == user_uuid, UserSession.is_active))
|
||||
.values(is_active=False)
|
||||
)
|
||||
|
||||
result = await db.execute(stmt)
|
||||
await db.commit()
|
||||
|
||||
count = result.rowcount
|
||||
|
||||
logger.info("Deactivated %s sessions for user %s", count, user_id)
|
||||
|
||||
return count
|
||||
except Exception as e:
|
||||
await db.rollback()
|
||||
logger.error("Error deactivating all sessions for user %s: %s", user_id, e)
|
||||
raise
|
||||
|
||||
async def update_last_used(
|
||||
self, db: AsyncSession, *, session: UserSession
|
||||
) -> UserSession:
|
||||
"""Update the last_used_at timestamp for a session."""
|
||||
try:
|
||||
session.last_used_at = datetime.now(UTC)
|
||||
db.add(session)
|
||||
await db.commit()
|
||||
await db.refresh(session)
|
||||
return session
|
||||
except Exception as e:
|
||||
await db.rollback()
|
||||
logger.error("Error updating last_used for session %s: %s", session.id, e)
|
||||
raise
|
||||
|
||||
async def update_refresh_token(
|
||||
self,
|
||||
db: AsyncSession,
|
||||
*,
|
||||
session: UserSession,
|
||||
new_jti: str,
|
||||
new_expires_at: datetime,
|
||||
) -> UserSession:
|
||||
"""Update session with new refresh token JTI and expiration."""
|
||||
try:
|
||||
session.refresh_token_jti = new_jti
|
||||
session.expires_at = new_expires_at
|
||||
session.last_used_at = datetime.now(UTC)
|
||||
db.add(session)
|
||||
await db.commit()
|
||||
await db.refresh(session)
|
||||
return session
|
||||
except Exception as e:
|
||||
await db.rollback()
|
||||
logger.error(
|
||||
"Error updating refresh token for session %s: %s", session.id, e
|
||||
)
|
||||
raise
|
||||
|
||||
async def cleanup_expired(self, db: AsyncSession, *, keep_days: int = 30) -> int:
|
||||
"""Clean up expired sessions using optimized bulk DELETE."""
|
||||
try:
|
||||
cutoff_date = datetime.now(UTC) - timedelta(days=keep_days)
|
||||
now = datetime.now(UTC)
|
||||
|
||||
stmt = delete(UserSession).where(
|
||||
and_(
|
||||
UserSession.is_active == False, # noqa: E712
|
||||
UserSession.expires_at < now,
|
||||
UserSession.created_at < cutoff_date,
|
||||
)
|
||||
)
|
||||
|
||||
result = await db.execute(stmt)
|
||||
await db.commit()
|
||||
|
||||
count = result.rowcount
|
||||
|
||||
if count > 0:
|
||||
logger.info("Cleaned up %s expired sessions using bulk DELETE", count)
|
||||
|
||||
return count
|
||||
except Exception as e:
|
||||
await db.rollback()
|
||||
logger.error("Error cleaning up expired sessions: %s", e)
|
||||
raise
|
||||
|
||||
async def cleanup_expired_for_user(self, db: AsyncSession, *, user_id: str) -> int:
|
||||
"""Clean up expired and inactive sessions for a specific user."""
|
||||
try:
|
||||
try:
|
||||
uuid_obj = uuid.UUID(user_id)
|
||||
except (ValueError, AttributeError):
|
||||
logger.error("Invalid UUID format: %s", user_id)
|
||||
raise InvalidInputError(f"Invalid user ID format: {user_id}")
|
||||
|
||||
now = datetime.now(UTC)
|
||||
|
||||
stmt = delete(UserSession).where(
|
||||
and_(
|
||||
UserSession.user_id == uuid_obj,
|
||||
UserSession.is_active == False, # noqa: E712
|
||||
UserSession.expires_at < now,
|
||||
)
|
||||
)
|
||||
|
||||
result = await db.execute(stmt)
|
||||
await db.commit()
|
||||
|
||||
count = result.rowcount
|
||||
|
||||
if count > 0:
|
||||
logger.info(
|
||||
"Cleaned up %s expired sessions for user %s using bulk DELETE",
|
||||
count,
|
||||
user_id,
|
||||
)
|
||||
|
||||
return count
|
||||
except Exception as e:
|
||||
await db.rollback()
|
||||
logger.error(
|
||||
"Error cleaning up expired sessions for user %s: %s", user_id, e
|
||||
)
|
||||
raise
|
||||
|
||||
async def get_user_session_count(self, db: AsyncSession, *, user_id: str) -> int:
|
||||
"""Get count of active sessions for a user."""
|
||||
try:
|
||||
user_uuid = UUID(user_id) if isinstance(user_id, str) else user_id
|
||||
|
||||
result = await db.execute(
|
||||
select(func.count(UserSession.id)).where(
|
||||
and_(UserSession.user_id == user_uuid, UserSession.is_active)
|
||||
)
|
||||
)
|
||||
return result.scalar_one()
|
||||
except Exception as e:
|
||||
logger.error("Error counting sessions for user %s: %s", user_id, e)
|
||||
raise
|
||||
|
||||
async def get_all_sessions(
|
||||
self,
|
||||
db: AsyncSession,
|
||||
*,
|
||||
skip: int = 0,
|
||||
limit: int = 100,
|
||||
active_only: bool = True,
|
||||
with_user: bool = True,
|
||||
) -> tuple[list[UserSession], int]:
|
||||
"""Get all sessions across all users with pagination (admin only)."""
|
||||
try:
|
||||
query = select(UserSession)
|
||||
|
||||
if with_user:
|
||||
query = query.options(joinedload(UserSession.user))
|
||||
|
||||
if active_only:
|
||||
query = query.where(UserSession.is_active)
|
||||
|
||||
count_query = select(func.count(UserSession.id))
|
||||
if active_only:
|
||||
count_query = count_query.where(UserSession.is_active)
|
||||
|
||||
count_result = await db.execute(count_query)
|
||||
total = count_result.scalar_one()
|
||||
|
||||
query = (
|
||||
query.order_by(UserSession.last_used_at.desc())
|
||||
.offset(skip)
|
||||
.limit(limit)
|
||||
)
|
||||
|
||||
result = await db.execute(query)
|
||||
sessions = list(result.scalars().all())
|
||||
|
||||
return sessions, total
|
||||
|
||||
except Exception as e:
|
||||
logger.exception("Error getting all sessions: %s", e)
|
||||
raise
|
||||
|
||||
|
||||
# Singleton instance
|
||||
session_repo = SessionRepository(UserSession)
|
||||
269
backend/app/repositories/user.py
Normal file
269
backend/app/repositories/user.py
Normal file
@@ -0,0 +1,269 @@
|
||||
# app/repositories/user.py
|
||||
"""Repository for User model async database operations using SQLAlchemy 2.0 patterns."""
|
||||
|
||||
import logging
|
||||
from datetime import UTC, datetime
|
||||
from typing import Any
|
||||
from uuid import UUID
|
||||
|
||||
from sqlalchemy import or_, select, update
|
||||
from sqlalchemy.exc import IntegrityError
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
|
||||
from app.core.auth import get_password_hash_async
|
||||
from app.core.repository_exceptions import DuplicateEntryError, InvalidInputError
|
||||
from app.models.user import User
|
||||
from app.repositories.base import BaseRepository
|
||||
from app.schemas.users import UserCreate, UserUpdate
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class UserRepository(BaseRepository[User, UserCreate, UserUpdate]):
|
||||
"""Repository for User model."""
|
||||
|
||||
async def get_by_email(self, db: AsyncSession, *, email: str) -> User | None:
|
||||
"""Get user by email address."""
|
||||
try:
|
||||
result = await db.execute(select(User).where(User.email == email))
|
||||
return result.scalar_one_or_none()
|
||||
except Exception as e:
|
||||
logger.error("Error getting user by email %s: %s", email, e)
|
||||
raise
|
||||
|
||||
async def create(self, db: AsyncSession, *, obj_in: UserCreate) -> User:
|
||||
"""Create a new user with async password hashing and error handling."""
|
||||
try:
|
||||
password_hash = await get_password_hash_async(obj_in.password)
|
||||
|
||||
db_obj = User(
|
||||
email=obj_in.email,
|
||||
password_hash=password_hash,
|
||||
first_name=obj_in.first_name,
|
||||
last_name=obj_in.last_name,
|
||||
phone_number=obj_in.phone_number
|
||||
if hasattr(obj_in, "phone_number")
|
||||
else None,
|
||||
is_superuser=obj_in.is_superuser
|
||||
if hasattr(obj_in, "is_superuser")
|
||||
else False,
|
||||
preferences={},
|
||||
)
|
||||
db.add(db_obj)
|
||||
await db.commit()
|
||||
await db.refresh(db_obj)
|
||||
return db_obj
|
||||
except IntegrityError as e:
|
||||
await db.rollback()
|
||||
error_msg = str(e.orig) if hasattr(e, "orig") else str(e)
|
||||
if "email" in error_msg.lower():
|
||||
logger.warning("Duplicate email attempted: %s", obj_in.email)
|
||||
raise DuplicateEntryError(
|
||||
f"User with email {obj_in.email} already exists"
|
||||
)
|
||||
logger.error("Integrity error creating user: %s", error_msg)
|
||||
raise DuplicateEntryError(f"Database integrity error: {error_msg}")
|
||||
except Exception as e:
|
||||
await db.rollback()
|
||||
logger.exception("Unexpected error creating user: %s", e)
|
||||
raise
|
||||
|
||||
async def create_oauth_user(
|
||||
self,
|
||||
db: AsyncSession,
|
||||
*,
|
||||
email: str,
|
||||
first_name: str = "User",
|
||||
last_name: str | None = None,
|
||||
) -> User:
|
||||
"""Create a new passwordless user for OAuth sign-in."""
|
||||
try:
|
||||
db_obj = User(
|
||||
email=email,
|
||||
password_hash=None, # OAuth-only user
|
||||
first_name=first_name,
|
||||
last_name=last_name,
|
||||
is_active=True,
|
||||
is_superuser=False,
|
||||
)
|
||||
db.add(db_obj)
|
||||
await db.flush() # Get user.id without committing
|
||||
return db_obj
|
||||
except IntegrityError as e:
|
||||
await db.rollback()
|
||||
error_msg = str(e.orig) if hasattr(e, "orig") else str(e)
|
||||
if "email" in error_msg.lower():
|
||||
logger.warning("Duplicate email attempted: %s", email)
|
||||
raise DuplicateEntryError(f"User with email {email} already exists")
|
||||
logger.error("Integrity error creating OAuth user: %s", error_msg)
|
||||
raise DuplicateEntryError(f"Database integrity error: {error_msg}")
|
||||
except Exception as e:
|
||||
await db.rollback()
|
||||
logger.exception("Unexpected error creating OAuth user: %s", e)
|
||||
raise
|
||||
|
||||
async def update(
|
||||
self, db: AsyncSession, *, db_obj: User, obj_in: UserUpdate | dict[str, Any]
|
||||
) -> User:
|
||||
"""Update user with async password hashing if password is updated."""
|
||||
if isinstance(obj_in, dict):
|
||||
update_data = obj_in
|
||||
else:
|
||||
update_data = obj_in.model_dump(exclude_unset=True)
|
||||
|
||||
if "password" in update_data:
|
||||
update_data["password_hash"] = await get_password_hash_async(
|
||||
update_data["password"]
|
||||
)
|
||||
del update_data["password"]
|
||||
|
||||
return await super().update(db, db_obj=db_obj, obj_in=update_data)
|
||||
|
||||
async def update_password(
|
||||
self, db: AsyncSession, *, user: User, password_hash: str
|
||||
) -> User:
|
||||
"""Set a new password hash on a user and commit."""
|
||||
user.password_hash = password_hash
|
||||
await db.commit()
|
||||
await db.refresh(user)
|
||||
return user
|
||||
|
||||
async def get_multi_with_total(
|
||||
self,
|
||||
db: AsyncSession,
|
||||
*,
|
||||
skip: int = 0,
|
||||
limit: int = 100,
|
||||
sort_by: str | None = None,
|
||||
sort_order: str = "asc",
|
||||
filters: dict[str, Any] | None = None,
|
||||
search: str | None = None,
|
||||
) -> tuple[list[User], int]:
|
||||
"""Get multiple users with total count, filtering, sorting, and search."""
|
||||
if skip < 0:
|
||||
raise InvalidInputError("skip must be non-negative")
|
||||
if limit < 0:
|
||||
raise InvalidInputError("limit must be non-negative")
|
||||
if limit > 1000:
|
||||
raise InvalidInputError("Maximum limit is 1000")
|
||||
|
||||
try:
|
||||
query = select(User)
|
||||
query = query.where(User.deleted_at.is_(None))
|
||||
|
||||
if filters:
|
||||
for field, value in filters.items():
|
||||
if hasattr(User, field) and value is not None:
|
||||
query = query.where(getattr(User, field) == value)
|
||||
|
||||
if search:
|
||||
search_filter = or_(
|
||||
User.email.ilike(f"%{search}%"),
|
||||
User.first_name.ilike(f"%{search}%"),
|
||||
User.last_name.ilike(f"%{search}%"),
|
||||
)
|
||||
query = query.where(search_filter)
|
||||
|
||||
from sqlalchemy import func
|
||||
|
||||
count_query = select(func.count()).select_from(query.alias())
|
||||
count_result = await db.execute(count_query)
|
||||
total = count_result.scalar_one()
|
||||
|
||||
if sort_by and hasattr(User, sort_by):
|
||||
sort_column = getattr(User, sort_by)
|
||||
if sort_order.lower() == "desc":
|
||||
query = query.order_by(sort_column.desc())
|
||||
else:
|
||||
query = query.order_by(sort_column.asc())
|
||||
|
||||
query = query.offset(skip).limit(limit)
|
||||
result = await db.execute(query)
|
||||
users = list(result.scalars().all())
|
||||
|
||||
return users, total
|
||||
|
||||
except Exception as e:
|
||||
logger.error("Error retrieving paginated users: %s", e)
|
||||
raise
|
||||
|
||||
async def bulk_update_status(
|
||||
self, db: AsyncSession, *, user_ids: list[UUID], is_active: bool
|
||||
) -> int:
|
||||
"""Bulk update is_active status for multiple users."""
|
||||
try:
|
||||
if not user_ids:
|
||||
return 0
|
||||
|
||||
stmt = (
|
||||
update(User)
|
||||
.where(User.id.in_(user_ids))
|
||||
.where(User.deleted_at.is_(None))
|
||||
.values(is_active=is_active, updated_at=datetime.now(UTC))
|
||||
)
|
||||
|
||||
result = await db.execute(stmt)
|
||||
await db.commit()
|
||||
|
||||
updated_count = result.rowcount
|
||||
logger.info(
|
||||
"Bulk updated %s users to is_active=%s", updated_count, is_active
|
||||
)
|
||||
return updated_count
|
||||
|
||||
except Exception as e:
|
||||
await db.rollback()
|
||||
logger.exception("Error bulk updating user status: %s", e)
|
||||
raise
|
||||
|
||||
async def bulk_soft_delete(
|
||||
self,
|
||||
db: AsyncSession,
|
||||
*,
|
||||
user_ids: list[UUID],
|
||||
exclude_user_id: UUID | None = None,
|
||||
) -> int:
|
||||
"""Bulk soft delete multiple users."""
|
||||
try:
|
||||
if not user_ids:
|
||||
return 0
|
||||
|
||||
filtered_ids = [uid for uid in user_ids if uid != exclude_user_id]
|
||||
|
||||
if not filtered_ids:
|
||||
return 0
|
||||
|
||||
stmt = (
|
||||
update(User)
|
||||
.where(User.id.in_(filtered_ids))
|
||||
.where(User.deleted_at.is_(None))
|
||||
.values(
|
||||
deleted_at=datetime.now(UTC),
|
||||
is_active=False,
|
||||
updated_at=datetime.now(UTC),
|
||||
)
|
||||
)
|
||||
|
||||
result = await db.execute(stmt)
|
||||
await db.commit()
|
||||
|
||||
deleted_count = result.rowcount
|
||||
logger.info("Bulk soft deleted %s users", deleted_count)
|
||||
return deleted_count
|
||||
|
||||
except Exception as e:
|
||||
await db.rollback()
|
||||
logger.exception("Error bulk deleting users: %s", e)
|
||||
raise
|
||||
|
||||
def is_active(self, user: User) -> bool:
|
||||
"""Check if user is active."""
|
||||
return bool(user.is_active)
|
||||
|
||||
def is_superuser(self, user: User) -> bool:
|
||||
"""Check if user is a superuser."""
|
||||
return bool(user.is_superuser)
|
||||
|
||||
|
||||
# Singleton instance
|
||||
user_repo = UserRepository(User)
|
||||
@@ -1,17 +1,20 @@
|
||||
"""
|
||||
Common schemas used across the API for pagination, responses, filtering, and sorting.
|
||||
"""
|
||||
from typing import Generic, TypeVar, List, Optional
|
||||
|
||||
from enum import Enum
|
||||
from pydantic import BaseModel, Field
|
||||
from math import ceil
|
||||
from typing import TypeVar
|
||||
from uuid import UUID
|
||||
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
T = TypeVar('T')
|
||||
T = TypeVar("T")
|
||||
|
||||
|
||||
class SortOrder(str, Enum):
|
||||
"""Sort order options."""
|
||||
|
||||
ASC = "asc"
|
||||
DESC = "desc"
|
||||
|
||||
@@ -19,16 +22,9 @@ class SortOrder(str, Enum):
|
||||
class PaginationParams(BaseModel):
|
||||
"""Parameters for pagination."""
|
||||
|
||||
page: int = Field(
|
||||
default=1,
|
||||
ge=1,
|
||||
description="Page number (1-indexed)"
|
||||
)
|
||||
page: int = Field(default=1, ge=1, description="Page number (1-indexed)")
|
||||
limit: int = Field(
|
||||
default=20,
|
||||
ge=1,
|
||||
le=100,
|
||||
description="Number of items per page (max 100)"
|
||||
default=20, ge=1, le=100, description="Number of items per page (max 100)"
|
||||
)
|
||||
|
||||
@property
|
||||
@@ -41,34 +37,20 @@ class PaginationParams(BaseModel):
|
||||
"""Alias for offset (compatibility with existing code)."""
|
||||
return self.offset
|
||||
|
||||
model_config = {
|
||||
"json_schema_extra": {
|
||||
"example": {
|
||||
"page": 1,
|
||||
"limit": 20
|
||||
}
|
||||
}
|
||||
}
|
||||
model_config = {"json_schema_extra": {"example": {"page": 1, "limit": 20}}}
|
||||
|
||||
|
||||
class SortParams(BaseModel):
|
||||
"""Parameters for sorting."""
|
||||
|
||||
sort_by: Optional[str] = Field(
|
||||
default=None,
|
||||
description="Field name to sort by"
|
||||
)
|
||||
sort_by: str | None = Field(default=None, description="Field name to sort by")
|
||||
sort_order: SortOrder = Field(
|
||||
default=SortOrder.ASC,
|
||||
description="Sort order (asc or desc)"
|
||||
default=SortOrder.ASC, description="Sort order (asc or desc)"
|
||||
)
|
||||
|
||||
model_config = {
|
||||
"json_schema_extra": {
|
||||
"example": {
|
||||
"sort_by": "created_at",
|
||||
"sort_order": "desc"
|
||||
}
|
||||
"example": {"sort_by": "created_at", "sort_order": "desc"}
|
||||
}
|
||||
}
|
||||
|
||||
@@ -91,32 +73,30 @@ class PaginationMeta(BaseModel):
|
||||
"page_size": 20,
|
||||
"total_pages": 8,
|
||||
"has_next": True,
|
||||
"has_prev": False
|
||||
"has_prev": False,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
class PaginatedResponse(BaseModel, Generic[T]):
|
||||
class PaginatedResponse[T](BaseModel):
|
||||
"""Generic paginated response wrapper."""
|
||||
|
||||
data: List[T] = Field(..., description="List of items")
|
||||
data: list[T] = Field(..., description="List of items")
|
||||
pagination: PaginationMeta = Field(..., description="Pagination metadata")
|
||||
|
||||
model_config = {
|
||||
"json_schema_extra": {
|
||||
"example": {
|
||||
"data": [
|
||||
{"id": "123", "name": "Example Item"}
|
||||
],
|
||||
"data": [{"id": "123", "name": "Example Item"}],
|
||||
"pagination": {
|
||||
"total": 150,
|
||||
"page": 1,
|
||||
"page_size": 20,
|
||||
"total_pages": 8,
|
||||
"has_next": True,
|
||||
"has_prev": False
|
||||
}
|
||||
"has_prev": False,
|
||||
},
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -128,21 +108,57 @@ class MessageResponse(BaseModel):
|
||||
success: bool = Field(default=True, description="Operation success status")
|
||||
message: str = Field(..., description="Human-readable message")
|
||||
|
||||
model_config = {
|
||||
"json_schema_extra": {
|
||||
"example": {"success": True, "message": "Operation completed successfully"}
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
class BulkActionRequest(BaseModel):
|
||||
"""Request schema for bulk operations on multiple items."""
|
||||
|
||||
ids: list[UUID] = Field(
|
||||
...,
|
||||
min_length=1,
|
||||
max_length=100,
|
||||
description="List of item IDs to perform action on (max 100)",
|
||||
)
|
||||
|
||||
model_config = {
|
||||
"json_schema_extra": {
|
||||
"example": {
|
||||
"ids": [
|
||||
"550e8400-e29b-41d4-a716-446655440000",
|
||||
"6ba7b810-9dad-11d1-80b4-00c04fd430c8",
|
||||
]
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
class BulkActionResponse(BaseModel):
|
||||
"""Response schema for bulk operations."""
|
||||
|
||||
success: bool = Field(default=True, description="Operation success status")
|
||||
message: str = Field(..., description="Human-readable message")
|
||||
affected_count: int = Field(
|
||||
..., description="Number of items affected by the operation"
|
||||
)
|
||||
|
||||
model_config = {
|
||||
"json_schema_extra": {
|
||||
"example": {
|
||||
"success": True,
|
||||
"message": "Operation completed successfully"
|
||||
"message": "Successfully deactivated 5 users",
|
||||
"affected_count": 5,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
def create_pagination_meta(
|
||||
total: int,
|
||||
page: int,
|
||||
limit: int,
|
||||
items_count: int
|
||||
total: int, page: int, limit: int, items_count: int
|
||||
) -> PaginationMeta:
|
||||
"""
|
||||
Helper function to create pagination metadata.
|
||||
@@ -164,5 +180,5 @@ def create_pagination_meta(
|
||||
page_size=items_count,
|
||||
total_pages=total_pages,
|
||||
has_next=page < total_pages,
|
||||
has_prev=page > 1
|
||||
has_prev=page > 1,
|
||||
)
|
||||
|
||||
@@ -1,8 +1,9 @@
|
||||
"""
|
||||
Error schemas for standardized API error responses.
|
||||
"""
|
||||
|
||||
from enum import Enum
|
||||
from typing import List, Optional
|
||||
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
|
||||
@@ -16,6 +17,7 @@ class ErrorCode(str, Enum):
|
||||
INSUFFICIENT_PERMISSIONS = "AUTH_004"
|
||||
USER_INACTIVE = "AUTH_005"
|
||||
AUTHENTICATION_REQUIRED = "AUTH_006"
|
||||
OPERATION_FORBIDDEN = "AUTH_007" # Operation not allowed for this user/role
|
||||
|
||||
# User errors (USER_xxx)
|
||||
USER_NOT_FOUND = "USER_001"
|
||||
@@ -43,6 +45,7 @@ class ErrorCode(str, Enum):
|
||||
NOT_FOUND = "SYS_002"
|
||||
METHOD_NOT_ALLOWED = "SYS_003"
|
||||
RATE_LIMIT_EXCEEDED = "SYS_004"
|
||||
ALREADY_EXISTS = "SYS_005" # Generic resource already exists error
|
||||
|
||||
|
||||
class ErrorDetail(BaseModel):
|
||||
@@ -50,14 +53,14 @@ class ErrorDetail(BaseModel):
|
||||
|
||||
code: ErrorCode = Field(..., description="Machine-readable error code")
|
||||
message: str = Field(..., description="Human-readable error message")
|
||||
field: Optional[str] = Field(None, description="Field name if error is field-specific")
|
||||
field: str | None = Field(None, description="Field name if error is field-specific")
|
||||
|
||||
model_config = {
|
||||
"json_schema_extra": {
|
||||
"example": {
|
||||
"code": "VAL_002",
|
||||
"message": "Password must be at least 8 characters long",
|
||||
"field": "password"
|
||||
"field": "password",
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -67,7 +70,7 @@ class ErrorResponse(BaseModel):
|
||||
"""Standardized error response format."""
|
||||
|
||||
success: bool = Field(default=False, description="Always false for error responses")
|
||||
errors: List[ErrorDetail] = Field(..., description="List of errors that occurred")
|
||||
errors: list[ErrorDetail] = Field(..., description="List of errors that occurred")
|
||||
|
||||
model_config = {
|
||||
"json_schema_extra": {
|
||||
@@ -77,9 +80,9 @@ class ErrorResponse(BaseModel):
|
||||
{
|
||||
"code": "AUTH_001",
|
||||
"message": "Invalid email or password",
|
||||
"field": None
|
||||
"field": None,
|
||||
}
|
||||
]
|
||||
],
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
395
backend/app/schemas/oauth.py
Normal file
395
backend/app/schemas/oauth.py
Normal file
@@ -0,0 +1,395 @@
|
||||
"""
|
||||
Pydantic schemas for OAuth authentication.
|
||||
"""
|
||||
|
||||
from datetime import datetime
|
||||
from uuid import UUID
|
||||
|
||||
from pydantic import BaseModel, ConfigDict, Field
|
||||
|
||||
# ============================================================================
|
||||
# OAuth Provider Info (for frontend to display available providers)
|
||||
# ============================================================================
|
||||
|
||||
|
||||
class OAuthProviderInfo(BaseModel):
|
||||
"""Information about an available OAuth provider."""
|
||||
|
||||
provider: str = Field(..., description="Provider identifier (google, github)")
|
||||
name: str = Field(..., description="Human-readable provider name")
|
||||
icon: str | None = Field(None, description="Icon identifier for frontend")
|
||||
|
||||
|
||||
class OAuthProvidersResponse(BaseModel):
|
||||
"""Response containing list of enabled OAuth providers."""
|
||||
|
||||
enabled: bool = Field(..., description="Whether OAuth is globally enabled")
|
||||
providers: list[OAuthProviderInfo] = Field(
|
||||
default_factory=list, description="List of enabled providers"
|
||||
)
|
||||
|
||||
model_config = ConfigDict(
|
||||
json_schema_extra={
|
||||
"example": {
|
||||
"enabled": True,
|
||||
"providers": [
|
||||
{"provider": "google", "name": "Google", "icon": "google"},
|
||||
{"provider": "github", "name": "GitHub", "icon": "github"},
|
||||
],
|
||||
}
|
||||
}
|
||||
)
|
||||
|
||||
|
||||
# ============================================================================
|
||||
# OAuth Account (linked provider accounts)
|
||||
# ============================================================================
|
||||
|
||||
|
||||
class OAuthAccountBase(BaseModel):
|
||||
"""Base schema for OAuth accounts."""
|
||||
|
||||
provider: str = Field(..., max_length=50, description="OAuth provider name")
|
||||
provider_email: str | None = Field(
|
||||
None, max_length=255, description="Email from OAuth provider"
|
||||
)
|
||||
|
||||
|
||||
class OAuthAccountCreate(OAuthAccountBase):
|
||||
"""Schema for creating an OAuth account link (internal use)."""
|
||||
|
||||
user_id: UUID
|
||||
provider_user_id: str = Field(..., max_length=255)
|
||||
access_token: str | None = None
|
||||
refresh_token: str | None = None
|
||||
token_expires_at: datetime | None = None
|
||||
|
||||
|
||||
class OAuthAccountResponse(OAuthAccountBase):
|
||||
"""Schema for OAuth account response to clients."""
|
||||
|
||||
id: UUID
|
||||
created_at: datetime
|
||||
|
||||
model_config = ConfigDict(
|
||||
from_attributes=True,
|
||||
json_schema_extra={
|
||||
"example": {
|
||||
"id": "123e4567-e89b-12d3-a456-426614174000",
|
||||
"provider": "google",
|
||||
"provider_email": "user@gmail.com",
|
||||
"created_at": "2025-11-24T12:00:00Z",
|
||||
}
|
||||
},
|
||||
)
|
||||
|
||||
|
||||
class OAuthAccountsListResponse(BaseModel):
|
||||
"""Response containing list of linked OAuth accounts."""
|
||||
|
||||
accounts: list[OAuthAccountResponse]
|
||||
|
||||
model_config = ConfigDict(
|
||||
json_schema_extra={
|
||||
"example": {
|
||||
"accounts": [
|
||||
{
|
||||
"id": "123e4567-e89b-12d3-a456-426614174000",
|
||||
"provider": "google",
|
||||
"provider_email": "user@gmail.com",
|
||||
"created_at": "2025-11-24T12:00:00Z",
|
||||
}
|
||||
]
|
||||
}
|
||||
}
|
||||
)
|
||||
|
||||
|
||||
# ============================================================================
|
||||
# OAuth Flow (authorization, callback, etc.)
|
||||
# ============================================================================
|
||||
|
||||
|
||||
class OAuthAuthorizeRequest(BaseModel):
|
||||
"""Request parameters for OAuth authorization."""
|
||||
|
||||
provider: str = Field(..., description="OAuth provider (google, github)")
|
||||
redirect_uri: str | None = Field(
|
||||
None, description="Frontend callback URL after OAuth"
|
||||
)
|
||||
mode: str = Field(
|
||||
default="login",
|
||||
description="OAuth mode: login, register, or link",
|
||||
pattern="^(login|register|link)$",
|
||||
)
|
||||
|
||||
|
||||
class OAuthCallbackRequest(BaseModel):
|
||||
"""Request parameters for OAuth callback."""
|
||||
|
||||
code: str = Field(..., description="Authorization code from provider")
|
||||
state: str = Field(..., description="State parameter for CSRF protection")
|
||||
|
||||
|
||||
class OAuthCallbackResponse(BaseModel):
|
||||
"""Response after successful OAuth authentication."""
|
||||
|
||||
access_token: str = Field(..., description="JWT access token")
|
||||
refresh_token: str = Field(..., description="JWT refresh token")
|
||||
token_type: str = Field(default="bearer")
|
||||
expires_in: int = Field(..., description="Token expiration in seconds")
|
||||
is_new_user: bool = Field(
|
||||
default=False, description="Whether a new user was created"
|
||||
)
|
||||
|
||||
model_config = ConfigDict(
|
||||
json_schema_extra={
|
||||
"example": {
|
||||
"access_token": "eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9...",
|
||||
"refresh_token": "eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9...",
|
||||
"token_type": "bearer",
|
||||
"expires_in": 900,
|
||||
"is_new_user": False,
|
||||
}
|
||||
}
|
||||
)
|
||||
|
||||
|
||||
class OAuthUnlinkResponse(BaseModel):
|
||||
"""Response after unlinking an OAuth account."""
|
||||
|
||||
success: bool = Field(..., description="Whether the unlink was successful")
|
||||
message: str = Field(..., description="Status message")
|
||||
|
||||
model_config = ConfigDict(
|
||||
json_schema_extra={
|
||||
"example": {"success": True, "message": "Google account unlinked"}
|
||||
}
|
||||
)
|
||||
|
||||
|
||||
# ============================================================================
|
||||
# OAuth State (CSRF protection - internal use)
|
||||
# ============================================================================
|
||||
|
||||
|
||||
class OAuthStateCreate(BaseModel):
|
||||
"""Schema for creating OAuth state (internal use)."""
|
||||
|
||||
state: str = Field(..., max_length=255)
|
||||
code_verifier: str | None = Field(None, max_length=128)
|
||||
nonce: str | None = Field(None, max_length=255)
|
||||
provider: str = Field(..., max_length=50)
|
||||
redirect_uri: str | None = Field(None, max_length=500)
|
||||
user_id: UUID | None = None
|
||||
expires_at: datetime
|
||||
|
||||
|
||||
# ============================================================================
|
||||
# OAuth Client (Provider Mode - MCP clients)
|
||||
# ============================================================================
|
||||
|
||||
|
||||
class OAuthClientBase(BaseModel):
|
||||
"""Base schema for OAuth clients."""
|
||||
|
||||
client_name: str = Field(..., max_length=255, description="Client application name")
|
||||
client_description: str | None = Field(
|
||||
None, max_length=1000, description="Client description"
|
||||
)
|
||||
redirect_uris: list[str] = Field(
|
||||
default_factory=list, description="Allowed redirect URIs"
|
||||
)
|
||||
allowed_scopes: list[str] = Field(
|
||||
default_factory=list, description="Allowed OAuth scopes"
|
||||
)
|
||||
|
||||
|
||||
class OAuthClientCreate(OAuthClientBase):
|
||||
"""Schema for creating an OAuth client."""
|
||||
|
||||
client_type: str = Field(
|
||||
default="public",
|
||||
description="Client type: public or confidential",
|
||||
pattern="^(public|confidential)$",
|
||||
)
|
||||
|
||||
|
||||
class OAuthClientResponse(OAuthClientBase):
|
||||
"""Schema for OAuth client response."""
|
||||
|
||||
id: UUID
|
||||
client_id: str = Field(..., description="OAuth client ID")
|
||||
client_type: str
|
||||
is_active: bool
|
||||
created_at: datetime
|
||||
|
||||
model_config = ConfigDict(
|
||||
from_attributes=True,
|
||||
json_schema_extra={
|
||||
"example": {
|
||||
"id": "123e4567-e89b-12d3-a456-426614174000",
|
||||
"client_id": "abc123def456",
|
||||
"client_name": "My MCP App",
|
||||
"client_description": "My application that uses MCP",
|
||||
"client_type": "public",
|
||||
"redirect_uris": ["http://localhost:3000/callback"],
|
||||
"allowed_scopes": ["read:users", "write:users"],
|
||||
"is_active": True,
|
||||
"created_at": "2025-11-24T12:00:00Z",
|
||||
}
|
||||
},
|
||||
)
|
||||
|
||||
|
||||
class OAuthClientWithSecret(OAuthClientResponse):
|
||||
"""Schema for OAuth client response including secret (only shown once)."""
|
||||
|
||||
client_secret: str | None = Field(
|
||||
None, description="Client secret (only shown once for confidential clients)"
|
||||
)
|
||||
|
||||
model_config = ConfigDict(
|
||||
from_attributes=True,
|
||||
json_schema_extra={
|
||||
"example": {
|
||||
"id": "123e4567-e89b-12d3-a456-426614174000",
|
||||
"client_id": "abc123def456",
|
||||
"client_secret": "secret_xyz789",
|
||||
"client_name": "My MCP App",
|
||||
"client_type": "confidential",
|
||||
"redirect_uris": ["http://localhost:3000/callback"],
|
||||
"allowed_scopes": ["read:users"],
|
||||
"is_active": True,
|
||||
"created_at": "2025-11-24T12:00:00Z",
|
||||
}
|
||||
},
|
||||
)
|
||||
|
||||
|
||||
# ============================================================================
|
||||
# OAuth Provider Discovery (RFC 8414 - skeleton)
|
||||
# ============================================================================
|
||||
|
||||
|
||||
class OAuthServerMetadata(BaseModel):
|
||||
"""OAuth 2.0 Authorization Server Metadata (RFC 8414)."""
|
||||
|
||||
issuer: str = Field(..., description="Authorization server issuer URL")
|
||||
authorization_endpoint: str = Field(..., description="Authorization endpoint URL")
|
||||
token_endpoint: str = Field(..., description="Token endpoint URL")
|
||||
registration_endpoint: str | None = Field(
|
||||
None, description="Dynamic client registration endpoint"
|
||||
)
|
||||
revocation_endpoint: str | None = Field(
|
||||
None, description="Token revocation endpoint"
|
||||
)
|
||||
introspection_endpoint: str | None = Field(
|
||||
None, description="Token introspection endpoint (RFC 7662)"
|
||||
)
|
||||
scopes_supported: list[str] = Field(
|
||||
default_factory=list, description="Supported scopes"
|
||||
)
|
||||
response_types_supported: list[str] = Field(
|
||||
default_factory=lambda: ["code"], description="Supported response types"
|
||||
)
|
||||
grant_types_supported: list[str] = Field(
|
||||
default_factory=lambda: ["authorization_code", "refresh_token"],
|
||||
description="Supported grant types",
|
||||
)
|
||||
code_challenge_methods_supported: list[str] = Field(
|
||||
default_factory=lambda: ["S256"], description="Supported PKCE methods"
|
||||
)
|
||||
token_endpoint_auth_methods_supported: list[str] = Field(
|
||||
default_factory=lambda: ["client_secret_basic", "client_secret_post", "none"],
|
||||
description="Supported client authentication methods",
|
||||
)
|
||||
|
||||
model_config = ConfigDict(
|
||||
json_schema_extra={
|
||||
"example": {
|
||||
"issuer": "https://api.example.com",
|
||||
"authorization_endpoint": "https://api.example.com/oauth/authorize",
|
||||
"token_endpoint": "https://api.example.com/oauth/token",
|
||||
"revocation_endpoint": "https://api.example.com/oauth/revoke",
|
||||
"introspection_endpoint": "https://api.example.com/oauth/introspect",
|
||||
"scopes_supported": ["openid", "profile", "email", "read:users"],
|
||||
"response_types_supported": ["code"],
|
||||
"grant_types_supported": ["authorization_code", "refresh_token"],
|
||||
"code_challenge_methods_supported": ["S256"],
|
||||
"token_endpoint_auth_methods_supported": [
|
||||
"client_secret_basic",
|
||||
"client_secret_post",
|
||||
"none",
|
||||
],
|
||||
}
|
||||
}
|
||||
)
|
||||
|
||||
|
||||
# ============================================================================
|
||||
# OAuth Token Responses (RFC 6749)
|
||||
# ============================================================================
|
||||
|
||||
|
||||
class OAuthTokenResponse(BaseModel):
|
||||
"""OAuth 2.0 Token Response (RFC 6749 Section 5.1)."""
|
||||
|
||||
access_token: str = Field(..., description="The access token issued by the server")
|
||||
token_type: str = Field(
|
||||
default="Bearer", description="The type of token (typically 'Bearer')"
|
||||
)
|
||||
expires_in: int | None = Field(None, description="Token lifetime in seconds")
|
||||
refresh_token: str | None = Field(
|
||||
None, description="Refresh token for obtaining new access tokens"
|
||||
)
|
||||
scope: str | None = Field(
|
||||
None, description="Space-separated list of granted scopes"
|
||||
)
|
||||
|
||||
model_config = ConfigDict(
|
||||
json_schema_extra={
|
||||
"example": {
|
||||
"access_token": "eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9...",
|
||||
"token_type": "Bearer",
|
||||
"expires_in": 3600,
|
||||
"refresh_token": "dGhpcyBpcyBhIHJlZnJlc2ggdG9rZW4...",
|
||||
"scope": "openid profile email",
|
||||
}
|
||||
}
|
||||
)
|
||||
|
||||
|
||||
class OAuthTokenIntrospectionResponse(BaseModel):
|
||||
"""OAuth 2.0 Token Introspection Response (RFC 7662)."""
|
||||
|
||||
active: bool = Field(..., description="Whether the token is currently active")
|
||||
scope: str | None = Field(None, description="Space-separated list of scopes")
|
||||
client_id: str | None = Field(None, description="Client identifier for the token")
|
||||
username: str | None = Field(
|
||||
None, description="Human-readable identifier for the resource owner"
|
||||
)
|
||||
token_type: str | None = Field(
|
||||
None, description="Type of the token (e.g., 'Bearer')"
|
||||
)
|
||||
exp: int | None = Field(None, description="Token expiration time (Unix timestamp)")
|
||||
iat: int | None = Field(None, description="Token issue time (Unix timestamp)")
|
||||
nbf: int | None = Field(None, description="Token not-before time (Unix timestamp)")
|
||||
sub: str | None = Field(None, description="Subject of the token (user ID)")
|
||||
aud: str | None = Field(None, description="Intended audience of the token")
|
||||
iss: str | None = Field(None, description="Issuer of the token")
|
||||
|
||||
model_config = ConfigDict(
|
||||
json_schema_extra={
|
||||
"example": {
|
||||
"active": True,
|
||||
"scope": "openid profile",
|
||||
"client_id": "client123",
|
||||
"username": "user@example.com",
|
||||
"token_type": "Bearer",
|
||||
"exp": 1735689600,
|
||||
"iat": 1735686000,
|
||||
"sub": "user-uuid-here",
|
||||
}
|
||||
}
|
||||
)
|
||||
@@ -1,10 +1,10 @@
|
||||
# app/schemas/organizations.py
|
||||
import re
|
||||
from datetime import datetime
|
||||
from typing import Optional, Dict, Any, List
|
||||
from typing import Any
|
||||
from uuid import UUID
|
||||
|
||||
from pydantic import BaseModel, field_validator, ConfigDict, Field
|
||||
from pydantic import BaseModel, ConfigDict, Field, field_validator
|
||||
|
||||
from app.models.user_organization import OrganizationRole
|
||||
|
||||
@@ -12,85 +12,94 @@ from app.models.user_organization import OrganizationRole
|
||||
# Organization Schemas
|
||||
class OrganizationBase(BaseModel):
|
||||
"""Base organization schema with common fields."""
|
||||
name: str = Field(..., min_length=1, max_length=255)
|
||||
slug: Optional[str] = Field(None, min_length=1, max_length=255)
|
||||
description: Optional[str] = None
|
||||
is_active: bool = True
|
||||
settings: Optional[Dict[str, Any]] = {}
|
||||
|
||||
@field_validator('slug')
|
||||
name: str = Field(..., min_length=1, max_length=255)
|
||||
slug: str | None = Field(None, min_length=1, max_length=255)
|
||||
description: str | None = None
|
||||
is_active: bool = True
|
||||
settings: dict[str, Any] | None = {}
|
||||
|
||||
@field_validator("slug")
|
||||
@classmethod
|
||||
def validate_slug(cls, v: Optional[str]) -> Optional[str]:
|
||||
def validate_slug(cls, v: str | None) -> str | None:
|
||||
"""Validate slug format: lowercase, alphanumeric, hyphens only."""
|
||||
if v is None:
|
||||
return v
|
||||
if not re.match(r'^[a-z0-9-]+$', v):
|
||||
raise ValueError('Slug must contain only lowercase letters, numbers, and hyphens')
|
||||
if v.startswith('-') or v.endswith('-'):
|
||||
raise ValueError('Slug cannot start or end with a hyphen')
|
||||
if '--' in v:
|
||||
raise ValueError('Slug cannot contain consecutive hyphens')
|
||||
if not re.match(r"^[a-z0-9-]+$", v):
|
||||
raise ValueError(
|
||||
"Slug must contain only lowercase letters, numbers, and hyphens"
|
||||
)
|
||||
if v.startswith("-") or v.endswith("-"):
|
||||
raise ValueError("Slug cannot start or end with a hyphen")
|
||||
if "--" in v:
|
||||
raise ValueError("Slug cannot contain consecutive hyphens")
|
||||
return v
|
||||
|
||||
@field_validator('name')
|
||||
@field_validator("name")
|
||||
@classmethod
|
||||
def validate_name(cls, v: str) -> str:
|
||||
"""Validate organization name."""
|
||||
if not v or v.strip() == "":
|
||||
raise ValueError('Organization name cannot be empty')
|
||||
raise ValueError("Organization name cannot be empty")
|
||||
return v.strip()
|
||||
|
||||
|
||||
class OrganizationCreate(OrganizationBase):
|
||||
"""Schema for creating a new organization."""
|
||||
|
||||
name: str = Field(..., min_length=1, max_length=255)
|
||||
slug: str = Field(..., min_length=1, max_length=255)
|
||||
slug: str = Field(..., min_length=1, max_length=255) # pyright: ignore[reportIncompatibleVariableOverride]
|
||||
|
||||
|
||||
class OrganizationUpdate(BaseModel):
|
||||
"""Schema for updating an organization."""
|
||||
name: Optional[str] = Field(None, min_length=1, max_length=255)
|
||||
slug: Optional[str] = Field(None, min_length=1, max_length=255)
|
||||
description: Optional[str] = None
|
||||
is_active: Optional[bool] = None
|
||||
settings: Optional[Dict[str, Any]] = None
|
||||
|
||||
@field_validator('slug')
|
||||
name: str | None = Field(None, min_length=1, max_length=255)
|
||||
slug: str | None = Field(None, min_length=1, max_length=255)
|
||||
description: str | None = None
|
||||
is_active: bool | None = None
|
||||
settings: dict[str, Any] | None = None
|
||||
|
||||
@field_validator("slug")
|
||||
@classmethod
|
||||
def validate_slug(cls, v: Optional[str]) -> Optional[str]:
|
||||
def validate_slug(cls, v: str | None) -> str | None:
|
||||
"""Validate slug format."""
|
||||
if v is None:
|
||||
return v
|
||||
if not re.match(r'^[a-z0-9-]+$', v):
|
||||
raise ValueError('Slug must contain only lowercase letters, numbers, and hyphens')
|
||||
if v.startswith('-') or v.endswith('-'):
|
||||
raise ValueError('Slug cannot start or end with a hyphen')
|
||||
if '--' in v:
|
||||
raise ValueError('Slug cannot contain consecutive hyphens')
|
||||
if not re.match(r"^[a-z0-9-]+$", v):
|
||||
raise ValueError(
|
||||
"Slug must contain only lowercase letters, numbers, and hyphens"
|
||||
)
|
||||
if v.startswith("-") or v.endswith("-"):
|
||||
raise ValueError("Slug cannot start or end with a hyphen")
|
||||
if "--" in v:
|
||||
raise ValueError("Slug cannot contain consecutive hyphens")
|
||||
return v
|
||||
|
||||
@field_validator('name')
|
||||
@field_validator("name")
|
||||
@classmethod
|
||||
def validate_name(cls, v: Optional[str]) -> Optional[str]:
|
||||
def validate_name(cls, v: str | None) -> str | None:
|
||||
"""Validate organization name."""
|
||||
if v is not None and (not v or v.strip() == ""):
|
||||
raise ValueError('Organization name cannot be empty')
|
||||
raise ValueError("Organization name cannot be empty")
|
||||
return v.strip() if v else v
|
||||
|
||||
|
||||
class OrganizationResponse(OrganizationBase):
|
||||
"""Schema for organization API responses."""
|
||||
|
||||
id: UUID
|
||||
created_at: datetime
|
||||
updated_at: Optional[datetime] = None
|
||||
member_count: Optional[int] = 0
|
||||
updated_at: datetime | None = None
|
||||
member_count: int | None = 0
|
||||
|
||||
model_config = ConfigDict(from_attributes=True)
|
||||
|
||||
|
||||
class OrganizationListResponse(BaseModel):
|
||||
"""Schema for paginated organization list responses."""
|
||||
organizations: List[OrganizationResponse]
|
||||
|
||||
organizations: list[OrganizationResponse]
|
||||
total: int
|
||||
page: int
|
||||
page_size: int
|
||||
@@ -100,44 +109,49 @@ class OrganizationListResponse(BaseModel):
|
||||
# User-Organization Relationship Schemas
|
||||
class UserOrganizationBase(BaseModel):
|
||||
"""Base schema for user-organization relationship."""
|
||||
|
||||
role: OrganizationRole = OrganizationRole.MEMBER
|
||||
is_active: bool = True
|
||||
custom_permissions: Optional[str] = None
|
||||
custom_permissions: str | None = None
|
||||
|
||||
|
||||
class UserOrganizationCreate(BaseModel):
|
||||
"""Schema for adding a user to an organization."""
|
||||
|
||||
user_id: UUID
|
||||
role: OrganizationRole = OrganizationRole.MEMBER
|
||||
custom_permissions: Optional[str] = None
|
||||
custom_permissions: str | None = None
|
||||
|
||||
|
||||
class UserOrganizationUpdate(BaseModel):
|
||||
"""Schema for updating user's role in an organization."""
|
||||
role: Optional[OrganizationRole] = None
|
||||
is_active: Optional[bool] = None
|
||||
custom_permissions: Optional[str] = None
|
||||
|
||||
role: OrganizationRole | None = None
|
||||
is_active: bool | None = None
|
||||
custom_permissions: str | None = None
|
||||
|
||||
|
||||
class UserOrganizationResponse(BaseModel):
|
||||
"""Schema for user-organization relationship responses."""
|
||||
|
||||
user_id: UUID
|
||||
organization_id: UUID
|
||||
role: OrganizationRole
|
||||
is_active: bool
|
||||
custom_permissions: Optional[str] = None
|
||||
custom_permissions: str | None = None
|
||||
created_at: datetime
|
||||
updated_at: Optional[datetime] = None
|
||||
updated_at: datetime | None = None
|
||||
|
||||
model_config = ConfigDict(from_attributes=True)
|
||||
|
||||
|
||||
class OrganizationMemberResponse(BaseModel):
|
||||
"""Schema for organization member information."""
|
||||
|
||||
user_id: UUID
|
||||
email: str
|
||||
first_name: str
|
||||
last_name: Optional[str] = None
|
||||
last_name: str | None = None
|
||||
role: OrganizationRole
|
||||
is_active: bool
|
||||
joined_at: datetime
|
||||
@@ -147,7 +161,8 @@ class OrganizationMemberResponse(BaseModel):
|
||||
|
||||
class OrganizationMemberListResponse(BaseModel):
|
||||
"""Schema for paginated organization member list."""
|
||||
members: List[OrganizationMemberResponse]
|
||||
|
||||
members: list[OrganizationMemberResponse]
|
||||
total: int
|
||||
page: int
|
||||
page_size: int
|
||||
|
||||
@@ -1,37 +1,44 @@
|
||||
"""
|
||||
Pydantic schemas for user session management.
|
||||
"""
|
||||
|
||||
from datetime import datetime
|
||||
from typing import Optional
|
||||
from uuid import UUID
|
||||
|
||||
from pydantic import BaseModel, Field, ConfigDict
|
||||
from pydantic import BaseModel, ConfigDict, Field
|
||||
|
||||
|
||||
class SessionBase(BaseModel):
|
||||
"""Base schema for user sessions."""
|
||||
device_name: Optional[str] = Field(None, max_length=255, description="Friendly device name")
|
||||
device_id: Optional[str] = Field(None, max_length=255, description="Persistent device identifier")
|
||||
|
||||
device_name: str | None = Field(
|
||||
None, max_length=255, description="Friendly device name"
|
||||
)
|
||||
device_id: str | None = Field(
|
||||
None, max_length=255, description="Persistent device identifier"
|
||||
)
|
||||
|
||||
|
||||
class SessionCreate(SessionBase):
|
||||
"""Schema for creating a new session (internal use)."""
|
||||
|
||||
user_id: UUID
|
||||
refresh_token_jti: str = Field(..., max_length=255)
|
||||
ip_address: Optional[str] = Field(None, max_length=45)
|
||||
user_agent: Optional[str] = Field(None, max_length=500)
|
||||
ip_address: str | None = Field(None, max_length=45)
|
||||
user_agent: str | None = Field(None, max_length=500)
|
||||
last_used_at: datetime
|
||||
expires_at: datetime
|
||||
location_city: Optional[str] = Field(None, max_length=100)
|
||||
location_country: Optional[str] = Field(None, max_length=100)
|
||||
location_city: str | None = Field(None, max_length=100)
|
||||
location_country: str | None = Field(None, max_length=100)
|
||||
|
||||
|
||||
class SessionUpdate(BaseModel):
|
||||
"""Schema for updating a session (internal use)."""
|
||||
last_used_at: Optional[datetime] = None
|
||||
is_active: Optional[bool] = None
|
||||
refresh_token_jti: Optional[str] = None
|
||||
expires_at: Optional[datetime] = None
|
||||
|
||||
last_used_at: datetime | None = None
|
||||
is_active: bool | None = None
|
||||
refresh_token_jti: str | None = None
|
||||
expires_at: datetime | None = None
|
||||
|
||||
|
||||
class SessionResponse(SessionBase):
|
||||
@@ -40,14 +47,17 @@ class SessionResponse(SessionBase):
|
||||
|
||||
This is what users see when they list their active sessions.
|
||||
"""
|
||||
|
||||
id: UUID
|
||||
ip_address: Optional[str] = None
|
||||
location_city: Optional[str] = None
|
||||
location_country: Optional[str] = None
|
||||
ip_address: str | None = None
|
||||
location_city: str | None = None
|
||||
location_country: str | None = None
|
||||
last_used_at: datetime
|
||||
created_at: datetime
|
||||
expires_at: datetime
|
||||
is_current: bool = Field(default=False, description="Whether this is the current session")
|
||||
is_current: bool = Field(
|
||||
default=False, description="Whether this is the current session"
|
||||
)
|
||||
|
||||
model_config = ConfigDict(
|
||||
from_attributes=True,
|
||||
@@ -62,14 +72,15 @@ class SessionResponse(SessionBase):
|
||||
"last_used_at": "2025-10-31T12:00:00Z",
|
||||
"created_at": "2025-10-30T09:00:00Z",
|
||||
"expires_at": "2025-11-06T09:00:00Z",
|
||||
"is_current": True
|
||||
"is_current": True,
|
||||
}
|
||||
}
|
||||
},
|
||||
)
|
||||
|
||||
|
||||
class SessionListResponse(BaseModel):
|
||||
"""Response containing list of sessions."""
|
||||
|
||||
sessions: list[SessionResponse]
|
||||
total: int = Field(..., description="Total number of active sessions")
|
||||
|
||||
@@ -84,10 +95,10 @@ class SessionListResponse(BaseModel):
|
||||
"last_used_at": "2025-10-31T12:00:00Z",
|
||||
"created_at": "2025-10-30T09:00:00Z",
|
||||
"expires_at": "2025-11-06T09:00:00Z",
|
||||
"is_current": True
|
||||
"is_current": True,
|
||||
}
|
||||
],
|
||||
"total": 1
|
||||
"total": 1,
|
||||
}
|
||||
}
|
||||
)
|
||||
@@ -95,29 +106,68 @@ class SessionListResponse(BaseModel):
|
||||
|
||||
class LogoutRequest(BaseModel):
|
||||
"""Request schema for logout endpoint."""
|
||||
|
||||
refresh_token: str = Field(
|
||||
...,
|
||||
description="Refresh token for the session to logout from",
|
||||
min_length=10
|
||||
..., description="Refresh token for the session to logout from", min_length=10
|
||||
)
|
||||
|
||||
model_config = ConfigDict(
|
||||
json_schema_extra={
|
||||
"example": {
|
||||
"refresh_token": "eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9..."
|
||||
}
|
||||
"example": {"refresh_token": "eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9..."}
|
||||
}
|
||||
)
|
||||
|
||||
|
||||
class AdminSessionResponse(SessionBase):
|
||||
"""
|
||||
Schema for session responses in admin panel.
|
||||
|
||||
Includes user information for admin to see who owns each session.
|
||||
"""
|
||||
|
||||
id: UUID
|
||||
user_id: UUID
|
||||
user_email: str = Field(..., description="Email of the user who owns this session")
|
||||
user_full_name: str | None = Field(None, description="Full name of the user")
|
||||
ip_address: str | None = None
|
||||
location_city: str | None = None
|
||||
location_country: str | None = None
|
||||
last_used_at: datetime
|
||||
created_at: datetime
|
||||
expires_at: datetime
|
||||
is_active: bool
|
||||
|
||||
model_config = ConfigDict(
|
||||
from_attributes=True,
|
||||
json_schema_extra={
|
||||
"example": {
|
||||
"id": "123e4567-e89b-12d3-a456-426614174000",
|
||||
"user_id": "456e7890-e89b-12d3-a456-426614174001",
|
||||
"user_email": "user@example.com",
|
||||
"user_full_name": "John Doe",
|
||||
"device_name": "iPhone 14",
|
||||
"device_id": "device-abc-123",
|
||||
"ip_address": "192.168.1.100",
|
||||
"location_city": "San Francisco",
|
||||
"location_country": "United States",
|
||||
"last_used_at": "2025-10-31T12:00:00Z",
|
||||
"created_at": "2025-10-30T09:00:00Z",
|
||||
"expires_at": "2025-11-06T09:00:00Z",
|
||||
"is_active": True,
|
||||
}
|
||||
},
|
||||
)
|
||||
|
||||
|
||||
class DeviceInfo(BaseModel):
|
||||
"""Device information extracted from request."""
|
||||
device_name: Optional[str] = None
|
||||
device_id: Optional[str] = None
|
||||
ip_address: Optional[str] = None
|
||||
user_agent: Optional[str] = None
|
||||
location_city: Optional[str] = None
|
||||
location_country: Optional[str] = None
|
||||
|
||||
device_name: str | None = None
|
||||
device_id: str | None = None
|
||||
ip_address: str | None = None
|
||||
user_agent: str | None = None
|
||||
location_city: str | None = None
|
||||
location_country: str | None = None
|
||||
|
||||
model_config = ConfigDict(
|
||||
json_schema_extra={
|
||||
@@ -127,7 +177,7 @@ class DeviceInfo(BaseModel):
|
||||
"ip_address": "192.168.1.50",
|
||||
"user_agent": "Mozilla/5.0 (Macintosh; Intel Mac OS X 10_15_7)...",
|
||||
"location_city": "San Francisco",
|
||||
"location_country": "United States"
|
||||
"location_country": "United States",
|
||||
}
|
||||
}
|
||||
)
|
||||
|
||||
191
backend/app/schemas/users.py
Normal file → Executable file
191
backend/app/schemas/users.py
Normal file → Executable file
@@ -1,84 +1,93 @@
|
||||
# app/schemas/users.py
|
||||
import re
|
||||
from datetime import datetime
|
||||
from typing import Optional, Dict, Any
|
||||
from typing import Any
|
||||
from uuid import UUID
|
||||
|
||||
from pydantic import BaseModel, EmailStr, field_validator, ConfigDict, Field
|
||||
from pydantic import BaseModel, ConfigDict, EmailStr, Field, field_validator
|
||||
|
||||
from app.schemas.validators import validate_password_strength, validate_phone_number
|
||||
|
||||
|
||||
class UserBase(BaseModel):
|
||||
email: EmailStr
|
||||
first_name: str
|
||||
last_name: Optional[str] = None
|
||||
phone_number: Optional[str] = None
|
||||
last_name: str | None = None
|
||||
phone_number: str | None = None
|
||||
|
||||
@field_validator('phone_number')
|
||||
@field_validator("phone_number")
|
||||
@classmethod
|
||||
def validate_phone_number(cls, v: Optional[str]) -> Optional[str]:
|
||||
if v is None:
|
||||
return v
|
||||
# Simple regex for phone validation
|
||||
if not re.match(r'^\+?[0-9\s\-\(\)]{8,20}$', v):
|
||||
raise ValueError('Invalid phone number format')
|
||||
return v
|
||||
def validate_phone(cls, v: str | None) -> str | None:
|
||||
return validate_phone_number(v)
|
||||
|
||||
|
||||
class UserCreate(UserBase):
|
||||
password: str
|
||||
is_superuser: bool = False
|
||||
is_active: bool = True
|
||||
|
||||
@field_validator('password')
|
||||
@field_validator("password")
|
||||
@classmethod
|
||||
def password_strength(cls, v: str) -> str:
|
||||
"""Basic password strength validation"""
|
||||
if len(v) < 8:
|
||||
raise ValueError('Password must be at least 8 characters')
|
||||
if not any(char.isdigit() for char in v):
|
||||
raise ValueError('Password must contain at least one digit')
|
||||
if not any(char.isupper() for char in v):
|
||||
raise ValueError('Password must contain at least one uppercase letter')
|
||||
return v
|
||||
"""Enterprise-grade password strength validation"""
|
||||
return validate_password_strength(v)
|
||||
|
||||
|
||||
class UserUpdate(BaseModel):
|
||||
first_name: Optional[str] = None
|
||||
last_name: Optional[str] = None
|
||||
phone_number: Optional[str] = None
|
||||
preferences: Optional[Dict[str, Any]] = None
|
||||
is_active: Optional[bool] = True
|
||||
@field_validator('phone_number')
|
||||
def validate_phone_number(cls, v: Optional[str]) -> Optional[str]:
|
||||
first_name: str | None = None
|
||||
last_name: str | None = None
|
||||
phone_number: str | None = None
|
||||
password: str | None = None
|
||||
preferences: dict[str, Any] | None = None
|
||||
locale: str | None = Field(
|
||||
None,
|
||||
max_length=10,
|
||||
pattern=r"^[a-z]{2}(-[A-Z]{2})?$",
|
||||
description="User's preferred locale (BCP 47 format: en, it, en-US, it-IT)",
|
||||
examples=["en", "it", "en-US", "it-IT"],
|
||||
)
|
||||
is_active: bool | None = (
|
||||
None # Changed default from True to None to avoid unintended updates
|
||||
)
|
||||
is_superuser: bool | None = None # Explicitly reject privilege escalation attempts
|
||||
|
||||
@field_validator("phone_number")
|
||||
@classmethod
|
||||
def validate_phone(cls, v: str | None) -> str | None:
|
||||
return validate_phone_number(v)
|
||||
|
||||
@field_validator("password")
|
||||
@classmethod
|
||||
def password_strength(cls, v: str | None) -> str | None:
|
||||
"""Enterprise-grade password strength validation"""
|
||||
if v is None:
|
||||
return v
|
||||
return validate_password_strength(v)
|
||||
|
||||
# Return early for empty strings or whitespace-only strings
|
||||
if not v or v.strip() == "":
|
||||
raise ValueError('Phone number cannot be empty')
|
||||
@field_validator("locale")
|
||||
@classmethod
|
||||
def validate_locale(cls, v: str | None) -> str | None:
|
||||
"""Validate locale against supported locales."""
|
||||
if v is None:
|
||||
return v
|
||||
# Only support English and Italian for template showcase
|
||||
# Note: Locales stored in lowercase for case-insensitive matching
|
||||
supported_locales = {"en", "it", "en-us", "en-gb", "it-it"}
|
||||
# Normalize to lowercase for comparison and storage
|
||||
v_lower = v.lower()
|
||||
if v_lower not in supported_locales:
|
||||
raise ValueError(
|
||||
f"Unsupported locale '{v}'. Supported locales: {sorted(supported_locales)}"
|
||||
)
|
||||
# Return normalized lowercase version for consistency
|
||||
return v_lower
|
||||
|
||||
# Remove all spaces and formatting characters
|
||||
cleaned = re.sub(r'[\s\-\(\)]', '', v)
|
||||
|
||||
# Basic pattern:
|
||||
# Must start with + or 0
|
||||
# After + must have at least 8 digits
|
||||
# After 0 must have at least 8 digits
|
||||
# Maximum total length of 15 digits (international standard)
|
||||
# Only allowed characters are + at start and digits
|
||||
pattern = r'^(?:\+[0-9]{8,14}|0[0-9]{8,14})$'
|
||||
|
||||
if not re.match(pattern, cleaned):
|
||||
raise ValueError('Phone number must start with + or 0 followed by 8-14 digits')
|
||||
|
||||
# Additional validation to catch specific invalid cases
|
||||
if cleaned.count('+') > 1:
|
||||
raise ValueError('Phone number can only contain one + symbol at the start')
|
||||
|
||||
# Check for any non-digit characters (except the leading +)
|
||||
if not all(c.isdigit() for c in cleaned[1:]):
|
||||
raise ValueError('Phone number can only contain digits after the prefix')
|
||||
|
||||
return cleaned
|
||||
@field_validator("is_superuser")
|
||||
@classmethod
|
||||
def prevent_superuser_modification(cls, v: bool | None) -> bool | None:
|
||||
"""Prevent users from modifying their superuser status via this schema."""
|
||||
if v is not None:
|
||||
raise ValueError("Cannot modify superuser status through user update")
|
||||
return v
|
||||
|
||||
|
||||
class UserInDB(UserBase):
|
||||
@@ -86,7 +95,8 @@ class UserInDB(UserBase):
|
||||
is_active: bool
|
||||
is_superuser: bool
|
||||
created_at: datetime
|
||||
updated_at: Optional[datetime] = None
|
||||
updated_at: datetime | None = None
|
||||
locale: str | None = None
|
||||
|
||||
model_config = ConfigDict(from_attributes=True)
|
||||
|
||||
@@ -96,26 +106,29 @@ class UserResponse(UserBase):
|
||||
is_active: bool
|
||||
is_superuser: bool
|
||||
created_at: datetime
|
||||
updated_at: Optional[datetime] = None
|
||||
updated_at: datetime | None = None
|
||||
locale: str | None = None
|
||||
|
||||
model_config = ConfigDict(from_attributes=True)
|
||||
|
||||
|
||||
class Token(BaseModel):
|
||||
access_token: str
|
||||
refresh_token: Optional[str] = None
|
||||
refresh_token: str | None = None
|
||||
token_type: str = "bearer"
|
||||
user: "UserResponse" # Forward reference since UserResponse is defined above
|
||||
expires_in: int | None = None # Token expiration in seconds
|
||||
|
||||
|
||||
class TokenPayload(BaseModel):
|
||||
sub: str # User ID
|
||||
exp: int # Expiration time
|
||||
iat: Optional[int] = None # Issued at
|
||||
jti: Optional[str] = None # JWT ID
|
||||
is_superuser: Optional[bool] = False
|
||||
first_name: Optional[str] = None
|
||||
email: Optional[str] = None
|
||||
type: Optional[str] = None # Token type (access/refresh)
|
||||
iat: int | None = None # Issued at
|
||||
jti: str | None = None # JWT ID
|
||||
is_superuser: bool | None = False
|
||||
first_name: str | None = None
|
||||
email: str | None = None
|
||||
type: str | None = None # Token type (access/refresh)
|
||||
|
||||
|
||||
class TokenData(BaseModel):
|
||||
@@ -125,38 +138,28 @@ class TokenData(BaseModel):
|
||||
|
||||
class PasswordChange(BaseModel):
|
||||
"""Schema for changing password (requires current password)."""
|
||||
|
||||
current_password: str
|
||||
new_password: str
|
||||
|
||||
@field_validator('new_password')
|
||||
@field_validator("new_password")
|
||||
@classmethod
|
||||
def password_strength(cls, v: str) -> str:
|
||||
"""Basic password strength validation"""
|
||||
if len(v) < 8:
|
||||
raise ValueError('Password must be at least 8 characters')
|
||||
if not any(char.isdigit() for char in v):
|
||||
raise ValueError('Password must contain at least one digit')
|
||||
if not any(char.isupper() for char in v):
|
||||
raise ValueError('Password must contain at least one uppercase letter')
|
||||
return v
|
||||
"""Enterprise-grade password strength validation"""
|
||||
return validate_password_strength(v)
|
||||
|
||||
|
||||
class PasswordReset(BaseModel):
|
||||
"""Schema for resetting password (via email token)."""
|
||||
|
||||
token: str
|
||||
new_password: str
|
||||
|
||||
@field_validator('new_password')
|
||||
@field_validator("new_password")
|
||||
@classmethod
|
||||
def password_strength(cls, v: str) -> str:
|
||||
"""Basic password strength validation"""
|
||||
if len(v) < 8:
|
||||
raise ValueError('Password must be at least 8 characters')
|
||||
if not any(char.isdigit() for char in v):
|
||||
raise ValueError('Password must contain at least one digit')
|
||||
if not any(char.isupper() for char in v):
|
||||
raise ValueError('Password must contain at least one uppercase letter')
|
||||
return v
|
||||
"""Enterprise-grade password strength validation"""
|
||||
return validate_password_strength(v)
|
||||
|
||||
|
||||
class LoginRequest(BaseModel):
|
||||
@@ -170,39 +173,29 @@ class RefreshTokenRequest(BaseModel):
|
||||
|
||||
class PasswordResetRequest(BaseModel):
|
||||
"""Schema for requesting a password reset."""
|
||||
|
||||
email: EmailStr = Field(..., description="Email address of the account")
|
||||
|
||||
model_config = {
|
||||
"json_schema_extra": {
|
||||
"example": {
|
||||
"email": "user@example.com"
|
||||
}
|
||||
}
|
||||
}
|
||||
model_config = {"json_schema_extra": {"example": {"email": "user@example.com"}}}
|
||||
|
||||
|
||||
class PasswordResetConfirm(BaseModel):
|
||||
"""Schema for confirming a password reset with token."""
|
||||
|
||||
token: str = Field(..., description="Password reset token from email")
|
||||
new_password: str = Field(..., min_length=8, description="New password")
|
||||
|
||||
@field_validator('new_password')
|
||||
@field_validator("new_password")
|
||||
@classmethod
|
||||
def password_strength(cls, v: str) -> str:
|
||||
"""Basic password strength validation"""
|
||||
if len(v) < 8:
|
||||
raise ValueError('Password must be at least 8 characters')
|
||||
if not any(char.isdigit() for char in v):
|
||||
raise ValueError('Password must contain at least one digit')
|
||||
if not any(char.isupper() for char in v):
|
||||
raise ValueError('Password must contain at least one uppercase letter')
|
||||
return v
|
||||
"""Enterprise-grade password strength validation"""
|
||||
return validate_password_strength(v)
|
||||
|
||||
model_config = {
|
||||
"json_schema_extra": {
|
||||
"example": {
|
||||
"token": "eyJwYXlsb2FkIjp7ImVtYWlsIjoidXNlckBleGFtcGxlLmNvbSIsImV4cCI6MTcxMjM0NTY3OH19",
|
||||
"new_password": "NewSecurePassword123"
|
||||
"new_password": "NewSecurePassword123",
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
211
backend/app/schemas/validators.py
Normal file
211
backend/app/schemas/validators.py
Normal file
@@ -0,0 +1,211 @@
|
||||
"""
|
||||
Shared validators for Pydantic schemas.
|
||||
|
||||
This module provides reusable validation functions to ensure consistency
|
||||
across all schemas and avoid code duplication.
|
||||
"""
|
||||
|
||||
import re
|
||||
|
||||
# Common weak passwords that should be rejected
|
||||
COMMON_PASSWORDS: set[str] = {
|
||||
"password",
|
||||
"password1",
|
||||
"password123",
|
||||
"password1234",
|
||||
"admin",
|
||||
"admin123",
|
||||
"admin1234",
|
||||
"welcome",
|
||||
"welcome1",
|
||||
"welcome123",
|
||||
"qwerty",
|
||||
"qwerty123",
|
||||
"12345678",
|
||||
"123456789",
|
||||
"1234567890",
|
||||
"letmein",
|
||||
"letmein1",
|
||||
"letmein123",
|
||||
"monkey123",
|
||||
"dragon123",
|
||||
"passw0rd",
|
||||
"p@ssw0rd",
|
||||
"p@ssword",
|
||||
}
|
||||
|
||||
|
||||
def validate_password_strength(password: str) -> str:
|
||||
"""
|
||||
Validate password strength with enterprise-grade requirements.
|
||||
|
||||
Requirements:
|
||||
- Minimum 12 characters (increased from 8 for better security)
|
||||
- At least one lowercase letter
|
||||
- At least one uppercase letter
|
||||
- At least one digit
|
||||
- At least one special character
|
||||
- Not in common password list
|
||||
|
||||
Args:
|
||||
password: The password to validate
|
||||
|
||||
Returns:
|
||||
The validated password
|
||||
|
||||
Raises:
|
||||
ValueError: If password doesn't meet requirements
|
||||
|
||||
Examples:
|
||||
>>> validate_password_strength("MySecureP@ss123") # Valid
|
||||
>>> validate_password_strength("password1") # Invalid - too weak
|
||||
"""
|
||||
# Check if we are in demo mode
|
||||
from app.core.config import settings
|
||||
|
||||
if settings.DEMO_MODE:
|
||||
# In demo mode, allow specific weak passwords for demo accounts
|
||||
demo_passwords = {"Demo123!", "Admin123!"}
|
||||
if password in demo_passwords:
|
||||
return password
|
||||
|
||||
# Check minimum length
|
||||
if len(password) < 12:
|
||||
raise ValueError("Password must be at least 12 characters long")
|
||||
|
||||
# Check against common passwords (case-insensitive)
|
||||
if password.lower() in COMMON_PASSWORDS:
|
||||
raise ValueError("Password is too common. Please choose a stronger password")
|
||||
|
||||
# Check for required character types
|
||||
checks = [
|
||||
(any(c.islower() for c in password), "at least one lowercase letter"),
|
||||
(any(c.isupper() for c in password), "at least one uppercase letter"),
|
||||
(any(c.isdigit() for c in password), "at least one digit"),
|
||||
(
|
||||
any(c in "!@#$%^&*()_+-=[]{}|;:,.<>?~`" for c in password),
|
||||
"at least one special character (!@#$%^&*()_+-=[]{}|;:,.<>?~`)",
|
||||
),
|
||||
]
|
||||
|
||||
failed = [msg for check, msg in checks if not check]
|
||||
if failed:
|
||||
raise ValueError(f"Password must contain {', '.join(failed)}")
|
||||
|
||||
return password
|
||||
|
||||
|
||||
def validate_phone_number(phone: str | None) -> str | None:
|
||||
"""
|
||||
Validate phone number format.
|
||||
|
||||
Accepts international format with + prefix or local format with 0 prefix.
|
||||
Removes formatting characters (spaces, hyphens, parentheses).
|
||||
|
||||
Args:
|
||||
phone: Phone number to validate (can be None)
|
||||
|
||||
Returns:
|
||||
Cleaned phone number or None
|
||||
|
||||
Raises:
|
||||
ValueError: If phone number format is invalid
|
||||
|
||||
Examples:
|
||||
>>> validate_phone_number("+1 (555) 123-4567") # Valid
|
||||
>>> validate_phone_number("0412 345 678") # Valid
|
||||
>>> validate_phone_number("invalid") # Invalid
|
||||
"""
|
||||
if phone is None:
|
||||
return None
|
||||
|
||||
# Check for empty strings
|
||||
if not phone or phone.strip() == "":
|
||||
raise ValueError("Phone number cannot be empty")
|
||||
|
||||
# Remove all spaces and formatting characters
|
||||
cleaned = re.sub(r"[\s\-\(\)]", "", phone)
|
||||
|
||||
# Basic pattern:
|
||||
# Must start with + or 0
|
||||
# After + must have at least 8 digits
|
||||
# After 0 must have at least 8 digits
|
||||
# Maximum total length of 15 digits (international standard)
|
||||
# Only allowed characters are + at start and digits
|
||||
pattern = r"^(?:\+[0-9]{8,14}|0[0-9]{8,14})$"
|
||||
|
||||
if not re.match(pattern, cleaned):
|
||||
raise ValueError("Phone number must start with + or 0 followed by 8-14 digits")
|
||||
|
||||
# Additional validation to catch specific invalid cases
|
||||
# NOTE: These checks are defensive code - the regex pattern above already catches these cases
|
||||
if cleaned.count("+") > 1: # pragma: no cover
|
||||
raise ValueError("Phone number can only contain one + symbol at the start")
|
||||
|
||||
# Check for any non-digit characters (except the leading +)
|
||||
if not all(c.isdigit() for c in cleaned[1:]): # pragma: no cover
|
||||
raise ValueError("Phone number can only contain digits after the prefix")
|
||||
|
||||
return cleaned
|
||||
|
||||
|
||||
def validate_email_format(email: str) -> str:
|
||||
"""
|
||||
Additional email validation beyond Pydantic's EmailStr.
|
||||
|
||||
This can be extended for custom email validation rules.
|
||||
|
||||
Args:
|
||||
email: Email address to validate
|
||||
|
||||
Returns:
|
||||
Validated email address
|
||||
|
||||
Raises:
|
||||
ValueError: If email format is invalid
|
||||
"""
|
||||
# Pydantic's EmailStr already does comprehensive validation
|
||||
# This function is here for custom rules if needed
|
||||
|
||||
# Example: Reject disposable email domains (optional)
|
||||
# disposable_domains = {'tempmail.com', '10minutemail.com', 'guerrillamail.com'}
|
||||
# domain = email.split('@')[1].lower()
|
||||
# if domain in disposable_domains:
|
||||
# raise ValueError('Disposable email addresses are not allowed')
|
||||
|
||||
return email.lower() # Normalize to lowercase
|
||||
|
||||
|
||||
def validate_slug(slug: str) -> str:
|
||||
"""
|
||||
Validate URL slug format.
|
||||
|
||||
Slugs must:
|
||||
- Be 2-50 characters long
|
||||
- Contain only lowercase letters, numbers, and hyphens
|
||||
- Not start or end with a hyphen
|
||||
- Not contain consecutive hyphens
|
||||
|
||||
Args:
|
||||
slug: URL slug to validate
|
||||
|
||||
Returns:
|
||||
Validated slug
|
||||
|
||||
Raises:
|
||||
ValueError: If slug format is invalid
|
||||
"""
|
||||
if not slug or len(slug) < 2:
|
||||
raise ValueError("Slug must be at least 2 characters long")
|
||||
|
||||
if len(slug) > 50:
|
||||
raise ValueError("Slug must be at most 50 characters long")
|
||||
|
||||
# Check format
|
||||
if not re.match(r"^[a-z0-9]+(?:-[a-z0-9]+)*$", slug):
|
||||
raise ValueError(
|
||||
"Slug can only contain lowercase letters, numbers, and hyphens. "
|
||||
"It cannot start or end with a hyphen, and cannot contain consecutive hyphens"
|
||||
)
|
||||
|
||||
return slug
|
||||
@@ -0,0 +1,19 @@
|
||||
# app/services/__init__.py
|
||||
from . import oauth_provider_service
|
||||
from .auth_service import AuthService
|
||||
from .oauth_service import OAuthService
|
||||
from .organization_service import OrganizationService, organization_service
|
||||
from .session_service import SessionService, session_service
|
||||
from .user_service import UserService, user_service
|
||||
|
||||
__all__ = [
|
||||
"AuthService",
|
||||
"OAuthService",
|
||||
"OrganizationService",
|
||||
"SessionService",
|
||||
"UserService",
|
||||
"oauth_provider_service",
|
||||
"organization_service",
|
||||
"session_service",
|
||||
"user_service",
|
||||
]
|
||||
|
||||
174
backend/app/services/auth_service.py
Normal file → Executable file
174
backend/app/services/auth_service.py
Normal file → Executable file
@@ -1,36 +1,40 @@
|
||||
# app/services/auth_service.py
|
||||
import logging
|
||||
from typing import Optional
|
||||
from uuid import UUID
|
||||
|
||||
from sqlalchemy.orm import Session
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
|
||||
from app.core.auth import (
|
||||
verify_password,
|
||||
get_password_hash,
|
||||
TokenExpiredError,
|
||||
TokenInvalidError,
|
||||
create_access_token,
|
||||
create_refresh_token,
|
||||
TokenExpiredError,
|
||||
TokenInvalidError
|
||||
get_password_hash_async,
|
||||
verify_password_async,
|
||||
)
|
||||
from app.core.config import settings
|
||||
from app.core.exceptions import AuthenticationError, DuplicateError
|
||||
from app.core.repository_exceptions import DuplicateEntryError
|
||||
from app.models.user import User
|
||||
from app.schemas.users import Token, UserCreate
|
||||
from app.repositories.user import user_repo
|
||||
from app.schemas.users import Token, UserCreate, UserResponse
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class AuthenticationError(Exception):
|
||||
"""Exception raised for authentication errors"""
|
||||
pass
|
||||
# Pre-computed bcrypt hash used for constant-time comparison when user is not found,
|
||||
# preventing timing attacks that could enumerate valid email addresses.
|
||||
_DUMMY_HASH = "$2b$12$EixZaYVK1fsbw1ZfbX3OXePaWxn96p36zLFbnJHfxPSEFBzXKiHia"
|
||||
|
||||
|
||||
class AuthService:
|
||||
"""Service for handling authentication operations"""
|
||||
|
||||
@staticmethod
|
||||
def authenticate_user(db: Session, email: str, password: str) -> Optional[User]:
|
||||
async def authenticate_user(
|
||||
db: AsyncSession, email: str, password: str
|
||||
) -> User | None:
|
||||
"""
|
||||
Authenticate a user with email and password.
|
||||
Authenticate a user with email and password using async password verification.
|
||||
|
||||
Args:
|
||||
db: Database session
|
||||
@@ -40,12 +44,16 @@ class AuthService:
|
||||
Returns:
|
||||
User if authenticated, None otherwise
|
||||
"""
|
||||
user = db.query(User).filter(User.email == email).first()
|
||||
user = await user_repo.get_by_email(db, email=email)
|
||||
|
||||
if not user:
|
||||
# Perform a dummy verification to match timing of a real bcrypt check,
|
||||
# preventing email enumeration via response-time differences.
|
||||
await verify_password_async(password, _DUMMY_HASH)
|
||||
return None
|
||||
|
||||
if not verify_password(password, user.password_hash):
|
||||
# Verify password asynchronously to avoid blocking event loop
|
||||
if not await verify_password_async(password, user.password_hash):
|
||||
return None
|
||||
|
||||
if not user.is_active:
|
||||
@@ -54,7 +62,7 @@ class AuthService:
|
||||
return user
|
||||
|
||||
@staticmethod
|
||||
def create_user(db: Session, user_data: UserCreate) -> User:
|
||||
async def create_user(db: AsyncSession, user_data: UserCreate) -> User:
|
||||
"""
|
||||
Create a new user.
|
||||
|
||||
@@ -64,31 +72,30 @@ class AuthService:
|
||||
|
||||
Returns:
|
||||
Created user
|
||||
|
||||
Raises:
|
||||
AuthenticationError: If user already exists or creation fails
|
||||
"""
|
||||
# Check if user already exists
|
||||
existing_user = db.query(User).filter(User.email == user_data.email).first()
|
||||
if existing_user:
|
||||
raise AuthenticationError("User with this email already exists")
|
||||
try:
|
||||
# Check if user already exists
|
||||
existing_user = await user_repo.get_by_email(db, email=user_data.email)
|
||||
if existing_user:
|
||||
raise DuplicateError("User with this email already exists")
|
||||
|
||||
# Create new user
|
||||
hashed_password = get_password_hash(user_data.password)
|
||||
# Delegate creation (hashing + commit) to the repository
|
||||
user = await user_repo.create(db, obj_in=user_data)
|
||||
|
||||
# Create user object from model
|
||||
user = User(
|
||||
email=user_data.email,
|
||||
password_hash=hashed_password,
|
||||
first_name=user_data.first_name,
|
||||
last_name=user_data.last_name,
|
||||
phone_number=user_data.phone_number,
|
||||
is_active=True,
|
||||
is_superuser=False
|
||||
)
|
||||
logger.info("User created successfully: %s", user.email)
|
||||
return user
|
||||
|
||||
db.add(user)
|
||||
db.commit()
|
||||
db.refresh(user)
|
||||
|
||||
return user
|
||||
except (AuthenticationError, DuplicateError):
|
||||
# Re-raise API exceptions without rollback
|
||||
raise
|
||||
except DuplicateEntryError as e:
|
||||
raise DuplicateError(str(e))
|
||||
except Exception as e:
|
||||
logger.exception("Error creating user: %s", e)
|
||||
raise AuthenticationError(f"Failed to create user: {e!s}")
|
||||
|
||||
@staticmethod
|
||||
def create_tokens(user: User) -> Token:
|
||||
@@ -99,32 +106,33 @@ class AuthService:
|
||||
user: User to create tokens for
|
||||
|
||||
Returns:
|
||||
Token object with access and refresh tokens
|
||||
Token object with access and refresh tokens and user info
|
||||
"""
|
||||
# Generate claims
|
||||
claims = {
|
||||
"is_superuser": user.is_superuser,
|
||||
"email": user.email,
|
||||
"first_name": user.first_name
|
||||
"first_name": user.first_name,
|
||||
}
|
||||
|
||||
# Create tokens
|
||||
access_token = create_access_token(
|
||||
subject=str(user.id),
|
||||
claims=claims
|
||||
)
|
||||
access_token = create_access_token(subject=str(user.id), claims=claims)
|
||||
|
||||
refresh_token = create_refresh_token(
|
||||
subject=str(user.id)
|
||||
)
|
||||
refresh_token = create_refresh_token(subject=str(user.id))
|
||||
|
||||
# Convert User model to UserResponse schema
|
||||
user_response = UserResponse.model_validate(user)
|
||||
|
||||
return Token(
|
||||
access_token=access_token,
|
||||
refresh_token=refresh_token
|
||||
refresh_token=refresh_token,
|
||||
user=user_response,
|
||||
expires_in=settings.ACCESS_TOKEN_EXPIRE_MINUTES
|
||||
* 60, # Convert minutes to seconds
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
def refresh_tokens(db: Session, refresh_token: str) -> Token:
|
||||
async def refresh_tokens(db: AsyncSession, refresh_token: str) -> Token:
|
||||
"""
|
||||
Generate new tokens using a refresh token.
|
||||
|
||||
@@ -150,7 +158,7 @@ class AuthService:
|
||||
user_id = token_data.user_id
|
||||
|
||||
# Get user from database
|
||||
user = db.query(User).filter(User.id == user_id).first()
|
||||
user = await user_repo.get(db, id=str(user_id))
|
||||
if not user or not user.is_active:
|
||||
raise TokenInvalidError("Invalid user or inactive account")
|
||||
|
||||
@@ -158,11 +166,13 @@ class AuthService:
|
||||
return AuthService.create_tokens(user)
|
||||
|
||||
except (TokenExpiredError, TokenInvalidError) as e:
|
||||
logger.warning(f"Token refresh failed: {str(e)}")
|
||||
logger.warning("Token refresh failed: %s", e)
|
||||
raise
|
||||
|
||||
@staticmethod
|
||||
def change_password(db: Session, user_id: UUID, current_password: str, new_password: str) -> bool:
|
||||
async def change_password(
|
||||
db: AsyncSession, user_id: UUID, current_password: str, new_password: str
|
||||
) -> bool:
|
||||
"""
|
||||
Change a user's password.
|
||||
|
||||
@@ -176,18 +186,58 @@ class AuthService:
|
||||
True if password was changed successfully
|
||||
|
||||
Raises:
|
||||
AuthenticationError: If current password is incorrect
|
||||
AuthenticationError: If current password is incorrect or update fails
|
||||
"""
|
||||
user = db.query(User).filter(User.id == user_id).first()
|
||||
try:
|
||||
user = await user_repo.get(db, id=str(user_id))
|
||||
if not user:
|
||||
raise AuthenticationError("User not found")
|
||||
|
||||
# Verify current password asynchronously
|
||||
if not await verify_password_async(current_password, user.password_hash):
|
||||
raise AuthenticationError("Current password is incorrect")
|
||||
|
||||
# Hash new password asynchronously to avoid blocking event loop
|
||||
new_hash = await get_password_hash_async(new_password)
|
||||
await user_repo.update_password(db, user=user, password_hash=new_hash)
|
||||
|
||||
logger.info("Password changed successfully for user %s", user_id)
|
||||
return True
|
||||
|
||||
except AuthenticationError:
|
||||
# Re-raise authentication errors without rollback
|
||||
raise
|
||||
except Exception as e:
|
||||
# Rollback on any database errors
|
||||
await db.rollback()
|
||||
logger.exception("Error changing password for user %s: %s", user_id, e)
|
||||
raise AuthenticationError(f"Failed to change password: {e!s}")
|
||||
|
||||
@staticmethod
|
||||
async def reset_password(
|
||||
db: AsyncSession, *, email: str, new_password: str
|
||||
) -> User:
|
||||
"""
|
||||
Reset a user's password without requiring the current password.
|
||||
|
||||
Args:
|
||||
db: Database session
|
||||
email: User email address
|
||||
new_password: New password to set
|
||||
|
||||
Returns:
|
||||
Updated user
|
||||
|
||||
Raises:
|
||||
AuthenticationError: If user not found or inactive
|
||||
"""
|
||||
user = await user_repo.get_by_email(db, email=email)
|
||||
if not user:
|
||||
raise AuthenticationError("User not found")
|
||||
if not user.is_active:
|
||||
raise AuthenticationError("User account is inactive")
|
||||
|
||||
# Verify current password
|
||||
if not verify_password(current_password, user.password_hash):
|
||||
raise AuthenticationError("Current password is incorrect")
|
||||
|
||||
# Update password
|
||||
user.password_hash = get_password_hash(new_password)
|
||||
db.commit()
|
||||
|
||||
return True
|
||||
new_hash = await get_password_hash_async(new_password)
|
||||
user = await user_repo.update_password(db, user=user, password_hash=new_hash)
|
||||
logger.info("Password reset successfully for %s", email)
|
||||
return user
|
||||
|
||||
@@ -5,8 +5,8 @@ Email service with placeholder implementation.
|
||||
This service provides email sending functionality with a simple console/log-based
|
||||
placeholder that can be easily replaced with a real email provider (SendGrid, SES, etc.)
|
||||
"""
|
||||
|
||||
import logging
|
||||
from typing import List, Optional
|
||||
from abc import ABC, abstractmethod
|
||||
|
||||
from app.core.config import settings
|
||||
@@ -20,13 +20,12 @@ class EmailBackend(ABC):
|
||||
@abstractmethod
|
||||
async def send_email(
|
||||
self,
|
||||
to: List[str],
|
||||
to: list[str],
|
||||
subject: str,
|
||||
html_content: str,
|
||||
text_content: Optional[str] = None
|
||||
text_content: str | None = None,
|
||||
) -> bool:
|
||||
"""Send an email."""
|
||||
pass
|
||||
|
||||
|
||||
class ConsoleEmailBackend(EmailBackend):
|
||||
@@ -39,10 +38,10 @@ class ConsoleEmailBackend(EmailBackend):
|
||||
|
||||
async def send_email(
|
||||
self,
|
||||
to: List[str],
|
||||
to: list[str],
|
||||
subject: str,
|
||||
html_content: str,
|
||||
text_content: Optional[str] = None
|
||||
text_content: str | None = None,
|
||||
) -> bool:
|
||||
"""
|
||||
Log email content to console/logs.
|
||||
@@ -59,8 +58,8 @@ class ConsoleEmailBackend(EmailBackend):
|
||||
logger.info("=" * 80)
|
||||
logger.info("EMAIL SENT (Console Backend)")
|
||||
logger.info("=" * 80)
|
||||
logger.info(f"To: {', '.join(to)}")
|
||||
logger.info(f"Subject: {subject}")
|
||||
logger.info("To: %s", ", ".join(to))
|
||||
logger.info("Subject: %s", subject)
|
||||
logger.info("-" * 80)
|
||||
if text_content:
|
||||
logger.info("Plain Text Content:")
|
||||
@@ -88,10 +87,10 @@ class SMTPEmailBackend(EmailBackend):
|
||||
|
||||
async def send_email(
|
||||
self,
|
||||
to: List[str],
|
||||
to: list[str],
|
||||
subject: str,
|
||||
html_content: str,
|
||||
text_content: Optional[str] = None
|
||||
text_content: str | None = None,
|
||||
) -> bool:
|
||||
"""Send email via SMTP."""
|
||||
# TODO: Implement SMTP sending
|
||||
@@ -108,7 +107,7 @@ class EmailService:
|
||||
and can be configured to use different backends (console, SMTP, SendGrid, etc.)
|
||||
"""
|
||||
|
||||
def __init__(self, backend: Optional[EmailBackend] = None):
|
||||
def __init__(self, backend: EmailBackend | None = None):
|
||||
"""
|
||||
Initialize email service with a backend.
|
||||
|
||||
@@ -118,10 +117,7 @@ class EmailService:
|
||||
self.backend = backend or ConsoleEmailBackend()
|
||||
|
||||
async def send_password_reset_email(
|
||||
self,
|
||||
to_email: str,
|
||||
reset_token: str,
|
||||
user_name: Optional[str] = None
|
||||
self, to_email: str, reset_token: str, user_name: str | None = None
|
||||
) -> bool:
|
||||
"""
|
||||
Send password reset email.
|
||||
@@ -142,7 +138,7 @@ class EmailService:
|
||||
|
||||
# Plain text version
|
||||
text_content = f"""
|
||||
Hello{' ' + user_name if user_name else ''},
|
||||
Hello{" " + user_name if user_name else ""},
|
||||
|
||||
You requested a password reset for your account. Click the link below to reset your password:
|
||||
|
||||
@@ -177,7 +173,7 @@ The {settings.PROJECT_NAME} Team
|
||||
<h1>Password Reset</h1>
|
||||
</div>
|
||||
<div class="content">
|
||||
<p>Hello{' ' + user_name if user_name else ''},</p>
|
||||
<p>Hello{" " + user_name if user_name else ""},</p>
|
||||
<p>You requested a password reset for your account. Click the button below to reset your password:</p>
|
||||
<p style="text-align: center;">
|
||||
<a href="{reset_url}" class="button">Reset Password</a>
|
||||
@@ -200,17 +196,14 @@ The {settings.PROJECT_NAME} Team
|
||||
to=[to_email],
|
||||
subject=subject,
|
||||
html_content=html_content,
|
||||
text_content=text_content
|
||||
text_content=text_content,
|
||||
)
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to send password reset email to {to_email}: {str(e)}")
|
||||
logger.error("Failed to send password reset email to %s: %s", to_email, e)
|
||||
return False
|
||||
|
||||
async def send_email_verification(
|
||||
self,
|
||||
to_email: str,
|
||||
verification_token: str,
|
||||
user_name: Optional[str] = None
|
||||
self, to_email: str, verification_token: str, user_name: str | None = None
|
||||
) -> bool:
|
||||
"""
|
||||
Send email verification email.
|
||||
@@ -224,14 +217,16 @@ The {settings.PROJECT_NAME} Team
|
||||
True if email sent successfully
|
||||
"""
|
||||
# Generate verification URL
|
||||
verification_url = f"{settings.FRONTEND_URL}/verify-email?token={verification_token}"
|
||||
verification_url = (
|
||||
f"{settings.FRONTEND_URL}/verify-email?token={verification_token}"
|
||||
)
|
||||
|
||||
# Prepare email content
|
||||
subject = "Verify Your Email Address"
|
||||
|
||||
# Plain text version
|
||||
text_content = f"""
|
||||
Hello{' ' + user_name if user_name else ''},
|
||||
Hello{" " + user_name if user_name else ""},
|
||||
|
||||
Thank you for signing up! Please verify your email address by clicking the link below:
|
||||
|
||||
@@ -266,7 +261,7 @@ The {settings.PROJECT_NAME} Team
|
||||
<h1>Verify Your Email</h1>
|
||||
</div>
|
||||
<div class="content">
|
||||
<p>Hello{' ' + user_name if user_name else ''},</p>
|
||||
<p>Hello{" " + user_name if user_name else ""},</p>
|
||||
<p>Thank you for signing up! Please verify your email address by clicking the button below:</p>
|
||||
<p style="text-align: center;">
|
||||
<a href="{verification_url}" class="button">Verify Email</a>
|
||||
@@ -289,10 +284,10 @@ The {settings.PROJECT_NAME} Team
|
||||
to=[to_email],
|
||||
subject=subject,
|
||||
html_content=html_content,
|
||||
text_content=text_content
|
||||
text_content=text_content,
|
||||
)
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to send verification email to {to_email}: {str(e)}")
|
||||
logger.error("Failed to send verification email to %s: %s", to_email, e)
|
||||
return False
|
||||
|
||||
|
||||
|
||||
970
backend/app/services/oauth_provider_service.py
Executable file
970
backend/app/services/oauth_provider_service.py
Executable file
@@ -0,0 +1,970 @@
|
||||
"""
|
||||
OAuth Provider Service for MCP integration.
|
||||
|
||||
Implements OAuth 2.0 Authorization Server functionality:
|
||||
- Authorization code flow with PKCE
|
||||
- Token issuance (JWT access tokens, opaque refresh tokens)
|
||||
- Token refresh
|
||||
- Token revocation
|
||||
- Consent management
|
||||
|
||||
Security features:
|
||||
- PKCE required for public clients (S256)
|
||||
- Short-lived authorization codes (10 minutes)
|
||||
- JWT access tokens (self-contained, no DB lookup)
|
||||
- Secure refresh token storage (hashed)
|
||||
- Token rotation on refresh
|
||||
- Comprehensive validation
|
||||
"""
|
||||
|
||||
import base64
|
||||
import hashlib
|
||||
import logging
|
||||
import secrets
|
||||
from datetime import UTC, datetime, timedelta
|
||||
from typing import Any
|
||||
from uuid import UUID
|
||||
|
||||
import jwt
|
||||
from jwt.exceptions import ExpiredSignatureError, InvalidTokenError
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
|
||||
from app.core.config import settings
|
||||
from app.models.oauth_client import OAuthClient
|
||||
from app.models.user import User
|
||||
from app.repositories.oauth_authorization_code import oauth_authorization_code_repo
|
||||
from app.repositories.oauth_client import oauth_client_repo
|
||||
from app.repositories.oauth_consent import oauth_consent_repo
|
||||
from app.repositories.oauth_provider_token import oauth_provider_token_repo
|
||||
from app.repositories.user import user_repo
|
||||
from app.schemas.oauth import OAuthClientCreate
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
# Constants
|
||||
AUTHORIZATION_CODE_EXPIRY_MINUTES = 10
|
||||
ACCESS_TOKEN_EXPIRY_MINUTES = 60 # 1 hour for MCP clients
|
||||
REFRESH_TOKEN_EXPIRY_DAYS = 30
|
||||
|
||||
|
||||
class OAuthProviderError(Exception):
|
||||
"""Base exception for OAuth provider errors."""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
error: str,
|
||||
error_description: str | None = None,
|
||||
error_uri: str | None = None,
|
||||
):
|
||||
self.error = error
|
||||
self.error_description = error_description
|
||||
self.error_uri = error_uri
|
||||
super().__init__(error_description or error)
|
||||
|
||||
|
||||
class InvalidClientError(OAuthProviderError):
|
||||
"""Client authentication failed."""
|
||||
|
||||
def __init__(self, description: str = "Invalid client credentials"):
|
||||
super().__init__("invalid_client", description)
|
||||
|
||||
|
||||
class InvalidGrantError(OAuthProviderError):
|
||||
"""Invalid authorization grant."""
|
||||
|
||||
def __init__(self, description: str = "Invalid grant"):
|
||||
super().__init__("invalid_grant", description)
|
||||
|
||||
|
||||
class InvalidRequestError(OAuthProviderError):
|
||||
"""Invalid request parameters."""
|
||||
|
||||
def __init__(self, description: str = "Invalid request"):
|
||||
super().__init__("invalid_request", description)
|
||||
|
||||
|
||||
class InvalidScopeError(OAuthProviderError):
|
||||
"""Invalid scope requested."""
|
||||
|
||||
def __init__(self, description: str = "Invalid scope"):
|
||||
super().__init__("invalid_scope", description)
|
||||
|
||||
|
||||
class UnauthorizedClientError(OAuthProviderError):
|
||||
"""Client not authorized for this grant type."""
|
||||
|
||||
def __init__(self, description: str = "Unauthorized client"):
|
||||
super().__init__("unauthorized_client", description)
|
||||
|
||||
|
||||
class AccessDeniedError(OAuthProviderError):
|
||||
"""User denied authorization."""
|
||||
|
||||
def __init__(self, description: str = "Access denied"):
|
||||
super().__init__("access_denied", description)
|
||||
|
||||
|
||||
# ============================================================================
|
||||
# Helper Functions
|
||||
# ============================================================================
|
||||
|
||||
|
||||
def generate_code() -> str:
|
||||
"""Generate a cryptographically secure authorization code."""
|
||||
return secrets.token_urlsafe(64)
|
||||
|
||||
|
||||
def generate_token() -> str:
|
||||
"""Generate a cryptographically secure token."""
|
||||
return secrets.token_urlsafe(48)
|
||||
|
||||
|
||||
def generate_jti() -> str:
|
||||
"""Generate a unique JWT ID."""
|
||||
return secrets.token_urlsafe(32)
|
||||
|
||||
|
||||
def hash_token(token: str) -> str:
|
||||
"""Hash a token using SHA-256."""
|
||||
return hashlib.sha256(token.encode()).hexdigest()
|
||||
|
||||
|
||||
def verify_pkce(code_verifier: str, code_challenge: str, method: str) -> bool:
|
||||
"""
|
||||
Verify PKCE code_verifier against stored code_challenge.
|
||||
|
||||
SECURITY: Only S256 method is supported. The 'plain' method provides
|
||||
no security benefit and is explicitly rejected per RFC 7636 Section 4.3.
|
||||
"""
|
||||
if method != "S256":
|
||||
# SECURITY: Reject any method other than S256
|
||||
# 'plain' method provides no security against code interception attacks
|
||||
logger.warning("PKCE verification rejected for unsupported method: %s", method)
|
||||
return False
|
||||
|
||||
# SHA-256 hash, then base64url encode (RFC 7636 Section 4.2)
|
||||
digest = hashlib.sha256(code_verifier.encode()).digest()
|
||||
computed = base64.urlsafe_b64encode(digest).rstrip(b"=").decode()
|
||||
return secrets.compare_digest(computed, code_challenge)
|
||||
|
||||
|
||||
def parse_scope(scope: str) -> list[str]:
|
||||
"""Parse space-separated scope string into list."""
|
||||
return [s.strip() for s in scope.split() if s.strip()]
|
||||
|
||||
|
||||
def join_scope(scopes: list[str]) -> str:
|
||||
"""Join scope list into space-separated string."""
|
||||
return " ".join(sorted(set(scopes)))
|
||||
|
||||
|
||||
# ============================================================================
|
||||
# Client Validation
|
||||
# ============================================================================
|
||||
|
||||
|
||||
async def get_client(db: AsyncSession, client_id: str) -> OAuthClient | None:
|
||||
"""Get OAuth client by client_id."""
|
||||
return await oauth_client_repo.get_by_client_id(db, client_id=client_id)
|
||||
|
||||
|
||||
async def validate_client(
|
||||
db: AsyncSession,
|
||||
client_id: str,
|
||||
client_secret: str | None = None,
|
||||
require_secret: bool = False,
|
||||
) -> OAuthClient:
|
||||
"""
|
||||
Validate OAuth client credentials.
|
||||
|
||||
Args:
|
||||
db: Database session
|
||||
client_id: Client identifier
|
||||
client_secret: Client secret (required for confidential clients)
|
||||
require_secret: Whether to require secret validation
|
||||
|
||||
Returns:
|
||||
Validated OAuthClient
|
||||
|
||||
Raises:
|
||||
InvalidClientError: If client validation fails
|
||||
"""
|
||||
client = await get_client(db, client_id)
|
||||
if not client:
|
||||
raise InvalidClientError("Unknown client_id")
|
||||
|
||||
# Confidential clients must provide valid secret
|
||||
if client.client_type == "confidential" or require_secret:
|
||||
if not client_secret:
|
||||
raise InvalidClientError("Client secret required")
|
||||
if not client.client_secret_hash:
|
||||
raise InvalidClientError("Client not configured with secret")
|
||||
|
||||
# SECURITY: Verify secret using bcrypt
|
||||
from app.core.auth import verify_password
|
||||
|
||||
stored_hash = str(client.client_secret_hash)
|
||||
|
||||
if not stored_hash.startswith("$2"):
|
||||
raise InvalidClientError(
|
||||
"Client secret uses deprecated hash format. "
|
||||
"Please regenerate your client credentials."
|
||||
)
|
||||
|
||||
if not verify_password(client_secret, stored_hash):
|
||||
raise InvalidClientError("Invalid client secret")
|
||||
|
||||
return client
|
||||
|
||||
|
||||
def validate_redirect_uri(client: OAuthClient, redirect_uri: str) -> None:
|
||||
"""
|
||||
Validate redirect_uri against client's registered URIs.
|
||||
|
||||
Raises:
|
||||
InvalidRequestError: If redirect_uri is not registered
|
||||
"""
|
||||
if not client.redirect_uris:
|
||||
raise InvalidRequestError("Client has no registered redirect URIs")
|
||||
|
||||
if redirect_uri not in client.redirect_uris:
|
||||
raise InvalidRequestError("Invalid redirect_uri")
|
||||
|
||||
|
||||
def validate_scopes(client: OAuthClient, requested_scopes: list[str]) -> list[str]:
|
||||
"""
|
||||
Validate requested scopes against client's allowed scopes.
|
||||
|
||||
Returns:
|
||||
List of valid scopes (intersection of requested and allowed)
|
||||
|
||||
Raises:
|
||||
InvalidScopeError: If no valid scopes
|
||||
"""
|
||||
allowed = set(client.allowed_scopes or [])
|
||||
requested = set(requested_scopes)
|
||||
|
||||
# If no scopes requested, use all allowed scopes
|
||||
if not requested:
|
||||
return list(allowed)
|
||||
|
||||
valid = requested & allowed
|
||||
if not valid:
|
||||
raise InvalidScopeError(
|
||||
"None of the requested scopes are allowed for this client"
|
||||
)
|
||||
|
||||
# Warn if some scopes were filtered out
|
||||
invalid = requested - allowed
|
||||
if invalid:
|
||||
logger.warning(
|
||||
"Client %s requested invalid scopes: %s", client.client_id, invalid
|
||||
)
|
||||
|
||||
return list(valid)
|
||||
|
||||
|
||||
# ============================================================================
|
||||
# Authorization Code Flow
|
||||
# ============================================================================
|
||||
|
||||
|
||||
async def create_authorization_code(
|
||||
db: AsyncSession,
|
||||
client: OAuthClient,
|
||||
user: User,
|
||||
redirect_uri: str,
|
||||
scope: str,
|
||||
code_challenge: str | None = None,
|
||||
code_challenge_method: str | None = None,
|
||||
state: str | None = None,
|
||||
nonce: str | None = None,
|
||||
) -> str:
|
||||
"""
|
||||
Create an authorization code for the authorization code flow.
|
||||
|
||||
Args:
|
||||
db: Database session
|
||||
client: Validated OAuth client
|
||||
user: Authenticated user
|
||||
redirect_uri: Validated redirect URI
|
||||
scope: Granted scopes (space-separated)
|
||||
code_challenge: PKCE code challenge
|
||||
code_challenge_method: PKCE method (S256)
|
||||
state: CSRF state parameter
|
||||
nonce: OpenID Connect nonce
|
||||
|
||||
Returns:
|
||||
Authorization code string
|
||||
"""
|
||||
# Public clients MUST use PKCE
|
||||
if client.client_type == "public":
|
||||
if not code_challenge or code_challenge_method != "S256":
|
||||
raise InvalidRequestError("PKCE with S256 is required for public clients")
|
||||
|
||||
code = generate_code()
|
||||
expires_at = datetime.now(UTC) + timedelta(
|
||||
minutes=AUTHORIZATION_CODE_EXPIRY_MINUTES
|
||||
)
|
||||
|
||||
await oauth_authorization_code_repo.create_code(
|
||||
db,
|
||||
code=code,
|
||||
client_id=client.client_id,
|
||||
user_id=user.id,
|
||||
redirect_uri=redirect_uri,
|
||||
scope=scope,
|
||||
expires_at=expires_at,
|
||||
code_challenge=code_challenge,
|
||||
code_challenge_method=code_challenge_method,
|
||||
state=state,
|
||||
nonce=nonce,
|
||||
)
|
||||
|
||||
logger.info(
|
||||
"Created authorization code for user %s and client %s",
|
||||
user.id,
|
||||
client.client_id,
|
||||
)
|
||||
return code
|
||||
|
||||
|
||||
async def exchange_authorization_code(
|
||||
db: AsyncSession,
|
||||
code: str,
|
||||
client_id: str,
|
||||
redirect_uri: str,
|
||||
code_verifier: str | None = None,
|
||||
client_secret: str | None = None,
|
||||
device_info: str | None = None,
|
||||
ip_address: str | None = None,
|
||||
) -> dict[str, Any]:
|
||||
"""
|
||||
Exchange authorization code for tokens.
|
||||
|
||||
Args:
|
||||
db: Database session
|
||||
code: Authorization code
|
||||
client_id: Client identifier
|
||||
redirect_uri: Must match the original redirect_uri
|
||||
code_verifier: PKCE code verifier
|
||||
client_secret: Client secret (for confidential clients)
|
||||
device_info: Optional device information
|
||||
ip_address: Optional IP address
|
||||
|
||||
Returns:
|
||||
Token response dict with access_token, refresh_token, etc.
|
||||
|
||||
Raises:
|
||||
InvalidGrantError: If code is invalid, expired, or already used
|
||||
InvalidClientError: If client validation fails
|
||||
"""
|
||||
# Atomically mark code as used and fetch it (prevents race condition)
|
||||
# RFC 6749 Section 4.1.2: Authorization codes MUST be single-use
|
||||
updated_id = await oauth_authorization_code_repo.consume_code_atomically(
|
||||
db, code=code
|
||||
)
|
||||
|
||||
if not updated_id:
|
||||
# Either code doesn't exist or was already used
|
||||
# Check if it exists to provide appropriate error
|
||||
existing_code = await oauth_authorization_code_repo.get_by_code(db, code=code)
|
||||
|
||||
if existing_code and existing_code.used:
|
||||
# Code reuse is a security incident - revoke all tokens for this grant
|
||||
logger.warning(
|
||||
"Authorization code reuse detected for client %s",
|
||||
existing_code.client_id,
|
||||
)
|
||||
await revoke_tokens_for_user_client(
|
||||
db, UUID(str(existing_code.user_id)), str(existing_code.client_id)
|
||||
)
|
||||
raise InvalidGrantError("Authorization code has already been used")
|
||||
else:
|
||||
raise InvalidGrantError("Invalid authorization code")
|
||||
|
||||
# Now fetch the full auth code record
|
||||
auth_code = await oauth_authorization_code_repo.get_by_id(db, code_id=updated_id)
|
||||
if auth_code is None:
|
||||
raise InvalidGrantError("Authorization code not found after consumption")
|
||||
|
||||
if auth_code.is_expired:
|
||||
raise InvalidGrantError("Authorization code has expired")
|
||||
|
||||
if auth_code.client_id != client_id:
|
||||
raise InvalidGrantError("Authorization code was not issued to this client")
|
||||
|
||||
if auth_code.redirect_uri != redirect_uri:
|
||||
raise InvalidGrantError("redirect_uri mismatch")
|
||||
|
||||
# Validate client - ALWAYS require secret for confidential clients
|
||||
client = await get_client(db, client_id)
|
||||
if not client:
|
||||
raise InvalidClientError("Unknown client_id")
|
||||
|
||||
# Confidential clients MUST authenticate (RFC 6749 Section 3.2.1)
|
||||
if client.client_type == "confidential":
|
||||
if not client_secret:
|
||||
raise InvalidClientError("Client secret required for confidential clients")
|
||||
client = await validate_client(
|
||||
db, client_id, client_secret, require_secret=True
|
||||
)
|
||||
elif client_secret:
|
||||
# Public client provided secret - validate it if given
|
||||
client = await validate_client(
|
||||
db, client_id, client_secret, require_secret=True
|
||||
)
|
||||
|
||||
# Verify PKCE
|
||||
if auth_code.code_challenge:
|
||||
if not code_verifier:
|
||||
raise InvalidGrantError("code_verifier required")
|
||||
if not verify_pkce(
|
||||
code_verifier,
|
||||
str(auth_code.code_challenge),
|
||||
str(auth_code.code_challenge_method or "S256"),
|
||||
):
|
||||
raise InvalidGrantError("Invalid code_verifier")
|
||||
elif client.client_type == "public":
|
||||
# Public clients without PKCE - this shouldn't happen if we validated on authorize
|
||||
raise InvalidGrantError("PKCE required for public clients")
|
||||
|
||||
# Get user
|
||||
user = await user_repo.get(db, id=str(auth_code.user_id))
|
||||
if not user or not user.is_active:
|
||||
raise InvalidGrantError("User not found or inactive")
|
||||
|
||||
# Generate tokens
|
||||
return await create_tokens(
|
||||
db=db,
|
||||
client=client,
|
||||
user=user,
|
||||
scope=str(auth_code.scope),
|
||||
nonce=str(auth_code.nonce) if auth_code.nonce else None,
|
||||
device_info=device_info,
|
||||
ip_address=ip_address,
|
||||
)
|
||||
|
||||
|
||||
# ============================================================================
|
||||
# Token Generation
|
||||
# ============================================================================
|
||||
|
||||
|
||||
async def create_tokens(
|
||||
db: AsyncSession,
|
||||
client: OAuthClient,
|
||||
user: User,
|
||||
scope: str,
|
||||
nonce: str | None = None,
|
||||
device_info: str | None = None,
|
||||
ip_address: str | None = None,
|
||||
) -> dict[str, Any]:
|
||||
"""
|
||||
Create access and refresh tokens.
|
||||
|
||||
Args:
|
||||
db: Database session
|
||||
client: OAuth client
|
||||
user: User
|
||||
scope: Granted scopes
|
||||
nonce: OpenID Connect nonce (included in ID token)
|
||||
device_info: Optional device information
|
||||
ip_address: Optional IP address
|
||||
|
||||
Returns:
|
||||
Token response dict
|
||||
"""
|
||||
now = datetime.now(UTC)
|
||||
jti = generate_jti()
|
||||
|
||||
# Access token expiry
|
||||
access_token_lifetime = int(client.access_token_lifetime or "3600")
|
||||
access_expires = now + timedelta(seconds=access_token_lifetime)
|
||||
|
||||
# Refresh token expiry
|
||||
refresh_token_lifetime = int(
|
||||
client.refresh_token_lifetime or str(REFRESH_TOKEN_EXPIRY_DAYS * 86400)
|
||||
)
|
||||
refresh_expires = now + timedelta(seconds=refresh_token_lifetime)
|
||||
|
||||
# Create JWT access token
|
||||
# SECURITY: Include all standard JWT claims per RFC 7519
|
||||
access_token_payload = {
|
||||
"iss": settings.OAUTH_ISSUER,
|
||||
"sub": str(user.id),
|
||||
"aud": client.client_id,
|
||||
"exp": int(access_expires.timestamp()),
|
||||
"iat": int(now.timestamp()),
|
||||
"nbf": int(now.timestamp()), # Not Before - token is valid immediately
|
||||
"jti": jti,
|
||||
"scope": scope,
|
||||
"client_id": client.client_id,
|
||||
# User info (basic claims)
|
||||
"email": user.email,
|
||||
"name": f"{user.first_name or ''} {user.last_name or ''}".strip() or user.email,
|
||||
}
|
||||
|
||||
# Add nonce for OpenID Connect
|
||||
if nonce:
|
||||
access_token_payload["nonce"] = nonce
|
||||
|
||||
access_token = jwt.encode(
|
||||
access_token_payload,
|
||||
settings.SECRET_KEY,
|
||||
algorithm=settings.ALGORITHM,
|
||||
)
|
||||
|
||||
# Create opaque refresh token
|
||||
refresh_token = generate_token()
|
||||
refresh_token_hash = hash_token(refresh_token)
|
||||
|
||||
# Store refresh token in database
|
||||
await oauth_provider_token_repo.create_token(
|
||||
db,
|
||||
token_hash=refresh_token_hash,
|
||||
jti=jti,
|
||||
client_id=client.client_id,
|
||||
user_id=user.id,
|
||||
scope=scope,
|
||||
expires_at=refresh_expires,
|
||||
device_info=device_info,
|
||||
ip_address=ip_address,
|
||||
)
|
||||
|
||||
logger.info("Issued tokens for user %s to client %s", user.id, client.client_id)
|
||||
|
||||
return {
|
||||
"access_token": access_token,
|
||||
"token_type": "Bearer",
|
||||
"expires_in": access_token_lifetime,
|
||||
"refresh_token": refresh_token,
|
||||
"scope": scope,
|
||||
}
|
||||
|
||||
|
||||
async def refresh_tokens(
|
||||
db: AsyncSession,
|
||||
refresh_token: str,
|
||||
client_id: str,
|
||||
client_secret: str | None = None,
|
||||
scope: str | None = None,
|
||||
device_info: str | None = None,
|
||||
ip_address: str | None = None,
|
||||
) -> dict[str, Any]:
|
||||
"""
|
||||
Refresh access token using refresh token.
|
||||
|
||||
Implements token rotation - old refresh token is invalidated,
|
||||
new refresh token is issued.
|
||||
|
||||
Args:
|
||||
db: Database session
|
||||
refresh_token: Refresh token
|
||||
client_id: Client identifier
|
||||
client_secret: Client secret (for confidential clients)
|
||||
scope: Optional reduced scope
|
||||
device_info: Optional device information
|
||||
ip_address: Optional IP address
|
||||
|
||||
Returns:
|
||||
New token response dict
|
||||
|
||||
Raises:
|
||||
InvalidGrantError: If refresh token is invalid
|
||||
"""
|
||||
# Find refresh token
|
||||
token_hash = hash_token(refresh_token)
|
||||
token_record = await oauth_provider_token_repo.get_by_token_hash(
|
||||
db, token_hash=token_hash
|
||||
)
|
||||
|
||||
if not token_record:
|
||||
raise InvalidGrantError("Invalid refresh token")
|
||||
|
||||
if token_record.revoked:
|
||||
# Token reuse after revocation - security incident
|
||||
logger.warning(
|
||||
"Revoked refresh token reuse detected for client %s", token_record.client_id
|
||||
)
|
||||
raise InvalidGrantError("Refresh token has been revoked")
|
||||
|
||||
if token_record.is_expired:
|
||||
raise InvalidGrantError("Refresh token has expired")
|
||||
|
||||
if token_record.client_id != client_id:
|
||||
raise InvalidGrantError("Refresh token was not issued to this client")
|
||||
|
||||
# Validate client
|
||||
client = await validate_client(
|
||||
db,
|
||||
client_id,
|
||||
client_secret,
|
||||
require_secret=(client_secret is not None),
|
||||
)
|
||||
|
||||
# Get user
|
||||
user = await user_repo.get(db, id=str(token_record.user_id))
|
||||
if not user or not user.is_active:
|
||||
raise InvalidGrantError("User not found or inactive")
|
||||
|
||||
# Validate scope (can only reduce, not expand)
|
||||
token_scope = str(token_record.scope) if token_record.scope else ""
|
||||
original_scopes = set(parse_scope(token_scope))
|
||||
if scope:
|
||||
requested_scopes = set(parse_scope(scope))
|
||||
if not requested_scopes.issubset(original_scopes):
|
||||
raise InvalidScopeError("Cannot expand scope on refresh")
|
||||
final_scope = join_scope(list(requested_scopes))
|
||||
else:
|
||||
final_scope = token_scope
|
||||
|
||||
# Revoke old refresh token (token rotation)
|
||||
await oauth_provider_token_repo.revoke(db, token=token_record)
|
||||
|
||||
# Issue new tokens
|
||||
device = str(token_record.device_info) if token_record.device_info else None
|
||||
ip_addr = str(token_record.ip_address) if token_record.ip_address else None
|
||||
return await create_tokens(
|
||||
db=db,
|
||||
client=client,
|
||||
user=user,
|
||||
scope=final_scope,
|
||||
device_info=device_info or device,
|
||||
ip_address=ip_address or ip_addr,
|
||||
)
|
||||
|
||||
|
||||
# ============================================================================
|
||||
# Token Revocation
|
||||
# ============================================================================
|
||||
|
||||
|
||||
async def revoke_token(
|
||||
db: AsyncSession,
|
||||
token: str,
|
||||
token_type_hint: str | None = None,
|
||||
client_id: str | None = None,
|
||||
client_secret: str | None = None,
|
||||
) -> bool:
|
||||
"""
|
||||
Revoke a token (access or refresh).
|
||||
|
||||
For refresh tokens: marks as revoked in database
|
||||
For access tokens: we can't truly revoke JWTs, but we can revoke
|
||||
the associated refresh token to prevent further refreshes
|
||||
|
||||
Args:
|
||||
db: Database session
|
||||
token: Token to revoke
|
||||
token_type_hint: "access_token" or "refresh_token"
|
||||
client_id: Client identifier (for validation)
|
||||
client_secret: Client secret (for confidential clients)
|
||||
|
||||
Returns:
|
||||
True if token was revoked, False if not found
|
||||
"""
|
||||
# Try as refresh token first (more likely)
|
||||
if token_type_hint != "access_token":
|
||||
token_hash = hash_token(token)
|
||||
refresh_record = await oauth_provider_token_repo.get_by_token_hash(
|
||||
db, token_hash=token_hash
|
||||
)
|
||||
|
||||
if refresh_record:
|
||||
# Validate client if provided
|
||||
if client_id and refresh_record.client_id != client_id:
|
||||
raise InvalidClientError("Token was not issued to this client")
|
||||
|
||||
await oauth_provider_token_repo.revoke(db, token=refresh_record)
|
||||
logger.info("Revoked refresh token %s...", refresh_record.jti[:8])
|
||||
return True
|
||||
|
||||
# Try as access token (JWT)
|
||||
if token_type_hint != "refresh_token":
|
||||
try:
|
||||
payload = jwt.decode(
|
||||
token,
|
||||
settings.SECRET_KEY,
|
||||
algorithms=[settings.ALGORITHM],
|
||||
options={
|
||||
"verify_exp": False,
|
||||
"verify_aud": False,
|
||||
}, # Allow expired tokens
|
||||
)
|
||||
jti = payload.get("jti")
|
||||
if jti:
|
||||
# Find and revoke the associated refresh token
|
||||
refresh_record = await oauth_provider_token_repo.get_by_jti(db, jti=jti)
|
||||
if refresh_record:
|
||||
if client_id and refresh_record.client_id != client_id:
|
||||
raise InvalidClientError("Token was not issued to this client")
|
||||
await oauth_provider_token_repo.revoke(db, token=refresh_record)
|
||||
logger.info(
|
||||
"Revoked refresh token via access token JTI %s...", jti[:8]
|
||||
)
|
||||
return True
|
||||
except InvalidTokenError:
|
||||
pass
|
||||
except Exception: # noqa: S110 - Intentional: invalid JWT not an error
|
||||
pass
|
||||
|
||||
return False
|
||||
|
||||
|
||||
async def revoke_tokens_for_user_client(
|
||||
db: AsyncSession,
|
||||
user_id: UUID,
|
||||
client_id: str,
|
||||
) -> int:
|
||||
"""
|
||||
Revoke all tokens for a specific user-client pair.
|
||||
|
||||
Used when security incidents are detected (e.g., code reuse).
|
||||
|
||||
Args:
|
||||
db: Database session
|
||||
user_id: User identifier
|
||||
client_id: Client identifier
|
||||
|
||||
Returns:
|
||||
Number of tokens revoked
|
||||
"""
|
||||
count = await oauth_provider_token_repo.revoke_all_for_user_client(
|
||||
db, user_id=user_id, client_id=client_id
|
||||
)
|
||||
|
||||
if count > 0:
|
||||
logger.warning(
|
||||
"Revoked %s tokens for user %s and client %s", count, user_id, client_id
|
||||
)
|
||||
|
||||
return count
|
||||
|
||||
|
||||
async def revoke_all_user_tokens(db: AsyncSession, user_id: UUID) -> int:
|
||||
"""
|
||||
Revoke all OAuth provider tokens for a user.
|
||||
|
||||
Used when user changes password or explicitly logs out everywhere.
|
||||
|
||||
Args:
|
||||
db: Database session
|
||||
user_id: User identifier
|
||||
|
||||
Returns:
|
||||
Number of tokens revoked
|
||||
"""
|
||||
count = await oauth_provider_token_repo.revoke_all_for_user(db, user_id=user_id)
|
||||
|
||||
if count > 0:
|
||||
logger.info("Revoked %s OAuth provider tokens for user %s", count, user_id)
|
||||
|
||||
return count
|
||||
|
||||
|
||||
# ============================================================================
|
||||
# Token Introspection (RFC 7662)
|
||||
# ============================================================================
|
||||
|
||||
|
||||
async def introspect_token(
|
||||
db: AsyncSession,
|
||||
token: str,
|
||||
token_type_hint: str | None = None,
|
||||
client_id: str | None = None,
|
||||
client_secret: str | None = None,
|
||||
) -> dict[str, Any]:
|
||||
"""
|
||||
Introspect a token to determine its validity and metadata.
|
||||
|
||||
Implements RFC 7662 Token Introspection.
|
||||
|
||||
Args:
|
||||
db: Database session
|
||||
token: Token to introspect
|
||||
token_type_hint: "access_token" or "refresh_token"
|
||||
client_id: Client requesting introspection
|
||||
client_secret: Client secret
|
||||
|
||||
Returns:
|
||||
Introspection response dict
|
||||
"""
|
||||
# Validate client if credentials provided
|
||||
if client_id:
|
||||
await validate_client(db, client_id, client_secret)
|
||||
|
||||
# Try as access token (JWT) first
|
||||
if token_type_hint != "refresh_token":
|
||||
try:
|
||||
payload = jwt.decode(
|
||||
token,
|
||||
settings.SECRET_KEY,
|
||||
algorithms=[settings.ALGORITHM],
|
||||
options={
|
||||
"verify_aud": False
|
||||
}, # Don't require audience match for introspection
|
||||
)
|
||||
|
||||
# Check if associated refresh token is revoked
|
||||
jti = payload.get("jti")
|
||||
if jti:
|
||||
refresh_record = await oauth_provider_token_repo.get_by_jti(db, jti=jti)
|
||||
if refresh_record and refresh_record.revoked:
|
||||
return {"active": False}
|
||||
|
||||
return {
|
||||
"active": True,
|
||||
"scope": payload.get("scope", ""),
|
||||
"client_id": payload.get("client_id"),
|
||||
"username": payload.get("email"),
|
||||
"token_type": "Bearer",
|
||||
"exp": payload.get("exp"),
|
||||
"iat": payload.get("iat"),
|
||||
"sub": payload.get("sub"),
|
||||
"aud": payload.get("aud"),
|
||||
"iss": payload.get("iss"),
|
||||
}
|
||||
except ExpiredSignatureError:
|
||||
return {"active": False}
|
||||
except InvalidTokenError:
|
||||
pass
|
||||
except Exception: # noqa: S110 - Intentional: invalid JWT falls through to refresh token check
|
||||
pass
|
||||
|
||||
# Try as refresh token
|
||||
if token_type_hint != "access_token":
|
||||
token_hash = hash_token(token)
|
||||
refresh_record = await oauth_provider_token_repo.get_by_token_hash(
|
||||
db, token_hash=token_hash
|
||||
)
|
||||
|
||||
if refresh_record and refresh_record.is_valid:
|
||||
return {
|
||||
"active": True,
|
||||
"scope": refresh_record.scope,
|
||||
"client_id": refresh_record.client_id,
|
||||
"token_type": "refresh_token",
|
||||
"exp": int(refresh_record.expires_at.timestamp()),
|
||||
"iat": int(refresh_record.created_at.timestamp()),
|
||||
"sub": str(refresh_record.user_id),
|
||||
}
|
||||
|
||||
return {"active": False}
|
||||
|
||||
|
||||
# ============================================================================
|
||||
# Consent Management
|
||||
# ============================================================================
|
||||
|
||||
|
||||
async def get_consent(
|
||||
db: AsyncSession,
|
||||
user_id: UUID,
|
||||
client_id: str,
|
||||
):
|
||||
"""Get existing consent record for user-client pair."""
|
||||
return await oauth_consent_repo.get_consent(
|
||||
db, user_id=user_id, client_id=client_id
|
||||
)
|
||||
|
||||
|
||||
async def check_consent(
|
||||
db: AsyncSession,
|
||||
user_id: UUID,
|
||||
client_id: str,
|
||||
requested_scopes: list[str],
|
||||
) -> bool:
|
||||
"""
|
||||
Check if user has already consented to the requested scopes.
|
||||
|
||||
Returns True if all requested scopes are already granted.
|
||||
"""
|
||||
consent = await get_consent(db, user_id, client_id)
|
||||
if not consent:
|
||||
return False
|
||||
return consent.has_scopes(requested_scopes)
|
||||
|
||||
|
||||
async def grant_consent(
|
||||
db: AsyncSession,
|
||||
user_id: UUID,
|
||||
client_id: str,
|
||||
scopes: list[str],
|
||||
):
|
||||
"""
|
||||
Grant or update consent for a user-client pair.
|
||||
|
||||
If consent already exists, updates the granted scopes.
|
||||
"""
|
||||
return await oauth_consent_repo.grant_consent(
|
||||
db, user_id=user_id, client_id=client_id, scopes=scopes
|
||||
)
|
||||
|
||||
|
||||
async def revoke_consent(
|
||||
db: AsyncSession,
|
||||
user_id: UUID,
|
||||
client_id: str,
|
||||
) -> bool:
|
||||
"""
|
||||
Revoke consent and all tokens for a user-client pair.
|
||||
|
||||
Returns True if consent was found and revoked.
|
||||
"""
|
||||
# Revoke all tokens first
|
||||
await revoke_tokens_for_user_client(db, user_id, client_id)
|
||||
|
||||
# Delete consent record
|
||||
return await oauth_consent_repo.revoke_consent(
|
||||
db, user_id=user_id, client_id=client_id
|
||||
)
|
||||
|
||||
|
||||
# ============================================================================
|
||||
# Cleanup
|
||||
# ============================================================================
|
||||
|
||||
|
||||
async def register_client(db: AsyncSession, client_data: OAuthClientCreate) -> tuple:
|
||||
"""Create a new OAuth client. Returns (client, secret)."""
|
||||
return await oauth_client_repo.create_client(db, obj_in=client_data)
|
||||
|
||||
|
||||
async def list_clients(db: AsyncSession) -> list:
|
||||
"""List all registered OAuth clients."""
|
||||
return await oauth_client_repo.get_all_clients(db)
|
||||
|
||||
|
||||
async def delete_client_by_id(db: AsyncSession, client_id: str) -> None:
|
||||
"""Delete an OAuth client by client_id."""
|
||||
await oauth_client_repo.delete_client(db, client_id=client_id)
|
||||
|
||||
|
||||
async def list_user_consents(db: AsyncSession, user_id: UUID) -> list[dict]:
|
||||
"""Get all OAuth consents for a user with client details."""
|
||||
return await oauth_consent_repo.get_user_consents_with_clients(db, user_id=user_id)
|
||||
|
||||
|
||||
async def cleanup_expired_codes(db: AsyncSession) -> int:
|
||||
"""
|
||||
Delete expired authorization codes.
|
||||
|
||||
Should be called periodically (e.g., every hour).
|
||||
|
||||
Returns:
|
||||
Number of codes deleted
|
||||
"""
|
||||
return await oauth_authorization_code_repo.cleanup_expired(db)
|
||||
|
||||
|
||||
async def cleanup_expired_tokens(db: AsyncSession) -> int:
|
||||
"""
|
||||
Delete expired and revoked refresh tokens.
|
||||
|
||||
Should be called periodically (e.g., daily).
|
||||
|
||||
Returns:
|
||||
Number of tokens deleted
|
||||
"""
|
||||
return await oauth_provider_token_repo.cleanup_expired(db, cutoff_days=7)
|
||||
744
backend/app/services/oauth_service.py
Normal file
744
backend/app/services/oauth_service.py
Normal file
@@ -0,0 +1,744 @@
|
||||
"""
|
||||
OAuth Service for handling social authentication flows.
|
||||
|
||||
Supports:
|
||||
- Google OAuth (OpenID Connect)
|
||||
- GitHub OAuth
|
||||
|
||||
Features:
|
||||
- PKCE support for public clients
|
||||
- State parameter for CSRF protection
|
||||
- Auto-linking by email (configurable)
|
||||
- Account linking for existing users
|
||||
"""
|
||||
|
||||
import logging
|
||||
import secrets
|
||||
from datetime import UTC, datetime, timedelta
|
||||
from typing import TypedDict, cast
|
||||
from uuid import UUID
|
||||
|
||||
from authlib.integrations.httpx_client import AsyncOAuth2Client
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
|
||||
from app.core.auth import create_access_token, create_refresh_token
|
||||
from app.core.config import settings
|
||||
from app.core.exceptions import AuthenticationError
|
||||
from app.models.user import User
|
||||
from app.repositories.oauth_account import oauth_account_repo as oauth_account
|
||||
from app.repositories.oauth_state import oauth_state_repo as oauth_state
|
||||
from app.repositories.user import user_repo
|
||||
from app.schemas.oauth import (
|
||||
OAuthAccountCreate,
|
||||
OAuthCallbackResponse,
|
||||
OAuthProviderInfo,
|
||||
OAuthProvidersResponse,
|
||||
OAuthStateCreate,
|
||||
)
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class _OAuthProviderConfigRequired(TypedDict):
|
||||
name: str
|
||||
icon: str
|
||||
authorize_url: str
|
||||
token_url: str
|
||||
userinfo_url: str
|
||||
scopes: list[str]
|
||||
supports_pkce: bool
|
||||
|
||||
|
||||
class OAuthProviderConfig(_OAuthProviderConfigRequired, total=False):
|
||||
"""Type definition for OAuth provider configuration."""
|
||||
|
||||
email_url: str # Optional, GitHub-only
|
||||
|
||||
|
||||
# Provider configurations
|
||||
OAUTH_PROVIDERS: dict[str, OAuthProviderConfig] = {
|
||||
"google": {
|
||||
"name": "Google",
|
||||
"icon": "google",
|
||||
"authorize_url": "https://accounts.google.com/o/oauth2/v2/auth",
|
||||
"token_url": "https://oauth2.googleapis.com/token",
|
||||
"userinfo_url": "https://www.googleapis.com/oauth2/v3/userinfo",
|
||||
"scopes": ["openid", "email", "profile"],
|
||||
"supports_pkce": True,
|
||||
},
|
||||
"github": {
|
||||
"name": "GitHub",
|
||||
"icon": "github",
|
||||
"authorize_url": "https://github.com/login/oauth/authorize",
|
||||
"token_url": "https://github.com/login/oauth/access_token",
|
||||
"userinfo_url": "https://api.github.com/user",
|
||||
"email_url": "https://api.github.com/user/emails",
|
||||
"scopes": ["read:user", "user:email"],
|
||||
"supports_pkce": False, # GitHub doesn't support PKCE
|
||||
},
|
||||
}
|
||||
|
||||
|
||||
class OAuthService:
|
||||
"""Service for handling OAuth authentication flows."""
|
||||
|
||||
@staticmethod
|
||||
def get_enabled_providers() -> OAuthProvidersResponse:
|
||||
"""
|
||||
Get list of enabled OAuth providers.
|
||||
|
||||
Returns:
|
||||
OAuthProvidersResponse with enabled providers
|
||||
"""
|
||||
providers = []
|
||||
|
||||
for provider_id in settings.enabled_oauth_providers:
|
||||
if provider_id in OAUTH_PROVIDERS:
|
||||
config = OAUTH_PROVIDERS[provider_id]
|
||||
providers.append(
|
||||
OAuthProviderInfo(
|
||||
provider=provider_id,
|
||||
name=config["name"],
|
||||
icon=config["icon"],
|
||||
)
|
||||
)
|
||||
|
||||
return OAuthProvidersResponse(
|
||||
enabled=settings.OAUTH_ENABLED and len(providers) > 0,
|
||||
providers=providers,
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
def _get_provider_credentials(provider: str) -> tuple[str, str]:
|
||||
"""Get client ID and secret for a provider."""
|
||||
if provider == "google":
|
||||
client_id = settings.OAUTH_GOOGLE_CLIENT_ID
|
||||
client_secret = settings.OAUTH_GOOGLE_CLIENT_SECRET
|
||||
elif provider == "github":
|
||||
client_id = settings.OAUTH_GITHUB_CLIENT_ID
|
||||
client_secret = settings.OAUTH_GITHUB_CLIENT_SECRET
|
||||
else:
|
||||
raise AuthenticationError(f"Unknown OAuth provider: {provider}")
|
||||
|
||||
if not client_id or not client_secret:
|
||||
raise AuthenticationError(f"OAuth provider {provider} is not configured")
|
||||
|
||||
return client_id, client_secret
|
||||
|
||||
@staticmethod
|
||||
async def create_authorization_url(
|
||||
db: AsyncSession,
|
||||
*,
|
||||
provider: str,
|
||||
redirect_uri: str,
|
||||
user_id: str | None = None,
|
||||
) -> tuple[str, str]:
|
||||
"""
|
||||
Create OAuth authorization URL with state and optional PKCE.
|
||||
|
||||
Args:
|
||||
db: Database session
|
||||
provider: OAuth provider (google, github)
|
||||
redirect_uri: Callback URL after OAuth
|
||||
user_id: User ID if linking account (user is logged in)
|
||||
|
||||
Returns:
|
||||
Tuple of (authorization_url, state)
|
||||
|
||||
Raises:
|
||||
AuthenticationError: If provider is not configured
|
||||
"""
|
||||
if not settings.OAUTH_ENABLED:
|
||||
raise AuthenticationError("OAuth is not enabled")
|
||||
|
||||
if provider not in OAUTH_PROVIDERS:
|
||||
raise AuthenticationError(f"Unknown OAuth provider: {provider}")
|
||||
|
||||
if provider not in settings.enabled_oauth_providers:
|
||||
raise AuthenticationError(f"OAuth provider {provider} is not enabled")
|
||||
|
||||
config = OAUTH_PROVIDERS[provider]
|
||||
client_id, client_secret = OAuthService._get_provider_credentials(provider)
|
||||
|
||||
# Generate state for CSRF protection
|
||||
state = secrets.token_urlsafe(32)
|
||||
|
||||
# Generate PKCE code verifier and challenge if supported
|
||||
code_verifier = None
|
||||
code_challenge = None
|
||||
if config.get("supports_pkce"):
|
||||
code_verifier = secrets.token_urlsafe(64)
|
||||
# Create code_challenge using S256 method
|
||||
import base64
|
||||
import hashlib
|
||||
|
||||
code_challenge_bytes = hashlib.sha256(code_verifier.encode()).digest()
|
||||
code_challenge = (
|
||||
base64.urlsafe_b64encode(code_challenge_bytes).decode().rstrip("=")
|
||||
)
|
||||
|
||||
# Generate nonce for OIDC (Google)
|
||||
nonce = secrets.token_urlsafe(32) if provider == "google" else None
|
||||
|
||||
# Store state in database
|
||||
from uuid import UUID
|
||||
|
||||
state_data = OAuthStateCreate(
|
||||
state=state,
|
||||
code_verifier=code_verifier,
|
||||
nonce=nonce,
|
||||
provider=provider,
|
||||
redirect_uri=redirect_uri,
|
||||
user_id=UUID(user_id) if user_id else None,
|
||||
expires_at=datetime.now(UTC)
|
||||
+ timedelta(minutes=settings.OAUTH_STATE_EXPIRE_MINUTES),
|
||||
)
|
||||
await oauth_state.create_state(db, obj_in=state_data)
|
||||
|
||||
# Build authorization URL
|
||||
async with AsyncOAuth2Client(
|
||||
client_id=client_id,
|
||||
client_secret=client_secret,
|
||||
redirect_uri=redirect_uri,
|
||||
) as client:
|
||||
# Prepare authorization params
|
||||
auth_params = {
|
||||
"state": state,
|
||||
"scope": " ".join(config["scopes"]),
|
||||
}
|
||||
|
||||
if code_challenge:
|
||||
auth_params["code_challenge"] = code_challenge
|
||||
auth_params["code_challenge_method"] = "S256"
|
||||
|
||||
if nonce:
|
||||
auth_params["nonce"] = nonce
|
||||
|
||||
url, _ = client.create_authorization_url(
|
||||
config["authorize_url"],
|
||||
**auth_params,
|
||||
)
|
||||
|
||||
logger.info("OAuth authorization URL created for %s", provider)
|
||||
return url, state
|
||||
|
||||
@staticmethod
|
||||
async def handle_callback(
|
||||
db: AsyncSession,
|
||||
*,
|
||||
code: str,
|
||||
state: str,
|
||||
redirect_uri: str,
|
||||
) -> OAuthCallbackResponse:
|
||||
"""
|
||||
Handle OAuth callback and authenticate/create user.
|
||||
|
||||
Args:
|
||||
db: Database session
|
||||
code: Authorization code from provider
|
||||
state: State parameter for CSRF verification
|
||||
redirect_uri: Callback URL (must match authorization request)
|
||||
|
||||
Returns:
|
||||
OAuthCallbackResponse with tokens
|
||||
|
||||
Raises:
|
||||
AuthenticationError: If authentication fails
|
||||
"""
|
||||
# Validate and consume state
|
||||
state_record = await oauth_state.get_and_consume_state(db, state=state)
|
||||
if not state_record:
|
||||
raise AuthenticationError("Invalid or expired OAuth state")
|
||||
|
||||
# SECURITY: Validate redirect_uri matches the one from authorization request
|
||||
# This prevents authorization code injection attacks (RFC 6749 Section 10.6)
|
||||
if state_record.redirect_uri != redirect_uri:
|
||||
logger.warning(
|
||||
"OAuth redirect_uri mismatch: expected %s, got %s",
|
||||
state_record.redirect_uri,
|
||||
redirect_uri,
|
||||
)
|
||||
raise AuthenticationError("Redirect URI mismatch")
|
||||
|
||||
# Extract provider from state record (str for type safety)
|
||||
provider: str = str(state_record.provider)
|
||||
|
||||
if provider not in OAUTH_PROVIDERS:
|
||||
raise AuthenticationError(f"Unknown OAuth provider: {provider}")
|
||||
|
||||
config = OAUTH_PROVIDERS[provider]
|
||||
client_id, client_secret = OAuthService._get_provider_credentials(provider)
|
||||
|
||||
# Exchange code for tokens
|
||||
async with AsyncOAuth2Client(
|
||||
client_id=client_id,
|
||||
client_secret=client_secret,
|
||||
redirect_uri=redirect_uri,
|
||||
) as client:
|
||||
try:
|
||||
# Prepare token request params
|
||||
token_params: dict[str, str] = {"code": code}
|
||||
|
||||
if state_record.code_verifier:
|
||||
token_params["code_verifier"] = str(state_record.code_verifier)
|
||||
|
||||
token = await client.fetch_token(
|
||||
config["token_url"],
|
||||
**token_params,
|
||||
)
|
||||
|
||||
# SECURITY: Validate ID token signature and nonce for OpenID Connect
|
||||
# This prevents token forgery and replay attacks (OIDC Core 3.1.3.7)
|
||||
if provider == "google" and state_record.nonce:
|
||||
id_token = token.get("id_token")
|
||||
if id_token:
|
||||
await OAuthService._verify_google_id_token(
|
||||
id_token=str(id_token),
|
||||
expected_nonce=str(state_record.nonce),
|
||||
client_id=client_id,
|
||||
)
|
||||
except AuthenticationError:
|
||||
raise
|
||||
except Exception as e:
|
||||
logger.error("OAuth token exchange failed: %s", e)
|
||||
raise AuthenticationError("Failed to exchange authorization code")
|
||||
|
||||
# Get user info from provider
|
||||
try:
|
||||
access_token = token.get("access_token")
|
||||
if not access_token:
|
||||
raise AuthenticationError("No access token received")
|
||||
|
||||
user_info = await OAuthService._get_user_info(
|
||||
client, provider, config, access_token
|
||||
)
|
||||
except Exception as e:
|
||||
logger.error("Failed to get user info: %s", e)
|
||||
raise AuthenticationError(
|
||||
"Failed to get user information from provider"
|
||||
)
|
||||
|
||||
# Process user info and create/link account
|
||||
provider_user_id = str(user_info.get("id") or user_info.get("sub"))
|
||||
# Email can be None if user didn't grant email permission
|
||||
# SECURITY: Normalize email (lowercase, strip) to prevent case-based account duplication
|
||||
email_raw = user_info.get("email")
|
||||
provider_email: str | None = (
|
||||
str(email_raw).lower().strip() if email_raw else None
|
||||
)
|
||||
|
||||
if not provider_user_id:
|
||||
raise AuthenticationError("Provider did not return user ID")
|
||||
|
||||
# Check if this OAuth account already exists
|
||||
existing_oauth = await oauth_account.get_by_provider_id(
|
||||
db, provider=provider, provider_user_id=provider_user_id
|
||||
)
|
||||
|
||||
is_new_user = False
|
||||
|
||||
if existing_oauth:
|
||||
# Existing OAuth account - login
|
||||
user = existing_oauth.user
|
||||
if not user.is_active:
|
||||
raise AuthenticationError("User account is inactive")
|
||||
|
||||
# Update tokens if stored
|
||||
if token.get("access_token"):
|
||||
await oauth_account.update_tokens(
|
||||
db,
|
||||
account=existing_oauth,
|
||||
access_token=token.get("access_token"),
|
||||
refresh_token=token.get("refresh_token"),
|
||||
token_expires_at=datetime.now(UTC)
|
||||
+ timedelta(seconds=token.get("expires_in", 3600)),
|
||||
)
|
||||
|
||||
logger.info("OAuth login successful for %s via %s", user.email, provider)
|
||||
|
||||
elif state_record.user_id:
|
||||
# Account linking flow (user is already logged in)
|
||||
user = await user_repo.get(db, id=str(state_record.user_id))
|
||||
|
||||
if not user:
|
||||
raise AuthenticationError("User not found for account linking")
|
||||
|
||||
# Check if user already has this provider linked
|
||||
user_id = cast(UUID, user.id)
|
||||
existing_provider = await oauth_account.get_user_account_by_provider(
|
||||
db, user_id=user_id, provider=provider
|
||||
)
|
||||
if existing_provider:
|
||||
raise AuthenticationError(
|
||||
f"You already have a {provider} account linked"
|
||||
)
|
||||
|
||||
# Create OAuth account link
|
||||
oauth_create = OAuthAccountCreate(
|
||||
user_id=user_id,
|
||||
provider=provider,
|
||||
provider_user_id=provider_user_id,
|
||||
provider_email=provider_email,
|
||||
access_token=token.get("access_token"),
|
||||
refresh_token=token.get("refresh_token"),
|
||||
token_expires_at=datetime.now(UTC)
|
||||
+ timedelta(seconds=token.get("expires_in", 3600))
|
||||
if token.get("expires_in")
|
||||
else None,
|
||||
)
|
||||
await oauth_account.create_account(db, obj_in=oauth_create)
|
||||
|
||||
logger.info("OAuth account linked: %s -> %s", provider, user.email)
|
||||
|
||||
else:
|
||||
# New OAuth login - check for existing user by email
|
||||
user = None
|
||||
|
||||
if provider_email and settings.OAUTH_AUTO_LINK_BY_EMAIL:
|
||||
user = await user_repo.get_by_email(db, email=provider_email)
|
||||
|
||||
if user:
|
||||
# Auto-link to existing user
|
||||
if not user.is_active:
|
||||
raise AuthenticationError("User account is inactive")
|
||||
|
||||
# Check if user already has this provider linked
|
||||
user_id = cast(UUID, user.id)
|
||||
existing_provider = await oauth_account.get_user_account_by_provider(
|
||||
db, user_id=user_id, provider=provider
|
||||
)
|
||||
if existing_provider:
|
||||
# This shouldn't happen if we got here, but safety check
|
||||
logger.warning(
|
||||
"OAuth account already linked (race condition?): %s -> %s",
|
||||
provider,
|
||||
user.email,
|
||||
)
|
||||
else:
|
||||
# Create OAuth account link
|
||||
oauth_create = OAuthAccountCreate(
|
||||
user_id=user_id,
|
||||
provider=provider,
|
||||
provider_user_id=provider_user_id,
|
||||
provider_email=provider_email,
|
||||
access_token=token.get("access_token"),
|
||||
refresh_token=token.get("refresh_token"),
|
||||
token_expires_at=datetime.now(UTC)
|
||||
+ timedelta(seconds=token.get("expires_in", 3600))
|
||||
if token.get("expires_in")
|
||||
else None,
|
||||
)
|
||||
await oauth_account.create_account(db, obj_in=oauth_create)
|
||||
|
||||
logger.info(
|
||||
"OAuth auto-linked by email: %s -> %s", provider, user.email
|
||||
)
|
||||
|
||||
else:
|
||||
# Create new user
|
||||
if not provider_email:
|
||||
raise AuthenticationError(
|
||||
f"Email is required for registration. "
|
||||
f"Please grant email permission to {provider}."
|
||||
)
|
||||
|
||||
user = await OAuthService._create_oauth_user(
|
||||
db,
|
||||
email=provider_email,
|
||||
provider=provider,
|
||||
provider_user_id=provider_user_id,
|
||||
user_info=user_info,
|
||||
token=token,
|
||||
)
|
||||
is_new_user = True
|
||||
|
||||
logger.info("New user created via OAuth: %s (%s)", user.email, provider)
|
||||
|
||||
# Generate JWT tokens
|
||||
claims = {
|
||||
"is_superuser": user.is_superuser,
|
||||
"email": user.email,
|
||||
"first_name": user.first_name,
|
||||
}
|
||||
|
||||
access_token_jwt = create_access_token(subject=str(user.id), claims=claims)
|
||||
refresh_token_jwt = create_refresh_token(subject=str(user.id))
|
||||
|
||||
return OAuthCallbackResponse(
|
||||
access_token=access_token_jwt,
|
||||
refresh_token=refresh_token_jwt,
|
||||
token_type="bearer",
|
||||
expires_in=settings.ACCESS_TOKEN_EXPIRE_MINUTES * 60,
|
||||
is_new_user=is_new_user,
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
async def _get_user_info(
|
||||
client: AsyncOAuth2Client,
|
||||
provider: str,
|
||||
config: OAuthProviderConfig,
|
||||
access_token: str,
|
||||
) -> dict[str, object]:
|
||||
"""Get user info from OAuth provider."""
|
||||
headers = {"Authorization": f"Bearer {access_token}"}
|
||||
|
||||
if provider == "github":
|
||||
# GitHub returns JSON with Accept header
|
||||
headers["Accept"] = "application/vnd.github+json"
|
||||
|
||||
resp = await client.get(config["userinfo_url"], headers=headers)
|
||||
resp.raise_for_status()
|
||||
user_info = resp.json()
|
||||
|
||||
# GitHub requires separate request for email
|
||||
if provider == "github" and not user_info.get("email"):
|
||||
email_resp = await client.get(
|
||||
config["email_url"], # pyright: ignore[reportTypedDictNotRequiredAccess]
|
||||
headers=headers,
|
||||
)
|
||||
email_resp.raise_for_status()
|
||||
emails = email_resp.json()
|
||||
|
||||
# Find primary verified email
|
||||
for email_data in emails:
|
||||
if email_data.get("primary") and email_data.get("verified"):
|
||||
user_info["email"] = email_data["email"]
|
||||
break
|
||||
|
||||
return user_info
|
||||
|
||||
# Google's OIDC configuration endpoints
|
||||
GOOGLE_JWKS_URL = "https://www.googleapis.com/oauth2/v3/certs"
|
||||
GOOGLE_ISSUERS = ("https://accounts.google.com", "accounts.google.com")
|
||||
|
||||
@staticmethod
|
||||
async def _verify_google_id_token(
|
||||
id_token: str,
|
||||
expected_nonce: str,
|
||||
client_id: str,
|
||||
) -> dict[str, object]:
|
||||
"""
|
||||
Verify Google ID token signature and claims.
|
||||
|
||||
SECURITY: This properly verifies the ID token by:
|
||||
1. Fetching Google's public keys (JWKS)
|
||||
2. Verifying the JWT signature against the public key
|
||||
3. Validating issuer, audience, expiry, and nonce claims
|
||||
|
||||
Args:
|
||||
id_token: The ID token JWT string
|
||||
expected_nonce: The nonce we sent in the authorization request
|
||||
client_id: Our OAuth client ID (expected audience)
|
||||
|
||||
Returns:
|
||||
Decoded ID token payload
|
||||
|
||||
Raises:
|
||||
AuthenticationError: If verification fails
|
||||
"""
|
||||
import httpx
|
||||
import jwt as pyjwt
|
||||
from jwt.algorithms import RSAAlgorithm
|
||||
from jwt.exceptions import InvalidTokenError
|
||||
|
||||
try:
|
||||
# Fetch Google's public keys (JWKS)
|
||||
# In production, consider caching this with TTL matching Cache-Control header
|
||||
async with httpx.AsyncClient() as client:
|
||||
jwks_response = await client.get(
|
||||
OAuthService.GOOGLE_JWKS_URL,
|
||||
timeout=10.0,
|
||||
)
|
||||
jwks_response.raise_for_status()
|
||||
jwks = jwks_response.json()
|
||||
|
||||
# Get the key ID from the token header
|
||||
unverified_header = pyjwt.get_unverified_header(id_token)
|
||||
kid = unverified_header.get("kid")
|
||||
if not kid:
|
||||
raise AuthenticationError("ID token missing key ID (kid)")
|
||||
|
||||
# Find the matching public key
|
||||
jwk_data = None
|
||||
for key in jwks.get("keys", []):
|
||||
if key.get("kid") == kid:
|
||||
jwk_data = key
|
||||
break
|
||||
|
||||
if not jwk_data:
|
||||
raise AuthenticationError("ID token signed with unknown key")
|
||||
|
||||
# Convert JWK to a public key object for PyJWT
|
||||
public_key = RSAAlgorithm.from_jwk(jwk_data)
|
||||
|
||||
# Verify the token signature and decode claims
|
||||
# PyJWT will verify signature against the RSA public key
|
||||
payload = pyjwt.decode(
|
||||
id_token,
|
||||
public_key,
|
||||
algorithms=["RS256"], # Google uses RS256
|
||||
audience=client_id,
|
||||
issuer=OAuthService.GOOGLE_ISSUERS,
|
||||
options={
|
||||
"verify_signature": True,
|
||||
"verify_aud": True,
|
||||
"verify_iss": True,
|
||||
"verify_exp": True,
|
||||
"verify_iat": True,
|
||||
},
|
||||
)
|
||||
|
||||
# Verify nonce (OIDC replay attack protection)
|
||||
token_nonce = payload.get("nonce")
|
||||
if token_nonce != expected_nonce:
|
||||
logger.warning(
|
||||
"OAuth ID token nonce mismatch: expected %s, got %s",
|
||||
expected_nonce,
|
||||
token_nonce,
|
||||
)
|
||||
raise AuthenticationError("Invalid ID token nonce")
|
||||
|
||||
logger.debug("Google ID token verified successfully")
|
||||
return payload
|
||||
|
||||
except InvalidTokenError as e:
|
||||
logger.warning("Google ID token verification failed: %s", e)
|
||||
raise AuthenticationError("Invalid ID token signature")
|
||||
except httpx.HTTPError as e:
|
||||
logger.error("Failed to fetch Google JWKS: %s", e)
|
||||
# If we can't verify the ID token, fail closed for security
|
||||
raise AuthenticationError("Failed to verify ID token")
|
||||
except Exception as e:
|
||||
logger.error("Unexpected error verifying Google ID token: %s", e)
|
||||
raise AuthenticationError("ID token verification error")
|
||||
|
||||
@staticmethod
|
||||
async def _create_oauth_user(
|
||||
db: AsyncSession,
|
||||
*,
|
||||
email: str,
|
||||
provider: str,
|
||||
provider_user_id: str,
|
||||
user_info: dict,
|
||||
token: dict,
|
||||
) -> User:
|
||||
"""Create a new user from OAuth provider data."""
|
||||
# Extract name from user_info
|
||||
first_name = "User"
|
||||
last_name = None
|
||||
|
||||
if provider == "google":
|
||||
first_name = user_info.get("given_name") or user_info.get("name", "User")
|
||||
last_name = user_info.get("family_name")
|
||||
elif provider == "github":
|
||||
# GitHub has full name, try to split
|
||||
name = user_info.get("name") or user_info.get("login", "User")
|
||||
parts = name.split(" ", 1)
|
||||
first_name = parts[0]
|
||||
last_name = parts[1] if len(parts) > 1 else None
|
||||
|
||||
# Create user (no password for OAuth-only users)
|
||||
user = User(
|
||||
email=email,
|
||||
password_hash=None, # OAuth-only user
|
||||
first_name=first_name,
|
||||
last_name=last_name,
|
||||
is_active=True,
|
||||
is_superuser=False,
|
||||
)
|
||||
db.add(user)
|
||||
await db.flush() # Get user.id
|
||||
|
||||
# Create OAuth account link
|
||||
user_id = cast(UUID, user.id)
|
||||
oauth_create = OAuthAccountCreate(
|
||||
user_id=user_id,
|
||||
provider=provider,
|
||||
provider_user_id=provider_user_id,
|
||||
provider_email=email,
|
||||
access_token=token.get("access_token"),
|
||||
refresh_token=token.get("refresh_token"),
|
||||
token_expires_at=datetime.now(UTC)
|
||||
+ timedelta(seconds=token.get("expires_in", 3600))
|
||||
if token.get("expires_in")
|
||||
else None,
|
||||
)
|
||||
await oauth_account.create_account(db, obj_in=oauth_create)
|
||||
|
||||
await db.refresh(user)
|
||||
|
||||
return user
|
||||
|
||||
@staticmethod
|
||||
async def unlink_provider(
|
||||
db: AsyncSession,
|
||||
*,
|
||||
user: User,
|
||||
provider: str,
|
||||
) -> bool:
|
||||
"""
|
||||
Unlink an OAuth provider from a user account.
|
||||
|
||||
Args:
|
||||
db: Database session
|
||||
user: User to unlink from
|
||||
provider: Provider to unlink
|
||||
|
||||
Returns:
|
||||
True if unlinked successfully
|
||||
|
||||
Raises:
|
||||
AuthenticationError: If unlinking would leave user without login method
|
||||
"""
|
||||
# Check if user can safely remove this OAuth account
|
||||
# Note: We query directly instead of using user.can_remove_oauth property
|
||||
# because the property uses lazy loading which doesn't work in async context
|
||||
user_id = cast(UUID, user.id)
|
||||
has_password = user.password_hash is not None
|
||||
oauth_accounts = await oauth_account.get_user_accounts(db, user_id=user_id)
|
||||
can_remove = has_password or len(oauth_accounts) > 1
|
||||
|
||||
if not can_remove:
|
||||
raise AuthenticationError(
|
||||
"Cannot unlink OAuth account. You must have either a password set "
|
||||
"or at least one other OAuth provider linked."
|
||||
)
|
||||
|
||||
deleted = await oauth_account.delete_account(
|
||||
db, user_id=user_id, provider=provider
|
||||
)
|
||||
|
||||
if not deleted:
|
||||
raise AuthenticationError(f"No {provider} account found to unlink")
|
||||
|
||||
logger.info("OAuth provider unlinked: %s from %s", provider, user.email)
|
||||
return True
|
||||
|
||||
@staticmethod
|
||||
async def get_user_accounts(db: AsyncSession, *, user_id: UUID) -> list:
|
||||
"""Get all OAuth accounts linked to a user."""
|
||||
return await oauth_account.get_user_accounts(db, user_id=user_id)
|
||||
|
||||
@staticmethod
|
||||
async def get_user_account_by_provider(
|
||||
db: AsyncSession, *, user_id: UUID, provider: str
|
||||
):
|
||||
"""Get a specific OAuth account for a user and provider."""
|
||||
return await oauth_account.get_user_account_by_provider(
|
||||
db, user_id=user_id, provider=provider
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
async def cleanup_expired_states(db: AsyncSession) -> int:
|
||||
"""
|
||||
Clean up expired OAuth states.
|
||||
|
||||
Should be called periodically (e.g., by a background task).
|
||||
|
||||
Args:
|
||||
db: Database session
|
||||
|
||||
Returns:
|
||||
Number of states cleaned up
|
||||
"""
|
||||
return await oauth_state.cleanup_expired(db)
|
||||
155
backend/app/services/organization_service.py
Normal file
155
backend/app/services/organization_service.py
Normal file
@@ -0,0 +1,155 @@
|
||||
# app/services/organization_service.py
|
||||
"""Service layer for organization operations — delegates to OrganizationRepository."""
|
||||
|
||||
import logging
|
||||
from typing import Any
|
||||
from uuid import UUID
|
||||
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
|
||||
from app.core.exceptions import NotFoundError
|
||||
from app.models.organization import Organization
|
||||
from app.models.user_organization import OrganizationRole, UserOrganization
|
||||
from app.repositories.organization import OrganizationRepository, organization_repo
|
||||
from app.schemas.organizations import OrganizationCreate, OrganizationUpdate
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class OrganizationService:
|
||||
"""Service for organization management operations."""
|
||||
|
||||
def __init__(
|
||||
self, organization_repository: OrganizationRepository | None = None
|
||||
) -> None:
|
||||
self._repo = organization_repository or organization_repo
|
||||
|
||||
async def get_organization(self, db: AsyncSession, org_id: str) -> Organization:
|
||||
"""Get organization by ID, raising NotFoundError if not found."""
|
||||
org = await self._repo.get(db, id=org_id)
|
||||
if not org:
|
||||
raise NotFoundError(f"Organization {org_id} not found")
|
||||
return org
|
||||
|
||||
async def create_organization(
|
||||
self, db: AsyncSession, *, obj_in: OrganizationCreate
|
||||
) -> Organization:
|
||||
"""Create a new organization."""
|
||||
return await self._repo.create(db, obj_in=obj_in)
|
||||
|
||||
async def update_organization(
|
||||
self,
|
||||
db: AsyncSession,
|
||||
*,
|
||||
org: Organization,
|
||||
obj_in: OrganizationUpdate | dict[str, Any],
|
||||
) -> Organization:
|
||||
"""Update an existing organization."""
|
||||
return await self._repo.update(db, db_obj=org, obj_in=obj_in)
|
||||
|
||||
async def remove_organization(self, db: AsyncSession, org_id: str) -> None:
|
||||
"""Permanently delete an organization by ID."""
|
||||
await self._repo.remove(db, id=org_id)
|
||||
|
||||
async def get_member_count(self, db: AsyncSession, *, organization_id: UUID) -> int:
|
||||
"""Get number of active members in an organization."""
|
||||
return await self._repo.get_member_count(db, organization_id=organization_id)
|
||||
|
||||
async def get_multi_with_member_counts(
|
||||
self,
|
||||
db: AsyncSession,
|
||||
*,
|
||||
skip: int = 0,
|
||||
limit: int = 100,
|
||||
is_active: bool | None = None,
|
||||
search: str | None = None,
|
||||
) -> tuple[list[dict[str, Any]], int]:
|
||||
"""List organizations with member counts and pagination."""
|
||||
return await self._repo.get_multi_with_member_counts(
|
||||
db, skip=skip, limit=limit, is_active=is_active, search=search
|
||||
)
|
||||
|
||||
async def get_user_organizations_with_details(
|
||||
self,
|
||||
db: AsyncSession,
|
||||
*,
|
||||
user_id: UUID,
|
||||
is_active: bool | None = None,
|
||||
) -> list[dict[str, Any]]:
|
||||
"""Get all organizations a user belongs to, with membership details."""
|
||||
return await self._repo.get_user_organizations_with_details(
|
||||
db, user_id=user_id, is_active=is_active
|
||||
)
|
||||
|
||||
async def get_organization_members(
|
||||
self,
|
||||
db: AsyncSession,
|
||||
*,
|
||||
organization_id: UUID,
|
||||
skip: int = 0,
|
||||
limit: int = 100,
|
||||
is_active: bool | None = True,
|
||||
) -> tuple[list[dict[str, Any]], int]:
|
||||
"""Get members of an organization with pagination."""
|
||||
return await self._repo.get_organization_members(
|
||||
db,
|
||||
organization_id=organization_id,
|
||||
skip=skip,
|
||||
limit=limit,
|
||||
is_active=is_active,
|
||||
)
|
||||
|
||||
async def add_member(
|
||||
self,
|
||||
db: AsyncSession,
|
||||
*,
|
||||
organization_id: UUID,
|
||||
user_id: UUID,
|
||||
role: OrganizationRole = OrganizationRole.MEMBER,
|
||||
) -> UserOrganization:
|
||||
"""Add a user to an organization."""
|
||||
return await self._repo.add_user(
|
||||
db, organization_id=organization_id, user_id=user_id, role=role
|
||||
)
|
||||
|
||||
async def remove_member(
|
||||
self,
|
||||
db: AsyncSession,
|
||||
*,
|
||||
organization_id: UUID,
|
||||
user_id: UUID,
|
||||
) -> bool:
|
||||
"""Remove a user from an organization. Returns True if found and removed."""
|
||||
return await self._repo.remove_user(
|
||||
db, organization_id=organization_id, user_id=user_id
|
||||
)
|
||||
|
||||
async def get_user_role_in_org(
|
||||
self, db: AsyncSession, *, user_id: UUID, organization_id: UUID
|
||||
) -> OrganizationRole | None:
|
||||
"""Get the role of a user in an organization."""
|
||||
return await self._repo.get_user_role_in_org(
|
||||
db, user_id=user_id, organization_id=organization_id
|
||||
)
|
||||
|
||||
async def get_org_distribution(
|
||||
self, db: AsyncSession, *, limit: int = 6
|
||||
) -> list[dict[str, Any]]:
|
||||
"""Return top organizations by member count for admin dashboard."""
|
||||
from sqlalchemy import func, select
|
||||
|
||||
result = await db.execute(
|
||||
select(
|
||||
Organization.name,
|
||||
func.count(UserOrganization.user_id).label("count"),
|
||||
)
|
||||
.join(UserOrganization, Organization.id == UserOrganization.organization_id)
|
||||
.group_by(Organization.name)
|
||||
.order_by(func.count(UserOrganization.user_id).desc())
|
||||
.limit(limit)
|
||||
)
|
||||
return [{"name": row.name, "value": row.count} for row in result.all()]
|
||||
|
||||
|
||||
# Default singleton
|
||||
organization_service = OrganizationService()
|
||||
80
backend/app/services/session_cleanup.py
Normal file → Executable file
80
backend/app/services/session_cleanup.py
Normal file → Executable file
@@ -3,16 +3,17 @@ Background job for cleaning up expired sessions.
|
||||
|
||||
This service runs periodically to remove old session records from the database.
|
||||
"""
|
||||
|
||||
import logging
|
||||
from datetime import datetime, timezone
|
||||
from datetime import UTC, datetime
|
||||
|
||||
from app.core.database import SessionLocal
|
||||
from app.crud.session import session as session_crud
|
||||
from app.repositories.session import session_repo as session_repo
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
def cleanup_expired_sessions(keep_days: int = 30) -> int:
|
||||
async def cleanup_expired_sessions(keep_days: int = 30) -> int:
|
||||
"""
|
||||
Clean up expired and inactive sessions.
|
||||
|
||||
@@ -29,52 +30,59 @@ def cleanup_expired_sessions(keep_days: int = 30) -> int:
|
||||
"""
|
||||
logger.info("Starting session cleanup job...")
|
||||
|
||||
db = SessionLocal()
|
||||
try:
|
||||
# Use CRUD method to cleanup
|
||||
count = session_crud.cleanup_expired(db, keep_days=keep_days)
|
||||
async with SessionLocal() as db:
|
||||
try:
|
||||
# Use repository method to cleanup
|
||||
count = await session_repo.cleanup_expired(db, keep_days=keep_days)
|
||||
|
||||
logger.info(f"Session cleanup complete: {count} sessions deleted")
|
||||
logger.info("Session cleanup complete: %s sessions deleted", count)
|
||||
|
||||
return count
|
||||
return count
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error during session cleanup: {str(e)}", exc_info=True)
|
||||
return 0
|
||||
finally:
|
||||
db.close()
|
||||
except Exception as e:
|
||||
logger.exception("Error during session cleanup: %s", e)
|
||||
return 0
|
||||
|
||||
|
||||
def get_session_statistics() -> dict:
|
||||
async def get_session_statistics() -> dict:
|
||||
"""
|
||||
Get statistics about current sessions.
|
||||
|
||||
Returns:
|
||||
Dictionary with session stats
|
||||
"""
|
||||
db = SessionLocal()
|
||||
try:
|
||||
from app.models.user_session import UserSession
|
||||
async with SessionLocal() as db:
|
||||
try:
|
||||
from sqlalchemy import func, select
|
||||
|
||||
total_sessions = db.query(UserSession).count()
|
||||
active_sessions = db.query(UserSession).filter(UserSession.is_active == True).count()
|
||||
expired_sessions = db.query(UserSession).filter(
|
||||
UserSession.expires_at < datetime.now(timezone.utc)
|
||||
).count()
|
||||
from app.models.user_session import UserSession
|
||||
|
||||
stats = {
|
||||
"total": total_sessions,
|
||||
"active": active_sessions,
|
||||
"inactive": total_sessions - active_sessions,
|
||||
"expired": expired_sessions,
|
||||
}
|
||||
total_result = await db.execute(select(func.count(UserSession.id)))
|
||||
total_sessions = total_result.scalar_one()
|
||||
|
||||
logger.info(f"Session statistics: {stats}")
|
||||
active_result = await db.execute(
|
||||
select(func.count(UserSession.id)).where(UserSession.is_active)
|
||||
)
|
||||
active_sessions = active_result.scalar_one()
|
||||
|
||||
return stats
|
||||
expired_result = await db.execute(
|
||||
select(func.count(UserSession.id)).where(
|
||||
UserSession.expires_at < datetime.now(UTC)
|
||||
)
|
||||
)
|
||||
expired_sessions = expired_result.scalar_one()
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error getting session statistics: {str(e)}", exc_info=True)
|
||||
return {}
|
||||
finally:
|
||||
db.close()
|
||||
stats = {
|
||||
"total": total_sessions,
|
||||
"active": active_sessions,
|
||||
"inactive": total_sessions - active_sessions,
|
||||
"expired": expired_sessions,
|
||||
}
|
||||
|
||||
logger.info("Session statistics: %s", stats)
|
||||
|
||||
return stats
|
||||
|
||||
except Exception as e:
|
||||
logger.exception("Error getting session statistics: %s", e)
|
||||
return {}
|
||||
|
||||
97
backend/app/services/session_service.py
Normal file
97
backend/app/services/session_service.py
Normal file
@@ -0,0 +1,97 @@
|
||||
# app/services/session_service.py
|
||||
"""Service layer for session operations — delegates to SessionRepository."""
|
||||
|
||||
import logging
|
||||
from datetime import datetime
|
||||
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
|
||||
from app.models.user_session import UserSession
|
||||
from app.repositories.session import SessionRepository, session_repo
|
||||
from app.schemas.sessions import SessionCreate
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class SessionService:
|
||||
"""Service for user session management operations."""
|
||||
|
||||
def __init__(self, session_repository: SessionRepository | None = None) -> None:
|
||||
self._repo = session_repository or session_repo
|
||||
|
||||
async def create_session(
|
||||
self, db: AsyncSession, *, obj_in: SessionCreate
|
||||
) -> UserSession:
|
||||
"""Create a new session record."""
|
||||
return await self._repo.create_session(db, obj_in=obj_in)
|
||||
|
||||
async def get_session(
|
||||
self, db: AsyncSession, session_id: str
|
||||
) -> UserSession | None:
|
||||
"""Get session by ID."""
|
||||
return await self._repo.get(db, id=session_id)
|
||||
|
||||
async def get_user_sessions(
|
||||
self, db: AsyncSession, *, user_id: str, active_only: bool = True
|
||||
) -> list[UserSession]:
|
||||
"""Get all sessions for a user."""
|
||||
return await self._repo.get_user_sessions(
|
||||
db, user_id=user_id, active_only=active_only
|
||||
)
|
||||
|
||||
async def get_active_by_jti(
|
||||
self, db: AsyncSession, *, jti: str
|
||||
) -> UserSession | None:
|
||||
"""Get active session by refresh token JTI."""
|
||||
return await self._repo.get_active_by_jti(db, jti=jti)
|
||||
|
||||
async def get_by_jti(self, db: AsyncSession, *, jti: str) -> UserSession | None:
|
||||
"""Get session by refresh token JTI (active or inactive)."""
|
||||
return await self._repo.get_by_jti(db, jti=jti)
|
||||
|
||||
async def deactivate(
|
||||
self, db: AsyncSession, *, session_id: str
|
||||
) -> UserSession | None:
|
||||
"""Deactivate a session (logout from device)."""
|
||||
return await self._repo.deactivate(db, session_id=session_id)
|
||||
|
||||
async def deactivate_all_user_sessions(
|
||||
self, db: AsyncSession, *, user_id: str
|
||||
) -> int:
|
||||
"""Deactivate all sessions for a user. Returns count deactivated."""
|
||||
return await self._repo.deactivate_all_user_sessions(db, user_id=user_id)
|
||||
|
||||
async def update_refresh_token(
|
||||
self,
|
||||
db: AsyncSession,
|
||||
*,
|
||||
session: UserSession,
|
||||
new_jti: str,
|
||||
new_expires_at: datetime,
|
||||
) -> UserSession:
|
||||
"""Update session with a rotated refresh token."""
|
||||
return await self._repo.update_refresh_token(
|
||||
db, session=session, new_jti=new_jti, new_expires_at=new_expires_at
|
||||
)
|
||||
|
||||
async def cleanup_expired_for_user(self, db: AsyncSession, *, user_id: str) -> int:
|
||||
"""Remove expired sessions for a user. Returns count removed."""
|
||||
return await self._repo.cleanup_expired_for_user(db, user_id=user_id)
|
||||
|
||||
async def get_all_sessions(
|
||||
self,
|
||||
db: AsyncSession,
|
||||
*,
|
||||
skip: int = 0,
|
||||
limit: int = 100,
|
||||
active_only: bool = True,
|
||||
with_user: bool = True,
|
||||
) -> tuple[list[UserSession], int]:
|
||||
"""Get all sessions with pagination (admin only)."""
|
||||
return await self._repo.get_all_sessions(
|
||||
db, skip=skip, limit=limit, active_only=active_only, with_user=with_user
|
||||
)
|
||||
|
||||
|
||||
# Default singleton
|
||||
session_service = SessionService()
|
||||
120
backend/app/services/user_service.py
Normal file
120
backend/app/services/user_service.py
Normal file
@@ -0,0 +1,120 @@
|
||||
# app/services/user_service.py
|
||||
"""Service layer for user operations — delegates to UserRepository."""
|
||||
|
||||
import logging
|
||||
from typing import Any
|
||||
from uuid import UUID
|
||||
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
|
||||
from app.core.exceptions import NotFoundError
|
||||
from app.models.user import User
|
||||
from app.repositories.user import UserRepository, user_repo
|
||||
from app.schemas.users import UserCreate, UserUpdate
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class UserService:
|
||||
"""Service for user management operations."""
|
||||
|
||||
def __init__(self, user_repository: UserRepository | None = None) -> None:
|
||||
self._repo = user_repository or user_repo
|
||||
|
||||
async def get_user(self, db: AsyncSession, user_id: str) -> User:
|
||||
"""Get user by ID, raising NotFoundError if not found."""
|
||||
user = await self._repo.get(db, id=user_id)
|
||||
if not user:
|
||||
raise NotFoundError(f"User {user_id} not found")
|
||||
return user
|
||||
|
||||
async def get_by_email(self, db: AsyncSession, email: str) -> User | None:
|
||||
"""Get user by email address."""
|
||||
return await self._repo.get_by_email(db, email=email)
|
||||
|
||||
async def create_user(self, db: AsyncSession, user_data: UserCreate) -> User:
|
||||
"""Create a new user."""
|
||||
return await self._repo.create(db, obj_in=user_data)
|
||||
|
||||
async def update_user(
|
||||
self, db: AsyncSession, *, user: User, obj_in: UserUpdate | dict[str, Any]
|
||||
) -> User:
|
||||
"""Update an existing user."""
|
||||
return await self._repo.update(db, db_obj=user, obj_in=obj_in)
|
||||
|
||||
async def soft_delete_user(self, db: AsyncSession, user_id: str) -> None:
|
||||
"""Soft-delete a user by ID."""
|
||||
await self._repo.soft_delete(db, id=user_id)
|
||||
|
||||
async def list_users(
|
||||
self,
|
||||
db: AsyncSession,
|
||||
*,
|
||||
skip: int = 0,
|
||||
limit: int = 100,
|
||||
sort_by: str | None = None,
|
||||
sort_order: str = "asc",
|
||||
filters: dict[str, Any] | None = None,
|
||||
search: str | None = None,
|
||||
) -> tuple[list[User], int]:
|
||||
"""List users with pagination, sorting, filtering, and search."""
|
||||
return await self._repo.get_multi_with_total(
|
||||
db,
|
||||
skip=skip,
|
||||
limit=limit,
|
||||
sort_by=sort_by,
|
||||
sort_order=sort_order,
|
||||
filters=filters,
|
||||
search=search,
|
||||
)
|
||||
|
||||
async def bulk_update_status(
|
||||
self, db: AsyncSession, *, user_ids: list[UUID], is_active: bool
|
||||
) -> int:
|
||||
"""Bulk update active status for multiple users. Returns count updated."""
|
||||
return await self._repo.bulk_update_status(
|
||||
db, user_ids=user_ids, is_active=is_active
|
||||
)
|
||||
|
||||
async def bulk_soft_delete(
|
||||
self,
|
||||
db: AsyncSession,
|
||||
*,
|
||||
user_ids: list[UUID],
|
||||
exclude_user_id: UUID | None = None,
|
||||
) -> int:
|
||||
"""Bulk soft-delete multiple users. Returns count deleted."""
|
||||
return await self._repo.bulk_soft_delete(
|
||||
db, user_ids=user_ids, exclude_user_id=exclude_user_id
|
||||
)
|
||||
|
||||
async def get_stats(self, db: AsyncSession) -> dict[str, Any]:
|
||||
"""Return user stats needed for the admin dashboard."""
|
||||
from sqlalchemy import func, select
|
||||
|
||||
total_users = (
|
||||
await db.execute(select(func.count()).select_from(User))
|
||||
).scalar() or 0
|
||||
active_count = (
|
||||
await db.execute(
|
||||
select(func.count()).select_from(User).where(User.is_active)
|
||||
)
|
||||
).scalar() or 0
|
||||
inactive_count = (
|
||||
await db.execute(
|
||||
select(func.count()).select_from(User).where(User.is_active.is_(False))
|
||||
)
|
||||
).scalar() or 0
|
||||
all_users = list(
|
||||
(await db.execute(select(User).order_by(User.created_at))).scalars().all()
|
||||
)
|
||||
return {
|
||||
"total_users": total_users,
|
||||
"active_count": active_count,
|
||||
"inactive_count": inactive_count,
|
||||
"all_users": all_users,
|
||||
}
|
||||
|
||||
|
||||
# Default singleton
|
||||
user_service = UserService()
|
||||
@@ -2,7 +2,8 @@
|
||||
Authentication utilities for testing.
|
||||
This module provides tools to bypass FastAPI's authentication in tests.
|
||||
"""
|
||||
from typing import Callable, Dict, Optional
|
||||
|
||||
from collections.abc import Callable
|
||||
|
||||
from fastapi import FastAPI
|
||||
from fastapi.security import OAuth2PasswordBearer
|
||||
@@ -13,9 +14,9 @@ from app.models.user import User
|
||||
|
||||
|
||||
def create_test_auth_client(
|
||||
app: FastAPI,
|
||||
test_user: User,
|
||||
extra_overrides: Optional[Dict[Callable, Callable]] = None
|
||||
app: FastAPI,
|
||||
test_user: User,
|
||||
extra_overrides: dict[Callable, Callable] | None = None,
|
||||
) -> TestClient:
|
||||
"""
|
||||
Create a test client with authentication pre-configured.
|
||||
@@ -47,10 +48,7 @@ def create_test_auth_client(
|
||||
return TestClient(app)
|
||||
|
||||
|
||||
def create_test_optional_auth_client(
|
||||
app: FastAPI,
|
||||
test_user: User
|
||||
) -> TestClient:
|
||||
def create_test_optional_auth_client(app: FastAPI, test_user: User) -> TestClient:
|
||||
"""
|
||||
Create a test client with optional authentication pre-configured.
|
||||
|
||||
@@ -70,10 +68,7 @@ def create_test_optional_auth_client(
|
||||
return TestClient(app)
|
||||
|
||||
|
||||
def create_test_superuser_client(
|
||||
app: FastAPI,
|
||||
test_user: User
|
||||
) -> TestClient:
|
||||
def create_test_superuser_client(app: FastAPI, test_user: User) -> TestClient:
|
||||
"""
|
||||
Create a test client with superuser authentication pre-configured.
|
||||
|
||||
@@ -120,7 +115,7 @@ def cleanup_test_client_auth(app: FastAPI) -> None:
|
||||
auth_deps = [
|
||||
get_current_user,
|
||||
get_optional_current_user,
|
||||
OAuth2PasswordBearer(tokenUrl="/api/v1/auth/login")
|
||||
OAuth2PasswordBearer(tokenUrl="/api/v1/auth/login"),
|
||||
]
|
||||
|
||||
# Remove overrides
|
||||
|
||||
Some files were not shown because too many files have changed in this diff Show More
Reference in New Issue
Block a user