forked from cardosofelipe/pragma-stack
Compare commits
23 Commits
| Author | SHA1 | Date |
|---|---|---|
|  | a94e29d99c |  |
|  | 81e48c73ca |  |
|  | a3f78dc801 |  |
|  | 07309013d7 |  |
|  | 846fc31190 |  |
|  | ff7a67cb58 |  |
|  | 0760a8284d |  |
|  | ce4d0c7b0d |  |
|  | 4ceb8ad98c |  |
|  | f8aafb250d |  |
|  | 4385d20ca6 |  |
|  | 1a36907f10 |  |
|  | 0553a1fc53 |  |
|  | 57e969ed67 |  |
|  | 68275b1dd3 |  |
|  | 80d2dc0cb2 |  |
|  | a8aa416ecb |  |
|  | 4c6bf55bcc |  |
|  | 98b455fdc3 |  |
|  | 0646c96b19 |  |
|  | 62afb328fe |  |
|  | b9a746bc16 |  |
|  | de8e18e97d |  |
.github/workflows/README.md (vendored, 2 changes)
@@ -41,7 +41,7 @@ To enable CI/CD workflows:
 - Runs on: Push to main/develop, PRs affecting frontend code
 - Tests: Frontend unit tests (Jest)
 - Coverage: Uploads to Codecov
-- Fast: Uses npm cache
+- Fast: Uses bun cache

 ### `e2e-tests.yml`
 - Runs on: All pushes and PRs
.gitignore (vendored, 2 changes)
@@ -187,7 +187,7 @@ coverage.xml
 .hypothesis/
 .pytest_cache/
 cover/

+backend/.benchmarks
 # Translations
 *.mo
 *.pot
AGENTS.md (40 changes)
@@ -13,10 +13,10 @@ uv run uvicorn app.main:app --reload # Start dev server
 # Frontend (Node.js)
 cd frontend
-npm install            # Install dependencies
-npm run dev            # Start dev server
-npm run generate:api   # Generate API client from OpenAPI
-npm run test:e2e       # Run E2E tests
+bun install            # Install dependencies
+bun run dev            # Start dev server
+bun run generate:api   # Generate API client from OpenAPI
+bun run test:e2e       # Run E2E tests
 ```

 **Access points:**
@@ -37,7 +37,7 @@ Default superuser (change in production):
 │   ├── app/
 │   │   ├── api/           # API routes (auth, users, organizations, admin)
 │   │   ├── core/          # Core functionality (auth, config, database)
-│   │   ├── crud/          # Database CRUD operations
+│   │   ├── repositories/  # Repository pattern (database operations)
 │   │   ├── models/        # SQLAlchemy ORM models
 │   │   ├── schemas/       # Pydantic request/response schemas
 │   │   ├── services/      # Business logic layer
@@ -113,7 +113,7 @@ OAUTH_ISSUER=https://api.yourdomain.com # JWT issuer URL (must be HTTPS in
 ### Database Pattern
 - **Async SQLAlchemy 2.0** with PostgreSQL
 - **Connection pooling**: 20 base connections, 50 max overflow
-- **CRUD base class**: `crud/base.py` with common operations
+- **Repository base class**: `repositories/base.py` with common operations
 - **Migrations**: Alembic with helper script `migrate.py`
   - `python migrate.py auto "message"` - Generate and apply
   - `python migrate.py list` - View history
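The repository base class mentioned in the hunk above can be illustrated with a minimal, synchronous, stdlib-only sketch. The real `repositories/base.py` presumably builds on async SQLAlchemy sessions; the `InMemorySession` here is a stand-in invented purely for illustration:

```python
from dataclasses import dataclass
from typing import Dict, Generic, Optional, TypeVar

T = TypeVar("T")

class InMemorySession:
    """Stand-in for a database session: a dict keyed by primary key."""
    def __init__(self) -> None:
        self.rows: Dict[int, object] = {}

class BaseRepository(Generic[T]):
    """Common operations shared by all repositories (create/get/delete)."""
    def __init__(self, session: InMemorySession) -> None:
        self.session = session
        self._next_id = 1

    def create(self, obj: T) -> int:
        pk = self._next_id
        self._next_id += 1
        self.session.rows[pk] = obj
        return pk

    def get(self, pk: int) -> Optional[T]:
        return self.session.rows.get(pk)

    def delete(self, pk: int) -> bool:
        return self.session.rows.pop(pk, None) is not None

@dataclass
class User:
    email: str
    is_active: bool = True

class UserRepository(BaseRepository[User]):
    """Domain-specific queries live in subclasses."""
    def get_by_email(self, email: str) -> Optional[User]:
        return next((u for u in self.session.rows.values()
                     if isinstance(u, User) and u.email == email), None)

repo = UserRepository(InMemorySession())
pk = repo.create(User(email="demo@example.com"))
assert repo.get(pk).email == "demo@example.com"
assert repo.get_by_email("demo@example.com") is repo.get(pk)
```

The point of the pattern is that routes and services depend only on the repository interface, so the storage backend can be swapped (or mocked in tests) without touching business logic.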
@@ -121,7 +121,7 @@ OAUTH_ISSUER=https://api.yourdomain.com # JWT issuer URL (must be HTTPS in
 ### Frontend State Management
 - **Zustand stores**: Lightweight state management
 - **TanStack Query**: API data fetching/caching
-- **Auto-generated client**: From OpenAPI spec via `npm run generate:api`
+- **Auto-generated client**: From OpenAPI spec via `bun run generate:api`
 - **Dependency Injection**: ALWAYS use `useAuth()` from `AuthContext`, NEVER import `useAuthStore` directly

 ### Internationalization (i18n)
@@ -165,21 +165,25 @@ Permission dependencies in `api/dependencies/permissions.py`:
 **Frontend Unit Tests (Jest):**
 - 97% coverage
 - Component, hook, and utility testing
-- Run: `npm test`
-- Coverage: `npm run test:coverage`
+- Run: `bun run test`
+- Coverage: `bun run test:coverage`

 **Frontend E2E Tests (Playwright):**
 - 56 passing, 1 skipped (zero flaky tests)
 - Complete user flows (auth, navigation, settings)
-- Run: `npm run test:e2e`
-- UI mode: `npm run test:e2e:ui`
+- Run: `bun run test:e2e`
+- UI mode: `bun run test:e2e:ui`

 ### Development Tooling

 **Backend:**
 - **uv**: Modern Python package manager (10-100x faster than pip)
 - **Ruff**: All-in-one linting/formatting (replaces Black, Flake8, isort)
-- **mypy**: Type checking with Pydantic plugin
+- **Pyright**: Static type checking (strict mode)
+- **pip-audit**: Dependency vulnerability scanning (OSV database)
+- **detect-secrets**: Hardcoded secrets detection
+- **pip-licenses**: License compliance checking
+- **pre-commit**: Git hook framework (Ruff, detect-secrets, standard checks)
 - **Makefile**: `make help` for all commands

 **Frontend:**
@@ -218,11 +222,11 @@ NEXT_PUBLIC_API_URL=http://localhost:8000/api/v1
 ### Adding a New API Endpoint

 1. **Define schema** in `backend/app/schemas/`
-2. **Create CRUD operations** in `backend/app/crud/`
+2. **Create repository** in `backend/app/repositories/`
 3. **Implement route** in `backend/app/api/routes/`
 4. **Register router** in `backend/app/api/main.py`
 5. **Write tests** in `backend/tests/api/`
-6. **Generate frontend client**: `npm run generate:api`
+6. **Generate frontend client**: `bun run generate:api`
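The schema → repository → route flow in the steps above can be sketched without FastAPI or Pydantic; names such as `ItemCreate` and `create_item_route` are illustrative, not from the codebase:

```python
from dataclasses import dataclass

# 1. Schema: validates the request payload (Pydantic in the real stack).
@dataclass(frozen=True)
class ItemCreate:
    name: str
    def __post_init__(self) -> None:
        if not self.name.strip():
            raise ValueError("name must be non-empty")

# 2. Repository: owns persistence, here just a list.
class ItemRepository:
    def __init__(self) -> None:
        self._items = []
    def create(self, payload: ItemCreate) -> dict:
        row = {"id": len(self._items) + 1, "name": payload.name}
        self._items.append(row)
        return row

# 3. Route: stays thin, validates input and delegates to the repository.
def create_item_route(repo: ItemRepository, body: dict) -> dict:
    payload = ItemCreate(**body)
    return repo.create(payload)

repo = ItemRepository()
assert create_item_route(repo, {"name": "widget"}) == {"id": 1, "name": "widget"}
```

Keeping the route a thin adapter over the repository is what makes steps 5 and 6 cheap: tests target each layer in isolation, and the generated client only cares about the schema.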

 ### Database Migrations

@@ -239,7 +243,7 @@ python migrate.py auto "description" # Generate + apply
 2. **Follow design system** (see `frontend/docs/design-system/`)
 3. **Use dependency injection** for auth (`useAuth()` not `useAuthStore`)
 4. **Write tests** in `frontend/tests/` or `__tests__/`
-5. **Run type check**: `npm run type-check`
+5. **Run type check**: `bun run type-check`

 ## Security Features

@@ -249,6 +253,10 @@ python migrate.py auto "description" # Generate + apply
 - **CSRF protection**: Built into FastAPI
 - **Session revocation**: Database-backed session tracking
 - **Comprehensive security tests**: JWT algorithm attacks, session hijacking, privilege escalation
+- **Dependency vulnerability scanning**: `make dep-audit` (pip-audit against OSV database)
+- **License compliance**: `make license-check` (blocks GPL-3.0/AGPL)
+- **Secrets detection**: Pre-commit hook blocks hardcoded secrets
+- **Unified security pipeline**: `make audit` (all security checks), `make check` (quality + security + tests)

 ## Docker Deployment

@@ -281,7 +289,7 @@ docker-compose exec backend python -c "from app.init_db import init_db; import a
 - Authentication system (JWT with refresh tokens, OAuth/social login)
 - **OAuth Provider Mode (MCP-ready)**: Full OAuth 2.0 Authorization Server
 - Session management (device tracking, revocation)
-- User management (CRUD, password change)
+- User management (full lifecycle, password change)
 - Organization system (multi-tenant with RBAC)
 - Admin panel (user/org management, bulk operations)
 - **Internationalization (i18n)** with English and Italian
CLAUDE.md (30 changes)
@@ -43,7 +43,7 @@ EOF
 - Check current state: `python migrate.py current`

 **Frontend API Client Generation:**
-- Run `npm run generate:api` after backend schema changes
+- Run `bun run generate:api` after backend schema changes
 - Client is auto-generated from OpenAPI spec
 - Located in `frontend/src/lib/api/generated/`
 - NEVER manually edit generated files
@@ -51,10 +51,16 @@ EOF
 **Testing Commands:**
 - Backend unit/integration: `IS_TEST=True uv run pytest` (always prefix with `IS_TEST=True`)
 - Backend E2E (requires Docker): `make test-e2e`
-- Frontend unit: `npm test`
-- Frontend E2E: `npm run test:e2e`
+- Frontend unit: `bun run test`
+- Frontend E2E: `bun run test:e2e`
 - Use `make test` or `make test-cov` in backend for convenience

+**Security & Quality Commands (Backend):**
+- `make validate` — lint + format + type checks
+- `make audit` — dependency vulnerabilities + license compliance
+- `make validate-all` — quality + security checks
+- `make check` — **full pipeline**: quality + security + tests
+
 **Backend E2E Testing (requires Docker):**
 - Install deps: `make install-e2e`
 - Run all E2E tests: `make test-e2e`
@@ -142,7 +148,7 @@ async def mock_commit():
 with patch.object(session, 'commit', side_effect=mock_commit):
     with patch.object(session, 'rollback', new_callable=AsyncMock) as mock_rollback:
         with pytest.raises(OperationalError):
-            await crud_method(session, obj_in=data)
+            await repo_method(session, obj_in=data)
         mock_rollback.assert_called_once()
 ```

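The pattern in the hunk above (force `commit` to fail, then assert `rollback` ran and the error propagated) can be exercised end to end with only the standard library. The `Session` and `create_user` names here are invented for the sketch; in the real suite they would be the SQLAlchemy session and a repository method, with `pytest.raises` in place of the manual try/except:

```python
import asyncio
from unittest.mock import AsyncMock, patch

class OperationalError(Exception):
    """Stand-in for sqlalchemy.exc.OperationalError."""

class Session:
    async def commit(self): ...
    async def rollback(self): ...

async def create_user(session: Session, obj_in: dict) -> dict:
    # Repository method contract: commit, roll back on failure, re-raise.
    try:
        await session.commit()
    except OperationalError:
        await session.rollback()
        raise
    return obj_in

async def main() -> bool:
    session = Session()

    async def mock_commit():
        raise OperationalError("db down")

    with patch.object(session, "commit", side_effect=mock_commit):
        with patch.object(session, "rollback", new_callable=AsyncMock) as mock_rollback:
            try:
                await create_user(session, obj_in={"email": "x@y.z"})
            except OperationalError:
                pass  # the error must propagate to the caller
            else:
                raise AssertionError("expected OperationalError")
            mock_rollback.assert_called_once()
    return True

assert asyncio.run(main()) is True
```

`side_effect` set to an async function means the patched `commit()` returns a coroutine that raises when awaited, which is exactly what the snippet in the diff relies on.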
@@ -157,14 +163,18 @@ with patch.object(session, 'commit', side_effect=mock_commit):
 - Never skip security headers in production
 - Rate limiting is configured in route decorators: `@limiter.limit("10/minute")`
 - Session revocation is database-backed, not just JWT expiry
+- Run `make audit` to check for dependency vulnerabilities and license compliance
+- Run `make check` for the full pipeline: quality + security + tests
+- Pre-commit hooks enforce Ruff lint/format and detect-secrets on every commit
+  - Setup hooks: `cd backend && uv run pre-commit install`

 ### Common Workflows Guidance

 **When Adding a New Feature:**
-1. Start with backend schema and CRUD
+1. Start with backend schema and repository
 2. Implement API route with proper authorization
 3. Write backend tests (aim for >90% coverage)
-4. Generate frontend API client: `npm run generate:api`
+4. Generate frontend API client: `bun run generate:api`
 5. Implement frontend components
 6. Write frontend unit tests
 7. Add E2E tests for critical flows
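The `@limiter.limit("10/minute")` decorator syntax mentioned above is the style used by rate-limiting libraries for FastAPI (such as slowapi); the sketch below is not that library's implementation, only a stdlib illustration of the underlying idea: a fixed-window counter keyed per endpoint, with an injectable clock so the window can be tested deterministically:

```python
import time
from functools import wraps

_PERIODS = {"second": 1, "minute": 60, "hour": 3600}

class RateLimitExceeded(Exception):
    pass

class Limiter:
    """Fixed-window rate limiter; `clock` is injectable for tests."""
    def __init__(self, clock=time.monotonic):
        self.clock = clock
        self._windows = {}  # func name -> (window_start, call_count)

    def limit(self, rule: str):
        count_s, period_s = rule.split("/")      # e.g. "10/minute"
        max_calls, window = int(count_s), _PERIODS[period_s]

        def decorator(func):
            @wraps(func)
            def wrapper(*args, **kwargs):
                now = self.clock()
                start, count = self._windows.get(func.__name__, (now, 0))
                if now - start >= window:         # window rolled over
                    start, count = now, 0
                if count >= max_calls:
                    raise RateLimitExceeded(f"{func.__name__}: over {rule}")
                self._windows[func.__name__] = (start, count + 1)
                return func(*args, **kwargs)
            return wrapper
        return decorator

# Usage with a fake clock so the window is deterministic:
fake_now = [0.0]
limiter = Limiter(clock=lambda: fake_now[0])

@limiter.limit("2/minute")
def ping():
    return "pong"

assert ping() == "pong"
assert ping() == "pong"
blocked = False
try:
    ping()
except RateLimitExceeded:
    blocked = True
assert blocked
fake_now[0] += 60  # advance past the window; calls are allowed again
assert ping() == "pong"
```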
@@ -177,8 +187,8 @@ with patch.object(session, 'commit', side_effect=mock_commit):

 **When Debugging:**
 - Backend: Check `IS_TEST=True` environment variable is set
-- Frontend: Run `npm run type-check` first
-- E2E: Use `npm run test:e2e:debug` for step-by-step debugging
+- Frontend: Run `bun run type-check` first
+- E2E: Use `bun run test:e2e:debug` for step-by-step debugging
 - Check logs: Backend has detailed error logging

 **Demo Mode (Frontend-Only Showcase):**
@@ -186,7 +196,7 @@ with patch.object(session, 'commit', side_effect=mock_commit):
 - Uses MSW (Mock Service Worker) to intercept API calls in browser
 - Zero backend required - perfect for Vercel deployments
 - **Fully Automated**: MSW handlers auto-generated from OpenAPI spec
-  - Run `npm run generate:api` → updates both API client AND MSW handlers
+  - Run `bun run generate:api` → updates both API client AND MSW handlers
   - No manual synchronization needed!
 - Demo credentials (any password ≥8 chars works):
   - User: `demo@example.com` / `DemoPass123`
@@ -214,7 +224,7 @@ with patch.object(session, 'commit', side_effect=mock_commit):
 No Claude Code Skills installed yet. To create one, invoke the built-in "skill-creator" skill.

 **Potential skill ideas for this project:**
-- API endpoint generator workflow (schema → CRUD → route → tests → frontend client)
+- API endpoint generator workflow (schema → repository → route → tests → frontend client)
 - Component generator with design system compliance
 - Database migration troubleshooting helper
 - Test coverage analyzer and improvement suggester
@@ -91,7 +91,10 @@ Ready to write some code? Awesome!
 cd backend

 # Install dependencies (uv manages virtual environment automatically)
-uv sync
+make install-dev
+
+# Setup pre-commit hooks
+uv run pre-commit install

 # Setup environment
 cp .env.example .env
@@ -100,8 +103,14 @@ cp .env.example .env
 # Run migrations
 python migrate.py apply

+# Run quality + security checks
+make validate-all
+
 # Run tests
-IS_TEST=True uv run pytest
+make test
+
+# Run full pipeline (quality + security + tests)
+make check

 # Start dev server
 uvicorn app.main:app --reload
@@ -113,20 +122,20 @@ uvicorn app.main:app --reload
 cd frontend

 # Install dependencies
-npm install
+bun install

 # Setup environment
 cp .env.local.example .env.local

 # Generate API client
-npm run generate:api
+bun run generate:api

 # Run tests
-npm test
-npm run test:e2e:ui
+bun run test
+bun run test:e2e:ui

 # Start dev server
-npm run dev
+bun run dev
 ```

 ---

@@ -195,7 +204,7 @@ export function UserProfile({ userId }: UserProfileProps) {

 ### Key Patterns

-- **Backend**: Use CRUD pattern, keep routes thin, business logic in services
+- **Backend**: Use repository pattern, keep routes thin, business logic in services
 - **Frontend**: Use React Query for server state, Zustand for client state
 - **Both**: Handle errors gracefully, log appropriately, write tests
@@ -316,7 +325,7 @@ Fixed stuff
 ### Before Submitting

 - [ ] Code follows project style guidelines
-- [ ] All tests pass locally
+- [ ] `make check` passes (quality + security + tests) in backend
 - [ ] New tests added for new features
 - [ ] Documentation updated if needed
 - [ ] No merge conflicts with `main`
Makefile (25 changes)
@@ -1,4 +1,4 @@
-.PHONY: help dev dev-full prod down logs logs-dev clean clean-slate drop-db reset-db push-images deploy
+.PHONY: help dev dev-full prod down logs logs-dev clean clean-slate drop-db reset-db push-images deploy scan-images

 VERSION ?= latest
 REGISTRY ?= ghcr.io/cardosofelipe/pragma-stack
@@ -21,6 +21,7 @@ help:
 	@echo "  make prod         - Start production stack"
 	@echo "  make deploy       - Pull and deploy latest images"
 	@echo "  make push-images  - Build and push images to registry"
+	@echo "  make scan-images  - Scan production images for CVEs (requires trivy)"
 	@echo "  make logs         - Follow production container logs"
 	@echo ""
 	@echo "Cleanup:"
@@ -89,6 +90,28 @@ push-images:
 	docker push $(REGISTRY)/backend:$(VERSION)
 	docker push $(REGISTRY)/frontend:$(VERSION)

+scan-images:
+	@docker info > /dev/null 2>&1 || (echo "❌ Docker is not running!"; exit 1)
+	@echo "🐳 Building and scanning production images for CVEs..."
+	docker build -t $(REGISTRY)/backend:scan --target production ./backend
+	docker build -t $(REGISTRY)/frontend:scan --target runner ./frontend
+	@echo ""
+	@echo "=== Backend Image Scan ==="
+	@if command -v trivy > /dev/null 2>&1; then \
+		trivy image --severity HIGH,CRITICAL --exit-code 1 $(REGISTRY)/backend:scan; \
+	else \
+		echo "ℹ️  Trivy not found locally, using Docker to run Trivy..."; \
+		docker run --rm -v /var/run/docker.sock:/var/run/docker.sock aquasec/trivy image --severity HIGH,CRITICAL --exit-code 1 $(REGISTRY)/backend:scan; \
+	fi
+	@echo ""
+	@echo "=== Frontend Image Scan ==="
+	@if command -v trivy > /dev/null 2>&1; then \
+		trivy image --severity HIGH,CRITICAL --exit-code 1 $(REGISTRY)/frontend:scan; \
+	else \
+		docker run --rm -v /var/run/docker.sock:/var/run/docker.sock aquasec/trivy image --severity HIGH,CRITICAL --exit-code 1 $(REGISTRY)/frontend:scan; \
+	fi
+	@echo "✅ No HIGH/CRITICAL CVEs found in production images!"

 # ============================================================================
 # Cleanup
 # ============================================================================
README.md (26 changes)
@@ -58,7 +58,7 @@ Full OAuth 2.0 Authorization Server for Model Context Protocol (MCP) and third-p
 - User can belong to multiple organizations

 ### 🛠️ **Admin Panel**
-- Complete user management (CRUD, activate/deactivate, bulk operations)
+- Complete user management (full lifecycle, activate/deactivate, bulk operations)
 - Organization management (create, edit, delete, member management)
 - Session monitoring across all users
 - Real-time statistics dashboard
@@ -166,7 +166,7 @@ Full OAuth 2.0 Authorization Server for Model Context Protocol (MCP) and third-p
 ```bash
 cd frontend
 echo "NEXT_PUBLIC_DEMO_MODE=true" > .env.local
-npm run dev
+bun run dev
 ```

 **Demo Credentials:**
@@ -298,17 +298,17 @@ uvicorn app.main:app --reload --host 0.0.0.0 --port 8000
 cd frontend

 # Install dependencies
-npm install
+bun install

 # Setup environment
 cp .env.local.example .env.local
 # Edit .env.local with your backend URL

 # Generate API client
-npm run generate:api
+bun run generate:api

 # Start development server
-npm run dev
+bun run dev
 ```

 Visit http://localhost:3000 to see your app!
@@ -322,7 +322,7 @@ Visit http://localhost:3000 to see your app!
 │   ├── app/
 │   │   ├── api/           # API routes and dependencies
 │   │   ├── core/          # Core functionality (auth, config, database)
-│   │   ├── crud/          # Database operations
+│   │   ├── repositories/  # Repository pattern (database operations)
 │   │   ├── models/        # SQLAlchemy models
 │   │   ├── schemas/       # Pydantic schemas
 │   │   ├── services/      # Business logic
@@ -377,7 +377,7 @@ open htmlcov/index.html
 ```

 **Test types:**
-- **Unit tests**: CRUD operations, utilities, business logic
+- **Unit tests**: Repository operations, utilities, business logic
 - **Integration tests**: API endpoints with database
 - **Security tests**: JWT algorithm attacks, session hijacking, privilege escalation
 - **Error handling tests**: Database failures, validation errors
@@ -390,13 +390,13 @@ open htmlcov/index.html
 cd frontend

 # Run unit tests
-npm test
+bun run test

 # Run with coverage
-npm run test:coverage
+bun run test:coverage

 # Watch mode
-npm run test:watch
+bun run test:watch
 ```

 **Test types:**
@@ -414,10 +414,10 @@ npm run test:watch
 cd frontend

 # Run E2E tests
-npm run test:e2e
+bun run test:e2e

 # Run E2E tests in UI mode (recommended for development)
-npm run test:e2e:ui
+bun run test:e2e:ui

 # Run specific test file
 npx playwright test auth-login.spec.ts
@@ -542,7 +542,7 @@ docker-compose down

 ### ✅ Completed
 - [x] Authentication system (JWT, refresh tokens, session management, OAuth)
-- [x] User management (CRUD, profile, password change)
+- [x] User management (full lifecycle, profile, password change)
 - [x] Organization system with RBAC (Owner, Admin, Member)
 - [x] Admin panel (users, organizations, sessions, statistics)
 - [x] **Internationalization (i18n)** with next-intl (English + Italian)
@@ -11,7 +11,7 @@ omit =
     app/utils/auth_test_utils.py

     # Async implementations not yet in use
-    app/crud/base_async.py
+    app/repositories/base_async.py
     app/core/database_async.py

     # CLI scripts - run manually, not tested
@@ -23,7 +23,7 @@ omit =
     app/api/routes/__init__.py
     app/api/dependencies/__init__.py
     app/core/__init__.py
-    app/crud/__init__.py
+    app/repositories/__init__.py
     app/models/__init__.py
     app/schemas/__init__.py
     app/services/__init__.py
backend/.pre-commit-config.yaml (new file, 44 lines)
@@ -0,0 +1,44 @@
+# Pre-commit hooks for backend quality and security checks.
+#
+# Install:
+#   cd backend && uv run pre-commit install
+#
+# Run manually on all files:
+#   cd backend && uv run pre-commit run --all-files
+#
+# Skip hooks temporarily:
+#   git commit --no-verify
+#
+repos:
+  # ── Code Quality ──────────────────────────────────────────────────────────
+  - repo: https://github.com/astral-sh/ruff-pre-commit
+    rev: v0.14.4
+    hooks:
+      - id: ruff
+        args: [--fix, --exit-non-zero-on-fix]
+      - id: ruff-format
+
+  # ── General File Hygiene ──────────────────────────────────────────────────
+  - repo: https://github.com/pre-commit/pre-commit-hooks
+    rev: v5.0.0
+    hooks:
+      - id: trailing-whitespace
+      - id: end-of-file-fixer
+      - id: check-yaml
+      - id: check-toml
+      - id: check-merge-conflict
+      - id: check-added-large-files
+        args: [--maxkb=500]
+      - id: debug-statements
+
+  # ── Security ──────────────────────────────────────────────────────────────
+  - repo: https://github.com/Yelp/detect-secrets
+    rev: v1.5.0
+    hooks:
+      - id: detect-secrets
+        args: ['--baseline', '.secrets.baseline']
+        exclude: |
+          (?x)^(
+              .*\.lock$|
+              .*\.svg$
+          )$
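The `exclude` pattern at the end of the new config is a verbose (`(?x)`) regex; its behavior can be checked quickly in Python:

```python
import re

# Same pattern as the detect-secrets `exclude` block above.
EXCLUDE = re.compile(r"""(?x)^(
    .*\.lock$|
    .*\.svg$
)$""")

assert EXCLUDE.match("uv.lock")            # lockfiles are skipped
assert EXCLUDE.match("assets/logo.svg")    # vector assets are skipped
assert not EXCLUDE.match("app/main.py")    # source files are still scanned
```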
backend/.secrets.baseline (new file, 1073 lines; diff suppressed because it is too large)
@@ -33,11 +33,11 @@ RUN chmod +x /usr/local/bin/entrypoint.sh

 ENTRYPOINT ["/usr/local/bin/entrypoint.sh"]

-# Production stage
-FROM python:3.12-slim AS production
+# Production stage — Alpine eliminates glibc CVEs (e.g. CVE-2026-0861)
+FROM python:3.12-alpine AS production

 # Create non-root user
-RUN groupadd -r appuser && useradd -r -g appuser appuser
+RUN addgroup -S appuser && adduser -S -G appuser appuser

 WORKDIR /app
 ENV PYTHONDONTWRITEBYTECODE=1 \
@@ -48,18 +48,18 @@ ENV PYTHONDONTWRITEBYTECODE=1 \
     UV_NO_CACHE=1

 # Install system dependencies and uv
-RUN apt-get update && \
-    apt-get install -y --no-install-recommends postgresql-client curl ca-certificates && \
+RUN apk add --no-cache postgresql-client curl ca-certificates && \
     curl -LsSf https://astral.sh/uv/install.sh | sh && \
-    mv /root/.local/bin/uv* /usr/local/bin/ && \
-    apt-get clean && \
-    rm -rf /var/lib/apt/lists/*
+    mv /root/.local/bin/uv* /usr/local/bin/

 # Copy dependency files
 COPY pyproject.toml uv.lock ./

-# Install only production dependencies using uv (no dev dependencies)
-RUN uv sync --frozen --no-dev
+# Install build dependencies, compile Python packages, then remove build deps
+RUN apk add --no-cache --virtual .build-deps \
+        gcc g++ musl-dev python3-dev linux-headers libffi-dev openssl-dev && \
+    uv sync --frozen --no-dev && \
+    apk del .build-deps

 # Copy application code
 COPY . .
@@ -1,4 +1,7 @@
-.PHONY: help lint lint-fix format format-check type-check test test-cov validate clean install-dev sync check-docker install-e2e test-e2e test-e2e-schema test-all
+.PHONY: help lint lint-fix format format-check type-check test test-cov validate clean install-dev sync check-docker install-e2e test-e2e test-e2e-schema test-all dep-audit license-check audit validate-all check benchmark benchmark-check benchmark-save scan-image test-api-security
+
+# Prevent a stale VIRTUAL_ENV in the caller's shell from confusing uv
+unexport VIRTUAL_ENV

 # Default target
 help:
@@ -14,8 +17,21 @@ help:
 	@echo "  make lint-fix      - Run Ruff linter with auto-fix"
 	@echo "  make format        - Format code with Ruff"
 	@echo "  make format-check  - Check if code is formatted"
-	@echo "  make type-check    - Run mypy type checking"
-	@echo "  make validate      - Run all checks (lint + format + types)"
+	@echo "  make type-check    - Run pyright type checking"
+	@echo "  make validate      - Run all checks (lint + format + types + schema fuzz)"
+	@echo ""
+	@echo "Performance:"
+	@echo "  make benchmark       - Run performance benchmarks"
+	@echo "  make benchmark-save  - Run benchmarks and save as baseline"
+	@echo "  make benchmark-check - Run benchmarks and compare against baseline"
+	@echo ""
+	@echo "Security & Audit:"
+	@echo "  make dep-audit     - Scan dependencies for known vulnerabilities"
+	@echo "  make license-check - Check dependency license compliance"
+	@echo "  make audit         - Run all security audits (deps + licenses)"
+	@echo "  make scan-image    - Scan Docker image for CVEs (requires trivy)"
+	@echo "  make validate-all  - Run all quality + security checks"
+	@echo "  make check         - Full pipeline: quality + security + tests"
 	@echo ""
 	@echo "Testing:"
 	@echo "  make test          - Run pytest (unit/integration, SQLite)"
@@ -24,6 +40,7 @@ help:
 	@echo "  make test-e2e-schema - Run Schemathesis API schema tests"
 	@echo "  make test-all      - Run all tests (unit + E2E)"
 	@echo "  make check-docker  - Check if Docker is available"
+	@echo "  make check         - Full pipeline: quality + security + tests"
 	@echo ""
 	@echo "Cleanup:"
 	@echo "  make clean         - Remove cache and build artifacts"
@@ -63,12 +80,52 @@ format-check:
 	@uv run ruff format --check app/ tests/

 type-check:
-	@echo "🔎 Running mypy type checking..."
-	@uv run mypy app/
+	@echo "🔎 Running pyright type checking..."
+	@uv run pyright app/

-validate: lint format-check type-check
+validate: lint format-check type-check test-api-security
 	@echo "✅ All quality checks passed!"

+# API Security Testing (Schemathesis property-based fuzzing)
+test-api-security: check-docker
+	@echo "🔐 Running Schemathesis API security fuzzing..."
+	@IS_TEST=True PYTHONPATH=. uv run pytest tests/e2e/ -v -m "schemathesis" --tb=short -n 0
+	@echo "✅ API schema security tests passed!"
+
+# ============================================================================
+# Security & Audit
+# ============================================================================
+
+dep-audit:
+	@echo "🔒 Scanning dependencies for known vulnerabilities..."
+	@uv run pip-audit --desc --skip-editable
+	@echo "✅ No known vulnerabilities found!"
+
+license-check:
+	@echo "📜 Checking dependency license compliance..."
+	@uv run pip-licenses --fail-on="GPL-3.0-or-later;AGPL-3.0-or-later" --format=plain > /dev/null
+	@echo "✅ All dependency licenses are compliant!"
+
+audit: dep-audit license-check
+	@echo "✅ All security audits passed!"
+
+scan-image: check-docker
+	@echo "🐳 Scanning Docker image for OS-level CVEs with Trivy..."
+	@docker build -t pragma-backend:scan -q --target production .
+	@if command -v trivy > /dev/null 2>&1; then \
+		trivy image --severity HIGH,CRITICAL --exit-code 1 pragma-backend:scan; \
+	else \
+		echo "ℹ️  Trivy not found locally, using Docker to run Trivy..."; \
+		docker run --rm -v /var/run/docker.sock:/var/run/docker.sock aquasec/trivy image --severity HIGH,CRITICAL --exit-code 1 pragma-backend:scan; \
+	fi
+	@echo "✅ No HIGH/CRITICAL CVEs found in Docker image!"
+
+validate-all: validate audit
+	@echo "✅ All quality + security checks passed!"
+
+check: validate-all test
+	@echo "✅ Full validation pipeline complete!"
+
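The `license-check` target above relies on pip-licenses' `--fail-on` flag; the underlying policy check is simple enough to sketch in plain Python (the package list here is made up for illustration):

```python
BLOCKED = {"GPL-3.0-or-later", "AGPL-3.0-or-later"}

def check_licenses(packages: dict) -> list:
    """Return names of packages whose license is on the blocklist."""
    return sorted(name for name, lic in packages.items() if lic in BLOCKED)

deps = {
    "fastapi": "MIT",
    "sqlalchemy": "MIT",
    "some-copyleft-lib": "AGPL-3.0-or-later",  # hypothetical offender
}
violations = check_licenses(deps)
assert violations == ["some-copyleft-lib"]
```

In CI the equivalent check fails the build (non-zero exit) instead of returning a list, which is exactly what `--fail-on` does for the Make target.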
 # ============================================================================
 # Testing
 # ============================================================================
@@ -114,6 +171,31 @@ test-e2e-schema: check-docker
 	@echo "🧪 Running Schemathesis API schema tests..."
 	@IS_TEST=True PYTHONPATH=. uv run pytest tests/e2e/ -v -m "schemathesis" --tb=short -n 0

+# ============================================================================
+# Performance Benchmarks
+# ============================================================================
+
+benchmark:
+	@echo "⏱️  Running performance benchmarks..."
+	@IS_TEST=True PYTHONPATH=. uv run pytest tests/benchmarks/ -v --benchmark-only --benchmark-sort=mean -p no:xdist --override-ini='addopts='
+
+benchmark-save:
+	@echo "⏱️  Running benchmarks and saving baseline..."
+	@IS_TEST=True PYTHONPATH=. uv run pytest tests/benchmarks/ -v --benchmark-only --benchmark-save=baseline --benchmark-sort=mean -p no:xdist --override-ini='addopts='
+	@echo "✅ Benchmark baseline saved to .benchmarks/"
+
+benchmark-check:
+	@echo "⏱️  Running benchmarks and comparing against baseline..."
+	@if find .benchmarks -name '*_baseline*' -print -quit 2>/dev/null | grep -q .; then \
+		IS_TEST=True PYTHONPATH=. uv run pytest tests/benchmarks/ -v --benchmark-only --benchmark-compare=0001_baseline --benchmark-sort=mean --benchmark-compare-fail=mean:200% -p no:xdist --override-ini='addopts='; \
+		echo "✅ No performance regressions detected!"; \
+	else \
+		echo "⚠️  No benchmark baseline found. Run 'make benchmark-save' first to create one."; \
+		echo "   Running benchmarks without comparison..."; \
+		IS_TEST=True PYTHONPATH=. uv run pytest tests/benchmarks/ -v --benchmark-only --benchmark-save=baseline --benchmark-sort=mean -p no:xdist --override-ini='addopts='; \
+		echo "✅ Benchmark baseline created. Future runs of 'make benchmark-check' will compare against it."; \
+	fi
+
 test-all:
 	@echo "🧪 Running ALL tests (unit + E2E)..."
 	@$(MAKE) test
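`--benchmark-compare-fail=mean:200%` in the target above makes pytest-benchmark fail the run when a benchmark's mean regresses past the threshold. Reading `mean:200%` as "the mean may grow by at most 200% over the baseline" (an assumption about the flag's semantics, not taken from this diff), the comparison rule itself is just:

```python
def regressed(baseline_mean: float, current_mean: float,
              max_growth_pct: float = 200.0) -> bool:
    """True when the current mean exceeds the allowed growth over baseline."""
    allowed = baseline_mean * (1 + max_growth_pct / 100.0)
    return current_mean > allowed

assert not regressed(baseline_mean=0.010, current_mean=0.015)  # +50%: fine
assert regressed(baseline_mean=0.010, current_mean=0.035)      # +250%: fails
```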
@@ -127,7 +209,7 @@ clean:
 	@echo "🧹 Cleaning up..."
 	@find . -type d -name "__pycache__" -exec rm -rf {} + 2>/dev/null || true
 	@find . -type d -name ".pytest_cache" -exec rm -rf {} + 2>/dev/null || true
-	@find . -type d -name ".mypy_cache" -exec rm -rf {} + 2>/dev/null || true
+	@find . -type d -name ".pyright" -exec rm -rf {} + 2>/dev/null || true
 	@find . -type d -name ".ruff_cache" -exec rm -rf {} + 2>/dev/null || true
 	@find . -type d -name "*.egg-info" -exec rm -rf {} + 2>/dev/null || true
 	@find . -type d -name "htmlcov" -exec rm -rf {} + 2>/dev/null || true
@@ -14,7 +14,9 @@ Features:
 - **Multi-tenancy**: Organization-based access control with roles (Owner/Admin/Member)
 - **Testing**: 97%+ coverage with security-focused test suite
 - **Performance**: Async throughout, connection pooling, optimized queries
-- **Modern Tooling**: uv for dependencies, Ruff for linting/formatting, mypy for type checking
+- **Modern Tooling**: uv for dependencies, Ruff for linting/formatting, Pyright for type checking
+- **Security Auditing**: Automated dependency vulnerability scanning, license compliance, secrets detection
+- **Pre-commit Hooks**: Ruff, detect-secrets, and standard checks on every commit

 ## Quick Start

@@ -149,7 +151,7 @@ uv pip list --outdated
 # Run any Python command via uv (no activation needed)
 uv run python script.py
 uv run pytest
-uv run mypy app/
+uv run pyright app/

 # Or activate the virtual environment
 source .venv/bin/activate
@@ -171,12 +173,22 @@ make lint # Run Ruff linter (check only)
make lint-fix # Run Ruff with auto-fix
make format # Format code with Ruff
make format-check # Check if code is formatted
-make type-check # Run mypy type checking
+make type-check # Run Pyright type checking
make validate # Run all checks (lint + format + types)

+# Security & Audit
+make dep-audit # Scan dependencies for known vulnerabilities (CVEs)
+make license-check # Check dependency license compliance
+make audit # Run all security audits (deps + licenses)
+make validate-all # Run all quality + security checks
+make check # Full pipeline: quality + security + tests
+
# Testing
make test # Run all tests
make test-cov # Run tests with coverage report
make test-e2e # Run E2E tests (PostgreSQL, requires Docker)
+make test-e2e-schema # Run Schemathesis API schema tests
make test-all # Run all tests (unit + E2E)

# Utilities
make clean # Remove cache and build artifacts
@@ -252,7 +264,7 @@ app/
│ ├── database.py # Database engine setup
│ ├── auth.py # JWT token handling
│ └── exceptions.py # Custom exceptions
-├── crud/ # Database operations
+├── repositories/ # Repository pattern (database operations)
├── models/ # SQLAlchemy ORM models
├── schemas/ # Pydantic request/response schemas
├── services/ # Business logic layer
@@ -352,18 +364,29 @@ open htmlcov/index.html
# Using Makefile (recommended)
make lint # Ruff linting
make format # Ruff formatting
-make type-check # mypy type checking
+make type-check # Pyright type checking
make validate # All checks at once

+# Security audits
+make dep-audit # Scan dependencies for CVEs
+make license-check # Check license compliance
+make audit # All security audits
+make validate-all # Quality + security checks
+make check # Full pipeline: quality + security + tests
+
# Using uv directly
uv run ruff check app/ tests/
uv run ruff format app/ tests/
-uv run mypy app/
+uv run pyright app/
```

**Tools:**
- **Ruff**: All-in-one linting, formatting, and import sorting (replaces Black, Flake8, isort)
-- **mypy**: Static type checking with Pydantic plugin
+- **Pyright**: Static type checking (strict mode)
+- **pip-audit**: Dependency vulnerability scanning against the OSV database
+- **pip-licenses**: Dependency license compliance checking
+- **detect-secrets**: Hardcoded secrets/credentials detection
+- **pre-commit**: Git hook framework for automated checks on every commit

All configurations are in `pyproject.toml`.
@@ -439,7 +462,7 @@ See [docs/FEATURE_EXAMPLE.md](docs/FEATURE_EXAMPLE.md) for step-by-step guide.

Quick overview:
1. Create Pydantic schemas in `app/schemas/`
-2. Create CRUD operations in `app/crud/`
+2. Create repository in `app/repositories/`
3. Create route in `app/api/routes/`
4. Register router in `app/api/main.py`
5. Write tests in `tests/api/`
@@ -589,13 +612,42 @@ Configured in `app/core/config.py`:
- **Security Headers**: CSP, HSTS, X-Frame-Options, etc.
- **Input Validation**: Pydantic schemas, SQL injection prevention (ORM)

+### Security Auditing
+
+Automated, deterministic security checks are built into the development workflow:
+
+```bash
+# Scan dependencies for known vulnerabilities (CVEs)
+make dep-audit
+
+# Check dependency license compliance (blocks GPL-3.0/AGPL)
+make license-check
+
+# Run all security audits
+make audit
+
+# Full pipeline: quality + security + tests
+make check
+```
+
+**Pre-commit hooks** automatically run on every commit:
+- **Ruff** lint + format checks
+- **detect-secrets** blocks commits containing hardcoded secrets
+- **Standard checks**: trailing whitespace, YAML/TOML validation, merge conflict detection, large file prevention
+
+Setup pre-commit hooks:
+```bash
+uv run pre-commit install
+```
+
### Security Best Practices

-1. **Never commit secrets**: Use `.env` files (git-ignored)
+1. **Never commit secrets**: Use `.env` files (git-ignored), enforced by detect-secrets pre-commit hook
2. **Strong SECRET_KEY**: Min 32 chars, cryptographically random
3. **HTTPS in production**: Required for token security
-4. **Regular updates**: Keep dependencies current (`uv sync --upgrade`)
+4. **Regular updates**: Keep dependencies current (`uv sync --upgrade`), run `make dep-audit` to check for CVEs
5. **Audit logs**: Monitor authentication events
+6. **Run `make check` before pushing**: Validates quality, security, and tests in one command

---
@@ -645,7 +697,11 @@ logging.basicConfig(level=logging.INFO)
**Built with modern Python tooling:**
- 🚀 **uv** - 10-100x faster dependency management
- ⚡ **Ruff** - 10-100x faster linting & formatting
-- 🔍 **mypy** - Static type checking
+- 🔍 **Pyright** - Static type checking (strict mode)
- ✅ **pytest** - Comprehensive test suite
+- 🔒 **pip-audit** - Dependency vulnerability scanning
+- 🔑 **detect-secrets** - Hardcoded secrets detection
+- 📜 **pip-licenses** - License compliance checking
+- 🪝 **pre-commit** - Automated git hooks

**All configured in a single `pyproject.toml` file!**
@@ -40,6 +40,7 @@ def include_object(object, name, type_, reflected, compare_to):
        return False
    return True

+
# Interpret the config file for Python logging.
# This line sets up loggers basically.
if config.config_file_name is not None:
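The `include_object` hook above only takes effect once it is passed to Alembic's `context.configure`. A minimal sketch of that wiring — the hook's body below is an assumed example (the diff only shows its return statements), while `include_object=` is a real `context.configure` keyword:

```python
# Sketch of an autogenerate filter hook; the argument order matches the
# signature shown above: object, name, type_, reflected, compare_to.
def include_object(object, name, type_, reflected, compare_to):
    # Assumed example body: ignore tables that autogenerate only found by
    # reflecting the database (reflected=True, no model counterpart), so
    # they are never emitted as drop candidates.
    if type_ == "table" and reflected and compare_to is None:
        return False
    return True

# Inside run_migrations_online(), the hook is registered like this:
# context.configure(
#     connection=connection,
#     target_metadata=target_metadata,
#     include_object=include_object,
# )
```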
@@ -1,262 +1,446 @@
"""initial models

Revision ID: 0001
-Revises: 
+Revises:
Create Date: 2025-11-27 09:08:09.464506

"""
-from typing import Sequence, Union

-from alembic import op
+from collections.abc import Sequence

import sqlalchemy as sa
+from alembic import op
from sqlalchemy.dialects import postgresql

# revision identifiers, used by Alembic.
-revision: str = '0001'
-down_revision: Union[str, None] = None
-branch_labels: Union[str, Sequence[str], None] = None
-depends_on: Union[str, Sequence[str], None] = None
+revision: str = "0001"
+down_revision: str | None = None
+branch_labels: str | Sequence[str] | None = None
+depends_on: str | Sequence[str] | None = None


def upgrade() -> None:
    # ### commands auto generated by Alembic - please adjust! ###
-    op.create_table('oauth_states',
-    sa.Column('state', sa.String(length=255), nullable=False),
-    sa.Column('code_verifier', sa.String(length=128), nullable=True),
-    sa.Column('nonce', sa.String(length=255), nullable=True),
-    sa.Column('provider', sa.String(length=50), nullable=False),
-    sa.Column('redirect_uri', sa.String(length=500), nullable=True),
-    sa.Column('user_id', sa.UUID(), nullable=True),
-    sa.Column('expires_at', sa.DateTime(timezone=True), nullable=False),
-    sa.Column('id', sa.UUID(), nullable=False),
-    sa.Column('created_at', sa.DateTime(timezone=True), nullable=False),
-    sa.Column('updated_at', sa.DateTime(timezone=True), nullable=False),
-    sa.PrimaryKeyConstraint('id')
+    op.create_table(
+        "oauth_states",
+        sa.Column("state", sa.String(length=255), nullable=False),
+        sa.Column("code_verifier", sa.String(length=128), nullable=True),
+        sa.Column("nonce", sa.String(length=255), nullable=True),
+        sa.Column("provider", sa.String(length=50), nullable=False),
+        sa.Column("redirect_uri", sa.String(length=500), nullable=True),
+        sa.Column("user_id", sa.UUID(), nullable=True),
+        sa.Column("expires_at", sa.DateTime(timezone=True), nullable=False),
+        sa.Column("id", sa.UUID(), nullable=False),
+        sa.Column("created_at", sa.DateTime(timezone=True), nullable=False),
+        sa.Column("updated_at", sa.DateTime(timezone=True), nullable=False),
+        sa.PrimaryKeyConstraint("id"),
    )
-    op.create_index(op.f('ix_oauth_states_state'), 'oauth_states', ['state'], unique=True)
-    op.create_table('organizations',
-    sa.Column('name', sa.String(length=255), nullable=False),
-    sa.Column('slug', sa.String(length=255), nullable=False),
-    sa.Column('description', sa.Text(), nullable=True),
-    sa.Column('is_active', sa.Boolean(), nullable=False),
-    sa.Column('settings', postgresql.JSONB(astext_type=sa.Text()), nullable=True),
-    sa.Column('id', sa.UUID(), nullable=False),
-    sa.Column('created_at', sa.DateTime(timezone=True), nullable=False),
-    sa.Column('updated_at', sa.DateTime(timezone=True), nullable=False),
-    sa.PrimaryKeyConstraint('id')
+    op.create_index(
+        op.f("ix_oauth_states_state"), "oauth_states", ["state"], unique=True
    )
-    op.create_index(op.f('ix_organizations_is_active'), 'organizations', ['is_active'], unique=False)
-    op.create_index(op.f('ix_organizations_name'), 'organizations', ['name'], unique=False)
-    op.create_index('ix_organizations_name_active', 'organizations', ['name', 'is_active'], unique=False)
-    op.create_index(op.f('ix_organizations_slug'), 'organizations', ['slug'], unique=True)
-    op.create_index('ix_organizations_slug_active', 'organizations', ['slug', 'is_active'], unique=False)
-    op.create_table('users',
-    sa.Column('email', sa.String(length=255), nullable=False),
-    sa.Column('password_hash', sa.String(length=255), nullable=True),
-    sa.Column('first_name', sa.String(length=100), nullable=False),
-    sa.Column('last_name', sa.String(length=100), nullable=True),
-    sa.Column('phone_number', sa.String(length=20), nullable=True),
-    sa.Column('is_active', sa.Boolean(), nullable=False),
-    sa.Column('is_superuser', sa.Boolean(), nullable=False),
-    sa.Column('preferences', postgresql.JSONB(astext_type=sa.Text()), nullable=True),
-    sa.Column('locale', sa.String(length=10), nullable=True),
-    sa.Column('deleted_at', sa.DateTime(timezone=True), nullable=True),
-    sa.Column('id', sa.UUID(), nullable=False),
-    sa.Column('created_at', sa.DateTime(timezone=True), nullable=False),
-    sa.Column('updated_at', sa.DateTime(timezone=True), nullable=False),
-    sa.PrimaryKeyConstraint('id')
+    op.create_table(
+        "organizations",
+        sa.Column("name", sa.String(length=255), nullable=False),
+        sa.Column("slug", sa.String(length=255), nullable=False),
+        sa.Column("description", sa.Text(), nullable=True),
+        sa.Column("is_active", sa.Boolean(), nullable=False),
+        sa.Column("settings", postgresql.JSONB(astext_type=sa.Text()), nullable=True),
+        sa.Column("id", sa.UUID(), nullable=False),
+        sa.Column("created_at", sa.DateTime(timezone=True), nullable=False),
+        sa.Column("updated_at", sa.DateTime(timezone=True), nullable=False),
+        sa.PrimaryKeyConstraint("id"),
    )
-    op.create_index(op.f('ix_users_deleted_at'), 'users', ['deleted_at'], unique=False)
-    op.create_index(op.f('ix_users_email'), 'users', ['email'], unique=True)
-    op.create_index(op.f('ix_users_is_active'), 'users', ['is_active'], unique=False)
-    op.create_index(op.f('ix_users_is_superuser'), 'users', ['is_superuser'], unique=False)
-    op.create_index(op.f('ix_users_locale'), 'users', ['locale'], unique=False)
-    op.create_table('oauth_accounts',
-    sa.Column('user_id', sa.UUID(), nullable=False),
-    sa.Column('provider', sa.String(length=50), nullable=False),
-    sa.Column('provider_user_id', sa.String(length=255), nullable=False),
-    sa.Column('provider_email', sa.String(length=255), nullable=True),
-    sa.Column('access_token_encrypted', sa.String(length=2048), nullable=True),
-    sa.Column('refresh_token_encrypted', sa.String(length=2048), nullable=True),
-    sa.Column('token_expires_at', sa.DateTime(timezone=True), nullable=True),
-    sa.Column('id', sa.UUID(), nullable=False),
-    sa.Column('created_at', sa.DateTime(timezone=True), nullable=False),
-    sa.Column('updated_at', sa.DateTime(timezone=True), nullable=False),
-    sa.ForeignKeyConstraint(['user_id'], ['users.id'], ondelete='CASCADE'),
-    sa.PrimaryKeyConstraint('id'),
-    sa.UniqueConstraint('provider', 'provider_user_id', name='uq_oauth_provider_user')
+    op.create_index(
+        op.f("ix_organizations_is_active"), "organizations", ["is_active"], unique=False
    )
-    op.create_index(op.f('ix_oauth_accounts_provider'), 'oauth_accounts', ['provider'], unique=False)
-    op.create_index(op.f('ix_oauth_accounts_provider_email'), 'oauth_accounts', ['provider_email'], unique=False)
-    op.create_index(op.f('ix_oauth_accounts_user_id'), 'oauth_accounts', ['user_id'], unique=False)
-    op.create_index('ix_oauth_accounts_user_provider', 'oauth_accounts', ['user_id', 'provider'], unique=False)
-    op.create_table('oauth_clients',
-    sa.Column('client_id', sa.String(length=64), nullable=False),
-    sa.Column('client_secret_hash', sa.String(length=255), nullable=True),
-    sa.Column('client_name', sa.String(length=255), nullable=False),
-    sa.Column('client_description', sa.String(length=1000), nullable=True),
-    sa.Column('client_type', sa.String(length=20), nullable=False),
-    sa.Column('redirect_uris', postgresql.JSONB(astext_type=sa.Text()), nullable=False),
-    sa.Column('allowed_scopes', postgresql.JSONB(astext_type=sa.Text()), nullable=False),
-    sa.Column('access_token_lifetime', sa.String(length=10), nullable=False),
-    sa.Column('refresh_token_lifetime', sa.String(length=10), nullable=False),
-    sa.Column('is_active', sa.Boolean(), nullable=False),
-    sa.Column('owner_user_id', sa.UUID(), nullable=True),
-    sa.Column('mcp_server_url', sa.String(length=2048), nullable=True),
-    sa.Column('id', sa.UUID(), nullable=False),
-    sa.Column('created_at', sa.DateTime(timezone=True), nullable=False),
-    sa.Column('updated_at', sa.DateTime(timezone=True), nullable=False),
-    sa.ForeignKeyConstraint(['owner_user_id'], ['users.id'], ondelete='SET NULL'),
-    sa.PrimaryKeyConstraint('id')
+    op.create_index(
+        op.f("ix_organizations_name"), "organizations", ["name"], unique=False
    )
-    op.create_index(op.f('ix_oauth_clients_client_id'), 'oauth_clients', ['client_id'], unique=True)
-    op.create_index(op.f('ix_oauth_clients_is_active'), 'oauth_clients', ['is_active'], unique=False)
-    op.create_table('user_organizations',
-    sa.Column('user_id', sa.UUID(), nullable=False),
-    sa.Column('organization_id', sa.UUID(), nullable=False),
-    sa.Column('role', sa.Enum('OWNER', 'ADMIN', 'MEMBER', 'GUEST', name='organizationrole'), nullable=False),
-    sa.Column('is_active', sa.Boolean(), nullable=False),
-    sa.Column('custom_permissions', sa.String(length=500), nullable=True),
-    sa.Column('created_at', sa.DateTime(timezone=True), nullable=False),
-    sa.Column('updated_at', sa.DateTime(timezone=True), nullable=False),
-    sa.ForeignKeyConstraint(['organization_id'], ['organizations.id'], ondelete='CASCADE'),
-    sa.ForeignKeyConstraint(['user_id'], ['users.id'], ondelete='CASCADE'),
-    sa.PrimaryKeyConstraint('user_id', 'organization_id')
+    op.create_index(
+        "ix_organizations_name_active",
+        "organizations",
+        ["name", "is_active"],
+        unique=False,
    )
-    op.create_index('ix_user_org_org_active', 'user_organizations', ['organization_id', 'is_active'], unique=False)
-    op.create_index('ix_user_org_role', 'user_organizations', ['role'], unique=False)
-    op.create_index('ix_user_org_user_active', 'user_organizations', ['user_id', 'is_active'], unique=False)
-    op.create_index(op.f('ix_user_organizations_is_active'), 'user_organizations', ['is_active'], unique=False)
-    op.create_table('user_sessions',
-    sa.Column('user_id', sa.UUID(), nullable=False),
-    sa.Column('refresh_token_jti', sa.String(length=255), nullable=False),
-    sa.Column('device_name', sa.String(length=255), nullable=True),
-    sa.Column('device_id', sa.String(length=255), nullable=True),
-    sa.Column('ip_address', sa.String(length=45), nullable=True),
-    sa.Column('user_agent', sa.String(length=500), nullable=True),
-    sa.Column('last_used_at', sa.DateTime(timezone=True), nullable=False),
-    sa.Column('expires_at', sa.DateTime(timezone=True), nullable=False),
-    sa.Column('is_active', sa.Boolean(), nullable=False),
-    sa.Column('location_city', sa.String(length=100), nullable=True),
-    sa.Column('location_country', sa.String(length=100), nullable=True),
-    sa.Column('id', sa.UUID(), nullable=False),
-    sa.Column('created_at', sa.DateTime(timezone=True), nullable=False),
-    sa.Column('updated_at', sa.DateTime(timezone=True), nullable=False),
-    sa.ForeignKeyConstraint(['user_id'], ['users.id'], ondelete='CASCADE'),
-    sa.PrimaryKeyConstraint('id')
+    op.create_index(
+        op.f("ix_organizations_slug"), "organizations", ["slug"], unique=True
    )
-    op.create_index(op.f('ix_user_sessions_is_active'), 'user_sessions', ['is_active'], unique=False)
-    op.create_index('ix_user_sessions_jti_active', 'user_sessions', ['refresh_token_jti', 'is_active'], unique=False)
-    op.create_index(op.f('ix_user_sessions_refresh_token_jti'), 'user_sessions', ['refresh_token_jti'], unique=True)
-    op.create_index('ix_user_sessions_user_active', 'user_sessions', ['user_id', 'is_active'], unique=False)
-    op.create_index(op.f('ix_user_sessions_user_id'), 'user_sessions', ['user_id'], unique=False)
-    op.create_table('oauth_authorization_codes',
-    sa.Column('code', sa.String(length=128), nullable=False),
-    sa.Column('client_id', sa.String(length=64), nullable=False),
-    sa.Column('user_id', sa.UUID(), nullable=False),
-    sa.Column('redirect_uri', sa.String(length=2048), nullable=False),
-    sa.Column('scope', sa.String(length=1000), nullable=False),
-    sa.Column('code_challenge', sa.String(length=128), nullable=True),
-    sa.Column('code_challenge_method', sa.String(length=10), nullable=True),
-    sa.Column('state', sa.String(length=256), nullable=True),
-    sa.Column('nonce', sa.String(length=256), nullable=True),
-    sa.Column('expires_at', sa.DateTime(timezone=True), nullable=False),
-    sa.Column('used', sa.Boolean(), nullable=False),
-    sa.Column('id', sa.UUID(), nullable=False),
-    sa.Column('created_at', sa.DateTime(timezone=True), nullable=False),
-    sa.Column('updated_at', sa.DateTime(timezone=True), nullable=False),
-    sa.ForeignKeyConstraint(['client_id'], ['oauth_clients.client_id'], ondelete='CASCADE'),
-    sa.ForeignKeyConstraint(['user_id'], ['users.id'], ondelete='CASCADE'),
-    sa.PrimaryKeyConstraint('id')
+    op.create_index(
+        "ix_organizations_slug_active",
+        "organizations",
+        ["slug", "is_active"],
+        unique=False,
    )
-    op.create_index('ix_oauth_authorization_codes_client_user', 'oauth_authorization_codes', ['client_id', 'user_id'], unique=False)
-    op.create_index(op.f('ix_oauth_authorization_codes_code'), 'oauth_authorization_codes', ['code'], unique=True)
-    op.create_index('ix_oauth_authorization_codes_expires_at', 'oauth_authorization_codes', ['expires_at'], unique=False)
-    op.create_table('oauth_consents',
-    sa.Column('user_id', sa.UUID(), nullable=False),
-    sa.Column('client_id', sa.String(length=64), nullable=False),
-    sa.Column('granted_scopes', sa.String(length=1000), nullable=False),
-    sa.Column('id', sa.UUID(), nullable=False),
-    sa.Column('created_at', sa.DateTime(timezone=True), nullable=False),
-    sa.Column('updated_at', sa.DateTime(timezone=True), nullable=False),
-    sa.ForeignKeyConstraint(['client_id'], ['oauth_clients.client_id'], ondelete='CASCADE'),
-    sa.ForeignKeyConstraint(['user_id'], ['users.id'], ondelete='CASCADE'),
-    sa.PrimaryKeyConstraint('id')
+    op.create_table(
+        "users",
+        sa.Column("email", sa.String(length=255), nullable=False),
+        sa.Column("password_hash", sa.String(length=255), nullable=True),
+        sa.Column("first_name", sa.String(length=100), nullable=False),
+        sa.Column("last_name", sa.String(length=100), nullable=True),
+        sa.Column("phone_number", sa.String(length=20), nullable=True),
+        sa.Column("is_active", sa.Boolean(), nullable=False),
+        sa.Column("is_superuser", sa.Boolean(), nullable=False),
+        sa.Column(
+            "preferences", postgresql.JSONB(astext_type=sa.Text()), nullable=True
+        ),
+        sa.Column("locale", sa.String(length=10), nullable=True),
+        sa.Column("deleted_at", sa.DateTime(timezone=True), nullable=True),
+        sa.Column("id", sa.UUID(), nullable=False),
+        sa.Column("created_at", sa.DateTime(timezone=True), nullable=False),
+        sa.Column("updated_at", sa.DateTime(timezone=True), nullable=False),
+        sa.PrimaryKeyConstraint("id"),
    )
-    op.create_index('ix_oauth_consents_user_client', 'oauth_consents', ['user_id', 'client_id'], unique=True)
-    op.create_table('oauth_provider_refresh_tokens',
-    sa.Column('token_hash', sa.String(length=64), nullable=False),
-    sa.Column('jti', sa.String(length=64), nullable=False),
-    sa.Column('client_id', sa.String(length=64), nullable=False),
-    sa.Column('user_id', sa.UUID(), nullable=False),
-    sa.Column('scope', sa.String(length=1000), nullable=False),
-    sa.Column('expires_at', sa.DateTime(timezone=True), nullable=False),
-    sa.Column('revoked', sa.Boolean(), nullable=False),
-    sa.Column('last_used_at', sa.DateTime(timezone=True), nullable=True),
-    sa.Column('device_info', sa.String(length=500), nullable=True),
-    sa.Column('ip_address', sa.String(length=45), nullable=True),
-    sa.Column('id', sa.UUID(), nullable=False),
-    sa.Column('created_at', sa.DateTime(timezone=True), nullable=False),
-    sa.Column('updated_at', sa.DateTime(timezone=True), nullable=False),
-    sa.ForeignKeyConstraint(['client_id'], ['oauth_clients.client_id'], ondelete='CASCADE'),
-    sa.ForeignKeyConstraint(['user_id'], ['users.id'], ondelete='CASCADE'),
-    sa.PrimaryKeyConstraint('id')
+    op.create_index(op.f("ix_users_deleted_at"), "users", ["deleted_at"], unique=False)
+    op.create_index(op.f("ix_users_email"), "users", ["email"], unique=True)
+    op.create_index(op.f("ix_users_is_active"), "users", ["is_active"], unique=False)
+    op.create_index(
+        op.f("ix_users_is_superuser"), "users", ["is_superuser"], unique=False
+    )
+    op.create_index(op.f("ix_users_locale"), "users", ["locale"], unique=False)
+    op.create_table(
+        "oauth_accounts",
+        sa.Column("user_id", sa.UUID(), nullable=False),
+        sa.Column("provider", sa.String(length=50), nullable=False),
+        sa.Column("provider_user_id", sa.String(length=255), nullable=False),
+        sa.Column("provider_email", sa.String(length=255), nullable=True),
+        sa.Column("access_token_encrypted", sa.String(length=2048), nullable=True),
+        sa.Column("refresh_token_encrypted", sa.String(length=2048), nullable=True),
+        sa.Column("token_expires_at", sa.DateTime(timezone=True), nullable=True),
+        sa.Column("id", sa.UUID(), nullable=False),
+        sa.Column("created_at", sa.DateTime(timezone=True), nullable=False),
+        sa.Column("updated_at", sa.DateTime(timezone=True), nullable=False),
+        sa.ForeignKeyConstraint(["user_id"], ["users.id"], ondelete="CASCADE"),
+        sa.PrimaryKeyConstraint("id"),
+        sa.UniqueConstraint(
+            "provider", "provider_user_id", name="uq_oauth_provider_user"
+        ),
+    )
+    op.create_index(
+        op.f("ix_oauth_accounts_provider"), "oauth_accounts", ["provider"], unique=False
+    )
+    op.create_index(
+        op.f("ix_oauth_accounts_provider_email"),
+        "oauth_accounts",
+        ["provider_email"],
+        unique=False,
+    )
+    op.create_index(
+        op.f("ix_oauth_accounts_user_id"), "oauth_accounts", ["user_id"], unique=False
+    )
+    op.create_index(
+        "ix_oauth_accounts_user_provider",
+        "oauth_accounts",
+        ["user_id", "provider"],
+        unique=False,
+    )
+    op.create_table(
+        "oauth_clients",
+        sa.Column("client_id", sa.String(length=64), nullable=False),
+        sa.Column("client_secret_hash", sa.String(length=255), nullable=True),
+        sa.Column("client_name", sa.String(length=255), nullable=False),
+        sa.Column("client_description", sa.String(length=1000), nullable=True),
+        sa.Column("client_type", sa.String(length=20), nullable=False),
+        sa.Column(
+            "redirect_uris", postgresql.JSONB(astext_type=sa.Text()), nullable=False
+        ),
+        sa.Column(
+            "allowed_scopes", postgresql.JSONB(astext_type=sa.Text()), nullable=False
+        ),
+        sa.Column("access_token_lifetime", sa.String(length=10), nullable=False),
+        sa.Column("refresh_token_lifetime", sa.String(length=10), nullable=False),
+        sa.Column("is_active", sa.Boolean(), nullable=False),
+        sa.Column("owner_user_id", sa.UUID(), nullable=True),
+        sa.Column("mcp_server_url", sa.String(length=2048), nullable=True),
+        sa.Column("id", sa.UUID(), nullable=False),
+        sa.Column("created_at", sa.DateTime(timezone=True), nullable=False),
+        sa.Column("updated_at", sa.DateTime(timezone=True), nullable=False),
+        sa.ForeignKeyConstraint(["owner_user_id"], ["users.id"], ondelete="SET NULL"),
+        sa.PrimaryKeyConstraint("id"),
+    )
+    op.create_index(
+        op.f("ix_oauth_clients_client_id"), "oauth_clients", ["client_id"], unique=True
+    )
+    op.create_index(
+        op.f("ix_oauth_clients_is_active"), "oauth_clients", ["is_active"], unique=False
+    )
+    op.create_table(
+        "user_organizations",
+        sa.Column("user_id", sa.UUID(), nullable=False),
+        sa.Column("organization_id", sa.UUID(), nullable=False),
+        sa.Column(
+            "role",
+            sa.Enum("OWNER", "ADMIN", "MEMBER", "GUEST", name="organizationrole"),
+            nullable=False,
+        ),
+        sa.Column("is_active", sa.Boolean(), nullable=False),
+        sa.Column("custom_permissions", sa.String(length=500), nullable=True),
+        sa.Column("created_at", sa.DateTime(timezone=True), nullable=False),
+        sa.Column("updated_at", sa.DateTime(timezone=True), nullable=False),
+        sa.ForeignKeyConstraint(
+            ["organization_id"], ["organizations.id"], ondelete="CASCADE"
+        ),
+        sa.ForeignKeyConstraint(["user_id"], ["users.id"], ondelete="CASCADE"),
+        sa.PrimaryKeyConstraint("user_id", "organization_id"),
+    )
+    op.create_index(
+        "ix_user_org_org_active",
+        "user_organizations",
+        ["organization_id", "is_active"],
+        unique=False,
+    )
+    op.create_index("ix_user_org_role", "user_organizations", ["role"], unique=False)
+    op.create_index(
+        "ix_user_org_user_active",
+        "user_organizations",
+        ["user_id", "is_active"],
+        unique=False,
+    )
+    op.create_index(
+        op.f("ix_user_organizations_is_active"),
+        "user_organizations",
+        ["is_active"],
+        unique=False,
+    )
+    op.create_table(
+        "user_sessions",
+        sa.Column("user_id", sa.UUID(), nullable=False),
+        sa.Column("refresh_token_jti", sa.String(length=255), nullable=False),
+        sa.Column("device_name", sa.String(length=255), nullable=True),
+        sa.Column("device_id", sa.String(length=255), nullable=True),
+        sa.Column("ip_address", sa.String(length=45), nullable=True),
+        sa.Column("user_agent", sa.String(length=500), nullable=True),
+        sa.Column("last_used_at", sa.DateTime(timezone=True), nullable=False),
+        sa.Column("expires_at", sa.DateTime(timezone=True), nullable=False),
+        sa.Column("is_active", sa.Boolean(), nullable=False),
+        sa.Column("location_city", sa.String(length=100), nullable=True),
+        sa.Column("location_country", sa.String(length=100), nullable=True),
+        sa.Column("id", sa.UUID(), nullable=False),
+        sa.Column("created_at", sa.DateTime(timezone=True), nullable=False),
+        sa.Column("updated_at", sa.DateTime(timezone=True), nullable=False),
+        sa.ForeignKeyConstraint(["user_id"], ["users.id"], ondelete="CASCADE"),
+        sa.PrimaryKeyConstraint("id"),
+    )
+    op.create_index(
+        op.f("ix_user_sessions_is_active"), "user_sessions", ["is_active"], unique=False
+    )
+    op.create_index(
+        "ix_user_sessions_jti_active",
+        "user_sessions",
+        ["refresh_token_jti", "is_active"],
+        unique=False,
+    )
+    op.create_index(
+        op.f("ix_user_sessions_refresh_token_jti"),
+        "user_sessions",
+        ["refresh_token_jti"],
+        unique=True,
+    )
+    op.create_index(
+        "ix_user_sessions_user_active",
+        "user_sessions",
+        ["user_id", "is_active"],
+        unique=False,
+    )
+    op.create_index(
+        op.f("ix_user_sessions_user_id"), "user_sessions", ["user_id"], unique=False
+    )
+    op.create_table(
+        "oauth_authorization_codes",
+        sa.Column("code", sa.String(length=128), nullable=False),
+        sa.Column("client_id", sa.String(length=64), nullable=False),
+        sa.Column("user_id", sa.UUID(), nullable=False),
+        sa.Column("redirect_uri", sa.String(length=2048), nullable=False),
+        sa.Column("scope", sa.String(length=1000), nullable=False),
+        sa.Column("code_challenge", sa.String(length=128), nullable=True),
+        sa.Column("code_challenge_method", sa.String(length=10), nullable=True),
+        sa.Column("state", sa.String(length=256), nullable=True),
+        sa.Column("nonce", sa.String(length=256), nullable=True),
+        sa.Column("expires_at", sa.DateTime(timezone=True), nullable=False),
+        sa.Column("used", sa.Boolean(), nullable=False),
+        sa.Column("id", sa.UUID(), nullable=False),
+        sa.Column("created_at", sa.DateTime(timezone=True), nullable=False),
+        sa.Column("updated_at", sa.DateTime(timezone=True), nullable=False),
+        sa.ForeignKeyConstraint(
+            ["client_id"], ["oauth_clients.client_id"], ondelete="CASCADE"
+        ),
+        sa.ForeignKeyConstraint(["user_id"], ["users.id"], ondelete="CASCADE"),
+        sa.PrimaryKeyConstraint("id"),
+    )
+    op.create_index(
+        "ix_oauth_authorization_codes_client_user",
+        "oauth_authorization_codes",
+        ["client_id", "user_id"],
+        unique=False,
+    )
+    op.create_index(
+        op.f("ix_oauth_authorization_codes_code"),
+        "oauth_authorization_codes",
+        ["code"],
+        unique=True,
+    )
+    op.create_index(
+        "ix_oauth_authorization_codes_expires_at",
+        "oauth_authorization_codes",
+        ["expires_at"],
+        unique=False,
+    )
+    op.create_table(
+        "oauth_consents",
+        sa.Column("user_id", sa.UUID(), nullable=False),
+        sa.Column("client_id", sa.String(length=64), nullable=False),
+        sa.Column("granted_scopes", sa.String(length=1000), nullable=False),
+        sa.Column("id", sa.UUID(), nullable=False),
+        sa.Column("created_at", sa.DateTime(timezone=True), nullable=False),
+        sa.Column("updated_at", sa.DateTime(timezone=True), nullable=False),
+        sa.ForeignKeyConstraint(
+            ["client_id"], ["oauth_clients.client_id"], ondelete="CASCADE"
+        ),
+        sa.ForeignKeyConstraint(["user_id"], ["users.id"], ondelete="CASCADE"),
+        sa.PrimaryKeyConstraint("id"),
+    )
+    op.create_index(
+        "ix_oauth_consents_user_client",
+        "oauth_consents",
+        ["user_id", "client_id"],
+        unique=True,
+    )
+    op.create_table(
+        "oauth_provider_refresh_tokens",
+        sa.Column("token_hash", sa.String(length=64), nullable=False),
+        sa.Column("jti", sa.String(length=64), nullable=False),
+        sa.Column("client_id", sa.String(length=64), nullable=False),
+        sa.Column("user_id", sa.UUID(), nullable=False),
+        sa.Column("scope", sa.String(length=1000), nullable=False),
+        sa.Column("expires_at", sa.DateTime(timezone=True), nullable=False),
+        sa.Column("revoked", sa.Boolean(), nullable=False),
+        sa.Column("last_used_at", sa.DateTime(timezone=True), nullable=True),
+        sa.Column("device_info", sa.String(length=500), nullable=True),
+        sa.Column("ip_address", sa.String(length=45), nullable=True),
+        sa.Column("id", sa.UUID(), nullable=False),
+        sa.Column("created_at", sa.DateTime(timezone=True), nullable=False),
+        sa.Column("updated_at", sa.DateTime(timezone=True), nullable=False),
+        sa.ForeignKeyConstraint(
+            ["client_id"], ["oauth_clients.client_id"], ondelete="CASCADE"
+        ),
+        sa.ForeignKeyConstraint(["user_id"], ["users.id"], ondelete="CASCADE"),
+        sa.PrimaryKeyConstraint("id"),
+    )
+    op.create_index(
+        "ix_oauth_provider_refresh_tokens_client_user",
+        "oauth_provider_refresh_tokens",
+        ["client_id", "user_id"],
+        unique=False,
+    )
+    op.create_index(
+        "ix_oauth_provider_refresh_tokens_expires_at",
+        "oauth_provider_refresh_tokens",
|
||||
["expires_at"],
|
||||
unique=False,
|
||||
)
|
||||
op.create_index(
|
||||
op.f("ix_oauth_provider_refresh_tokens_jti"),
|
||||
"oauth_provider_refresh_tokens",
|
||||
["jti"],
|
||||
unique=True,
|
||||
)
|
||||
op.create_index(
|
||||
op.f("ix_oauth_provider_refresh_tokens_revoked"),
|
||||
"oauth_provider_refresh_tokens",
|
||||
["revoked"],
|
||||
unique=False,
|
||||
)
|
||||
op.create_index(
|
||||
op.f("ix_oauth_provider_refresh_tokens_token_hash"),
|
||||
"oauth_provider_refresh_tokens",
|
||||
["token_hash"],
|
||||
unique=True,
|
||||
)
|
||||
op.create_index(
|
||||
"ix_oauth_provider_refresh_tokens_user_revoked",
|
||||
"oauth_provider_refresh_tokens",
|
||||
["user_id", "revoked"],
|
||||
unique=False,
|
||||
)
|
||||
op.create_index('ix_oauth_provider_refresh_tokens_client_user', 'oauth_provider_refresh_tokens', ['client_id', 'user_id'], unique=False)
|
||||
op.create_index('ix_oauth_provider_refresh_tokens_expires_at', 'oauth_provider_refresh_tokens', ['expires_at'], unique=False)
|
||||
op.create_index(op.f('ix_oauth_provider_refresh_tokens_jti'), 'oauth_provider_refresh_tokens', ['jti'], unique=True)
|
||||
op.create_index(op.f('ix_oauth_provider_refresh_tokens_revoked'), 'oauth_provider_refresh_tokens', ['revoked'], unique=False)
|
||||
op.create_index(op.f('ix_oauth_provider_refresh_tokens_token_hash'), 'oauth_provider_refresh_tokens', ['token_hash'], unique=True)
|
||||
op.create_index('ix_oauth_provider_refresh_tokens_user_revoked', 'oauth_provider_refresh_tokens', ['user_id', 'revoked'], unique=False)
|
||||
# ### end Alembic commands ###
|
||||
|
||||
|
||||
def downgrade() -> None:
    # ### commands auto generated by Alembic - please adjust! ###
    op.drop_index(
        "ix_oauth_provider_refresh_tokens_user_revoked",
        table_name="oauth_provider_refresh_tokens",
    )
    op.drop_index(
        op.f("ix_oauth_provider_refresh_tokens_token_hash"),
        table_name="oauth_provider_refresh_tokens",
    )
    op.drop_index(
        op.f("ix_oauth_provider_refresh_tokens_revoked"),
        table_name="oauth_provider_refresh_tokens",
    )
    op.drop_index(
        op.f("ix_oauth_provider_refresh_tokens_jti"),
        table_name="oauth_provider_refresh_tokens",
    )
    op.drop_index(
        "ix_oauth_provider_refresh_tokens_expires_at",
        table_name="oauth_provider_refresh_tokens",
    )
    op.drop_index(
        "ix_oauth_provider_refresh_tokens_client_user",
        table_name="oauth_provider_refresh_tokens",
    )
    op.drop_table("oauth_provider_refresh_tokens")
    op.drop_index("ix_oauth_consents_user_client", table_name="oauth_consents")
    op.drop_table("oauth_consents")
    op.drop_index(
        "ix_oauth_authorization_codes_expires_at",
        table_name="oauth_authorization_codes",
    )
    op.drop_index(
        op.f("ix_oauth_authorization_codes_code"),
        table_name="oauth_authorization_codes",
    )
    op.drop_index(
        "ix_oauth_authorization_codes_client_user",
        table_name="oauth_authorization_codes",
    )
    op.drop_table("oauth_authorization_codes")
    op.drop_index(op.f("ix_user_sessions_user_id"), table_name="user_sessions")
    op.drop_index("ix_user_sessions_user_active", table_name="user_sessions")
    op.drop_index(
        op.f("ix_user_sessions_refresh_token_jti"), table_name="user_sessions"
    )
    op.drop_index("ix_user_sessions_jti_active", table_name="user_sessions")
    op.drop_index(op.f("ix_user_sessions_is_active"), table_name="user_sessions")
    op.drop_table("user_sessions")
    op.drop_index(
        op.f("ix_user_organizations_is_active"), table_name="user_organizations"
    )
    op.drop_index("ix_user_org_user_active", table_name="user_organizations")
    op.drop_index("ix_user_org_role", table_name="user_organizations")
    op.drop_index("ix_user_org_org_active", table_name="user_organizations")
    op.drop_table("user_organizations")
    op.drop_index(op.f("ix_oauth_clients_is_active"), table_name="oauth_clients")
    op.drop_index(op.f("ix_oauth_clients_client_id"), table_name="oauth_clients")
    op.drop_table("oauth_clients")
    op.drop_index("ix_oauth_accounts_user_provider", table_name="oauth_accounts")
    op.drop_index(op.f("ix_oauth_accounts_user_id"), table_name="oauth_accounts")
    op.drop_index(op.f("ix_oauth_accounts_provider_email"), table_name="oauth_accounts")
    op.drop_index(op.f("ix_oauth_accounts_provider"), table_name="oauth_accounts")
    op.drop_table("oauth_accounts")
    op.drop_index(op.f("ix_users_locale"), table_name="users")
    op.drop_index(op.f("ix_users_is_superuser"), table_name="users")
    op.drop_index(op.f("ix_users_is_active"), table_name="users")
    op.drop_index(op.f("ix_users_email"), table_name="users")
    op.drop_index(op.f("ix_users_deleted_at"), table_name="users")
    op.drop_table("users")
    op.drop_index("ix_organizations_slug_active", table_name="organizations")
    op.drop_index(op.f("ix_organizations_slug"), table_name="organizations")
    op.drop_index("ix_organizations_name_active", table_name="organizations")
    op.drop_index(op.f("ix_organizations_name"), table_name="organizations")
    op.drop_index(op.f("ix_organizations_is_active"), table_name="organizations")
    op.drop_table("organizations")
    op.drop_index(op.f("ix_oauth_states_state"), table_name="oauth_states")
    op.drop_table("oauth_states")
    # ### end Alembic commands ###

@@ -114,8 +114,13 @@ def upgrade() -> None:

def downgrade() -> None:
    # Drop indexes in reverse order
    op.drop_index(
        "ix_perf_oauth_auth_codes_expires", table_name="oauth_authorization_codes"
    )
    op.drop_index(
        "ix_perf_oauth_refresh_tokens_expires",
        table_name="oauth_provider_refresh_tokens",
    )
    op.drop_index("ix_perf_user_sessions_expires", table_name="user_sessions")
    op.drop_index("ix_perf_organizations_slug_lower", table_name="organizations")
    op.drop_index("ix_perf_users_active", table_name="users")

@@ -0,0 +1,35 @@
"""rename oauth account token fields drop encrypted suffix

Revision ID: 0003
Revises: 0002
Create Date: 2026-02-27 01:03:18.869178

"""

from collections.abc import Sequence

from alembic import op

# revision identifiers, used by Alembic.
revision: str = "0003"
down_revision: str | None = "0002"
branch_labels: str | Sequence[str] | None = None
depends_on: str | Sequence[str] | None = None


def upgrade() -> None:
    op.alter_column(
        "oauth_accounts", "access_token_encrypted", new_column_name="access_token"
    )
    op.alter_column(
        "oauth_accounts", "refresh_token_encrypted", new_column_name="refresh_token"
    )


def downgrade() -> None:
    op.alter_column(
        "oauth_accounts", "access_token", new_column_name="access_token_encrypted"
    )
    op.alter_column(
        "oauth_accounts", "refresh_token", new_column_name="refresh_token_encrypted"
    )
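A plain column rename like migration 0003 is cheaply reversible. Applying and reverting it with the standard Alembic CLI (a sketch; assumes the repo's `uv`-managed backend environment and a configured database URL):

```shell
cd backend
# apply migration 0003: *_encrypted columns become plain names
uv run alembic upgrade head
# revert one revision: plain names back to *_encrypted
uv run alembic downgrade -1
```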
@@ -1,12 +1,12 @@
from fastapi import Depends, Header, HTTPException, status
from fastapi.security import OAuth2PasswordBearer
from fastapi.security.utils import get_authorization_scheme_param
from sqlalchemy.ext.asyncio import AsyncSession

from app.core.auth import TokenExpiredError, TokenInvalidError, get_token_data
from app.core.database import get_db
from app.models.user import User
from app.repositories.user import user_repo

# OAuth2 configuration
oauth2_scheme = OAuth2PasswordBearer(tokenUrl="/api/v1/auth/login")
@@ -32,9 +32,8 @@ async def get_current_user(
    # Decode token and get user ID
    token_data = get_token_data(token)

    # Get user from database via repository
    user = await user_repo.get(db, id=str(token_data.user_id))

    if not user:
        raise HTTPException(
@@ -144,8 +143,7 @@ async def get_optional_current_user(
    try:
        token_data = get_token_data(token)
        user = await user_repo.get(db, id=str(token_data.user_id))
        if not user or not user.is_active:
            return None
        return user

@@ -15,9 +15,9 @@ from sqlalchemy.ext.asyncio import AsyncSession

from app.api.dependencies.auth import get_current_user
from app.core.database import get_db
from app.models.user import User
from app.models.user_organization import OrganizationRole
from app.services.organization_service import organization_service


def require_superuser(current_user: User = Depends(get_current_user)) -> User:
@@ -81,7 +81,7 @@ class OrganizationPermission:
            return current_user

        # Get user's role in organization
        user_role = await organization_service.get_user_role_in_org(
            db, user_id=current_user.id, organization_id=organization_id
        )

@@ -123,7 +123,7 @@ async def require_org_membership(
    if current_user.is_superuser:
        return current_user

    user_role = await organization_service.get_user_role_in_org(
        db, user_id=current_user.id, organization_id=organization_id
    )

41 backend/app/api/dependencies/services.py Normal file
@@ -0,0 +1,41 @@
# app/api/dependencies/services.py
"""FastAPI dependency functions for service singletons."""

from app.services import oauth_provider_service
from app.services.auth_service import AuthService
from app.services.oauth_service import OAuthService
from app.services.organization_service import OrganizationService, organization_service
from app.services.session_service import SessionService, session_service
from app.services.user_service import UserService, user_service


def get_auth_service() -> AuthService:
    """Return the AuthService singleton for dependency injection."""
    from app.services.auth_service import AuthService as _AuthService

    return _AuthService()


def get_user_service() -> UserService:
    """Return the UserService singleton for dependency injection."""
    return user_service


def get_organization_service() -> OrganizationService:
    """Return the OrganizationService singleton for dependency injection."""
    return organization_service


def get_session_service() -> SessionService:
    """Return the SessionService singleton for dependency injection."""
    return session_service


def get_oauth_service() -> OAuthService:
    """Return OAuthService for dependency injection."""
    return OAuthService()


def get_oauth_provider_service():
    """Return the oauth_provider_service module for dependency injection."""
    return oauth_provider_service

@@ -14,7 +14,6 @@ from uuid import UUID

from fastapi import APIRouter, Depends, Query, status
from pydantic import BaseModel, Field
from sqlalchemy.ext.asyncio import AsyncSession

from app.api.dependencies.permissions import require_superuser
@@ -25,12 +24,9 @@ from app.core.exceptions import (
    ErrorCode,
    NotFoundError,
)
from app.core.repository_exceptions import DuplicateEntryError
from app.models.user import User
from app.models.user_organization import OrganizationRole
from app.schemas.common import (
    MessageResponse,
    PaginatedResponse,
@@ -46,6 +42,9 @@ from app.schemas.organizations import (
)
from app.schemas.sessions import AdminSessionResponse
from app.schemas.users import UserCreate, UserResponse, UserUpdate
from app.services.organization_service import organization_service
from app.services.session_service import session_service
from app.services.user_service import user_service

logger = logging.getLogger(__name__)

@@ -66,7 +65,7 @@ class BulkUserAction(BaseModel):

    action: BulkAction = Field(..., description="Action to perform on selected users")
    user_ids: list[UUID] = Field(
        ..., min_length=1, max_length=100, description="List of user IDs (max 100)"
    )

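`min_items`/`max_items` are the Pydantic v1 names for list constraints; Pydantic v2 renamed them to `min_length`/`max_length`, which is what this change adopts. A minimal self-contained sketch (toy model, assuming Pydantic v2 is installed):

```python
from uuid import UUID, uuid4

from pydantic import BaseModel, Field, ValidationError


class BulkIds(BaseModel):
    """Toy model mirroring the user_ids constraint above."""

    user_ids: list[UUID] = Field(..., min_length=1, max_length=100)


ok = BulkIds(user_ids=[uuid4()])  # one id: within bounds

try:
    BulkIds(user_ids=[])  # empty list violates min_length=1
except ValidationError:
    print("rejected: at least one id required")
```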
@@ -178,38 +177,29 @@ async def admin_get_stats(
    """Get admin dashboard statistics with real data from database."""
    from app.core.config import settings

    stats = await user_service.get_stats(db)
    total_users = stats["total_users"]
    active_count = stats["active_count"]
    inactive_count = stats["inactive_count"]
    all_users = stats["all_users"]

    # If database is essentially empty (only admin user), return demo data
    if total_users <= 1 and settings.DEMO_MODE:  # pragma: no cover
        logger.info("Returning demo stats data (empty database in demo mode)")
        return _generate_demo_stats()

    # 1. User Growth (Last 30 days)
    user_growth = []
    for i in range(29, -1, -1):
        date = datetime.now(UTC) - timedelta(days=i)
        date_start = date.replace(hour=0, minute=0, second=0, microsecond=0, tzinfo=UTC)
        date_end = date_start + timedelta(days=1)

        # Count all users created before end of this day
        # Make comparison timezone-aware
        total_users_on_date = sum(
            1
            for u in all_users
            if u.created_at and u.created_at.replace(tzinfo=UTC) < date_end
        )
        # Count active users created before end of this day
        active_users_on_date = sum(
            1
            for u in all_users

@@ -227,27 +217,16 @@ async def admin_get_stats(
        )

    # 2. Organization Distribution - Top 6 organizations by member count
    org_rows = await organization_service.get_org_distribution(db, limit=6)
    org_dist = [OrgDistributionData(name=r["name"], value=r["value"]) for r in org_rows]

    # 3. User Registration Activity (Last 14 days)
    registration_activity = []
    for i in range(13, -1, -1):
        date = datetime.now(UTC) - timedelta(days=i)
        date_start = date.replace(hour=0, minute=0, second=0, microsecond=0, tzinfo=UTC)
        date_end = date_start + timedelta(days=1)

        # Count users created on this specific day
        # Make comparison timezone-aware
        day_registrations = sum(
            1
            for u in all_users

@@ -263,16 +242,8 @@ async def admin_get_stats(
        )

    # 4. User Status - Active vs Inactive
    logger.info(
        "User status counts - Active: %s, Inactive: %s", active_count, inactive_count
    )

    user_status = [
@@ -321,7 +292,7 @@ async def admin_list_users(
            filters["is_superuser"] = is_superuser

        # Get users with search
        users, total = await user_service.list_users(
            db,
            skip=pagination.offset,
            limit=pagination.limit,
@@ -341,7 +312,7 @@ async def admin_list_users(
        return PaginatedResponse(data=users, pagination=pagination_meta)

    except Exception as e:
        logger.exception("Error listing users (admin): %s", e)
        raise

@@ -364,14 +335,14 @@ async def admin_create_user(
    Allows setting is_superuser and other fields.
    """
    try:
        user = await user_service.create_user(db, user_in)
        logger.info("Admin %s created user %s", admin.email, user.email)
        return user
    except DuplicateEntryError as e:
        logger.warning("Failed to create user: %s", e)
        raise DuplicateError(message=str(e), error_code=ErrorCode.USER_ALREADY_EXISTS)
    except Exception as e:
        logger.exception("Error creating user (admin): %s", e)
        raise

@@ -388,11 +359,7 @@ async def admin_get_user(
    db: AsyncSession = Depends(get_db),
) -> Any:
    """Get detailed information about a specific user."""
    user = await user_service.get_user(db, str(user_id))
    return user

@@ -411,20 +378,13 @@ async def admin_update_user(
) -> Any:
    """Update user information with admin privileges."""
    try:
        user = await user_service.get_user(db, str(user_id))
        updated_user = await user_service.update_user(db, user=user, obj_in=user_in)
        logger.info("Admin %s updated user %s", admin.email, updated_user.email)
        return updated_user

    except NotFoundError:
        raise
    except Exception as e:
        logger.exception("Error updating user (admin): %s", e)
        raise

@@ -442,11 +402,7 @@ async def admin_delete_user(
) -> Any:
    """Soft delete a user (sets deleted_at timestamp)."""
    try:
        user = await user_service.get_user(db, str(user_id))

        # Prevent deleting yourself
        if user.id == admin.id:
@@ -456,17 +412,15 @@ async def admin_delete_user(
                error_code=ErrorCode.OPERATION_FORBIDDEN,
            )

        await user_service.soft_delete_user(db, str(user_id))
        logger.info("Admin %s deleted user %s", admin.email, user.email)

        return MessageResponse(
            success=True, message=f"User {user.email} has been deleted"
        )

    except NotFoundError:
        raise
    except Exception as e:
        logger.exception("Error deleting user (admin): %s", e)
        raise

@@ -484,23 +438,16 @@ async def admin_activate_user(
) -> Any:
    """Activate a user account."""
    try:
        user = await user_service.get_user(db, str(user_id))
        await user_service.update_user(db, user=user, obj_in={"is_active": True})
        logger.info("Admin %s activated user %s", admin.email, user.email)

        return MessageResponse(
            success=True, message=f"User {user.email} has been activated"
        )

    except NotFoundError:
        raise
    except Exception as e:
        logger.exception("Error activating user (admin): %s", e)
        raise

@@ -518,11 +465,7 @@ async def admin_deactivate_user(
) -> Any:
    """Deactivate a user account."""
    try:
        user = await user_service.get_user(db, str(user_id))

        # Prevent deactivating yourself
        if user.id == admin.id:
@@ -532,17 +475,15 @@ async def admin_deactivate_user(
                error_code=ErrorCode.OPERATION_FORBIDDEN,
            )

        await user_service.update_user(db, user=user, obj_in={"is_active": False})
        logger.info("Admin %s deactivated user %s", admin.email, user.email)

        return MessageResponse(
            success=True, message=f"User {user.email} has been deactivated"
        )

    except NotFoundError:
        raise
    except Exception as e:
        logger.exception("Error deactivating user (admin): %s", e)
        raise

@@ -567,16 +508,16 @@ async def admin_bulk_user_action(
    try:
        # Use efficient bulk operations instead of loop
        if bulk_action.action == BulkAction.ACTIVATE:
            affected_count = await user_service.bulk_update_status(
                db, user_ids=bulk_action.user_ids, is_active=True
            )
        elif bulk_action.action == BulkAction.DEACTIVATE:
            affected_count = await user_service.bulk_update_status(
                db, user_ids=bulk_action.user_ids, is_active=False
            )
        elif bulk_action.action == BulkAction.DELETE:
            # bulk_soft_delete automatically excludes the admin user
            affected_count = await user_service.bulk_soft_delete(
                db, user_ids=bulk_action.user_ids, exclude_user_id=admin.id
            )
        else:  # pragma: no cover
@@ -587,8 +528,11 @@ async def admin_bulk_user_action(
        failed_count = requested_count - affected_count

        logger.info(
            "Admin %s performed bulk %s on %s users (%s skipped/failed)",
            admin.email,
            bulk_action.action.value,
            affected_count,
            failed_count,
        )

        return BulkActionResult(
@@ -600,7 +544,7 @@ async def admin_bulk_user_action(
        )

    except Exception as e:  # pragma: no cover
        logger.exception("Error in bulk user action: %s", e)
        raise

@@ -624,7 +568,7 @@ async def admin_list_organizations(
    """List all organizations with filtering and search."""
    try:
        # Use optimized method that gets member counts in single query (no N+1)
        orgs_with_data, total = await organization_service.get_multi_with_member_counts(
            db,
            skip=pagination.offset,
            limit=pagination.limit,
@@ -661,7 +605,7 @@ async def admin_list_organizations(
        return PaginatedResponse(data=orgs_with_count, pagination=pagination_meta)

    except Exception as e:
        logger.exception("Error listing organizations (admin): %s", e)
        raise

@@ -680,8 +624,8 @@ async def admin_create_organization(
|
||||
) -> Any:
|
||||
"""Create a new organization."""
|
||||
try:
|
||||
org = await organization_crud.create(db, obj_in=org_in)
|
||||
logger.info(f"Admin {admin.email} created organization {org.name}")
|
||||
org = await organization_service.create_organization(db, obj_in=org_in)
|
||||
logger.info("Admin %s created organization %s", admin.email, org.name)
|
||||
|
||||
# Add member count
|
||||
org_dict = {
|
||||
@@ -697,11 +641,11 @@ async def admin_create_organization(
|
||||
}
|
||||
return OrganizationResponse(**org_dict)
|
||||
|
||||
except ValueError as e:
|
||||
logger.warning(f"Failed to create organization: {e!s}")
|
||||
raise NotFoundError(message=str(e), error_code=ErrorCode.ALREADY_EXISTS)
|
||||
except DuplicateEntryError as e:
|
||||
logger.warning("Failed to create organization: %s", e)
|
||||
raise DuplicateError(message=str(e), error_code=ErrorCode.ALREADY_EXISTS)
|
||||
except Exception as e:
|
||||
logger.error(f"Error creating organization (admin): {e!s}", exc_info=True)
|
||||
logger.exception("Error creating organization (admin): %s", e)
|
||||
raise
|
||||
|
||||
|
||||
@@ -718,12 +662,7 @@ async def admin_get_organization(
|
||||
db: AsyncSession = Depends(get_db),
|
||||
) -> Any:
|
||||
"""Get detailed information about a specific organization."""
|
||||
org = await organization_crud.get(db, id=org_id)
|
||||
if not org:
|
||||
raise NotFoundError(
|
||||
message=f"Organization {org_id} not found", error_code=ErrorCode.NOT_FOUND
|
||||
)
|
||||
|
||||
org = await organization_service.get_organization(db, str(org_id))
|
||||
org_dict = {
|
||||
"id": org.id,
|
||||
"name": org.name,
|
||||
@@ -733,7 +672,7 @@ async def admin_get_organization(
|
||||
"settings": org.settings,
|
||||
"created_at": org.created_at,
|
||||
"updated_at": org.updated_at,
|
||||
"member_count": await organization_crud.get_member_count(
|
||||
"member_count": await organization_service.get_member_count(
|
||||
db, organization_id=org.id
|
||||
),
|
||||
}
|
||||
@@ -755,15 +694,11 @@ async def admin_update_organization(
|
||||
) -> Any:
|
||||
"""Update organization information."""
|
||||
try:
|
||||
org = await organization_crud.get(db, id=org_id)
|
||||
if not org:
|
||||
raise NotFoundError(
|
||||
message=f"Organization {org_id} not found",
|
||||
error_code=ErrorCode.NOT_FOUND,
|
||||
)
|
||||
|
||||
updated_org = await organization_crud.update(db, db_obj=org, obj_in=org_in)
|
||||
logger.info(f"Admin {admin.email} updated organization {updated_org.name}")
|
||||
org = await organization_service.get_organization(db, str(org_id))
|
||||
updated_org = await organization_service.update_organization(
|
||||
db, org=org, obj_in=org_in
|
||||
)
|
||||
logger.info("Admin %s updated organization %s", admin.email, updated_org.name)
|
||||
|
||||
org_dict = {
|
||||
"id": updated_org.id,
|
||||
@@ -774,16 +709,14 @@ async def admin_update_organization(
|
||||
"settings": updated_org.settings,
|
||||
"created_at": updated_org.created_at,
|
||||
"updated_at": updated_org.updated_at,
|
||||
"member_count": await organization_crud.get_member_count(
|
||||
"member_count": await organization_service.get_member_count(
|
||||
db, organization_id=updated_org.id
|
||||
),
|
||||
}
|
||||
return OrganizationResponse(**org_dict)
|
||||
|
||||
except NotFoundError:
|
||||
raise
|
||||
except Exception as e:
|
||||
logger.error(f"Error updating organization (admin): {e!s}", exc_info=True)
|
||||
logger.exception("Error updating organization (admin): %s", e)
|
||||
raise
|
||||
|
||||
|
||||
@@ -801,24 +734,16 @@ async def admin_delete_organization(
|
||||
) -> Any:
|
||||
"""Delete an organization and all its relationships."""
|
||||
try:
|
||||
org = await organization_crud.get(db, id=org_id)
|
||||
if not org:
|
||||
raise NotFoundError(
|
||||
message=f"Organization {org_id} not found",
|
||||
error_code=ErrorCode.NOT_FOUND,
|
||||
)
|
||||
|
||||
await organization_crud.remove(db, id=org_id)
|
||||
logger.info(f"Admin {admin.email} deleted organization {org.name}")
|
||||
org = await organization_service.get_organization(db, str(org_id))
|
||||
await organization_service.remove_organization(db, str(org_id))
|
||||
logger.info("Admin %s deleted organization %s", admin.email, org.name)
|
||||
|
||||
return MessageResponse(
|
||||
success=True, message=f"Organization {org.name} has been deleted"
|
||||
)
|
||||
|
||||
except NotFoundError:
|
||||
raise
|
||||
except Exception as e:
|
||||
logger.error(f"Error deleting organization (admin): {e!s}", exc_info=True)
|
||||
logger.exception("Error deleting organization (admin): %s", e)
|
||||
raise
|
||||
|
||||
|
||||
@@ -838,14 +763,8 @@ async def admin_list_organization_members(
|
||||
) -> Any:
|
||||
"""List all members of an organization."""
|
||||
try:
|
||||
org = await organization_crud.get(db, id=org_id)
|
||||
if not org:
|
||||
raise NotFoundError(
|
||||
message=f"Organization {org_id} not found",
|
||||
error_code=ErrorCode.NOT_FOUND,
|
||||
)
|
||||
|
||||
members, total = await organization_crud.get_organization_members(
|
||||
await organization_service.get_organization(db, str(org_id)) # validates exists
|
||||
members, total = await organization_service.get_organization_members(
|
||||
db,
|
||||
organization_id=org_id,
|
||||
skip=pagination.offset,
|
||||
@@ -868,9 +787,7 @@ async def admin_list_organization_members(
|
||||
except NotFoundError:
|
||||
raise
|
||||
except Exception as e:
|
||||
logger.error(
|
||||
f"Error listing organization members (admin): {e!s}", exc_info=True
|
||||
)
|
||||
logger.exception("Error listing organization members (admin): %s", e)
|
||||
raise
|
||||
|
||||
|
||||
@@ -898,45 +815,32 @@ async def admin_add_organization_member(
|
||||
) -> Any:
|
||||
"""Add a user to an organization."""
|
||||
try:
|
||||
org = await organization_crud.get(db, id=org_id)
|
||||
if not org:
|
||||
raise NotFoundError(
|
||||
message=f"Organization {org_id} not found",
|
||||
error_code=ErrorCode.NOT_FOUND,
|
||||
)
|
||||
org = await organization_service.get_organization(db, str(org_id))
|
||||
user = await user_service.get_user(db, str(request.user_id))
|
||||
|
||||
user = await user_crud.get(db, id=request.user_id)
|
||||
if not user:
|
||||
raise NotFoundError(
|
||||
message=f"User {request.user_id} not found",
|
||||
error_code=ErrorCode.USER_NOT_FOUND,
|
||||
)
|
||||
|
||||
await organization_crud.add_user(
|
||||
await organization_service.add_member(
|
||||
db, organization_id=org_id, user_id=request.user_id, role=request.role
|
||||
)
|
||||
|
||||
logger.info(
|
||||
f"Admin {admin.email} added user {user.email} to organization {org.name} "
|
||||
f"with role {request.role.value}"
|
||||
"Admin %s added user %s to organization %s with role %s",
|
||||
admin.email,
|
||||
user.email,
|
||||
org.name,
|
||||
request.role.value,
|
||||
)
|
||||
|
||||
return MessageResponse(
|
||||
success=True, message=f"User {user.email} added to organization {org.name}"
|
||||
)
|
||||
|
||||
except ValueError as e:
|
||||
logger.warning(f"Failed to add user to organization: {e!s}")
|
||||
# Use DuplicateError for "already exists" scenarios
|
||||
except DuplicateEntryError as e:
|
||||
logger.warning("Failed to add user to organization: %s", e)
|
||||
raise DuplicateError(
|
||||
message=str(e), error_code=ErrorCode.USER_ALREADY_EXISTS, field="user_id"
|
||||
)
|
||||
except NotFoundError:
|
||||
raise
|
||||
except Exception as e:
|
||||
logger.error(
|
||||
f"Error adding member to organization (admin): {e!s}", exc_info=True
|
||||
)
|
||||
logger.exception("Error adding member to organization (admin): %s", e)
|
||||
raise
|
||||
|
||||
|
||||
@@ -955,20 +859,10 @@ async def admin_remove_organization_member(
|
||||
) -> Any:
|
||||
"""Remove a user from an organization."""
|
||||
try:
|
||||
org = await organization_crud.get(db, id=org_id)
|
||||
if not org:
|
||||
raise NotFoundError(
|
||||
message=f"Organization {org_id} not found",
|
||||
error_code=ErrorCode.NOT_FOUND,
|
||||
)
|
||||
org = await organization_service.get_organization(db, str(org_id))
|
||||
user = await user_service.get_user(db, str(user_id))
|
||||
|
||||
user = await user_crud.get(db, id=user_id)
|
||||
if not user:
|
||||
raise NotFoundError(
|
||||
message=f"User {user_id} not found", error_code=ErrorCode.USER_NOT_FOUND
|
||||
)
|
||||
|
||||
success = await organization_crud.remove_user(
|
||||
success = await organization_service.remove_member(
|
||||
db, organization_id=org_id, user_id=user_id
|
||||
)
|
||||
|
||||
@@ -979,7 +873,10 @@ async def admin_remove_organization_member(
|
||||
)
|
||||
|
||||
logger.info(
|
||||
f"Admin {admin.email} removed user {user.email} from organization {org.name}"
|
||||
"Admin %s removed user %s from organization %s",
|
||||
admin.email,
|
||||
user.email,
|
||||
org.name,
|
||||
)
|
||||
|
||||
return MessageResponse(
|
||||
@@ -990,9 +887,7 @@ async def admin_remove_organization_member(
|
||||
except NotFoundError:
|
||||
raise
|
||||
except Exception as e: # pragma: no cover
|
||||
logger.error(
|
||||
f"Error removing member from organization (admin): {e!s}", exc_info=True
|
||||
)
|
||||
logger.exception("Error removing member from organization (admin): %s", e)
|
||||
raise
|
||||
|
||||
|
||||
@@ -1022,7 +917,7 @@ async def admin_list_sessions(
|
||||
"""List all sessions across all users with filtering and pagination."""
|
||||
try:
|
||||
# Get sessions with user info (eager loaded to prevent N+1)
|
||||
sessions, total = await session_crud.get_all_sessions(
|
||||
sessions, total = await session_service.get_all_sessions(
|
||||
db,
|
||||
skip=pagination.offset,
|
||||
limit=pagination.limit,
|
||||
@@ -1061,7 +956,10 @@ async def admin_list_sessions(
|
||||
session_responses.append(session_response)
|
||||
|
||||
logger.info(
|
||||
f"Admin {admin.email} listed {len(session_responses)} sessions (total: {total})"
|
||||
"Admin %s listed %s sessions (total: %s)",
|
||||
admin.email,
|
||||
len(session_responses),
|
||||
total,
|
||||
)
|
||||
|
||||
pagination_meta = create_pagination_meta(
|
||||
@@ -1074,5 +972,5 @@ async def admin_list_sessions(
|
||||
return PaginatedResponse(data=session_responses, pagination=pagination_meta)
|
||||
|
||||
except Exception as e: # pragma: no cover
|
||||
logger.error(f"Error listing sessions (admin): {e!s}", exc_info=True)
|
||||
logger.exception("Error listing sessions (admin): %s", e)
|
||||
raise
|
||||
|
||||
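A change repeated throughout the hunks above swaps eager f-string log messages for lazy %-style arguments, so the message is only formatted when a handler actually emits the record. A standalone sketch of the difference, with an illustrative logger name and values:

```python
import io
import logging

# Route records into a buffer so the demo is self-contained and inspectable.
buf = io.StringIO()
logger = logging.getLogger("demo.admin")
logger.addHandler(logging.StreamHandler(buf))
logger.setLevel(logging.INFO)

email, action, count = "admin@example.com", "activate", 3

# Lazy %-style: the format string and args are stored on the record;
# formatting happens only when a handler emits it.
logger.info("Admin %s performed bulk %s on %s users", email, action, count)

# DEBUG is below the configured level, so this record is dropped before its
# message is ever formatted -- an f-string here would be built regardless.
logger.debug("Admin %s performed bulk %s on %s users", email, action, count)

output = buf.getvalue()
print(output, end="")  # -> Admin admin@example.com performed bulk activate on 3 users
```

Only the INFO line reaches the buffer; the filtered DEBUG call never pays the formatting cost.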
```diff
@@ -15,16 +15,14 @@ from app.core.auth import (
     TokenExpiredError,
     TokenInvalidError,
     decode_token,
-    get_password_hash,
 )
 from app.core.database import get_db
 from app.core.exceptions import (
     AuthenticationError as AuthError,
     DatabaseError,
+    DuplicateError,
     ErrorCode,
 )
-from app.crud.session import session as session_crud
-from app.crud.user import user as user_crud
 from app.models.user import User
 from app.schemas.common import MessageResponse
 from app.schemas.sessions import LogoutRequest, SessionCreate
@@ -39,6 +37,8 @@ from app.schemas.users import (
 )
 from app.services.auth_service import AuthenticationError, AuthService
 from app.services.email_service import email_service
+from app.services.session_service import session_service
+from app.services.user_service import user_service
 from app.utils.device import extract_device_info
 from app.utils.security import create_password_reset_token, verify_password_reset_token
 
@@ -91,17 +91,18 @@ async def _create_login_session(
             location_country=device_info.location_country,
         )
 
-        await session_crud.create_session(db, obj_in=session_data)
+        await session_service.create_session(db, obj_in=session_data)
 
         logger.info(
-            f"{login_type.capitalize()} successful: {user.email} from {device_info.device_name} "
-            f"(IP: {device_info.ip_address})"
+            "%s successful: %s from %s (IP: %s)",
+            login_type.capitalize(),
+            user.email,
+            device_info.device_name,
+            device_info.ip_address,
         )
     except Exception as session_err:
         # Log but don't fail login if session creation fails
-        logger.error(
-            f"Failed to create session for {user.email}: {session_err!s}", exc_info=True
-        )
+        logger.exception("Failed to create session for %s: %s", user.email, session_err)
 
 
 @router.post(
@@ -123,15 +124,21 @@ async def register_user(
     try:
         user = await AuthService.create_user(db, user_data)
         return user
-    except AuthenticationError as e:
+    except DuplicateError:
         # SECURITY: Don't reveal if email exists - generic error message
-        logger.warning(f"Registration failed: {e!s}")
+        logger.warning("Registration failed: duplicate email %s", user_data.email)
         raise HTTPException(
             status_code=status.HTTP_400_BAD_REQUEST,
             detail="Registration failed. Please check your information and try again.",
         )
+    except AuthError as e:
+        logger.warning("Registration failed: %s", e)
+        raise HTTPException(
+            status_code=status.HTTP_400_BAD_REQUEST,
+            detail="Registration failed. Please check your information and try again.",
+        )
     except Exception as e:
-        logger.error(f"Unexpected error during registration: {e!s}", exc_info=True)
+        logger.exception("Unexpected error during registration: %s", e)
         raise DatabaseError(
             message="An unexpected error occurred. Please try again later.",
             error_code=ErrorCode.INTERNAL_ERROR,
@@ -159,7 +166,7 @@ async def login(
 
         # Explicitly check for None result and raise correct exception
         if user is None:
-            logger.warning(f"Invalid login attempt for: {login_data.email}")
+            logger.warning("Invalid login attempt for: %s", login_data.email)
             raise AuthError(
                 message="Invalid email or password",
                 error_code=ErrorCode.INVALID_CREDENTIALS,
@@ -175,14 +182,11 @@ async def login(
 
     except AuthenticationError as e:
         # Handle specific authentication errors like inactive accounts
-        logger.warning(f"Authentication failed: {e!s}")
+        logger.warning("Authentication failed: %s", e)
         raise AuthError(message=str(e), error_code=ErrorCode.INVALID_CREDENTIALS)
-    except AuthError:
-        # Re-raise custom auth exceptions without modification
-        raise
     except Exception as e:
         # Handle unexpected errors
-        logger.error(f"Unexpected error during login: {e!s}", exc_info=True)
+        logger.exception("Unexpected error during login: %s", e)
         raise DatabaseError(
             message="An unexpected error occurred. Please try again later.",
             error_code=ErrorCode.INTERNAL_ERROR,
@@ -224,13 +228,10 @@ async def login_oauth(
         # Return full token response with user data
         return tokens
     except AuthenticationError as e:
-        logger.warning(f"OAuth authentication failed: {e!s}")
+        logger.warning("OAuth authentication failed: %s", e)
         raise AuthError(message=str(e), error_code=ErrorCode.INVALID_CREDENTIALS)
-    except AuthError:
-        # Re-raise custom auth exceptions without modification
-        raise
     except Exception as e:
-        logger.error(f"Unexpected error during OAuth login: {e!s}", exc_info=True)
+        logger.exception("Unexpected error during OAuth login: %s", e)
         raise DatabaseError(
             message="An unexpected error occurred. Please try again later.",
             error_code=ErrorCode.INTERNAL_ERROR,
@@ -259,11 +260,12 @@ async def refresh_token(
         )
 
         # Check if session exists and is active
-        session = await session_crud.get_active_by_jti(db, jti=refresh_payload.jti)
+        session = await session_service.get_active_by_jti(db, jti=refresh_payload.jti)
 
         if not session:
             logger.warning(
-                f"Refresh token used for inactive or non-existent session: {refresh_payload.jti}"
+                "Refresh token used for inactive or non-existent session: %s",
+                refresh_payload.jti,
             )
             raise HTTPException(
                 status_code=status.HTTP_401_UNAUTHORIZED,
@@ -279,16 +281,14 @@ async def refresh_token(
 
         # Update session with new refresh token JTI and expiration
         try:
-            await session_crud.update_refresh_token(
+            await session_service.update_refresh_token(
                 db,
                 session=session,
                 new_jti=new_refresh_payload.jti,
                 new_expires_at=datetime.fromtimestamp(new_refresh_payload.exp, tz=UTC),
             )
         except Exception as session_err:
-            logger.error(
-                f"Failed to update session {session.id}: {session_err!s}", exc_info=True
-            )
+            logger.exception("Failed to update session %s: %s", session.id, session_err)
             # Continue anyway - tokens are already issued
 
         return tokens
@@ -311,7 +311,7 @@ async def refresh_token(
         # Re-raise HTTP exceptions (like session revoked)
         raise
     except Exception as e:
-        logger.error(f"Unexpected error during token refresh: {e!s}")
+        logger.error("Unexpected error during token refresh: %s", e)
         raise HTTPException(
             status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
             detail="An unexpected error occurred. Please try again later.",
@@ -347,7 +347,7 @@ async def request_password_reset(
     """
     try:
         # Look up user by email
-        user = await user_crud.get_by_email(db, email=reset_request.email)
+        user = await user_service.get_by_email(db, email=reset_request.email)
 
         # Only send email if user exists and is active
         if user and user.is_active:
@@ -358,11 +358,12 @@ async def request_password_reset(
             await email_service.send_password_reset_email(
                 to_email=user.email, reset_token=reset_token, user_name=user.first_name
             )
-            logger.info(f"Password reset requested for {user.email}")
+            logger.info("Password reset requested for %s", user.email)
         else:
             # Log attempt but don't reveal if email exists
             logger.warning(
-                f"Password reset requested for non-existent or inactive email: {reset_request.email}"
+                "Password reset requested for non-existent or inactive email: %s",
+                reset_request.email,
            )
 
         # Always return success to prevent email enumeration
@@ -371,7 +372,7 @@ async def request_password_reset(
             message="If your email is registered, you will receive a password reset link shortly",
         )
     except Exception as e:
-        logger.error(f"Error processing password reset request: {e!s}", exc_info=True)
+        logger.exception("Error processing password reset request: %s", e)
         # Still return success to prevent information leakage
         return MessageResponse(
             success=True,
@@ -412,40 +413,34 @@ async def confirm_password_reset(
                 detail="Invalid or expired password reset token",
             )
 
-        # Look up user
-        user = await user_crud.get_by_email(db, email=email)
-
-        if not user:
-            raise HTTPException(
-                status_code=status.HTTP_404_NOT_FOUND, detail="User not found"
-            )
-
-        if not user.is_active:
-            raise HTTPException(
-                status_code=status.HTTP_400_BAD_REQUEST,
-                detail="User account is inactive",
-            )
-
-        # Update password
-        user.password_hash = get_password_hash(reset_confirm.new_password)
-        db.add(user)
-        await db.commit()
+        # Reset password via service (validates user exists and is active)
+        try:
+            user = await AuthService.reset_password(
+                db, email=email, new_password=reset_confirm.new_password
+            )
+        except AuthenticationError as e:
+            err_msg = str(e)
+            if "inactive" in err_msg.lower():
+                raise HTTPException(
+                    status_code=status.HTTP_400_BAD_REQUEST, detail=err_msg
+                )
+            raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail=err_msg)
 
         # SECURITY: Invalidate all existing sessions after password reset
         # This prevents stolen sessions from being used after password change
-        from app.crud.session import session as session_crud
-
         try:
-            deactivated_count = await session_crud.deactivate_all_user_sessions(
+            deactivated_count = await session_service.deactivate_all_user_sessions(
                 db, user_id=str(user.id)
             )
             logger.info(
-                f"Password reset successful for {user.email}, invalidated {deactivated_count} sessions"
+                "Password reset successful for %s, invalidated %s sessions",
+                user.email,
+                deactivated_count,
            )
         except Exception as session_error:
             # Log but don't fail password reset if session invalidation fails
             logger.error(
-                f"Failed to invalidate sessions after password reset: {session_error!s}"
+                "Failed to invalidate sessions after password reset: %s", session_error
             )
 
         return MessageResponse(
@@ -456,7 +451,7 @@ async def confirm_password_reset(
     except HTTPException:
         raise
     except Exception as e:
-        logger.error(f"Error confirming password reset: {e!s}", exc_info=True)
+        logger.exception("Error confirming password reset: %s", e)
         await db.rollback()
         raise HTTPException(
             status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
@@ -506,19 +501,21 @@ async def logout(
             )
         except (TokenExpiredError, TokenInvalidError) as e:
             # Even if token is expired/invalid, try to deactivate session
-            logger.warning(f"Logout with invalid/expired token: {e!s}")
+            logger.warning("Logout with invalid/expired token: %s", e)
             # Don't fail - return success anyway
             return MessageResponse(success=True, message="Logged out successfully")
 
         # Find the session by JTI
-        session = await session_crud.get_by_jti(db, jti=refresh_payload.jti)
+        session = await session_service.get_by_jti(db, jti=refresh_payload.jti)
 
         if session:
             # Verify session belongs to current user (security check)
             if str(session.user_id) != str(current_user.id):
                 logger.warning(
-                    f"User {current_user.id} attempted to logout session {session.id} "
-                    f"belonging to user {session.user_id}"
+                    "User %s attempted to logout session %s belonging to user %s",
+                    current_user.id,
+                    session.id,
+                    session.user_id,
                 )
                 raise HTTPException(
                     status_code=status.HTTP_403_FORBIDDEN,
@@ -526,17 +523,20 @@ async def logout(
                 )
 
             # Deactivate the session
-            await session_crud.deactivate(db, session_id=str(session.id))
+            await session_service.deactivate(db, session_id=str(session.id))
 
             logger.info(
-                f"User {current_user.id} logged out from {session.device_name} "
-                f"(session {session.id})"
+                "User %s logged out from %s (session %s)",
+                current_user.id,
+                session.device_name,
+                session.id,
             )
         else:
             # Session not found - maybe already deleted or never existed
             # Return success anyway (idempotent)
             logger.info(
-                f"Logout requested for non-existent session (JTI: {refresh_payload.jti})"
+                "Logout requested for non-existent session (JTI: %s)",
+                refresh_payload.jti,
             )
 
         return MessageResponse(success=True, message="Logged out successfully")
@@ -544,9 +544,7 @@ async def logout(
     except HTTPException:
         raise
     except Exception as e:
-        logger.error(
-            f"Error during logout for user {current_user.id}: {e!s}", exc_info=True
-        )
+        logger.exception("Error during logout for user %s: %s", current_user.id, e)
         # Don't expose error details
         return MessageResponse(success=True, message="Logged out successfully")
@@ -584,12 +582,12 @@ async def logout_all(
     """
     try:
         # Deactivate all sessions for this user
-        count = await session_crud.deactivate_all_user_sessions(
+        count = await session_service.deactivate_all_user_sessions(
             db, user_id=str(current_user.id)
         )
 
         logger.info(
-            f"User {current_user.id} logged out from all devices ({count} sessions)"
+            "User %s logged out from all devices (%s sessions)", current_user.id, count
        )
 
         return MessageResponse(
@@ -598,9 +596,7 @@ async def logout_all(
         )
 
     except Exception as e:
-        logger.error(
-            f"Error during logout-all for user {current_user.id}: {e!s}", exc_info=True
-        )
+        logger.exception("Error during logout-all for user %s: %s", current_user.id, e)
         await db.rollback()
         raise HTTPException(
             status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
```
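The other recurring swap in the hunks above replaces `logger.error(f"...", exc_info=True)` inside `except` blocks with `logger.exception(...)`. The two are equivalent: `exception()` logs at ERROR level and attaches the active traceback. A minimal standalone illustration (logger name and message are illustrative):

```python
import io
import logging

# Capture log output in a buffer so the traceback can be inspected.
buf = io.StringIO()
logger = logging.getLogger("demo.auth")
logger.addHandler(logging.StreamHandler(buf))
logger.setLevel(logging.ERROR)

try:
    raise ValueError("boom")
except ValueError as e:
    # Must be called from an exception handler: it picks up the active
    # exception via sys.exc_info() and appends the formatted traceback,
    # exactly like logger.error(..., exc_info=True).
    logger.exception("Unexpected error during login: %s", e)

output = buf.getvalue()
```

Both the message and the full `ValueError` traceback end up in the record, without building the message eagerly via an f-string.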
```diff
@@ -25,8 +25,6 @@ from app.core.auth import decode_token
 from app.core.config import settings
 from app.core.database import get_db
 from app.core.exceptions import AuthenticationError as AuthError
-from app.crud import oauth_account
-from app.crud.session import session as session_crud
 from app.models.user import User
 from app.schemas.oauth import (
     OAuthAccountsListResponse,
@@ -38,6 +36,7 @@ from app.schemas.oauth import (
 from app.schemas.sessions import SessionCreate
 from app.schemas.users import Token
 from app.services.oauth_service import OAuthService
+from app.services.session_service import session_service
 from app.utils.device import extract_device_info
 
 router = APIRouter()
@@ -82,17 +81,19 @@ async def _create_oauth_login_session(
             location_country=device_info.location_country,
         )
 
-        await session_crud.create_session(db, obj_in=session_data)
+        await session_service.create_session(db, obj_in=session_data)
 
         logger.info(
-            f"OAuth login successful: {user.email} via {provider} "
-            f"from {device_info.device_name} (IP: {device_info.ip_address})"
+            "OAuth login successful: %s via %s from %s (IP: %s)",
+            user.email,
+            provider,
+            device_info.device_name,
+            device_info.ip_address,
         )
     except Exception as session_err:
         # Log but don't fail login if session creation fails
-        logger.error(
-            f"Failed to create session for OAuth login {user.email}: {session_err!s}",
-            exc_info=True,
+        logger.exception(
+            "Failed to create session for OAuth login %s: %s", user.email, session_err
         )
 
 
@@ -177,13 +178,13 @@ async def get_authorization_url(
         }
 
     except AuthError as e:
-        logger.warning(f"OAuth authorization failed: {e!s}")
+        logger.warning("OAuth authorization failed: %s", e)
         raise HTTPException(
             status_code=status.HTTP_400_BAD_REQUEST,
             detail=str(e),
         )
     except Exception as e:
-        logger.error(f"OAuth authorization error: {e!s}", exc_info=True)
+        logger.exception("OAuth authorization error: %s", e)
         raise HTTPException(
             status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
             detail="Failed to create authorization URL",
@@ -251,13 +252,13 @@ async def handle_callback(
         return result
 
     except AuthError as e:
-        logger.warning(f"OAuth callback failed: {e!s}")
+        logger.warning("OAuth callback failed: %s", e)
         raise HTTPException(
             status_code=status.HTTP_401_UNAUTHORIZED,
             detail=str(e),
         )
     except Exception as e:
-        logger.error(f"OAuth callback error: {e!s}", exc_info=True)
+        logger.exception("OAuth callback error: %s", e)
         raise HTTPException(
             status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
             detail="OAuth authentication failed",
@@ -289,7 +290,7 @@ async def list_accounts(
     Returns:
         List of linked OAuth accounts
     """
-    accounts = await oauth_account.get_user_accounts(db, user_id=current_user.id)
+    accounts = await OAuthService.get_user_accounts(db, user_id=current_user.id)
     return OAuthAccountsListResponse(accounts=accounts)
 
 
@@ -338,13 +339,13 @@ async def unlink_account(
         )
 
     except AuthError as e:
-        logger.warning(f"OAuth unlink failed for {current_user.email}: {e!s}")
+        logger.warning("OAuth unlink failed for %s: %s", current_user.email, e)
         raise HTTPException(
             status_code=status.HTTP_400_BAD_REQUEST,
             detail=str(e),
         )
     except Exception as e:
-        logger.error(f"OAuth unlink error: {e!s}", exc_info=True)
+        logger.exception("OAuth unlink error: %s", e)
         raise HTTPException(
             status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
             detail="Failed to unlink OAuth account",
@@ -397,7 +398,7 @@ async def start_link(
         )
 
     # Check if user already has this provider linked
-    existing = await oauth_account.get_user_account_by_provider(
+    existing = await OAuthService.get_user_account_by_provider(
         db, user_id=current_user.id, provider=provider
     )
     if existing:
@@ -420,13 +421,13 @@ async def start_link(
         }
 
     except AuthError as e:
-        logger.warning(f"OAuth link authorization failed: {e!s}")
+        logger.warning("OAuth link authorization failed: %s", e)
         raise HTTPException(
             status_code=status.HTTP_400_BAD_REQUEST,
             detail=str(e),
         )
     except Exception as e:
-        logger.error(f"OAuth link error: {e!s}", exc_info=True)
+        logger.exception("OAuth link error: %s", e)
         raise HTTPException(
             status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
             detail="Failed to create authorization URL",
```
```diff
@@ -34,7 +34,6 @@ from app.api.dependencies.auth import (
 )
 from app.core.config import settings
 from app.core.database import get_db
-from app.crud import oauth_client as oauth_client_crud
 from app.models.user import User
 from app.schemas.oauth import (
     OAuthClientCreate,
@@ -453,7 +452,7 @@ async def token(
         except Exception as e:
             # Log malformed Basic auth for security monitoring
             logger.warning(
-                f"Malformed Basic auth header in token request: {type(e).__name__}"
+                "Malformed Basic auth header in token request: %s", type(e).__name__
             )
             # Fall back to form body
 
@@ -564,7 +563,8 @@ async def revoke(
         except Exception as e:
             # Log malformed Basic auth for security monitoring
             logger.warning(
-                f"Malformed Basic auth header in revoke request: {type(e).__name__}"
+                "Malformed Basic auth header in revoke request: %s",
+                type(e).__name__,
             )
             # Fall back to form body
 
@@ -586,7 +586,7 @@ async def revoke(
         )
     except Exception as e:
         # Log but don't expose errors per RFC 7009
-        logger.warning(f"Token revocation error: {e}")
+        logger.warning("Token revocation error: %s", e)
 
     # Always return 200 OK per RFC 7009
     return {"status": "ok"}
@@ -635,7 +635,8 @@ async def introspect(
         except Exception as e:
             # Log malformed Basic auth for security monitoring
             logger.warning(
-                f"Malformed Basic auth header in introspect request: {type(e).__name__}"
+                "Malformed Basic auth header in introspect request: %s",
+                type(e).__name__,
             )
             # Fall back to form body
 
@@ -655,8 +656,8 @@ async def introspect(
             headers={"WWW-Authenticate": "Basic"},
         )
     except Exception as e:
-        logger.warning(f"Token introspection error: {e}")
-        return OAuthTokenIntrospectionResponse(active=False)
+        logger.warning("Token introspection error: %s", e)
+        return OAuthTokenIntrospectionResponse(active=False)  # pyright: ignore[reportCallIssue]
 
 
 # ============================================================================
@@ -712,7 +713,7 @@ async def register_client(
         client_type=client_type,
     )
 
-    client, secret = await oauth_client_crud.create_client(db, obj_in=client_data)
+    client, secret = await provider_service.register_client(db, client_data)
 
     # Update MCP server URL if provided
     if mcp_server_url:
@@ -750,7 +751,7 @@ async def list_clients(
     current_user: User = Depends(get_current_superuser),
 ) -> list[OAuthClientResponse]:
     """List all OAuth clients."""
-    clients = await oauth_client_crud.get_all_clients(db)
+    clients = await provider_service.list_clients(db)
     return [OAuthClientResponse.model_validate(c) for c in clients]
 
 
@@ -776,7 +777,7 @@ async def delete_client(
         detail="Client not found",
     )
 
-    await oauth_client_crud.delete_client(db, client_id=client_id)
+    await provider_service.delete_client_by_id(db, client_id=client_id)
 
 
 # ============================================================================
@@ -797,30 +798,7 @@ async def list_my_consents(
     current_user: User = Depends(get_current_active_user),
 ) -> list[dict]:
     """List applications the user has authorized."""
```
|
||||
from sqlalchemy import select
|
||||
|
||||
from app.models.oauth_client import OAuthClient
|
||||
from app.models.oauth_provider_token import OAuthConsent
|
||||
|
||||
result = await db.execute(
|
||||
select(OAuthConsent, OAuthClient)
|
||||
.join(OAuthClient, OAuthConsent.client_id == OAuthClient.client_id)
|
||||
.where(OAuthConsent.user_id == current_user.id)
|
||||
)
|
||||
rows = result.all()
|
||||
|
||||
return [
|
||||
{
|
||||
"client_id": consent.client_id,
|
||||
"client_name": client.client_name,
|
||||
"client_description": client.client_description,
|
||||
"granted_scopes": consent.granted_scopes.split()
|
||||
if consent.granted_scopes
|
||||
else [],
|
||||
"granted_at": consent.created_at.isoformat(),
|
||||
}
|
||||
for consent, client in rows
|
||||
]
|
||||
return await provider_service.list_user_consents(db, user_id=current_user.id)
|
||||
|
||||
|
||||
@router.delete(
|
||||
|
||||
@@ -15,8 +15,6 @@ from sqlalchemy.ext.asyncio import AsyncSession
 from app.api.dependencies.auth import get_current_user
 from app.api.dependencies.permissions import require_org_admin, require_org_membership
 from app.core.database import get_db
-from app.core.exceptions import ErrorCode, NotFoundError
-from app.crud.organization import organization as organization_crud
 from app.models.user import User
 from app.schemas.common import (
     PaginatedResponse,
@@ -28,6 +26,7 @@ from app.schemas.organizations import (
     OrganizationResponse,
     OrganizationUpdate,
 )
+from app.services.organization_service import organization_service

 logger = logging.getLogger(__name__)

@@ -54,7 +53,7 @@ async def get_my_organizations(
     """
     try:
         # Get all org data in single query with JOIN and subquery
-        orgs_data = await organization_crud.get_user_organizations_with_details(
+        orgs_data = await organization_service.get_user_organizations_with_details(
             db, user_id=current_user.id, is_active=is_active
         )

@@ -78,7 +77,7 @@ async def get_my_organizations(
         return orgs_with_data

     except Exception as e:
-        logger.error(f"Error getting user organizations: {e!s}", exc_info=True)
+        logger.exception("Error getting user organizations: %s", e)
         raise


@@ -100,13 +99,7 @@ async def get_organization(
     User must be a member of the organization.
     """
     try:
-        org = await organization_crud.get(db, id=organization_id)
-        if not org:  # pragma: no cover - Permission check prevents this (see docs/UNREACHABLE_DEFENSIVE_CODE_ANALYSIS.md)
-            raise NotFoundError(
-                detail=f"Organization {organization_id} not found",
-                error_code=ErrorCode.NOT_FOUND,
-            )
-
+        org = await organization_service.get_organization(db, str(organization_id))
         org_dict = {
             "id": org.id,
             "name": org.name,
@@ -116,16 +109,14 @@ async def get_organization(
             "settings": org.settings,
             "created_at": org.created_at,
             "updated_at": org.updated_at,
-            "member_count": await organization_crud.get_member_count(
+            "member_count": await organization_service.get_member_count(
                 db, organization_id=org.id
             ),
         }
         return OrganizationResponse(**org_dict)

-    except NotFoundError:  # pragma: no cover - See above
-        raise
     except Exception as e:
-        logger.error(f"Error getting organization: {e!s}", exc_info=True)
+        logger.exception("Error getting organization: %s", e)
         raise


@@ -149,7 +140,7 @@ async def get_organization_members(
     User must be a member of the organization to view members.
     """
     try:
-        members, total = await organization_crud.get_organization_members(
+        members, total = await organization_service.get_organization_members(
             db,
             organization_id=organization_id,
             skip=pagination.offset,
@@ -169,7 +160,7 @@ async def get_organization_members(
         return PaginatedResponse(data=member_responses, pagination=pagination_meta)

     except Exception as e:
-        logger.error(f"Error getting organization members: {e!s}", exc_info=True)
+        logger.exception("Error getting organization members: %s", e)
         raise


@@ -192,16 +183,12 @@ async def update_organization(
     Requires owner or admin role in the organization.
     """
     try:
-        org = await organization_crud.get(db, id=organization_id)
-        if not org:  # pragma: no cover - Permission check prevents this (see docs/UNREACHABLE_DEFENSIVE_CODE_ANALYSIS.md)
-            raise NotFoundError(
-                detail=f"Organization {organization_id} not found",
-                error_code=ErrorCode.NOT_FOUND,
-            )
-
-        updated_org = await organization_crud.update(db, db_obj=org, obj_in=org_in)
+        org = await organization_service.get_organization(db, str(organization_id))
+        updated_org = await organization_service.update_organization(
+            db, org=org, obj_in=org_in
+        )
         logger.info(
-            f"User {current_user.email} updated organization {updated_org.name}"
+            "User %s updated organization %s", current_user.email, updated_org.name
         )

         org_dict = {
@@ -213,14 +200,12 @@ async def update_organization(
             "settings": updated_org.settings,
             "created_at": updated_org.created_at,
             "updated_at": updated_org.updated_at,
-            "member_count": await organization_crud.get_member_count(
+            "member_count": await organization_service.get_member_count(
                 db, organization_id=updated_org.id
             ),
         }
         return OrganizationResponse(**org_dict)

-    except NotFoundError:  # pragma: no cover - See above
-        raise
     except Exception as e:
-        logger.error(f"Error updating organization: {e!s}", exc_info=True)
+        logger.exception("Error updating organization: %s", e)
         raise
@@ -17,10 +17,10 @@ from app.api.dependencies.auth import get_current_user
 from app.core.auth import decode_token
 from app.core.database import get_db
 from app.core.exceptions import AuthorizationError, ErrorCode, NotFoundError
-from app.crud.session import session as session_crud
 from app.models.user import User
 from app.schemas.common import MessageResponse
 from app.schemas.sessions import SessionListResponse, SessionResponse
+from app.services.session_service import session_service

 router = APIRouter()
 logger = logging.getLogger(__name__)
@@ -60,7 +60,7 @@ async def list_my_sessions(
     """
     try:
         # Get all active sessions for user
-        sessions = await session_crud.get_user_sessions(
+        sessions = await session_service.get_user_sessions(
             db, user_id=str(current_user.id), active_only=True
         )

@@ -74,9 +74,7 @@ async def list_my_sessions(
             # For now, we'll mark current based on most recent activity
         except Exception as e:
             # Optional token parsing - silently ignore failures
-            logger.debug(
-                f"Failed to decode access token for session marking: {e!s}"
-            )
+            logger.debug("Failed to decode access token for session marking: %s", e)

         # Convert to response format
         session_responses = []
@@ -98,7 +96,7 @@ async def list_my_sessions(
             session_responses.append(session_response)

         logger.info(
-            f"User {current_user.id} listed {len(session_responses)} active sessions"
+            "User %s listed %s active sessions", current_user.id, len(session_responses)
         )

         return SessionListResponse(
@@ -106,9 +104,7 @@ async def list_my_sessions(
         )

     except Exception as e:
-        logger.error(
-            f"Error listing sessions for user {current_user.id}: {e!s}", exc_info=True
-        )
+        logger.exception("Error listing sessions for user %s: %s", current_user.id, e)
         raise HTTPException(
             status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
             detail="Failed to retrieve sessions",
@@ -150,7 +146,7 @@ async def revoke_session(
     """
     try:
         # Get the session
-        session = await session_crud.get(db, id=str(session_id))
+        session = await session_service.get_session(db, str(session_id))

         if not session:
             raise NotFoundError(
@@ -161,8 +157,10 @@ async def revoke_session(
         # Verify session belongs to current user
         if str(session.user_id) != str(current_user.id):
             logger.warning(
-                f"User {current_user.id} attempted to revoke session {session_id} "
-                f"belonging to user {session.user_id}"
+                "User %s attempted to revoke session %s belonging to user %s",
+                current_user.id,
+                session_id,
+                session.user_id,
             )
             raise AuthorizationError(
                 message="You can only revoke your own sessions",
@@ -170,11 +168,13 @@ async def revoke_session(
             )

         # Deactivate the session
-        await session_crud.deactivate(db, session_id=str(session_id))
+        await session_service.deactivate(db, session_id=str(session_id))

         logger.info(
-            f"User {current_user.id} revoked session {session_id} "
-            f"({session.device_name})"
+            "User %s revoked session %s (%s)",
+            current_user.id,
+            session_id,
+            session.device_name,
         )

         return MessageResponse(
@@ -185,7 +185,7 @@ async def revoke_session(
     except (NotFoundError, AuthorizationError):
         raise
     except Exception as e:
-        logger.error(f"Error revoking session {session_id}: {e!s}", exc_info=True)
+        logger.exception("Error revoking session %s: %s", session_id, e)
         raise HTTPException(
             status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
             detail="Failed to revoke session",
@@ -224,12 +224,12 @@ async def cleanup_expired_sessions(
     """
     try:
         # Use optimized bulk DELETE instead of N individual deletes
-        deleted_count = await session_crud.cleanup_expired_for_user(
+        deleted_count = await session_service.cleanup_expired_for_user(
            db, user_id=str(current_user.id)
         )

         logger.info(
-            f"User {current_user.id} cleaned up {deleted_count} expired sessions"
+            "User %s cleaned up %s expired sessions", current_user.id, deleted_count
         )

         return MessageResponse(
@@ -237,9 +237,8 @@ async def cleanup_expired_sessions(
         )

     except Exception as e:
-        logger.error(
-            f"Error cleaning up sessions for user {current_user.id}: {e!s}",
-            exc_info=True,
+        logger.exception(
+            "Error cleaning up sessions for user %s: %s", current_user.id, e
         )
         await db.rollback()
         raise HTTPException(
@@ -1,5 +1,5 @@
 """
-User management endpoints for CRUD operations.
+User management endpoints for database operations.
 """

 import logging
@@ -13,8 +13,7 @@ from sqlalchemy.ext.asyncio import AsyncSession

 from app.api.dependencies.auth import get_current_superuser, get_current_user
 from app.core.database import get_db
-from app.core.exceptions import AuthorizationError, ErrorCode, NotFoundError
-from app.crud.user import user as user_crud
+from app.core.exceptions import AuthorizationError, ErrorCode
 from app.models.user import User
 from app.schemas.common import (
     MessageResponse,
@@ -25,6 +24,7 @@ from app.schemas.common import (
 )
 from app.schemas.users import PasswordChange, UserResponse, UserUpdate
 from app.services.auth_service import AuthenticationError, AuthService
+from app.services.user_service import user_service

 logger = logging.getLogger(__name__)

@@ -71,7 +71,7 @@ async def list_users(
            filters["is_superuser"] = is_superuser

         # Get paginated users with total count
-        users, total = await user_crud.get_multi_with_total(
+        users, total = await user_service.list_users(
             db,
             skip=pagination.offset,
             limit=pagination.limit,
@@ -90,7 +90,7 @@ async def list_users(

         return PaginatedResponse(data=users, pagination=pagination_meta)
     except Exception as e:
-        logger.error(f"Error listing users: {e!s}", exc_info=True)
+        logger.exception("Error listing users: %s", e)
         raise


@@ -107,7 +107,9 @@ async def list_users(
     """,
     operation_id="get_current_user_profile",
 )
-def get_current_user_profile(current_user: User = Depends(get_current_user)) -> Any:
+async def get_current_user_profile(
+    current_user: User = Depends(get_current_user),
+) -> Any:
     """Get current user's profile."""
     return current_user

@@ -138,18 +140,16 @@ async def update_current_user(
     Users cannot elevate their own permissions (protected by UserUpdate schema validator).
     """
     try:
-        updated_user = await user_crud.update(
-            db, db_obj=current_user, obj_in=user_update
+        updated_user = await user_service.update_user(
+            db, user=current_user, obj_in=user_update
         )
-        logger.info(f"User {current_user.id} updated their profile")
+        logger.info("User %s updated their profile", current_user.id)
         return updated_user
     except ValueError as e:
-        logger.error(f"Error updating user {current_user.id}: {e!s}")
+        logger.error("Error updating user %s: %s", current_user.id, e)
         raise
     except Exception as e:
-        logger.error(
-            f"Unexpected error updating user {current_user.id}: {e!s}", exc_info=True
-        )
+        logger.exception("Unexpected error updating user %s: %s", current_user.id, e)
         raise


@@ -182,7 +182,9 @@ async def get_user_by_id(
     # Check permissions
     if str(user_id) != str(current_user.id) and not current_user.is_superuser:
         logger.warning(
-            f"User {current_user.id} attempted to access user {user_id} without permission"
+            "User %s attempted to access user %s without permission",
+            current_user.id,
+            user_id,
         )
         raise AuthorizationError(
             message="Not enough permissions to view this user",
@@ -190,13 +192,7 @@ async def get_user_by_id(
         )

     # Get user
-    user = await user_crud.get(db, id=str(user_id))
-    if not user:
-        raise NotFoundError(
-            message=f"User with id {user_id} not found",
-            error_code=ErrorCode.USER_NOT_FOUND,
-        )
-
+    user = await user_service.get_user(db, str(user_id))
     return user


@@ -233,7 +229,9 @@ async def update_user(

     if not is_own_profile and not current_user.is_superuser:
         logger.warning(
-            f"User {current_user.id} attempted to update user {user_id} without permission"
+            "User %s attempted to update user %s without permission",
+            current_user.id,
+            user_id,
         )
         raise AuthorizationError(
             message="Not enough permissions to update this user",
@@ -241,22 +239,17 @@ async def update_user(
         )

     # Get user
-    user = await user_crud.get(db, id=str(user_id))
-    if not user:
-        raise NotFoundError(
-            message=f"User with id {user_id} not found",
-            error_code=ErrorCode.USER_NOT_FOUND,
-        )
+    user = await user_service.get_user(db, str(user_id))

     try:
-        updated_user = await user_crud.update(db, db_obj=user, obj_in=user_update)
-        logger.info(f"User {user_id} updated by {current_user.id}")
+        updated_user = await user_service.update_user(db, user=user, obj_in=user_update)
+        logger.info("User %s updated by %s", user_id, current_user.id)
         return updated_user
     except ValueError as e:
-        logger.error(f"Error updating user {user_id}: {e!s}")
+        logger.error("Error updating user %s: %s", user_id, e)
         raise
     except Exception as e:
-        logger.error(f"Unexpected error updating user {user_id}: {e!s}", exc_info=True)
+        logger.exception("Unexpected error updating user %s: %s", user_id, e)
         raise


@@ -296,19 +289,19 @@ async def change_current_user_password(
         )

         if success:
-            logger.info(f"User {current_user.id} changed their password")
+            logger.info("User %s changed their password", current_user.id)
             return MessageResponse(
                 success=True, message="Password changed successfully"
             )
     except AuthenticationError as e:
         logger.warning(
-            f"Failed password change attempt for user {current_user.id}: {e!s}"
+            "Failed password change attempt for user %s: %s", current_user.id, e
         )
         raise AuthorizationError(
             message=str(e), error_code=ErrorCode.INVALID_CREDENTIALS
         )
     except Exception as e:
-        logger.error(f"Error changing password for user {current_user.id}: {e!s}")
+        logger.error("Error changing password for user %s: %s", current_user.id, e)
         raise


@@ -346,24 +339,19 @@ async def delete_user(
             error_code=ErrorCode.INSUFFICIENT_PERMISSIONS,
         )

-    # Get user
-    user = await user_crud.get(db, id=str(user_id))
-    if not user:
-        raise NotFoundError(
-            message=f"User with id {user_id} not found",
-            error_code=ErrorCode.USER_NOT_FOUND,
-        )
+    # Get user (raises NotFoundError if not found)
+    await user_service.get_user(db, str(user_id))

     try:
         # Use soft delete instead of hard delete
-        await user_crud.soft_delete(db, id=str(user_id))
-        logger.info(f"User {user_id} soft-deleted by {current_user.id}")
+        await user_service.soft_delete_user(db, str(user_id))
+        logger.info("User %s soft-deleted by %s", user_id, current_user.id)
         return MessageResponse(
             success=True, message=f"User {user_id} deleted successfully"
         )
     except ValueError as e:
-        logger.error(f"Error deleting user {user_id}: {e!s}")
+        logger.error("Error deleting user %s: %s", user_id, e)
         raise
     except Exception as e:
-        logger.error(f"Unexpected error deleting user {user_id}: {e!s}", exc_info=True)
+        logger.exception("Unexpected error deleting user %s: %s", user_id, e)
         raise
@@ -1,23 +1,21 @@
 import asyncio
 import logging
 import uuid
 from datetime import UTC, datetime, timedelta
 from functools import partial
 from typing import Any

-from jose import JWTError, jwt
-from passlib.context import CryptContext
+import bcrypt
+import jwt
+from jwt.exceptions import (
+    ExpiredSignatureError,
+    InvalidTokenError,
+    MissingRequiredClaimError,
+)
 from pydantic import ValidationError

 from app.core.config import settings
 from app.schemas.users import TokenData, TokenPayload

-# Suppress passlib bcrypt warnings about ident
-logging.getLogger("passlib").setLevel(logging.ERROR)
-
-# Password hashing context
-pwd_context = CryptContext(schemes=["bcrypt"], deprecated="auto")
-

 # Custom exceptions for auth
 class AuthError(Exception):
@@ -37,13 +35,16 @@ class TokenMissingClaimError(AuthError):


 def verify_password(plain_password: str, hashed_password: str) -> bool:
-    """Verify a password against a hash."""
-    return pwd_context.verify(plain_password, hashed_password)
+    """Verify a password against a bcrypt hash."""
+    return bcrypt.checkpw(
+        plain_password.encode("utf-8"), hashed_password.encode("utf-8")
+    )


 def get_password_hash(password: str) -> str:
-    """Generate a password hash."""
-    return pwd_context.hash(password)
+    """Generate a bcrypt password hash."""
+    salt = bcrypt.gensalt()
+    return bcrypt.hashpw(password.encode("utf-8"), salt).decode("utf-8")


 async def verify_password_async(plain_password: str, hashed_password: str) -> bool:
@@ -60,9 +61,9 @@ async def verify_password_async(plain_password: str, hashed_password: str) -> bo
     Returns:
         True if password matches, False otherwise
     """
-    loop = asyncio.get_event_loop()
+    loop = asyncio.get_running_loop()
     return await loop.run_in_executor(
-        None, partial(pwd_context.verify, plain_password, hashed_password)
+        None, partial(verify_password, plain_password, hashed_password)
     )


@@ -80,8 +81,8 @@ async def get_password_hash_async(password: str) -> str:
     Returns:
         Hashed password string
     """
-    loop = asyncio.get_event_loop()
-    return await loop.run_in_executor(None, pwd_context.hash, password)
+    loop = asyncio.get_running_loop()
+    return await loop.run_in_executor(None, get_password_hash, password)


 def create_access_token(
@@ -121,11 +122,7 @@ def create_access_token(
     to_encode.update(claims)

     # Create the JWT
-    encoded_jwt = jwt.encode(
-        to_encode, settings.SECRET_KEY, algorithm=settings.ALGORITHM
-    )
-
-    return encoded_jwt
+    return jwt.encode(to_encode, settings.SECRET_KEY, algorithm=settings.ALGORITHM)


 def create_refresh_token(
@@ -154,11 +151,7 @@ def create_refresh_token(
         "type": "refresh",
     }

-    encoded_jwt = jwt.encode(
-        to_encode, settings.SECRET_KEY, algorithm=settings.ALGORITHM
-    )
-
-    return encoded_jwt
+    return jwt.encode(to_encode, settings.SECRET_KEY, algorithm=settings.ALGORITHM)


 def decode_token(token: str, verify_type: str | None = None) -> TokenPayload:
@@ -198,7 +191,7 @@ def decode_token(token: str, verify_type: str | None = None) -> TokenPayload:

         # Reject weak or unexpected algorithms
         # NOTE: These are defensive checks that provide defense-in-depth.
-        # The python-jose library rejects these tokens BEFORE we reach here,
+        # PyJWT rejects these tokens BEFORE we reach here,
         # but we keep these checks in case the library changes or is misconfigured.
         # Coverage: Marked as pragma since library catches first (see tests/core/test_auth_security.py)
         if token_algorithm == "NONE":  # pragma: no cover
@@ -219,10 +212,11 @@ def decode_token(token: str, verify_type: str | None = None) -> TokenPayload:
         token_data = TokenPayload(**payload)
         return token_data

-    except JWTError as e:
-        # Check if the error is due to an expired token
-        if "expired" in str(e).lower():
-            raise TokenExpiredError("Token has expired")
+    except ExpiredSignatureError:
+        raise TokenExpiredError("Token has expired")
+    except MissingRequiredClaimError as e:
+        raise TokenMissingClaimError(f"Token missing required claim: {e}")
+    except InvalidTokenError:
         raise TokenInvalidError("Invalid authentication token")
     except ValidationError:
         raise TokenInvalidError("Invalid token payload")
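The `asyncio.get_event_loop()` to `asyncio.get_running_loop()` change in the hunks above follows current asyncio guidance: `get_running_loop()` raises immediately when called outside a coroutine instead of implicitly creating a loop. A minimal sketch of the same executor-offload pattern, with an illustrative stand-in hash (not the project's bcrypt helpers):

```python
import asyncio
import hashlib
from functools import partial


def slow_hash(password: str) -> str:
    # Stand-in for a CPU-bound hash such as bcrypt (illustrative only)
    return hashlib.sha256(password.encode("utf-8")).hexdigest()


async def hash_async(password: str) -> str:
    # get_running_loop() must be called from inside a coroutine;
    # outside one it raises RuntimeError, catching misuse early.
    loop = asyncio.get_running_loop()
    # Run the blocking work in the default thread-pool executor so the
    # event loop stays free to serve other requests.
    return await loop.run_in_executor(None, partial(slow_hash, password))


result = asyncio.run(hash_async("s3cret"))
print(result == slow_hash("s3cret"))  # True: same digest, computed off the loop
```

The `partial(...)` wrapper is how extra arguments are passed, since `run_in_executor` only accepts positional args after the callable.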
@@ -128,8 +128,8 @@ async def async_transaction_scope() -> AsyncGenerator[AsyncSession, None]:

     Usage:
         async with async_transaction_scope() as db:
-            user = await user_crud.create(db, obj_in=user_create)
-            profile = await profile_crud.create(db, obj_in=profile_create)
+            user = await user_repo.create(db, obj_in=user_create)
+            profile = await profile_repo.create(db, obj_in=profile_create)
             # Both operations committed together
     """
     async with SessionLocal() as session:
@@ -139,7 +139,7 @@ async def async_transaction_scope() -> AsyncGenerator[AsyncSession, None]:
             logger.debug("Async transaction committed successfully")
         except Exception as e:
             await session.rollback()
-            logger.error(f"Async transaction failed, rolling back: {e!s}")
+            logger.error("Async transaction failed, rolling back: %s", e)
             raise
         finally:
             await session.close()
@@ -155,7 +155,7 @@ async def check_async_database_health() -> bool:
         await db.execute(text("SELECT 1"))
         return True
     except Exception as e:
-        logger.error(f"Async database health check failed: {e!s}")
+        logger.error("Async database health check failed: %s", e)
         return False
@@ -143,8 +143,11 @@ async def api_exception_handler(request: Request, exc: APIException) -> JSONResp
     Returns a standardized error response with error code and message.
     """
     logger.warning(
-        f"API exception: {exc.error_code} - {exc.message} "
-        f"(status: {exc.status_code}, path: {request.url.path})"
+        "API exception: %s - %s (status: %s, path: %s)",
+        exc.error_code,
+        exc.message,
+        exc.status_code,
+        request.url.path,
     )

     error_response = ErrorResponse(
@@ -186,7 +189,9 @@ async def validation_exception_handler(
             )
         )

-    logger.warning(f"Validation error: {len(errors)} errors (path: {request.url.path})")
+    logger.warning(
+        "Validation error: %s errors (path: %s)", len(errors), request.url.path
+    )

     error_response = ErrorResponse(errors=errors)

@@ -218,11 +223,14 @@ async def http_exception_handler(request: Request, exc: HTTPException) -> JSONRe
     )

     logger.warning(
-        f"HTTP exception: {exc.status_code} - {exc.detail} (path: {request.url.path})"
+        "HTTP exception: %s - %s (path: %s)",
+        exc.status_code,
+        exc.detail,
+        request.url.path,
     )

     error_response = ErrorResponse(
-        errors=[ErrorDetail(code=error_code, message=str(exc.detail))]
+        errors=[ErrorDetail(code=error_code, message=str(exc.detail), field=None)]
     )

     return JSONResponse(
@@ -239,10 +247,11 @@ async def unhandled_exception_handler(request: Request, exc: Exception) -> JSONR
     Logs the full exception and returns a generic error response to avoid
     leaking sensitive information in production.
     """
-    logger.error(
-        f"Unhandled exception: {type(exc).__name__} - {exc!s} "
-        f"(path: {request.url.path})",
-        exc_info=True,
+    logger.exception(
+        "Unhandled exception: %s - %s (path: %s)",
+        type(exc).__name__,
+        exc,
+        request.url.path,
     )

     # In production, don't expose internal error details
@@ -254,7 +263,7 @@ async def unhandled_exception_handler(request: Request, exc: Exception) -> JSONR
         message = f"{type(exc).__name__}: {exc!s}"

     error_response = ErrorResponse(
-        errors=[ErrorDetail(code=ErrorCode.INTERNAL_ERROR, message=message)]
+        errors=[ErrorDetail(code=ErrorCode.INTERNAL_ERROR, message=message, field=None)]
    )

     return JSONResponse(
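The recurring f-string to `%s` change across these hunks defers string interpolation to the logging framework: the message is only built if a handler actually emits the record, and the constant template makes records easier to group in log aggregation. A small stdlib-only sketch of the deferral:

```python
import logging


class CountingArg:
    """Counts how often it is str()-ified, to demonstrate deferred formatting."""

    def __init__(self) -> None:
        self.count = 0

    def __str__(self) -> str:
        self.count += 1
        return "value"


logging.basicConfig(level=logging.WARNING)
logger = logging.getLogger("demo")

arg = CountingArg()
logger.debug("dropped: %s", arg)    # below WARNING: argument never formatted
logger.warning("emitted: %s", arg)  # formatted exactly once, at emit time

print(arg.count)  # 1: only the emitted record interpolated its argument
```

With an f-string (`logger.debug(f"dropped: {arg}")`) both calls would have formatted the argument eagerly, even for the suppressed record.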
backend/app/core/repository_exceptions.py (new file, 26 additions)
@@ -0,0 +1,26 @@
+"""
+Custom exceptions for the repository layer.
+
+These exceptions allow services and routes to handle database-level errors
+with proper semantics, without leaking SQLAlchemy internals.
+"""
+
+
+class RepositoryError(Exception):
+    """Base for all repository-layer errors."""
+
+
+class DuplicateEntryError(RepositoryError):
+    """Raised on unique constraint violations. Maps to HTTP 409 Conflict."""
+
+
+class IntegrityConstraintError(RepositoryError):
+    """Raised on FK or check constraint violations."""
+
+
+class RecordNotFoundError(RepositoryError):
+    """Raised when an expected record doesn't exist."""
+
+
+class InvalidInputError(RepositoryError):
+    """Raised on bad pagination params, invalid UUIDs, or other invalid inputs."""
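The new exception hierarchy gives callers one stable surface for persistence failures. A hypothetical sketch of how a service layer might translate these into HTTP status codes; the classes are redeclared here so the example is self-contained, and only the 409 mapping is stated by the module's docstrings, the rest are plausible choices:

```python
class RepositoryError(Exception):
    """Base for all repository-layer errors."""


class DuplicateEntryError(RepositoryError):
    """Unique constraint violation."""


class RecordNotFoundError(RepositoryError):
    """Expected record missing."""


# Hypothetical translation table; checked in insertion order so the most
# specific subclasses must come before the RepositoryError fallback.
STATUS_FOR = {
    DuplicateEntryError: 409,   # per the docstring: maps to HTTP 409 Conflict
    RecordNotFoundError: 404,   # assumed mapping
    RepositoryError: 500,       # fallback for any other repository failure
}


def http_status(exc: RepositoryError) -> int:
    for exc_type, status in STATUS_FOR.items():
        if isinstance(exc, exc_type):
            return status
    return 500


print(http_status(DuplicateEntryError("email taken")))  # 409
print(http_status(RepositoryError("boom")))             # 500
```

Centralizing this mapping keeps SQLAlchemy's `IntegrityError` details out of route handlers, which only ever see the repository-layer types.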
@@ -1,14 +0,0 @@
-# app/crud/__init__.py
-from .oauth import oauth_account, oauth_client, oauth_state
-from .organization import organization
-from .session import session as session_crud
-from .user import user
-
-__all__ = [
-    "oauth_account",
-    "oauth_client",
-    "oauth_state",
-    "organization",
-    "session_crud",
-    "user",
-]
@@ -1,718 +0,0 @@
-"""
-Async CRUD operations for OAuth models using SQLAlchemy 2.0 patterns.
-
-Provides operations for:
-- OAuthAccount: Managing linked OAuth provider accounts
-- OAuthState: CSRF protection state during OAuth flows
-- OAuthClient: Registered OAuth clients (provider mode skeleton)
-"""
-
-import logging
-import secrets
-from datetime import UTC, datetime
-from uuid import UUID
-
-from pydantic import BaseModel
-from sqlalchemy import and_, delete, select
-from sqlalchemy.exc import IntegrityError
-from sqlalchemy.ext.asyncio import AsyncSession
-from sqlalchemy.orm import joinedload
-
-from app.crud.base import CRUDBase
-from app.models.oauth_account import OAuthAccount
-from app.models.oauth_client import OAuthClient
-from app.models.oauth_state import OAuthState
-from app.schemas.oauth import OAuthAccountCreate, OAuthClientCreate, OAuthStateCreate
-
-logger = logging.getLogger(__name__)
-
-
-# ============================================================================
-# OAuth Account CRUD
-# ============================================================================
-
-
-class EmptySchema(BaseModel):
-    """Placeholder schema for CRUD operations that don't need update schemas."""
-
-
-class CRUDOAuthAccount(CRUDBase[OAuthAccount, OAuthAccountCreate, EmptySchema]):
-    """CRUD operations for OAuth account links."""
-
-    async def get_by_provider_id(
-        self,
-        db: AsyncSession,
-        *,
-        provider: str,
-        provider_user_id: str,
-    ) -> OAuthAccount | None:
-        """
-        Get OAuth account by provider and provider user ID.
-
-        Args:
-            db: Database session
-            provider: OAuth provider name (google, github)
-            provider_user_id: User ID from the OAuth provider
-
-        Returns:
-            OAuthAccount if found, None otherwise
-        """
-        try:
-            result = await db.execute(
-                select(OAuthAccount)
-                .where(
-                    and_(
-                        OAuthAccount.provider == provider,
-                        OAuthAccount.provider_user_id == provider_user_id,
-                    )
-                )
-                .options(joinedload(OAuthAccount.user))
-            )
-            return result.scalar_one_or_none()
-        except Exception as e:  # pragma: no cover
-            logger.error(
-                f"Error getting OAuth account for {provider}:{provider_user_id}: {e!s}"
-            )
-            raise
-
-    async def get_by_provider_email(
-        self,
-        db: AsyncSession,
-        *,
-        provider: str,
-        email: str,
-    ) -> OAuthAccount | None:
-        """
-        Get OAuth account by provider and email.
-
-        Used for auto-linking existing accounts by email.
-
-        Args:
-            db: Database session
-            provider: OAuth provider name
-            email: Email address from the OAuth provider
-
-        Returns:
-            OAuthAccount if found, None otherwise
-        """
-        try:
-            result = await db.execute(
-                select(OAuthAccount)
-                .where(
-                    and_(
-                        OAuthAccount.provider == provider,
-                        OAuthAccount.provider_email == email,
-                    )
-                )
-                .options(joinedload(OAuthAccount.user))
-            )
-            return result.scalar_one_or_none()
-        except Exception as e:  # pragma: no cover
-            logger.error(
-                f"Error getting OAuth account for {provider} email {email}: {e!s}"
-            )
-            raise
-
-    async def get_user_accounts(
-        self,
-        db: AsyncSession,
-        *,
-        user_id: str | UUID,
-    ) -> list[OAuthAccount]:
-        """
-        Get all OAuth accounts linked to a user.
-
-        Args:
-            db: Database session
-            user_id: User ID
-
-        Returns:
-            List of OAuthAccount objects
-        """
-        try:
-            user_uuid = UUID(str(user_id)) if isinstance(user_id, str) else user_id
-
-            result = await db.execute(
-                select(OAuthAccount)
-                .where(OAuthAccount.user_id == user_uuid)
-                .order_by(OAuthAccount.created_at.desc())
-            )
-            return list(result.scalars().all())
-        except Exception as e:  # pragma: no cover
-            logger.error(f"Error getting OAuth accounts for user {user_id}: {e!s}")
-            raise
-
-    async def get_user_account_by_provider(
-        self,
-        db: AsyncSession,
-        *,
-        user_id: str | UUID,
-        provider: str,
-    ) -> OAuthAccount | None:
-        """
-        Get a specific OAuth account for a user and provider.
-
-        Args:
-            db: Database session
-            user_id: User ID
-            provider: OAuth provider name
-
-        Returns:
-            OAuthAccount if found, None otherwise
-        """
-        try:
-            user_uuid = UUID(str(user_id)) if isinstance(user_id, str) else user_id
-
-            result = await db.execute(
-                select(OAuthAccount).where(
-                    and_(
-                        OAuthAccount.user_id == user_uuid,
-                        OAuthAccount.provider == provider,
-                    )
-                )
-            )
-            return result.scalar_one_or_none()
-        except Exception as e:  # pragma: no cover
-            logger.error(
-                f"Error getting OAuth account for user {user_id}, provider {provider}: {e!s}"
-            )
-            raise
-
-    async def create_account(
-        self, db: AsyncSession, *, obj_in: OAuthAccountCreate
-    ) -> OAuthAccount:
-        """
-        Create a new OAuth account link.
-
-        Args:
-            db: Database session
-            obj_in: OAuth account creation data
-
-        Returns:
-            Created OAuthAccount
-
-        Raises:
-            ValueError: If account already exists or creation fails
-        """
-        try:
-            db_obj = OAuthAccount(
-                user_id=obj_in.user_id,
-                provider=obj_in.provider,
-                provider_user_id=obj_in.provider_user_id,
-                provider_email=obj_in.provider_email,
-                access_token_encrypted=obj_in.access_token_encrypted,
-                refresh_token_encrypted=obj_in.refresh_token_encrypted,
-                token_expires_at=obj_in.token_expires_at,
-            )
-            db.add(db_obj)
-            await db.commit()
-            await db.refresh(db_obj)
-
-            logger.info(
-                f"OAuth account created: {obj_in.provider} linked to user {obj_in.user_id}"
-            )
-            return db_obj
-        except IntegrityError as e:  # pragma: no cover
-            await db.rollback()
-            error_msg = str(e.orig) if hasattr(e, "orig") else str(e)
-            if "uq_oauth_provider_user" in error_msg.lower():
-                logger.warning(
-                    f"OAuth account already exists: {obj_in.provider}:{obj_in.provider_user_id}"
-                )
-                raise ValueError(
-                    f"This {obj_in.provider} account is already linked to another user"
-                )
-            logger.error(f"Integrity error creating OAuth account: {error_msg}")
-            raise ValueError(f"Failed to create OAuth account: {error_msg}")
-        except Exception as e:  # pragma: no cover
-            await db.rollback()
-            logger.error(f"Error creating OAuth account: {e!s}", exc_info=True)
-            raise
-
-    async def delete_account(
-        self,
-        db: AsyncSession,
-        *,
-        user_id: str | UUID,
-        provider: str,
-    ) -> bool:
-        """
-        Delete an OAuth account link.
-
-        Args:
-            db: Database session
-            user_id: User ID
-            provider: OAuth provider name
-
-        Returns:
-            True if deleted, False if not found
-        """
-        try:
-            user_uuid = UUID(str(user_id)) if isinstance(user_id, str) else user_id
-
-            result = await db.execute(
-                delete(OAuthAccount).where(
-                    and_(
-                        OAuthAccount.user_id == user_uuid,
-                        OAuthAccount.provider == provider,
-                    )
-                )
-            )
-            await db.commit()
-
-            deleted = result.rowcount > 0
-            if deleted:
-                logger.info(
-                    f"OAuth account deleted: {provider} unlinked from user {user_id}"
-                )
-            else:
-                logger.warning(
-                    f"OAuth account not found for deletion: {provider} for user {user_id}"
-                )
-
-            return deleted
-        except Exception as e:  # pragma: no cover
-            await db.rollback()
-            logger.error(
-                f"Error deleting OAuth account {provider} for user {user_id}: {e!s}"
-            )
-            raise
-
-    async def update_tokens(
-        self,
-        db: AsyncSession,
-        *,
-        account: OAuthAccount,
-        access_token_encrypted: str | None = None,
-        refresh_token_encrypted: str | None = None,
-        token_expires_at: datetime | None = None,
-    ) -> OAuthAccount:
-        """
-        Update OAuth tokens for an account.
-
-        Args:
-            db: Database session
-            account: OAuthAccount to update
-            access_token_encrypted: New encrypted access token
-            refresh_token_encrypted: New encrypted refresh token
-            token_expires_at: New token expiration time
-
-        Returns:
-            Updated OAuthAccount
-        """
-        try:
-            if access_token_encrypted is not None:
-                account.access_token_encrypted = access_token_encrypted
-            if refresh_token_encrypted is not None:
-                account.refresh_token_encrypted = refresh_token_encrypted
-            if token_expires_at is not None:
-                account.token_expires_at = token_expires_at
-
-            db.add(account)
-            await db.commit()
-            await db.refresh(account)
-
-            return account
-        except Exception as e:  # pragma: no cover
-            await db.rollback()
-            logger.error(f"Error updating OAuth tokens: {e!s}")
-            raise
-
-
-# ============================================================================
-# OAuth State CRUD
-# ============================================================================
-
-
-class CRUDOAuthState(CRUDBase[OAuthState, OAuthStateCreate, EmptySchema]):
-    """CRUD operations for OAuth state (CSRF protection)."""
-
-    async def create_state(
-        self, db: AsyncSession, *, obj_in: OAuthStateCreate
-    ) -> OAuthState:
-        """
-        Create a new OAuth state for CSRF protection.
-
-        Args:
-            db: Database session
-            obj_in: OAuth state creation data
-
-        Returns:
-            Created OAuthState
-        """
-        try:
-            db_obj = OAuthState(
-                state=obj_in.state,
-                code_verifier=obj_in.code_verifier,
-                nonce=obj_in.nonce,
-                provider=obj_in.provider,
-                redirect_uri=obj_in.redirect_uri,
-                user_id=obj_in.user_id,
-                expires_at=obj_in.expires_at,
-            )
-            db.add(db_obj)
-            await db.commit()
-            await db.refresh(db_obj)
-
-            logger.debug(f"OAuth state created for {obj_in.provider}")
-            return db_obj
-        except IntegrityError as e:  # pragma: no cover
-            await db.rollback()
-            # State collision (extremely rare with cryptographic random)
-            error_msg = str(e.orig) if hasattr(e, "orig") else str(e)
-            logger.error(f"OAuth state collision: {error_msg}")
-            raise ValueError("Failed to create OAuth state, please retry")
-        except Exception as e:  # pragma: no cover
-            await db.rollback()
-            logger.error(f"Error creating OAuth state: {e!s}", exc_info=True)
-            raise
-
-    async def get_and_consume_state(
-        self, db: AsyncSession, *, state: str
-    ) -> OAuthState | None:
-        """
-        Get and delete OAuth state (consume it).
-
-        This ensures each state can only be used once (replay protection).
-
-        Args:
-            db: Database session
-            state: State string to look up
-
-        Returns:
-            OAuthState if found and valid, None otherwise
-        """
-        try:
-            # Get the state
-            result = await db.execute(
-                select(OAuthState).where(OAuthState.state == state)
-            )
-            db_obj = result.scalar_one_or_none()
-
-            if db_obj is None:
-                logger.warning(f"OAuth state not found: {state[:8]}...")
-                return None
-
-            # Check expiration
-            # Handle both timezone-aware and timezone-naive datetimes
-            now = datetime.now(UTC)
-            expires_at = db_obj.expires_at
-            if expires_at.tzinfo is None:
-                # SQLite returns naive datetimes, assume UTC
-                expires_at = expires_at.replace(tzinfo=UTC)
-
-            if expires_at < now:
-                logger.warning(f"OAuth state expired: {state[:8]}...")
-                await db.delete(db_obj)
-                await db.commit()
-                return None
-
-            # Delete it (consume)
-            await db.delete(db_obj)
-            await db.commit()
-
-            logger.debug(f"OAuth state consumed: {state[:8]}...")
-            return db_obj
-        except Exception as e:  # pragma: no cover
-            await db.rollback()
-            logger.error(f"Error consuming OAuth state: {e!s}")
-            raise
-
-    async def cleanup_expired(self, db: AsyncSession) -> int:
-        """
-        Clean up expired OAuth states.
-
-        Should be called periodically to remove stale states.
-
-        Args:
-            db: Database session
-
-        Returns:
-            Number of states deleted
-        """
-        try:
-            now = datetime.now(UTC)
-
-            stmt = delete(OAuthState).where(OAuthState.expires_at < now)
-            result = await db.execute(stmt)
-            await db.commit()
-
-            count = result.rowcount
-            if count > 0:
-                logger.info(f"Cleaned up {count} expired OAuth states")
-
-            return count
-        except Exception as e:  # pragma: no cover
-            await db.rollback()
-            logger.error(f"Error cleaning up expired OAuth states: {e!s}")
-            raise
-
-
-# ============================================================================
-# OAuth Client CRUD (Provider Mode - Skeleton)
-# ============================================================================
-
-
-class CRUDOAuthClient(CRUDBase[OAuthClient, OAuthClientCreate, EmptySchema]):
-    """
-    CRUD operations for OAuth clients (provider mode).
-
-    This is a skeleton implementation for MCP client registration.
-    Full implementation can be expanded when needed.
-    """
-
-    async def get_by_client_id(
-        self, db: AsyncSession, *, client_id: str
-    ) -> OAuthClient | None:
-        """
-        Get OAuth client by client_id.
-
-        Args:
-            db: Database session
-            client_id: OAuth client ID
-
-        Returns:
-            OAuthClient if found, None otherwise
-        """
-        try:
-            result = await db.execute(
-                select(OAuthClient).where(
-                    and_(
-                        OAuthClient.client_id == client_id,
-                        OAuthClient.is_active == True,  # noqa: E712
-                    )
-                )
-            )
-            return result.scalar_one_or_none()
-        except Exception as e:  # pragma: no cover
-            logger.error(f"Error getting OAuth client {client_id}: {e!s}")
-            raise
-
-    async def create_client(
-        self,
-        db: AsyncSession,
-        *,
-        obj_in: OAuthClientCreate,
-        owner_user_id: UUID | None = None,
-    ) -> tuple[OAuthClient, str | None]:
-        """
-        Create a new OAuth client.
-
-        Args:
-            db: Database session
-            obj_in: OAuth client creation data
-            owner_user_id: Optional owner user ID
-
-        Returns:
-            Tuple of (created OAuthClient, client_secret or None for public clients)
-        """
-        try:
-            # Generate client_id
-            client_id = secrets.token_urlsafe(32)
-
-            # Generate client_secret for confidential clients
-            client_secret = None
-            client_secret_hash = None
-            if obj_in.client_type == "confidential":
-                client_secret = secrets.token_urlsafe(48)
-                # SECURITY: Use bcrypt for secret storage (not SHA-256)
-                # bcrypt is computationally expensive, making brute-force attacks infeasible
-                from app.core.auth import get_password_hash
-
-                client_secret_hash = get_password_hash(client_secret)
-
-            db_obj = OAuthClient(
-                client_id=client_id,
-                client_secret_hash=client_secret_hash,
-                client_name=obj_in.client_name,
-                client_description=obj_in.client_description,
-                client_type=obj_in.client_type,
-                redirect_uris=obj_in.redirect_uris,
-                allowed_scopes=obj_in.allowed_scopes,
-                owner_user_id=owner_user_id,
-                is_active=True,
-            )
-            db.add(db_obj)
-            await db.commit()
-            await db.refresh(db_obj)
-
-            logger.info(
-                f"OAuth client created: {obj_in.client_name} ({client_id[:8]}...)"
-            )
-            return db_obj, client_secret
-        except IntegrityError as e:  # pragma: no cover
-            await db.rollback()
-            error_msg = str(e.orig) if hasattr(e, "orig") else str(e)
-            logger.error(f"Error creating OAuth client: {error_msg}")
-            raise ValueError(f"Failed to create OAuth client: {error_msg}")
-        except Exception as e:  # pragma: no cover
-            await db.rollback()
-            logger.error(f"Error creating OAuth client: {e!s}", exc_info=True)
-            raise
-
-    async def deactivate_client(
-        self, db: AsyncSession, *, client_id: str
-    ) -> OAuthClient | None:
-        """
-        Deactivate an OAuth client.
-
-        Args:
-            db: Database session
-            client_id: OAuth client ID
-
-        Returns:
-            Deactivated OAuthClient if found, None otherwise
-        """
-        try:
-            client = await self.get_by_client_id(db, client_id=client_id)
-            if client is None:
-                return None
-
-            client.is_active = False
-            db.add(client)
-            await db.commit()
-            await db.refresh(client)
-
-            logger.info(f"OAuth client deactivated: {client.client_name}")
-            return client
-        except Exception as e:  # pragma: no cover
-            await db.rollback()
-            logger.error(f"Error deactivating OAuth client {client_id}: {e!s}")
-            raise
-
-    async def validate_redirect_uri(
-        self, db: AsyncSession, *, client_id: str, redirect_uri: str
-    ) -> bool:
-        """
-        Validate that a redirect URI is allowed for a client.
-
-        Args:
-            db: Database session
-            client_id: OAuth client ID
-            redirect_uri: Redirect URI to validate
-
-        Returns:
-            True if valid, False otherwise
-        """
-        try:
-            client = await self.get_by_client_id(db, client_id=client_id)
-            if client is None:
-                return False
-
-            return redirect_uri in (client.redirect_uris or [])
-        except Exception as e:  # pragma: no cover
-            logger.error(f"Error validating redirect URI: {e!s}")
-            return False
-
-    async def verify_client_secret(
-        self, db: AsyncSession, *, client_id: str, client_secret: str
-    ) -> bool:
-        """
-        Verify client credentials.
-
-        Args:
-            db: Database session
-            client_id: OAuth client ID
-            client_secret: Client secret to verify
-
-        Returns:
-            True if valid, False otherwise
-        """
-        try:
-            result = await db.execute(
-                select(OAuthClient).where(
-                    and_(
-                        OAuthClient.client_id == client_id,
-                        OAuthClient.is_active == True,  # noqa: E712
-                    )
-                )
-            )
-            client = result.scalar_one_or_none()
-
-            if client is None or client.client_secret_hash is None:
-                return False
-
-            # SECURITY: Verify secret using bcrypt (not SHA-256)
-            # This supports both old SHA-256 hashes (for migration) and new bcrypt hashes
-            from app.core.auth import verify_password
-
-            stored_hash: str = str(client.client_secret_hash)
-
-            # Check if it's a bcrypt hash (starts with $2b$) or legacy SHA-256
-            if stored_hash.startswith("$2"):
-                # New bcrypt format
-                return verify_password(client_secret, stored_hash)
-            else:
-                # Legacy SHA-256 format - still support for migration
-                import hashlib
-
-                secret_hash = hashlib.sha256(client_secret.encode()).hexdigest()
-                return secrets.compare_digest(stored_hash, secret_hash)
-        except Exception as e:  # pragma: no cover
-            logger.error(f"Error verifying client secret: {e!s}")
-            return False
-
-    async def get_all_clients(
-        self, db: AsyncSession, *, include_inactive: bool = False
-    ) -> list[OAuthClient]:
-        """
-        Get all OAuth clients.
-
-        Args:
-            db: Database session
-            include_inactive: Whether to include inactive clients
-
-        Returns:
-            List of OAuthClient objects
-        """
-        try:
-            query = select(OAuthClient).order_by(OAuthClient.created_at.desc())
-            if not include_inactive:
-                query = query.where(OAuthClient.is_active == True)  # noqa: E712
-
-            result = await db.execute(query)
-            return list(result.scalars().all())
-        except Exception as e:  # pragma: no cover
-            logger.error(f"Error getting all OAuth clients: {e!s}")
-            raise
-
-    async def delete_client(self, db: AsyncSession, *, client_id: str) -> bool:
-        """
-        Delete an OAuth client permanently.
-
-        Note: This will cascade delete related records (tokens, consents, etc.)
-        due to foreign key constraints.
-
-        Args:
-            db: Database session
-            client_id: OAuth client ID
-
-        Returns:
-            True if deleted, False if not found
-        """
-        try:
-            result = await db.execute(
-                delete(OAuthClient).where(OAuthClient.client_id == client_id)
-            )
-            await db.commit()
-
-            deleted = result.rowcount > 0
-            if deleted:
-                logger.info(f"OAuth client deleted: {client_id}")
-            else:
-                logger.warning(f"OAuth client not found for deletion: {client_id}")
-
-            return deleted
-        except Exception as e:  # pragma: no cover
-            await db.rollback()
-            logger.error(f"Error deleting OAuth client {client_id}: {e!s}")
-            raise
-
-
-# ============================================================================
-# Singleton instances
-# ============================================================================
-
-oauth_account = CRUDOAuthAccount(OAuthAccount)
-oauth_state = CRUDOAuthState(OAuthState)
-oauth_client = CRUDOAuthClient(OAuthClient)
@@ -16,10 +16,10 @@ from sqlalchemy import select, text
 
 from app.core.config import settings
 from app.core.database import SessionLocal, engine
-from app.crud.user import user as user_crud
 from app.models.organization import Organization
 from app.models.user import User
 from app.models.user_organization import UserOrganization
+from app.repositories.user import user_repo as user_repo
 from app.schemas.users import UserCreate
 
 logger = logging.getLogger(__name__)
@@ -44,16 +44,17 @@ async def init_db() -> User | None:
     if not settings.FIRST_SUPERUSER_EMAIL or not settings.FIRST_SUPERUSER_PASSWORD:
         logger.warning(
             "First superuser credentials not configured in settings. "
-            f"Using defaults: {superuser_email}"
+            "Using defaults: %s",
+            superuser_email,
         )
 
     async with SessionLocal() as session:
         try:
            # Check if superuser already exists
-            existing_user = await user_crud.get_by_email(session, email=superuser_email)
+            existing_user = await user_repo.get_by_email(session, email=superuser_email)
 
            if existing_user:
-                logger.info(f"Superuser already exists: {existing_user.email}")
+                logger.info("Superuser already exists: %s", existing_user.email)
                return existing_user
 
            # Create superuser if doesn't exist
@@ -65,11 +66,11 @@ async def init_db() -> User | None:
                is_superuser=True,
            )
 
-            user = await user_crud.create(session, obj_in=user_in)
+            user = await user_repo.create(session, obj_in=user_in)
            await session.commit()
            await session.refresh(user)
 
-            logger.info(f"Created first superuser: {user.email}")
+            logger.info("Created first superuser: %s", user.email)
 
            # Create demo data if in demo mode
            if settings.DEMO_MODE:
@@ -79,7 +80,7 @@ async def init_db() -> User | None:
 
        except Exception as e:
            await session.rollback()
-            logger.error(f"Error initializing database: {e}")
+            logger.error("Error initializing database: %s", e)
            raise
 
 
@@ -92,7 +93,7 @@ async def load_demo_data(session):
    """Load demo data from JSON file."""
    demo_data_path = Path(__file__).parent / "core" / "demo_data.json"
    if not demo_data_path.exists():
-        logger.warning(f"Demo data file not found: {demo_data_path}")
+        logger.warning("Demo data file not found: %s", demo_data_path)
        return
 
    try:
@@ -119,7 +120,7 @@ async def load_demo_data(session):
                session.add(org)
                await session.flush()  # Flush to get ID
                org_map[org.slug] = org
-                logger.info(f"Created demo organization: {org.name}")
+                logger.info("Created demo organization: %s", org.name)
            else:
                # We can't easily get the ORM object from raw SQL result for map without querying again or mapping
                # So let's just query it properly if we need it for relationships
@@ -135,7 +136,7 @@ async def load_demo_data(session):
 
        # Create Users
        for user_data in data.get("users", []):
-            existing_user = await user_crud.get_by_email(
+            existing_user = await user_repo.get_by_email(
                session, email=user_data["email"]
            )
            if not existing_user:
@@ -148,7 +149,7 @@ async def load_demo_data(session):
                is_superuser=user_data["is_superuser"],
                is_active=user_data.get("is_active", True),
            )
-            user = await user_crud.create(session, obj_in=user_in)
+            user = await user_repo.create(session, obj_in=user_in)
 
            # Randomize created_at for demo data (last 30 days)
            # This makes the charts look more realistic
@@ -174,7 +175,10 @@ async def load_demo_data(session):
            )
 
            logger.info(
-                f"Created demo user: {user.email} (created {days_ago} days ago, active={user_data.get('is_active', True)})"
+                "Created demo user: %s (created %s days ago, active=%s)",
+                user.email,
+                days_ago,
+                user_data.get("is_active", True),
            )
 
            # Add to organization if specified
@@ -187,15 +191,15 @@ async def load_demo_data(session):
                    user_id=user.id, organization_id=org.id, role=role
                )
                session.add(member)
-                logger.info(f"Added {user.email} to {org.name} as {role}")
+                logger.info("Added %s to %s as %s", user.email, org.name, role)
        else:
-            logger.info(f"Demo user already exists: {existing_user.email}")
+            logger.info("Demo user already exists: %s", existing_user.email)
 
        await session.commit()
        logger.info("Demo data loaded successfully")
 
    except Exception as e:
-        logger.error(f"Error loading demo data: {e}")
+        logger.error("Error loading demo data: %s", e)
        raise
 
@@ -1,7 +1,7 @@
 import logging
 import os
 from contextlib import asynccontextmanager
-from datetime import datetime
+from datetime import UTC, datetime
 from typing import Any
 
 from apscheduler.schedulers.asyncio import AsyncIOScheduler
@@ -16,7 +16,7 @@ from slowapi.util import get_remote_address
 from app.api.main import api_router
 from app.api.routes.oauth_provider import wellknown_router as oauth_wellknown_router
 from app.core.config import settings
-from app.core.database import check_database_health
+from app.core.database import check_database_health, close_async_db
 from app.core.exceptions import (
     APIException,
     api_exception_handler,
@@ -72,6 +72,7 @@ async def lifespan(app: FastAPI):
     if os.getenv("IS_TEST", "False") != "True":
         scheduler.shutdown()
         logger.info("Scheduled jobs stopped")
+    await close_async_db()
 
 
    logger.info("Starting app!!!")
@@ -294,7 +295,7 @@ async def health_check() -> JSONResponse:
    """
    health_status: dict[str, Any] = {
        "status": "healthy",
-        "timestamp": datetime.utcnow().isoformat() + "Z",
+        "timestamp": datetime.now(UTC).isoformat().replace("+00:00", "Z"),
        "version": settings.VERSION,
        "environment": settings.ENVIRONMENT,
        "checks": {},
@@ -319,7 +320,7 @@ async def health_check() -> JSONResponse:
                "message": f"Database connection failed: {e!s}",
            }
            response_status = status.HTTP_503_SERVICE_UNAVAILABLE
-            logger.error(f"Health check failed - database error: {e}")
+            logger.error("Health check failed - database error: %s", e)
 
    return JSONResponse(status_code=response_status, content=health_status)
@@ -36,9 +36,9 @@ class OAuthAccount(Base, UUIDMixin, TimestampMixin):
     )  # Email from provider (for reference)
 
     # Optional: store provider tokens for API access
-    # These should be encrypted at rest in production
-    access_token_encrypted = Column(String(2048), nullable=True)
-    refresh_token_encrypted = Column(String(2048), nullable=True)
+    # TODO: Encrypt these at rest in production (requires key management infrastructure)
+    access_token = Column(String(2048), nullable=True)
+    refresh_token = Column(String(2048), nullable=True)
     token_expires_at = Column(DateTime(timezone=True), nullable=True)
 
     # Relationship
@@ -92,7 +92,7 @@ class OAuthAuthorizationCode(Base, UUIDMixin, TimestampMixin):
         # Handle both timezone-aware and naive datetimes from DB
         if expires_at.tzinfo is None:
             expires_at = expires_at.replace(tzinfo=UTC)
-        return now > expires_at
+        return bool(now > expires_at)
 
     @property
     def is_valid(self) -> bool:
@@ -99,7 +99,7 @@ class OAuthProviderRefreshToken(Base, UUIDMixin, TimestampMixin):
         # Handle both timezone-aware and naive datetimes from DB
         if expires_at.tzinfo is None:
             expires_at = expires_at.replace(tzinfo=UTC)
-        return now > expires_at
+        return bool(now > expires_at)
 
     @property
     def is_valid(self) -> bool:
@@ -76,7 +76,11 @@ class UserSession(Base, UUIDMixin, TimestampMixin):
         """Check if session has expired."""
         from datetime import datetime
 
-        return self.expires_at < datetime.now(UTC)
+        now = datetime.now(UTC)
+        expires_at = self.expires_at
+        if expires_at.tzinfo is None:
+            expires_at = expires_at.replace(tzinfo=UTC)
+        return bool(expires_at < now)
 
     def to_dict(self):
         """Convert session to dictionary for serialization."""
39
backend/app/repositories/__init__.py
Normal file
@@ -0,0 +1,39 @@
+# app/repositories/__init__.py
+"""Repository layer — all database access goes through these classes."""
+
+from app.repositories.oauth_account import OAuthAccountRepository, oauth_account_repo
+from app.repositories.oauth_authorization_code import (
+    OAuthAuthorizationCodeRepository,
+    oauth_authorization_code_repo,
+)
+from app.repositories.oauth_client import OAuthClientRepository, oauth_client_repo
+from app.repositories.oauth_consent import OAuthConsentRepository, oauth_consent_repo
+from app.repositories.oauth_provider_token import (
+    OAuthProviderTokenRepository,
+    oauth_provider_token_repo,
+)
+from app.repositories.oauth_state import OAuthStateRepository, oauth_state_repo
+from app.repositories.organization import OrganizationRepository, organization_repo
+from app.repositories.session import SessionRepository, session_repo
+from app.repositories.user import UserRepository, user_repo
+
+__all__ = [
+    "OAuthAccountRepository",
+    "OAuthAuthorizationCodeRepository",
+    "OAuthClientRepository",
+    "OAuthConsentRepository",
+    "OAuthProviderTokenRepository",
+    "OAuthStateRepository",
+    "OrganizationRepository",
+    "SessionRepository",
+    "UserRepository",
+    "oauth_account_repo",
+    "oauth_authorization_code_repo",
+    "oauth_client_repo",
+    "oauth_consent_repo",
+    "oauth_provider_token_repo",
+    "oauth_state_repo",
+    "organization_repo",
+    "session_repo",
+    "user_repo",
+]
177
backend/app/crud/base.py → backend/app/repositories/base.py
Executable file → Normal file
@@ -1,6 +1,6 @@
-# app/crud/base_async.py
+# app/repositories/base.py
 """
-Async CRUD operations base class using SQLAlchemy 2.0 async patterns.
+Base repository class for async database operations using SQLAlchemy 2.0 async patterns.

 Provides reusable create, read, update, and delete operations for all models.
 """
@@ -18,6 +18,11 @@ from sqlalchemy.ext.asyncio import AsyncSession
 from sqlalchemy.orm import Load

 from app.core.database import Base
+from app.core.repository_exceptions import (
+    DuplicateEntryError,
+    IntegrityConstraintError,
+    InvalidInputError,
+)

 logger = logging.getLogger(__name__)

@@ -26,16 +31,16 @@ CreateSchemaType = TypeVar("CreateSchemaType", bound=BaseModel)
 UpdateSchemaType = TypeVar("UpdateSchemaType", bound=BaseModel)


-class CRUDBase[
+class BaseRepository[
     ModelType: Base,
     CreateSchemaType: BaseModel,
     UpdateSchemaType: BaseModel,
 ]:
-    """Async CRUD operations for a model."""
+    """Async repository operations for a model."""

     def __init__(self, model: type[ModelType]):
         """
-        CRUD object with default async methods to Create, Read, Update, Delete.
+        Repository object with default async methods to Create, Read, Update, Delete.

         Parameters:
             model: A SQLAlchemy model class
@@ -56,26 +61,19 @@ class CRUDBase[

         Returns:
             Model instance or None if not found
-
-        Example:
-            # Eager load user relationship
-            from sqlalchemy.orm import joinedload
-            session = await session_crud.get(db, id=session_id, options=[joinedload(UserSession.user)])
         """
         # Validate UUID format and convert to UUID object if string
         try:
             if isinstance(id, uuid.UUID):
                 uuid_obj = id
             else:
                 uuid_obj = uuid.UUID(str(id))
         except (ValueError, AttributeError, TypeError) as e:
-            logger.warning(f"Invalid UUID format: {id} - {e!s}")
+            logger.warning("Invalid UUID format: %s - %s", id, e)
             return None

         try:
             query = select(self.model).where(self.model.id == uuid_obj)

             # Apply eager loading options if provided
             if options:
                 for option in options:
                     query = query.options(option)
@@ -83,7 +81,9 @@ class CRUDBase[
             result = await db.execute(query)
             return result.scalar_one_or_none()
         except Exception as e:
-            logger.error(f"Error retrieving {self.model.__name__} with id {id}: {e!s}")
+            logger.error(
+                "Error retrieving %s with id %s: %s", self.model.__name__, id, e
+            )
             raise

     async def get_multi(
@@ -96,28 +96,17 @@ class CRUDBase[
     ) -> list[ModelType]:
         """
         Get multiple records with pagination validation and optional eager loading.

         Args:
             db: Database session
             skip: Number of records to skip
             limit: Maximum number of records to return
             options: Optional list of SQLAlchemy load options for eager loading

         Returns:
             List of model instances
         """
         # Validate pagination parameters
         if skip < 0:
-            raise ValueError("skip must be non-negative")
+            raise InvalidInputError("skip must be non-negative")
         if limit < 0:
-            raise ValueError("limit must be non-negative")
+            raise InvalidInputError("limit must be non-negative")
         if limit > 1000:
-            raise ValueError("Maximum limit is 1000")
+            raise InvalidInputError("Maximum limit is 1000")

         try:
-            query = select(self.model).offset(skip).limit(limit)
+            query = select(self.model).order_by(self.model.id).offset(skip).limit(limit)

             # Apply eager loading options if provided
             if options:
                 for option in options:
                     query = query.options(option)
@@ -126,7 +115,7 @@ class CRUDBase[
             return list(result.scalars().all())
         except Exception as e:
             logger.error(
-                f"Error retrieving multiple {self.model.__name__} records: {e!s}"
+                "Error retrieving multiple %s records: %s", self.model.__name__, e
             )
             raise

     async def get_multi(
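The get_multi hunk above swaps plain ValueError for a dedicated InvalidInputError on the pagination bounds. A minimal standalone sketch of that validation rule (the helper name and the exception's base class here are assumptions, not the app's actual code):

```python
class InvalidInputError(ValueError):
    """Stand-in for app.core.repository_exceptions.InvalidInputError."""


def validate_pagination(skip: int, limit: int, max_limit: int = 1000) -> None:
    """Mirror the bounds checks in BaseRepository.get_multi."""
    if skip < 0:
        raise InvalidInputError("skip must be non-negative")
    if limit < 0:
        raise InvalidInputError("limit must be non-negative")
    if limit > max_limit:
        raise InvalidInputError(f"Maximum limit is {max_limit}")
```

If the real exception also subclasses ValueError, callers that still catch ValueError keep working across the rename; the actual hierarchy lives in app.core.repository_exceptions.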
@@ -136,9 +125,8 @@ class CRUDBase[
        """Create a new record with error handling.

         NOTE: This method is defensive code that's never called in practice.
-        All CRUD subclasses (CRUDUser, CRUDOrganization, CRUDSession) override this method
-        with their own implementations, so the base implementation and its exception handlers
-        are never executed. Marked as pragma: no cover to avoid false coverage gaps.
+        All repository subclasses override this method with their own implementations.
+        Marked as pragma: no cover to avoid false coverage gaps.
         """
         try:  # pragma: no cover
             obj_in_data = jsonable_encoder(obj_in)
@@ -152,22 +140,24 @@ class CRUDBase[
             error_msg = str(e.orig) if hasattr(e, "orig") else str(e)
             if "unique" in error_msg.lower() or "duplicate" in error_msg.lower():
                 logger.warning(
-                    f"Duplicate entry attempted for {self.model.__name__}: {error_msg}"
+                    "Duplicate entry attempted for %s: %s",
+                    self.model.__name__,
+                    error_msg,
                 )
-                raise ValueError(
+                raise DuplicateEntryError(
                     f"A {self.model.__name__} with this data already exists"
                 )
-            logger.error(f"Integrity error creating {self.model.__name__}: {error_msg}")
-            raise ValueError(f"Database integrity error: {error_msg}")
+            logger.error(
+                "Integrity error creating %s: %s", self.model.__name__, error_msg
+            )
+            raise IntegrityConstraintError(f"Database integrity error: {error_msg}")
         except (OperationalError, DataError) as e:  # pragma: no cover
             await db.rollback()
-            logger.error(f"Database error creating {self.model.__name__}: {e!s}")
-            raise ValueError(f"Database operation failed: {e!s}")
+            logger.error("Database error creating %s: %s", self.model.__name__, e)
+            raise IntegrityConstraintError(f"Database operation failed: {e!s}")
         except Exception as e:  # pragma: no cover
             await db.rollback()
-            logger.error(
-                f"Unexpected error creating {self.model.__name__}: {e!s}", exc_info=True
-            )
+            logger.exception("Unexpected error creating %s: %s", self.model.__name__, e)
             raise

     async def update(
@@ -198,34 +188,35 @@ class CRUDBase[
             error_msg = str(e.orig) if hasattr(e, "orig") else str(e)
             if "unique" in error_msg.lower() or "duplicate" in error_msg.lower():
                 logger.warning(
-                    f"Duplicate entry attempted for {self.model.__name__}: {error_msg}"
+                    "Duplicate entry attempted for %s: %s",
+                    self.model.__name__,
+                    error_msg,
                 )
-                raise ValueError(
+                raise DuplicateEntryError(
                     f"A {self.model.__name__} with this data already exists"
                 )
-            logger.error(f"Integrity error updating {self.model.__name__}: {error_msg}")
-            raise ValueError(f"Database integrity error: {error_msg}")
+            logger.error(
+                "Integrity error updating %s: %s", self.model.__name__, error_msg
+            )
+            raise IntegrityConstraintError(f"Database integrity error: {error_msg}")
         except (OperationalError, DataError) as e:
             await db.rollback()
-            logger.error(f"Database error updating {self.model.__name__}: {e!s}")
-            raise ValueError(f"Database operation failed: {e!s}")
+            logger.error("Database error updating %s: %s", self.model.__name__, e)
+            raise IntegrityConstraintError(f"Database operation failed: {e!s}")
         except Exception as e:
             await db.rollback()
-            logger.error(
-                f"Unexpected error updating {self.model.__name__}: {e!s}", exc_info=True
-            )
+            logger.exception("Unexpected error updating %s: %s", self.model.__name__, e)
             raise

     async def remove(self, db: AsyncSession, *, id: str) -> ModelType | None:
         """Delete a record with error handling and null check."""
         # Validate UUID format and convert to UUID object if string
         try:
             if isinstance(id, uuid.UUID):
                 uuid_obj = id
             else:
                 uuid_obj = uuid.UUID(str(id))
         except (ValueError, AttributeError, TypeError) as e:
-            logger.warning(f"Invalid UUID format for deletion: {id} - {e!s}")
+            logger.warning("Invalid UUID format for deletion: %s - %s", id, e)
            return None

        try:
@@ -236,7 +227,7 @@ class CRUDBase[

             if obj is None:
                 logger.warning(
-                    f"{self.model.__name__} with id {id} not found for deletion"
+                    "%s with id %s not found for deletion", self.model.__name__, id
                 )
                 return None

@@ -246,15 +237,16 @@ class CRUDBase[
         except IntegrityError as e:
             await db.rollback()
             error_msg = str(e.orig) if hasattr(e, "orig") else str(e)
-            logger.error(f"Integrity error deleting {self.model.__name__}: {error_msg}")
-            raise ValueError(
+            logger.error(
+                "Integrity error deleting %s: %s", self.model.__name__, error_msg
+            )
+            raise IntegrityConstraintError(
                 f"Cannot delete {self.model.__name__}: referenced by other records"
             )
         except Exception as e:
             await db.rollback()
-            logger.error(
-                f"Error deleting {self.model.__name__} with id {id}: {e!s}",
-                exc_info=True,
+            logger.exception(
+                "Error deleting %s with id %s: %s", self.model.__name__, id, e
             )
             raise

@@ -272,57 +264,40 @@ class CRUDBase[
         Get multiple records with total count, filtering, and sorting.

         NOTE: This method is defensive code that's never called in practice.
-        All CRUD subclasses (CRUDUser, CRUDOrganization, CRUDSession) override this method
-        with their own implementations that include additional parameters like search.
+        All repository subclasses override this method with their own implementations.
         Marked as pragma: no cover to avoid false coverage gaps.

         Args:
             db: Database session
             skip: Number of records to skip
             limit: Maximum number of records to return
             sort_by: Field name to sort by (must be a valid model attribute)
             sort_order: Sort order ("asc" or "desc")
             filters: Dictionary of filters (field_name: value)

         Returns:
             Tuple of (items, total_count)
         """
         # Validate pagination parameters
         if skip < 0:
-            raise ValueError("skip must be non-negative")
+            raise InvalidInputError("skip must be non-negative")
         if limit < 0:
-            raise ValueError("limit must be non-negative")
+            raise InvalidInputError("limit must be non-negative")
         if limit > 1000:
-            raise ValueError("Maximum limit is 1000")
+            raise InvalidInputError("Maximum limit is 1000")

         try:
             # Build base query
             query = select(self.model)

             # Exclude soft-deleted records by default
             if hasattr(self.model, "deleted_at"):
                 query = query.where(self.model.deleted_at.is_(None))

             # Apply filters
             if filters:
                 for field, value in filters.items():
                     if hasattr(self.model, field) and value is not None:
                         query = query.where(getattr(self.model, field) == value)

             # Get total count (before pagination)
             count_query = select(func.count()).select_from(query.alias())
             count_result = await db.execute(count_query)
             total = count_result.scalar_one()

             # Apply sorting
             if sort_by and hasattr(self.model, sort_by):
                 sort_column = getattr(self.model, sort_by)
                 if sort_order.lower() == "desc":
                     query = query.order_by(sort_column.desc())
                 else:
                     query = query.order_by(sort_column.asc())
             else:
                 query = query.order_by(self.model.id)

             # Apply pagination
             query = query.offset(skip).limit(limit)
             items_result = await db.execute(query)
             items = list(items_result.scalars().all())
@@ -330,7 +305,7 @@ class CRUDBase[
             return items, total
         except Exception as e:  # pragma: no cover
             logger.error(
-                f"Error retrieving paginated {self.model.__name__} records: {e!s}"
+                "Error retrieving paginated %s records: %s", self.model.__name__, e
             )
             raise

@@ -340,7 +315,7 @@ class CRUDBase[
             result = await db.execute(select(func.count(self.model.id)))
             return result.scalar_one()
         except Exception as e:
-            logger.error(f"Error counting {self.model.__name__} records: {e!s}")
+            logger.error("Error counting %s records: %s", self.model.__name__, e)
             raise

     async def exists(self, db: AsyncSession, id: str) -> bool:
@@ -356,14 +331,13 @@ class CRUDBase[
         """
         from datetime import datetime

         # Validate UUID format and convert to UUID object if string
         try:
             if isinstance(id, uuid.UUID):
                 uuid_obj = id
             else:
                 uuid_obj = uuid.UUID(str(id))
         except (ValueError, AttributeError, TypeError) as e:
-            logger.warning(f"Invalid UUID format for soft deletion: {id} - {e!s}")
+            logger.warning("Invalid UUID format for soft deletion: %s - %s", id, e)
             return None

         try:
@@ -374,18 +348,16 @@ class CRUDBase[

             if obj is None:
                 logger.warning(
-                    f"{self.model.__name__} with id {id} not found for soft deletion"
+                    "%s with id %s not found for soft deletion", self.model.__name__, id
                 )
                 return None

             # Check if model supports soft deletes
             if not hasattr(self.model, "deleted_at"):
-                logger.error(f"{self.model.__name__} does not support soft deletes")
-                raise ValueError(
+                logger.error("%s does not support soft deletes", self.model.__name__)
+                raise InvalidInputError(
                     f"{self.model.__name__} does not have a deleted_at column"
                 )

             # Set deleted_at timestamp
             obj.deleted_at = datetime.now(UTC)
             db.add(obj)
             await db.commit()
@@ -393,9 +365,8 @@ class CRUDBase[
             return obj
         except Exception as e:
             await db.rollback()
-            logger.error(
-                f"Error soft deleting {self.model.__name__} with id {id}: {e!s}",
-                exc_info=True,
+            logger.exception(
+                "Error soft deleting %s with id %s: %s", self.model.__name__, id, e
             )
             raise

@@ -405,18 +376,16 @@ class CRUDBase[

         Only works if the model has a 'deleted_at' column.
         """
         # Validate UUID format
         try:
             if isinstance(id, uuid.UUID):
                 uuid_obj = id
             else:
                 uuid_obj = uuid.UUID(str(id))
         except (ValueError, AttributeError, TypeError) as e:
-            logger.warning(f"Invalid UUID format for restoration: {id} - {e!s}")
+            logger.warning("Invalid UUID format for restoration: %s - %s", id, e)
             return None

         try:
             # Find the soft-deleted record
             if hasattr(self.model, "deleted_at"):
                 result = await db.execute(
                     select(self.model).where(
@@ -425,18 +394,19 @@ class CRUDBase[
                 )
                 obj = result.scalar_one_or_none()
             else:
-                logger.error(f"{self.model.__name__} does not support soft deletes")
-                raise ValueError(
+                logger.error("%s does not support soft deletes", self.model.__name__)
+                raise InvalidInputError(
                     f"{self.model.__name__} does not have a deleted_at column"
                 )

             if obj is None:
                 logger.warning(
-                    f"Soft-deleted {self.model.__name__} with id {id} not found for restoration"
+                    "Soft-deleted %s with id %s not found for restoration",
+                    self.model.__name__,
+                    id,
                 )
                 return None

             # Clear deleted_at timestamp
             obj.deleted_at = None
             db.add(obj)
             await db.commit()
@@ -444,8 +414,7 @@ class CRUDBase[
             return obj
         except Exception as e:
             await db.rollback()
-            logger.error(
-                f"Error restoring {self.model.__name__} with id {id}: {e!s}",
-                exc_info=True,
+            logger.exception(
                "Error restoring %s with id %s: %s", self.model.__name__, id, e
             )
             raise
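The renamed BaseRepository above is generic over its model and schema types. As an illustrative sketch of the same pattern without SQLAlchemy, here is a toy in-memory repository using classic typing.Generic; the real class is async and database-backed, and everything below (names included) is hypothetical:

```python
import uuid
from typing import Generic, TypeVar

T = TypeVar("T")


class InMemoryRepository(Generic[T]):
    """Toy analogue of BaseRepository's create/get/remove surface."""

    def __init__(self) -> None:
        self._items: dict[uuid.UUID, T] = {}

    def create(self, obj: T) -> uuid.UUID:
        new_id = uuid.uuid4()
        self._items[new_id] = obj
        return new_id

    def get(self, id: "str | uuid.UUID") -> "T | None":
        # Like BaseRepository.get: an invalid UUID returns None instead of raising.
        try:
            key = id if isinstance(id, uuid.UUID) else uuid.UUID(str(id))
        except (ValueError, AttributeError, TypeError):
            return None
        return self._items.get(key)

    def remove(self, id: uuid.UUID) -> "T | None":
        return self._items.pop(id, None)
```

The diff's own class uses the newer bracketed type-parameter syntax (`class BaseRepository[ModelType: Base, ...]`), which requires Python 3.12; the `Generic[T]` spelling above is the older equivalent.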
249 backend/app/repositories/oauth_account.py (Normal file)
@@ -0,0 +1,249 @@
# app/repositories/oauth_account.py
"""Repository for OAuthAccount model async database operations."""

import logging
from datetime import datetime
from uuid import UUID

from pydantic import BaseModel
from sqlalchemy import and_, delete, select
from sqlalchemy.exc import IntegrityError
from sqlalchemy.ext.asyncio import AsyncSession
from sqlalchemy.orm import joinedload

from app.core.repository_exceptions import DuplicateEntryError
from app.models.oauth_account import OAuthAccount
from app.repositories.base import BaseRepository
from app.schemas.oauth import OAuthAccountCreate

logger = logging.getLogger(__name__)


class EmptySchema(BaseModel):
    """Placeholder schema for repository operations that don't need update schemas."""


class OAuthAccountRepository(
    BaseRepository[OAuthAccount, OAuthAccountCreate, EmptySchema]
):
    """Repository for OAuth account links."""

    async def get_by_provider_id(
        self,
        db: AsyncSession,
        *,
        provider: str,
        provider_user_id: str,
    ) -> OAuthAccount | None:
        """Get OAuth account by provider and provider user ID."""
        try:
            result = await db.execute(
                select(OAuthAccount)
                .where(
                    and_(
                        OAuthAccount.provider == provider,
                        OAuthAccount.provider_user_id == provider_user_id,
                    )
                )
                .options(joinedload(OAuthAccount.user))
            )
            return result.scalar_one_or_none()
        except Exception as e:  # pragma: no cover
            logger.error(
                "Error getting OAuth account for %s:%s: %s",
                provider,
                provider_user_id,
                e,
            )
            raise

    async def get_by_provider_email(
        self,
        db: AsyncSession,
        *,
        provider: str,
        email: str,
    ) -> OAuthAccount | None:
        """Get OAuth account by provider and email."""
        try:
            result = await db.execute(
                select(OAuthAccount)
                .where(
                    and_(
                        OAuthAccount.provider == provider,
                        OAuthAccount.provider_email == email,
                    )
                )
                .options(joinedload(OAuthAccount.user))
            )
            return result.scalar_one_or_none()
        except Exception as e:  # pragma: no cover
            logger.error(
                "Error getting OAuth account for %s email %s: %s", provider, email, e
            )
            raise

    async def get_user_accounts(
        self,
        db: AsyncSession,
        *,
        user_id: str | UUID,
    ) -> list[OAuthAccount]:
        """Get all OAuth accounts linked to a user."""
        try:
            user_uuid = UUID(str(user_id)) if isinstance(user_id, str) else user_id

            result = await db.execute(
                select(OAuthAccount)
                .where(OAuthAccount.user_id == user_uuid)
                .order_by(OAuthAccount.created_at.desc())
            )
            return list(result.scalars().all())
        except Exception as e:  # pragma: no cover
            logger.error("Error getting OAuth accounts for user %s: %s", user_id, e)
            raise

    async def get_user_account_by_provider(
        self,
        db: AsyncSession,
        *,
        user_id: str | UUID,
        provider: str,
    ) -> OAuthAccount | None:
        """Get a specific OAuth account for a user and provider."""
        try:
            user_uuid = UUID(str(user_id)) if isinstance(user_id, str) else user_id

            result = await db.execute(
                select(OAuthAccount).where(
                    and_(
                        OAuthAccount.user_id == user_uuid,
                        OAuthAccount.provider == provider,
                    )
                )
            )
            return result.scalar_one_or_none()
        except Exception as e:  # pragma: no cover
            logger.error(
                "Error getting OAuth account for user %s, provider %s: %s",
                user_id,
                provider,
                e,
            )
            raise

    async def create_account(
        self, db: AsyncSession, *, obj_in: OAuthAccountCreate
    ) -> OAuthAccount:
        """Create a new OAuth account link."""
        try:
            db_obj = OAuthAccount(
                user_id=obj_in.user_id,
                provider=obj_in.provider,
                provider_user_id=obj_in.provider_user_id,
                provider_email=obj_in.provider_email,
                access_token=obj_in.access_token,
                refresh_token=obj_in.refresh_token,
                token_expires_at=obj_in.token_expires_at,
            )
            db.add(db_obj)
            await db.commit()
            await db.refresh(db_obj)

            logger.info(
                "OAuth account created: %s linked to user %s",
                obj_in.provider,
                obj_in.user_id,
            )
            return db_obj
        except IntegrityError as e:  # pragma: no cover
            await db.rollback()
            error_msg = str(e.orig) if hasattr(e, "orig") else str(e)
            if "uq_oauth_provider_user" in error_msg.lower():
                logger.warning(
                    "OAuth account already exists: %s:%s",
                    obj_in.provider,
                    obj_in.provider_user_id,
                )
                raise DuplicateEntryError(
                    f"This {obj_in.provider} account is already linked to another user"
                )
            logger.error("Integrity error creating OAuth account: %s", error_msg)
            raise DuplicateEntryError(f"Failed to create OAuth account: {error_msg}")
        except Exception as e:  # pragma: no cover
            await db.rollback()
            logger.exception("Error creating OAuth account: %s", e)
            raise

    async def delete_account(
        self,
        db: AsyncSession,
        *,
        user_id: str | UUID,
        provider: str,
    ) -> bool:
        """Delete an OAuth account link."""
        try:
            user_uuid = UUID(str(user_id)) if isinstance(user_id, str) else user_id

            result = await db.execute(
                delete(OAuthAccount).where(
                    and_(
                        OAuthAccount.user_id == user_uuid,
                        OAuthAccount.provider == provider,
                    )
                )
            )
            await db.commit()

            deleted = result.rowcount > 0
            if deleted:
                logger.info(
                    "OAuth account deleted: %s unlinked from user %s", provider, user_id
                )
            else:
                logger.warning(
                    "OAuth account not found for deletion: %s for user %s",
                    provider,
                    user_id,
                )

            return deleted
        except Exception as e:  # pragma: no cover
            await db.rollback()
            logger.error(
                "Error deleting OAuth account %s for user %s: %s", provider, user_id, e
            )
            raise

    async def update_tokens(
        self,
        db: AsyncSession,
        *,
        account: OAuthAccount,
        access_token: str | None = None,
        refresh_token: str | None = None,
        token_expires_at: datetime | None = None,
    ) -> OAuthAccount:
        """Update OAuth tokens for an account."""
        try:
            if access_token is not None:
                account.access_token = access_token
            if refresh_token is not None:
                account.refresh_token = refresh_token
            if token_expires_at is not None:
                account.token_expires_at = token_expires_at

            db.add(account)
            await db.commit()
            await db.refresh(account)

            return account
        except Exception as e:  # pragma: no cover
            await db.rollback()
            logger.error("Error updating OAuth tokens: %s", e)
            raise


# Singleton instance
oauth_account_repo = OAuthAccountRepository(OAuthAccount)
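create_account above translates a violation of the uq_oauth_provider_user unique constraint into a DuplicateEntryError, so one external identity can only ever be linked once. A toy in-memory illustration of that invariant, keyed on (provider, provider_user_id); all names below are illustrative stand-ins, not the app's code:

```python
class DuplicateEntryError(Exception):
    """Stand-in for app.core.repository_exceptions.DuplicateEntryError."""


class AccountLinks:
    """Enforce the uq_oauth_provider_user uniqueness rule in memory."""

    def __init__(self) -> None:
        self._links: dict[tuple[str, str], str] = {}

    def link(self, provider: str, provider_user_id: str, user_id: str) -> None:
        key = (provider, provider_user_id)
        if key in self._links:
            # Same message shape as the repository's duplicate branch.
            raise DuplicateEntryError(
                f"This {provider} account is already linked to another user"
            )
        self._links[key] = user_id
```

In the real repository the database constraint does the enforcement and the except IntegrityError branch only classifies the failure.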
108 backend/app/repositories/oauth_authorization_code.py (Normal file)
@@ -0,0 +1,108 @@
# app/repositories/oauth_authorization_code.py
"""Repository for OAuthAuthorizationCode model."""

import logging
from datetime import UTC, datetime
from uuid import UUID

from sqlalchemy import and_, delete, select, update
from sqlalchemy.ext.asyncio import AsyncSession

from app.models.oauth_authorization_code import OAuthAuthorizationCode

logger = logging.getLogger(__name__)


class OAuthAuthorizationCodeRepository:
    """Repository for OAuth 2.0 authorization codes."""

    async def create_code(
        self,
        db: AsyncSession,
        *,
        code: str,
        client_id: str,
        user_id: UUID,
        redirect_uri: str,
        scope: str,
        expires_at: datetime,
        code_challenge: str | None = None,
        code_challenge_method: str | None = None,
        state: str | None = None,
        nonce: str | None = None,
    ) -> OAuthAuthorizationCode:
        """Create and persist a new authorization code."""
        auth_code = OAuthAuthorizationCode(
            code=code,
            client_id=client_id,
            user_id=user_id,
            redirect_uri=redirect_uri,
            scope=scope,
            code_challenge=code_challenge,
            code_challenge_method=code_challenge_method,
            state=state,
            nonce=nonce,
            expires_at=expires_at,
            used=False,
        )
        db.add(auth_code)
        await db.commit()
        return auth_code

    async def consume_code_atomically(
        self, db: AsyncSession, *, code: str
    ) -> UUID | None:
        """
        Atomically mark a code as used and return its UUID.

        Returns the UUID if the code was found and not yet used, None otherwise.
        This prevents race conditions per RFC 6749 Section 4.1.2.
        """
        stmt = (
            update(OAuthAuthorizationCode)
            .where(
                and_(
                    OAuthAuthorizationCode.code == code,
                    OAuthAuthorizationCode.used == False,  # noqa: E712
                )
            )
            .values(used=True)
            .returning(OAuthAuthorizationCode.id)
        )
        result = await db.execute(stmt)
        row_id = result.scalar_one_or_none()
        if row_id is not None:
            await db.commit()
        return row_id

    async def get_by_id(
        self, db: AsyncSession, *, code_id: UUID
    ) -> OAuthAuthorizationCode | None:
        """Get authorization code by its UUID primary key."""
        result = await db.execute(
            select(OAuthAuthorizationCode).where(OAuthAuthorizationCode.id == code_id)
        )
        return result.scalar_one_or_none()

    async def get_by_code(
        self, db: AsyncSession, *, code: str
    ) -> OAuthAuthorizationCode | None:
        """Get authorization code by the code string value."""
        result = await db.execute(
            select(OAuthAuthorizationCode).where(OAuthAuthorizationCode.code == code)
        )
        return result.scalar_one_or_none()

    async def cleanup_expired(self, db: AsyncSession) -> int:
        """Delete all expired authorization codes. Returns count deleted."""
        result = await db.execute(
            delete(OAuthAuthorizationCode).where(
                OAuthAuthorizationCode.expires_at < datetime.now(UTC)
            )
        )
        await db.commit()
        return result.rowcount  # type: ignore[attr-defined]


# Singleton instance
oauth_authorization_code_repo = OAuthAuthorizationCodeRepository()
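consume_code_atomically above folds the "is it unused?" check and the "mark used" write into a single UPDATE ... RETURNING statement, so two concurrent token requests cannot both redeem the same code. A pure-Python sketch of that single-use contract; here a lock supplies the atomicity that the database statement provides in the real code, and all names are illustrative:

```python
import threading
import uuid


class CodeStore:
    """Toy single-use code store mirroring consume_code_atomically's contract."""

    def __init__(self) -> None:
        self._lock = threading.Lock()
        self._codes: dict[str, dict] = {}

    def add(self, code: str) -> uuid.UUID:
        code_id = uuid.uuid4()
        self._codes[code] = {"id": code_id, "used": False}
        return code_id

    def consume(self, code: str) -> "uuid.UUID | None":
        # Test-and-set under one lock: only the first caller gets the id back,
        # every later caller (and any unknown code) gets None.
        with self._lock:
            row = self._codes.get(code)
            if row is None or row["used"]:
                return None
            row["used"] = True
            return row["id"]
```

Returning the id rather than the row also mirrors the repository: callers re-fetch the full record with get_by_id only after they have won the race.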
201 backend/app/repositories/oauth_client.py (Normal file)
@@ -0,0 +1,201 @@
# app/repositories/oauth_client.py
"""Repository for OAuthClient model async database operations."""

import logging
import secrets
from uuid import UUID

from pydantic import BaseModel
from sqlalchemy import and_, delete, select
from sqlalchemy.exc import IntegrityError
from sqlalchemy.ext.asyncio import AsyncSession

from app.core.repository_exceptions import DuplicateEntryError
from app.models.oauth_client import OAuthClient
from app.repositories.base import BaseRepository
from app.schemas.oauth import OAuthClientCreate

logger = logging.getLogger(__name__)


class EmptySchema(BaseModel):
    """Placeholder schema for repository operations that don't need update schemas."""


class OAuthClientRepository(
    BaseRepository[OAuthClient, OAuthClientCreate, EmptySchema]
):
    """Repository for OAuth clients (provider mode)."""

    async def get_by_client_id(
        self, db: AsyncSession, *, client_id: str
    ) -> OAuthClient | None:
        """Get OAuth client by client_id."""
        try:
            result = await db.execute(
                select(OAuthClient).where(
                    and_(
                        OAuthClient.client_id == client_id,
                        OAuthClient.is_active == True,  # noqa: E712
                    )
                )
            )
            return result.scalar_one_or_none()
        except Exception as e:  # pragma: no cover
            logger.error("Error getting OAuth client %s: %s", client_id, e)
            raise

    async def create_client(
        self,
        db: AsyncSession,
        *,
        obj_in: OAuthClientCreate,
        owner_user_id: UUID | None = None,
    ) -> tuple[OAuthClient, str | None]:
        """Create a new OAuth client."""
        try:
            client_id = secrets.token_urlsafe(32)

            client_secret = None
            client_secret_hash = None
            if obj_in.client_type == "confidential":
                client_secret = secrets.token_urlsafe(48)
                from app.core.auth import get_password_hash

                client_secret_hash = get_password_hash(client_secret)

            db_obj = OAuthClient(
                client_id=client_id,
                client_secret_hash=client_secret_hash,
                client_name=obj_in.client_name,
                client_description=obj_in.client_description,
                client_type=obj_in.client_type,
                redirect_uris=obj_in.redirect_uris,
                allowed_scopes=obj_in.allowed_scopes,
                owner_user_id=owner_user_id,
                is_active=True,
            )
            db.add(db_obj)
            await db.commit()
            await db.refresh(db_obj)

            logger.info(
                "OAuth client created: %s (%s...)", obj_in.client_name, client_id[:8]
            )
            return db_obj, client_secret
        except IntegrityError as e:  # pragma: no cover
            await db.rollback()
            error_msg = str(e.orig) if hasattr(e, "orig") else str(e)
            logger.error("Error creating OAuth client: %s", error_msg)
            raise DuplicateEntryError(f"Failed to create OAuth client: {error_msg}")
        except Exception as e:  # pragma: no cover
            await db.rollback()
            logger.exception("Error creating OAuth client: %s", e)
            raise
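create_client above generates a random client_id for every client and, for confidential clients, a secret that is returned once in plaintext but persisted only as a hash. A dependency-free sketch of that issuance step, with SHA-256 standing in for the app's get_password_hash and all names illustrative:

```python
import hashlib
import secrets


def issue_client_credentials(client_type: str) -> "tuple[str, str | None, str | None]":
    """Return (client_id, plaintext_secret, stored_hash).

    Public clients get no secret; confidential clients get a one-time
    plaintext secret whose hash is what would be stored.
    """
    client_id = secrets.token_urlsafe(32)
    if client_type != "confidential":
        return client_id, None, None
    client_secret = secrets.token_urlsafe(48)
    stored_hash = hashlib.sha256(client_secret.encode()).hexdigest()
    return client_id, client_secret, stored_hash
```

Returning the plaintext secret alongside the stored hash matches the repository's tuple[OAuthClient, str | None] contract: the caller shows the secret to the registrant once and it is never recoverable afterwards.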
    async def deactivate_client(
        self, db: AsyncSession, *, client_id: str
    ) -> OAuthClient | None:
        """Deactivate an OAuth client."""
        try:
            client = await self.get_by_client_id(db, client_id=client_id)
            if client is None:
                return None

            client.is_active = False
            db.add(client)
            await db.commit()
            await db.refresh(client)

            logger.info("OAuth client deactivated: %s", client.client_name)
            return client
        except Exception as e:  # pragma: no cover
            await db.rollback()
            logger.error("Error deactivating OAuth client %s: %s", client_id, e)
            raise

    async def validate_redirect_uri(
        self, db: AsyncSession, *, client_id: str, redirect_uri: str
    ) -> bool:
        """Validate that a redirect URI is allowed for a client."""
        try:
            client = await self.get_by_client_id(db, client_id=client_id)
            if client is None:
                return False

            return redirect_uri in (client.redirect_uris or [])
        except Exception as e:  # pragma: no cover
            logger.error("Error validating redirect URI: %s", e)
            return False

    async def verify_client_secret(
        self, db: AsyncSession, *, client_id: str, client_secret: str
    ) -> bool:
        """Verify client credentials."""
        try:
            result = await db.execute(
                select(OAuthClient).where(
                    and_(
                        OAuthClient.client_id == client_id,
                        OAuthClient.is_active == True,  # noqa: E712
                    )
                )
            )
            client = result.scalar_one_or_none()

            if client is None or client.client_secret_hash is None:
                return False

            from app.core.auth import verify_password

            stored_hash: str = str(client.client_secret_hash)

            if stored_hash.startswith("$2"):
                return verify_password(client_secret, stored_hash)
            else:
                import hashlib

                secret_hash = hashlib.sha256(client_secret.encode()).hexdigest()
                return secrets.compare_digest(stored_hash, secret_hash)
        except Exception as e:  # pragma: no cover
            logger.error("Error verifying client secret: %s", e)
            return False
async def get_all_clients(
|
||||
self, db: AsyncSession, *, include_inactive: bool = False
|
||||
) -> list[OAuthClient]:
|
||||
"""Get all OAuth clients."""
|
||||
try:
|
||||
query = select(OAuthClient).order_by(OAuthClient.created_at.desc())
|
||||
if not include_inactive:
|
||||
query = query.where(OAuthClient.is_active == True) # noqa: E712
|
||||
|
||||
result = await db.execute(query)
|
||||
return list(result.scalars().all())
|
||||
except Exception as e: # pragma: no cover
|
||||
logger.error("Error getting all OAuth clients: %s", e)
|
||||
raise
|
||||
|
||||
async def delete_client(self, db: AsyncSession, *, client_id: str) -> bool:
|
||||
"""Delete an OAuth client permanently."""
|
||||
try:
|
||||
result = await db.execute(
|
||||
delete(OAuthClient).where(OAuthClient.client_id == client_id)
|
||||
)
|
||||
await db.commit()
|
||||
|
||||
deleted = result.rowcount > 0
|
||||
if deleted:
|
||||
logger.info("OAuth client deleted: %s", client_id)
|
||||
else:
|
||||
logger.warning("OAuth client not found for deletion: %s", client_id)
|
||||
|
||||
return deleted
|
||||
except Exception as e: # pragma: no cover
|
||||
await db.rollback()
|
||||
logger.error("Error deleting OAuth client %s: %s", client_id, e)
|
||||
raise
|
||||
|
||||
|
||||
# Singleton instance
|
||||
oauth_client_repo = OAuthClientRepository(OAuthClient)
|
||||
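The dual-format check in `verify_client_secret` (bcrypt for current hashes, SHA-256 hex for legacy ones) can be exercised without a database. This sketch mirrors only the legacy branch; `verify_secret` is a hypothetical name, and the bcrypt path is stubbed since the real code delegates it to `app.core.auth.verify_password`:

```python
import hashlib
import secrets


def verify_secret(stored_hash: str, candidate: str) -> bool:
    """Mirror of the branch in verify_client_secret: bcrypt hashes start
    with "$2"; anything else is treated as a legacy SHA-256 hex digest."""
    if stored_hash.startswith("$2"):
        # The repository delegates this path to verify_password; out of
        # scope for this standalone sketch.
        raise NotImplementedError("bcrypt path handled by verify_password")
    candidate_hash = hashlib.sha256(candidate.encode()).hexdigest()
    # compare_digest keeps the comparison constant-time
    return secrets.compare_digest(stored_hash, candidate_hash)


legacy_hash = hashlib.sha256(b"s3cret").hexdigest()
print(verify_secret(legacy_hash, "s3cret"))  # True
print(verify_secret(legacy_hash, "wrong"))   # False
```

The `compare_digest` call matters even for hex digests: a plain `==` short-circuits on the first differing byte, which a timing attack can exploit.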
113
backend/app/repositories/oauth_consent.py
Normal file
@@ -0,0 +1,113 @@
# app/repositories/oauth_consent.py
"""Repository for OAuthConsent model."""

import logging
from typing import Any
from uuid import UUID

from sqlalchemy import and_, delete, select
from sqlalchemy.ext.asyncio import AsyncSession

from app.models.oauth_client import OAuthClient
from app.models.oauth_provider_token import OAuthConsent

logger = logging.getLogger(__name__)


class OAuthConsentRepository:
    """Repository for OAuth consent records (user grants to clients)."""

    async def get_consent(
        self, db: AsyncSession, *, user_id: UUID, client_id: str
    ) -> OAuthConsent | None:
        """Get the consent record for a user-client pair, or None if not found."""
        result = await db.execute(
            select(OAuthConsent).where(
                and_(
                    OAuthConsent.user_id == user_id,
                    OAuthConsent.client_id == client_id,
                )
            )
        )
        return result.scalar_one_or_none()

    async def grant_consent(
        self,
        db: AsyncSession,
        *,
        user_id: UUID,
        client_id: str,
        scopes: list[str],
    ) -> OAuthConsent:
        """
        Create or update consent for a user-client pair.

        If consent already exists, the new scopes are merged with existing ones.
        Returns the created or updated consent record.
        """
        consent = await self.get_consent(db, user_id=user_id, client_id=client_id)

        if consent:
            existing = (
                set(consent.granted_scopes.split()) if consent.granted_scopes else set()
            )
            merged = existing | set(scopes)
            consent.granted_scopes = " ".join(sorted(merged))  # type: ignore[assignment]
        else:
            consent = OAuthConsent(
                user_id=user_id,
                client_id=client_id,
                granted_scopes=" ".join(sorted(set(scopes))),
            )
            db.add(consent)

        await db.commit()
        await db.refresh(consent)
        return consent

    async def get_user_consents_with_clients(
        self, db: AsyncSession, *, user_id: UUID
    ) -> list[dict[str, Any]]:
        """Get all consent records for a user joined with client details."""
        result = await db.execute(
            select(OAuthConsent, OAuthClient)
            .join(OAuthClient, OAuthConsent.client_id == OAuthClient.client_id)
            .where(OAuthConsent.user_id == user_id)
        )
        rows = result.all()
        return [
            {
                "client_id": consent.client_id,
                "client_name": client.client_name,
                "client_description": client.client_description,
                "granted_scopes": consent.granted_scopes.split()
                if consent.granted_scopes
                else [],
                "granted_at": consent.created_at.isoformat(),
            }
            for consent, client in rows
        ]

    async def revoke_consent(
        self, db: AsyncSession, *, user_id: UUID, client_id: str
    ) -> bool:
        """
        Delete the consent record for a user-client pair.

        Returns True if a record was found and deleted.
        Note: Callers are responsible for also revoking associated tokens.
        """
        result = await db.execute(
            delete(OAuthConsent).where(
                and_(
                    OAuthConsent.user_id == user_id,
                    OAuthConsent.client_id == client_id,
                )
            )
        )
        await db.commit()
        return result.rowcount > 0  # type: ignore[attr-defined]


# Singleton instance
oauth_consent_repo = OAuthConsentRepository()
142
backend/app/repositories/oauth_provider_token.py
Normal file
@@ -0,0 +1,142 @@
# app/repositories/oauth_provider_token.py
"""Repository for OAuthProviderRefreshToken model."""

import logging
from datetime import UTC, datetime, timedelta
from uuid import UUID

from sqlalchemy import and_, delete, select, update
from sqlalchemy.ext.asyncio import AsyncSession

from app.models.oauth_provider_token import OAuthProviderRefreshToken

logger = logging.getLogger(__name__)


class OAuthProviderTokenRepository:
    """Repository for OAuth provider refresh tokens."""

    async def create_token(
        self,
        db: AsyncSession,
        *,
        token_hash: str,
        jti: str,
        client_id: str,
        user_id: UUID,
        scope: str,
        expires_at: datetime,
        device_info: str | None = None,
        ip_address: str | None = None,
    ) -> OAuthProviderRefreshToken:
        """Create and persist a new refresh token record."""
        token = OAuthProviderRefreshToken(
            token_hash=token_hash,
            jti=jti,
            client_id=client_id,
            user_id=user_id,
            scope=scope,
            expires_at=expires_at,
            device_info=device_info,
            ip_address=ip_address,
        )
        db.add(token)
        await db.commit()
        return token

    async def get_by_token_hash(
        self, db: AsyncSession, *, token_hash: str
    ) -> OAuthProviderRefreshToken | None:
        """Get refresh token record by SHA-256 token hash."""
        result = await db.execute(
            select(OAuthProviderRefreshToken).where(
                OAuthProviderRefreshToken.token_hash == token_hash
            )
        )
        return result.scalar_one_or_none()

    async def get_by_jti(
        self, db: AsyncSession, *, jti: str
    ) -> OAuthProviderRefreshToken | None:
        """Get refresh token record by JWT ID (JTI)."""
        result = await db.execute(
            select(OAuthProviderRefreshToken).where(
                OAuthProviderRefreshToken.jti == jti
            )
        )
        return result.scalar_one_or_none()

    async def revoke(
        self, db: AsyncSession, *, token: OAuthProviderRefreshToken
    ) -> None:
        """Mark a specific token record as revoked."""
        token.revoked = True  # type: ignore[assignment]
        token.last_used_at = datetime.now(UTC)  # type: ignore[assignment]
        await db.commit()

    async def revoke_all_for_user_client(
        self, db: AsyncSession, *, user_id: UUID, client_id: str
    ) -> int:
        """
        Revoke all active tokens for a specific user-client pair.

        Used when security incidents are detected (e.g., authorization code reuse).
        Returns the number of tokens revoked.
        """
        result = await db.execute(
            update(OAuthProviderRefreshToken)
            .where(
                and_(
                    OAuthProviderRefreshToken.user_id == user_id,
                    OAuthProviderRefreshToken.client_id == client_id,
                    OAuthProviderRefreshToken.revoked == False,  # noqa: E712
                )
            )
            .values(revoked=True)
        )
        count = result.rowcount  # type: ignore[attr-defined]
        if count > 0:
            await db.commit()
        return count

    async def revoke_all_for_user(self, db: AsyncSession, *, user_id: UUID) -> int:
        """
        Revoke all active tokens for a user across all clients.

        Used when user changes password or logs out everywhere.
        Returns the number of tokens revoked.
        """
        result = await db.execute(
            update(OAuthProviderRefreshToken)
            .where(
                and_(
                    OAuthProviderRefreshToken.user_id == user_id,
                    OAuthProviderRefreshToken.revoked == False,  # noqa: E712
                )
            )
            .values(revoked=True)
        )
        count = result.rowcount  # type: ignore[attr-defined]
        if count > 0:
            await db.commit()
        return count

    async def cleanup_expired(self, db: AsyncSession, *, cutoff_days: int = 7) -> int:
        """
        Delete expired refresh tokens older than cutoff_days.

        Should be called periodically (e.g., daily).
        Returns the number of tokens deleted.
        """
        cutoff = datetime.now(UTC) - timedelta(days=cutoff_days)
        result = await db.execute(
            delete(OAuthProviderRefreshToken).where(
                OAuthProviderRefreshToken.expires_at < cutoff
            )
        )
        await db.commit()
        return result.rowcount  # type: ignore[attr-defined]


# Singleton instance
oauth_provider_token_repo = OAuthProviderTokenRepository()
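Per its docstring, `get_by_token_hash` looks tokens up by a SHA-256 hash rather than the raw secret, so only the digest is ever persisted. A plausible caller-side hashing step might look like the following; `hash_refresh_token` and the exact encoding are assumptions, not part of the diff:

```python
import hashlib
import secrets


def hash_refresh_token(raw_token: str) -> str:
    """Derive the lookup key stored in token_hash. SHA-256 hex is what the
    repository's docstring describes; the encoding choice is an assumption."""
    return hashlib.sha256(raw_token.encode()).hexdigest()


raw = secrets.token_urlsafe(32)   # opaque token handed to the client
digest = hash_refresh_token(raw)  # 64 hex chars persisted server-side
print(len(digest))  # 64
```

Storing only the digest means a database leak does not expose usable refresh tokens, at the cost of making the raw token unrecoverable server-side.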
113
backend/app/repositories/oauth_state.py
Normal file
@@ -0,0 +1,113 @@
# app/repositories/oauth_state.py
"""Repository for OAuthState model async database operations."""

import logging
from datetime import UTC, datetime

from pydantic import BaseModel
from sqlalchemy import delete, select
from sqlalchemy.exc import IntegrityError
from sqlalchemy.ext.asyncio import AsyncSession

from app.core.repository_exceptions import DuplicateEntryError
from app.models.oauth_state import OAuthState
from app.repositories.base import BaseRepository
from app.schemas.oauth import OAuthStateCreate

logger = logging.getLogger(__name__)


class EmptySchema(BaseModel):
    """Placeholder schema for repository operations that don't need update schemas."""


class OAuthStateRepository(BaseRepository[OAuthState, OAuthStateCreate, EmptySchema]):
    """Repository for OAuth state (CSRF protection)."""

    async def create_state(
        self, db: AsyncSession, *, obj_in: OAuthStateCreate
    ) -> OAuthState:
        """Create a new OAuth state for CSRF protection."""
        try:
            db_obj = OAuthState(
                state=obj_in.state,
                code_verifier=obj_in.code_verifier,
                nonce=obj_in.nonce,
                provider=obj_in.provider,
                redirect_uri=obj_in.redirect_uri,
                user_id=obj_in.user_id,
                expires_at=obj_in.expires_at,
            )
            db.add(db_obj)
            await db.commit()
            await db.refresh(db_obj)

            logger.debug("OAuth state created for %s", obj_in.provider)
            return db_obj
        except IntegrityError as e:  # pragma: no cover
            await db.rollback()
            error_msg = str(e.orig) if hasattr(e, "orig") else str(e)
            logger.error("OAuth state collision: %s", error_msg)
            raise DuplicateEntryError("Failed to create OAuth state, please retry")
        except Exception as e:  # pragma: no cover
            await db.rollback()
            logger.exception("Error creating OAuth state: %s", e)
            raise

    async def get_and_consume_state(
        self, db: AsyncSession, *, state: str
    ) -> OAuthState | None:
        """Get and delete OAuth state (consume it)."""
        try:
            result = await db.execute(
                select(OAuthState).where(OAuthState.state == state)
            )
            db_obj = result.scalar_one_or_none()

            if db_obj is None:
                logger.warning("OAuth state not found: %s...", state[:8])
                return None

            now = datetime.now(UTC)
            expires_at = db_obj.expires_at
            if expires_at.tzinfo is None:
                expires_at = expires_at.replace(tzinfo=UTC)

            if expires_at < now:
                logger.warning("OAuth state expired: %s...", state[:8])
                await db.delete(db_obj)
                await db.commit()
                return None

            await db.delete(db_obj)
            await db.commit()

            logger.debug("OAuth state consumed: %s...", state[:8])
            return db_obj
        except Exception as e:  # pragma: no cover
            await db.rollback()
            logger.error("Error consuming OAuth state: %s", e)
            raise

    async def cleanup_expired(self, db: AsyncSession) -> int:
        """Clean up expired OAuth states."""
        try:
            now = datetime.now(UTC)

            stmt = delete(OAuthState).where(OAuthState.expires_at < now)
            result = await db.execute(stmt)
            await db.commit()

            count = result.rowcount
            if count > 0:
                logger.info("Cleaned up %s expired OAuth states", count)

            return count
        except Exception as e:  # pragma: no cover
            await db.rollback()
            logger.error("Error cleaning up expired OAuth states: %s", e)
            raise


# Singleton instance
oauth_state_repo = OAuthStateRepository(OAuthState)
128
backend/app/crud/organization.py → backend/app/repositories/organization.py
Executable file → Normal file
@@ -1,5 +1,5 @@
-# app/crud/organization_async.py
-"""Async CRUD operations for Organization model using SQLAlchemy 2.0 patterns."""
+# app/repositories/organization.py
+"""Repository for Organization model async database operations using SQLAlchemy 2.0 patterns."""
 
 import logging
 from typing import Any
@@ -9,10 +9,11 @@ from sqlalchemy import and_, case, func, or_, select
 from sqlalchemy.exc import IntegrityError
 from sqlalchemy.ext.asyncio import AsyncSession
 
-from app.crud.base import CRUDBase
+from app.core.repository_exceptions import DuplicateEntryError, IntegrityConstraintError
 from app.models.organization import Organization
 from app.models.user import User
 from app.models.user_organization import OrganizationRole, UserOrganization
+from app.repositories.base import BaseRepository
 from app.schemas.organizations import (
     OrganizationCreate,
     OrganizationUpdate,
@@ -21,8 +22,10 @@ from app.schemas.organizations import (
 logger = logging.getLogger(__name__)
 
 
-class CRUDOrganization(CRUDBase[Organization, OrganizationCreate, OrganizationUpdate]):
-    """Async CRUD operations for Organization model."""
+class OrganizationRepository(
+    BaseRepository[Organization, OrganizationCreate, OrganizationUpdate]
+):
+    """Repository for Organization model."""
 
     async def get_by_slug(self, db: AsyncSession, *, slug: str) -> Organization | None:
         """Get organization by slug."""
@@ -32,7 +35,7 @@ class OrganizationRepository(BaseRepository[Organization, OrganizationCreate, Or
             )
             return result.scalar_one_or_none()
         except Exception as e:
-            logger.error(f"Error getting organization by slug {slug}: {e!s}")
+            logger.error("Error getting organization by slug %s: %s", slug, e)
             raise
 
     async def create(
@@ -54,18 +57,20 @@ class OrganizationRepository(BaseRepository[Organization, OrganizationCreate, Or
         except IntegrityError as e:
             await db.rollback()
             error_msg = str(e.orig) if hasattr(e, "orig") else str(e)
-            if "slug" in error_msg.lower():
-                logger.warning(f"Duplicate slug attempted: {obj_in.slug}")
-                raise ValueError(
+            if (
+                "slug" in error_msg.lower()
+                or "unique" in error_msg.lower()
+                or "duplicate" in error_msg.lower()
+            ):
+                logger.warning("Duplicate slug attempted: %s", obj_in.slug)
+                raise DuplicateEntryError(
                     f"Organization with slug '{obj_in.slug}' already exists"
                 )
-            logger.error(f"Integrity error creating organization: {error_msg}")
-            raise ValueError(f"Database integrity error: {error_msg}")
+            logger.error("Integrity error creating organization: %s", error_msg)
+            raise IntegrityConstraintError(f"Database integrity error: {error_msg}")
         except Exception as e:
             await db.rollback()
-            logger.error(
-                f"Unexpected error creating organization: {e!s}", exc_info=True
-            )
+            logger.exception("Unexpected error creating organization: %s", e)
             raise
 
     async def get_multi_with_filters(
@@ -79,16 +84,10 @@ class OrganizationRepository(BaseRepository[Organization, OrganizationCreate, Or
         sort_by: str = "created_at",
         sort_order: str = "desc",
     ) -> tuple[list[Organization], int]:
-        """
-        Get multiple organizations with filtering, searching, and sorting.
-
-        Returns:
-            Tuple of (organizations list, total count)
-        """
+        """Get multiple organizations with filtering, searching, and sorting."""
         try:
             query = select(Organization)
 
-            # Apply filters
             if is_active is not None:
                 query = query.where(Organization.is_active == is_active)
 
@@ -100,26 +99,23 @@ class OrganizationRepository(BaseRepository[Organization, OrganizationCreate, Or
             )
             query = query.where(search_filter)
 
-            # Get total count before pagination
             count_query = select(func.count()).select_from(query.alias())
             count_result = await db.execute(count_query)
             total = count_result.scalar_one()
 
-            # Apply sorting
             sort_column = getattr(Organization, sort_by, Organization.created_at)
             if sort_order == "desc":
                 query = query.order_by(sort_column.desc())
             else:
                 query = query.order_by(sort_column.asc())
 
-            # Apply pagination
             query = query.offset(skip).limit(limit)
             result = await db.execute(query)
             organizations = list(result.scalars().all())
 
             return organizations, total
         except Exception as e:
-            logger.error(f"Error getting organizations with filters: {e!s}")
+            logger.error("Error getting organizations with filters: %s", e)
             raise
 
     async def get_member_count(self, db: AsyncSession, *, organization_id: UUID) -> int:
@@ -136,7 +132,7 @@ class OrganizationRepository(BaseRepository[Organization, OrganizationCreate, Or
             return result.scalar_one() or 0
         except Exception as e:
             logger.error(
-                f"Error getting member count for organization {organization_id}: {e!s}"
+                "Error getting member count for organization %s: %s", organization_id, e
             )
             raise
 
@@ -149,16 +145,8 @@ class OrganizationRepository(BaseRepository[Organization, OrganizationCreate, Or
         is_active: bool | None = None,
         search: str | None = None,
     ) -> tuple[list[dict[str, Any]], int]:
-        """
-        Get organizations with member counts in a SINGLE QUERY using JOIN and GROUP BY.
-        This eliminates the N+1 query problem.
-
-        Returns:
-            Tuple of (list of dicts with org and member_count, total count)
-        """
+        """Get organizations with member counts in a SINGLE QUERY using JOIN and GROUP BY."""
         try:
-            # Build base query with LEFT JOIN and GROUP BY
-            # Use CASE statement to count only active members
             query = (
                 select(
                     Organization,
@@ -181,10 +169,10 @@ class OrganizationRepository(BaseRepository[Organization, OrganizationCreate, Or
                 .group_by(Organization.id)
             )
 
-            # Apply filters
             if is_active is not None:
                 query = query.where(Organization.is_active == is_active)
 
+            search_filter = None
             if search:
                 search_filter = or_(
                     Organization.name.ilike(f"%{search}%"),
@@ -193,17 +181,15 @@ class OrganizationRepository(BaseRepository[Organization, OrganizationCreate, Or
                 )
                 query = query.where(search_filter)
 
-            # Get total count
             count_query = select(func.count(Organization.id))
             if is_active is not None:
                 count_query = count_query.where(Organization.is_active == is_active)
-            if search:
+            if search_filter is not None:
                 count_query = count_query.where(search_filter)
 
             count_result = await db.execute(count_query)
             total = count_result.scalar_one()
 
-            # Apply pagination and ordering
             query = (
                 query.order_by(Organization.created_at.desc()).offset(skip).limit(limit)
             )
@@ -211,7 +197,6 @@ class OrganizationRepository(BaseRepository[Organization, OrganizationCreate, Or
             result = await db.execute(query)
             rows = result.all()
 
-            # Convert to list of dicts
             orgs_with_counts = [
                 {"organization": org, "member_count": member_count}
                 for org, member_count in rows
@@ -220,9 +205,7 @@ class OrganizationRepository(BaseRepository[Organization, OrganizationCreate, Or
             return orgs_with_counts, total
 
         except Exception as e:
-            logger.error(
-                f"Error getting organizations with member counts: {e!s}", exc_info=True
-            )
+            logger.exception("Error getting organizations with member counts: %s", e)
             raise
 
     async def add_user(
@@ -236,7 +219,6 @@ class OrganizationRepository(BaseRepository[Organization, OrganizationCreate, Or
     ) -> UserOrganization:
         """Add a user to an organization with a specific role."""
         try:
-            # Check if relationship already exists
             result = await db.execute(
                 select(UserOrganization).where(
                     and_(
@@ -248,7 +230,6 @@ class OrganizationRepository(BaseRepository[Organization, OrganizationCreate, Or
             existing = result.scalar_one_or_none()
 
             if existing:
-                # Reactivate if inactive, or raise error if already active
                 if not existing.is_active:
                     existing.is_active = True
                     existing.role = role
@@ -257,9 +238,10 @@ class OrganizationRepository(BaseRepository[Organization, OrganizationCreate, Or
                 await db.refresh(existing)
                 return existing
                 else:
-                    raise ValueError("User is already a member of this organization")
+                    raise DuplicateEntryError(
+                        "User is already a member of this organization"
+                    )
 
-            # Create new relationship
             user_org = UserOrganization(
                 user_id=user_id,
                 organization_id=organization_id,
@@ -273,11 +255,11 @@ class OrganizationRepository(BaseRepository[Organization, OrganizationCreate, Or
             return user_org
         except IntegrityError as e:
             await db.rollback()
-            logger.error(f"Integrity error adding user to organization: {e!s}")
-            raise ValueError("Failed to add user to organization")
+            logger.error("Integrity error adding user to organization: %s", e)
+            raise IntegrityConstraintError("Failed to add user to organization")
         except Exception as e:
             await db.rollback()
-            logger.error(f"Error adding user to organization: {e!s}", exc_info=True)
+            logger.exception("Error adding user to organization: %s", e)
             raise
 
     async def remove_user(
@@ -303,7 +285,7 @@ class OrganizationRepository(BaseRepository[Organization, OrganizationCreate, Or
             return True
         except Exception as e:
             await db.rollback()
-            logger.error(f"Error removing user from organization: {e!s}", exc_info=True)
+            logger.exception("Error removing user from organization: %s", e)
             raise
 
     async def update_user_role(
@@ -338,7 +320,7 @@ class OrganizationRepository(BaseRepository[Organization, OrganizationCreate, Or
             return user_org
         except Exception as e:
             await db.rollback()
-            logger.error(f"Error updating user role: {e!s}", exc_info=True)
+            logger.exception("Error updating user role: %s", e)
             raise
 
     async def get_organization_members(
@@ -348,16 +330,10 @@ class OrganizationRepository(BaseRepository[Organization, OrganizationCreate, Or
         organization_id: UUID,
         skip: int = 0,
         limit: int = 100,
-        is_active: bool = True,
+        is_active: bool | None = True,
    ) -> tuple[list[dict[str, Any]], int]:
-        """
-        Get members of an organization with user details.
-
-        Returns:
-            Tuple of (members list with user details, total count)
-        """
+        """Get members of an organization with user details."""
         try:
-            # Build query with join
             query = (
                 select(UserOrganization, User)
                 .join(User, UserOrganization.user_id == User.id)
@@ -367,7 +343,6 @@ class OrganizationRepository(BaseRepository[Organization, OrganizationCreate, Or
             if is_active is not None:
                 query = query.where(UserOrganization.is_active == is_active)
 
-            # Get total count
             count_query = select(func.count()).select_from(
                 select(UserOrganization)
                 .where(UserOrganization.organization_id == organization_id)
@@ -381,7 +356,6 @@ class OrganizationRepository(BaseRepository[Organization, OrganizationCreate, Or
             count_result = await db.execute(count_query)
             total = count_result.scalar_one()
 
-            # Apply ordering and pagination
             query = (
                 query.order_by(UserOrganization.created_at.desc())
                 .offset(skip)
@@ -406,11 +380,11 @@ class OrganizationRepository(BaseRepository[Organization, OrganizationCreate, Or
 
             return members, total
         except Exception as e:
-            logger.error(f"Error getting organization members: {e!s}")
+            logger.error("Error getting organization members: %s", e)
             raise
 
     async def get_user_organizations(
-        self, db: AsyncSession, *, user_id: UUID, is_active: bool = True
+        self, db: AsyncSession, *, user_id: UUID, is_active: bool | None = True
     ) -> list[Organization]:
         """Get all organizations a user belongs to."""
         try:
@@ -429,21 +403,14 @@ class OrganizationRepository(BaseRepository[Organization, OrganizationCreate, Or
             result = await db.execute(query)
             return list(result.scalars().all())
         except Exception as e:
-            logger.error(f"Error getting user organizations: {e!s}")
+            logger.error("Error getting user organizations: %s", e)
             raise
 
     async def get_user_organizations_with_details(
-        self, db: AsyncSession, *, user_id: UUID, is_active: bool = True
+        self, db: AsyncSession, *, user_id: UUID, is_active: bool | None = True
     ) -> list[dict[str, Any]]:
-        """
-        Get user's organizations with role and member count in SINGLE QUERY.
-        Eliminates N+1 problem by using subquery for member counts.
-
-        Returns:
-            List of dicts with organization, role, and member_count
-        """
+        """Get user's organizations with role and member count in SINGLE QUERY."""
         try:
-            # Subquery to get member counts for each organization
             member_count_subq = (
                 select(
                     UserOrganization.organization_id,
@@ -454,7 +421,6 @@ class OrganizationRepository(BaseRepository[Organization, OrganizationCreate, Or
                 .subquery()
             )
 
-            # Main query with JOIN to get org, role, and member count
             query = (
                 select(
                     Organization,
@@ -486,9 +452,7 @@ class OrganizationRepository(BaseRepository[Organization, OrganizationCreate, Or
             ]
 
         except Exception as e:
-            logger.error(
-                f"Error getting user organizations with details: {e!s}", exc_info=True
-            )
+            logger.exception("Error getting user organizations with details: %s", e)
             raise
 
     async def get_user_role_in_org(
@@ -507,9 +471,9 @@ class OrganizationRepository(BaseRepository[Organization, OrganizationCreate, Or
             )
             user_org = result.scalar_one_or_none()
 
-            return user_org.role if user_org else None
+            return user_org.role if user_org else None  # pyright: ignore[reportReturnType]
         except Exception as e:
-            logger.error(f"Error getting user role in org: {e!s}")
+            logger.error("Error getting user role in org: %s", e)
             raise
 
     async def is_user_org_owner(
@@ -531,5 +495,5 @@ class OrganizationRepository(BaseRepository[Organization, OrganizationCreate, Or
         return role in [OrganizationRole.OWNER, OrganizationRole.ADMIN]
 
 
-# Create a singleton instance for use across the application
-organization = CRUDOrganization(Organization)
+# Singleton instance
+organization_repo = OrganizationRepository(Organization)
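The widened duplicate-detection branch in `create` now matches any of three substrings in the driver's error message instead of only "slug", so SQLite and PostgreSQL wordings are both caught. The predicate is easy to verify standalone; `looks_like_duplicate` is a hypothetical name for illustration:

```python
def looks_like_duplicate(error_msg: str) -> bool:
    """Mirror of the new IntegrityError branch: treat the failure as a
    duplicate if the backend mentions the slug column, a unique
    constraint, or a duplicate key."""
    msg = error_msg.lower()
    return "slug" in msg or "unique" in msg or "duplicate" in msg


print(looks_like_duplicate("UNIQUE constraint failed: organizations.slug"))    # True
print(looks_like_duplicate("duplicate key value violates unique constraint"))  # True
print(looks_like_duplicate("foreign key constraint fails"))                    # False
```

Substring matching on driver messages is inherently fuzzy; a foreign-key failure whose text happened to contain "unique" would be misclassified, which is why the fallback path still raises `IntegrityConstraintError` with the original message.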
231
backend/app/crud/session.py → backend/app/repositories/session.py
Executable file → Normal file
@@ -1,6 +1,5 @@
-"""
-Async CRUD operations for user sessions using SQLAlchemy 2.0 patterns.
-"""
+# app/repositories/session.py
+"""Repository for UserSession model async database operations using SQLAlchemy 2.0 patterns."""

 import logging
 import uuid
@@ -11,49 +10,32 @@ from sqlalchemy import and_, delete, func, select, update
 from sqlalchemy.ext.asyncio import AsyncSession
 from sqlalchemy.orm import joinedload

-from app.crud.base import CRUDBase
+from app.core.repository_exceptions import IntegrityConstraintError, InvalidInputError
 from app.models.user_session import UserSession
+from app.repositories.base import BaseRepository
 from app.schemas.sessions import SessionCreate, SessionUpdate

 logger = logging.getLogger(__name__)


-class CRUDSession(CRUDBase[UserSession, SessionCreate, SessionUpdate]):
-    """Async CRUD operations for user sessions."""
+class SessionRepository(BaseRepository[UserSession, SessionCreate, SessionUpdate]):
+    """Repository for UserSession model."""

     async def get_by_jti(self, db: AsyncSession, *, jti: str) -> UserSession | None:
-        """
-        Get session by refresh token JTI.
-
-        Args:
-            db: Database session
-            jti: Refresh token JWT ID
-
-        Returns:
-            UserSession if found, None otherwise
-        """
+        """Get session by refresh token JTI."""
         try:
             result = await db.execute(
                 select(UserSession).where(UserSession.refresh_token_jti == jti)
             )
             return result.scalar_one_or_none()
         except Exception as e:
-            logger.error(f"Error getting session by JTI {jti}: {e!s}")
+            logger.error("Error getting session by JTI %s: %s", jti, e)
             raise

     async def get_active_by_jti(
         self, db: AsyncSession, *, jti: str
     ) -> UserSession | None:
-        """
-        Get active session by refresh token JTI.
-
-        Args:
-            db: Database session
-            jti: Refresh token JWT ID
-
-        Returns:
-            Active UserSession if found, None otherwise
-        """
+        """Get active session by refresh token JTI."""
         try:
             result = await db.execute(
                 select(UserSession).where(
@@ -65,7 +47,7 @@ class CRUDSession(CRUDBase[UserSession, SessionCreate, SessionUpdate]):
             )
             return result.scalar_one_or_none()
         except Exception as e:
-            logger.error(f"Error getting active session by JTI {jti}: {e!s}")
+            logger.error("Error getting active session by JTI %s: %s", jti, e)
             raise
     async def get_user_sessions(
@@ -76,25 +58,12 @@ class CRUDSession(CRUDBase[UserSession, SessionCreate, SessionUpdate]):
         active_only: bool = True,
         with_user: bool = False,
     ) -> list[UserSession]:
-        """
-        Get all sessions for a user with optional eager loading.
-
-        Args:
-            db: Database session
-            user_id: User ID
-            active_only: If True, return only active sessions
-            with_user: If True, eager load user relationship to prevent N+1
-
-        Returns:
-            List of UserSession objects
-        """
+        """Get all sessions for a user with optional eager loading."""
         try:
             # Convert user_id string to UUID if needed
             user_uuid = UUID(user_id) if isinstance(user_id, str) else user_id

             query = select(UserSession).where(UserSession.user_id == user_uuid)

             # Add eager loading if requested to prevent N+1 queries
             if with_user:
                 query = query.options(joinedload(UserSession.user))

@@ -105,25 +74,13 @@ class CRUDSession(CRUDBase[UserSession, SessionCreate, SessionUpdate]):
             result = await db.execute(query)
             return list(result.scalars().all())
         except Exception as e:
-            logger.error(f"Error getting sessions for user {user_id}: {e!s}")
+            logger.error("Error getting sessions for user %s: %s", user_id, e)
             raise

     async def create_session(
         self, db: AsyncSession, *, obj_in: SessionCreate
     ) -> UserSession:
-        """
-        Create a new user session.
-
-        Args:
-            db: Database session
-            obj_in: SessionCreate schema with session data
-
-        Returns:
-            Created UserSession
-
-        Raises:
-            ValueError: If session creation fails
-        """
+        """Create a new user session."""
         try:
             db_obj = UserSession(
                 user_id=obj_in.user_id,
@@ -143,33 +100,26 @@ class CRUDSession(CRUDBase[UserSession, SessionCreate, SessionUpdate]):
             await db.refresh(db_obj)

             logger.info(
-                f"Session created for user {obj_in.user_id} from {obj_in.device_name} "
-                f"(IP: {obj_in.ip_address})"
+                "Session created for user %s from %s (IP: %s)",
+                obj_in.user_id,
+                obj_in.device_name,
+                obj_in.ip_address,
             )

             return db_obj
         except Exception as e:
             await db.rollback()
-            logger.error(f"Error creating session: {e!s}", exc_info=True)
-            raise ValueError(f"Failed to create session: {e!s}")
+            logger.exception("Error creating session: %s", e)
+            raise IntegrityConstraintError(f"Failed to create session: {e!s}")

     async def deactivate(
         self, db: AsyncSession, *, session_id: str
     ) -> UserSession | None:
-        """
-        Deactivate a session (logout from device).
-
-        Args:
-            db: Database session
-            session_id: Session UUID
-
-        Returns:
-            Deactivated UserSession if found, None otherwise
-        """
+        """Deactivate a session (logout from device)."""
         try:
             session = await self.get(db, id=session_id)
             if not session:
-                logger.warning(f"Session {session_id} not found for deactivation")
+                logger.warning("Session %s not found for deactivation", session_id)
                 return None

             session.is_active = False
@@ -178,31 +128,23 @@ class CRUDSession(CRUDBase[UserSession, SessionCreate, SessionUpdate]):
             await db.refresh(session)

             logger.info(
-                f"Session {session_id} deactivated for user {session.user_id} "
-                f"({session.device_name})"
+                "Session %s deactivated for user %s (%s)",
+                session_id,
+                session.user_id,
+                session.device_name,
             )

             return session
         except Exception as e:
             await db.rollback()
-            logger.error(f"Error deactivating session {session_id}: {e!s}")
+            logger.error("Error deactivating session %s: %s", session_id, e)
             raise

     async def deactivate_all_user_sessions(
         self, db: AsyncSession, *, user_id: str
     ) -> int:
-        """
-        Deactivate all active sessions for a user (logout from all devices).
-
-        Args:
-            db: Database session
-            user_id: User ID
-
-        Returns:
-            Number of sessions deactivated
-        """
+        """Deactivate all active sessions for a user (logout from all devices)."""
         try:
             # Convert user_id string to UUID if needed
             user_uuid = UUID(user_id) if isinstance(user_id, str) else user_id

             stmt = (
@@ -216,27 +158,18 @@ class CRUDSession(CRUDBase[UserSession, SessionCreate, SessionUpdate]):

             count = result.rowcount

-            logger.info(f"Deactivated {count} sessions for user {user_id}")
+            logger.info("Deactivated %s sessions for user %s", count, user_id)

             return count
         except Exception as e:
             await db.rollback()
-            logger.error(f"Error deactivating all sessions for user {user_id}: {e!s}")
+            logger.error("Error deactivating all sessions for user %s: %s", user_id, e)
             raise
     async def update_last_used(
         self, db: AsyncSession, *, session: UserSession
     ) -> UserSession:
-        """
-        Update the last_used_at timestamp for a session.
-
-        Args:
-            db: Database session
-            session: UserSession object
-
-        Returns:
-            Updated UserSession
-        """
+        """Update the last_used_at timestamp for a session."""
         try:
             session.last_used_at = datetime.now(UTC)
             db.add(session)
@@ -245,7 +178,7 @@ class CRUDSession(CRUDBase[UserSession, SessionCreate, SessionUpdate]):
             return session
         except Exception as e:
             await db.rollback()
-            logger.error(f"Error updating last_used for session {session.id}: {e!s}")
+            logger.error("Error updating last_used for session %s: %s", session.id, e)
             raise

     async def update_refresh_token(
@@ -256,20 +189,7 @@ class CRUDSession(CRUDBase[UserSession, SessionCreate, SessionUpdate]):
         new_jti: str,
         new_expires_at: datetime,
     ) -> UserSession:
-        """
-        Update session with new refresh token JTI and expiration.
-
-        Called during token refresh.
-
-        Args:
-            db: Database session
-            session: UserSession object
-            new_jti: New refresh token JTI
-            new_expires_at: New expiration datetime
-
-        Returns:
-            Updated UserSession
-        """
+        """Update session with new refresh token JTI and expiration."""
         try:
             session.refresh_token_jti = new_jti
             session.expires_at = new_expires_at
@@ -281,32 +201,16 @@ class CRUDSession(CRUDBase[UserSession, SessionCreate, SessionUpdate]):
         except Exception as e:
             await db.rollback()
             logger.error(
-                f"Error updating refresh token for session {session.id}: {e!s}"
+                "Error updating refresh token for session %s: %s", session.id, e
             )
             raise

     async def cleanup_expired(self, db: AsyncSession, *, keep_days: int = 30) -> int:
-        """
-        Clean up expired sessions using optimized bulk DELETE.
-
-        Deletes sessions that are:
-        - Expired AND inactive
-        - Older than keep_days
-
-        Uses single DELETE query instead of N individual deletes for efficiency.
-
-        Args:
-            db: Database session
-            keep_days: Keep inactive sessions for this many days (for audit)
-
-        Returns:
-            Number of sessions deleted
-        """
+        """Clean up expired sessions using optimized bulk DELETE."""
         try:
             cutoff_date = datetime.now(UTC) - timedelta(days=keep_days)
             now = datetime.now(UTC)

             # Use bulk DELETE with WHERE clause - single query
             stmt = delete(UserSession).where(
                 and_(
                     UserSession.is_active == False,  # noqa: E712
@@ -321,38 +225,25 @@ class CRUDSession(CRUDBase[UserSession, SessionCreate, SessionUpdate]):
             count = result.rowcount

             if count > 0:
-                logger.info(f"Cleaned up {count} expired sessions using bulk DELETE")
+                logger.info("Cleaned up %s expired sessions using bulk DELETE", count)

             return count
         except Exception as e:
             await db.rollback()
-            logger.error(f"Error cleaning up expired sessions: {e!s}")
+            logger.error("Error cleaning up expired sessions: %s", e)
             raise
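The `cleanup_expired` hunk above replaces N per-row deletes with a single bulk `DELETE ... WHERE` statement. A standalone sqlite sketch of the same idea (the table name and columns mirror the diff, but this is a simplified stand-in for the SQLAlchemy `delete()` call, and the `keep_days` audit cutoff is omitted):

```python
import sqlite3
from datetime import datetime, timedelta, timezone

# In-memory stand-in for the user_sessions table from the diff.
conn = sqlite3.connect(":memory:")
conn.execute("CREATE TABLE user_sessions (id INTEGER, is_active INTEGER, expires_at TEXT)")
now = datetime.now(timezone.utc)
rows = [
    (1, 0, (now - timedelta(days=40)).isoformat()),  # inactive AND expired -> deleted
    (2, 1, (now - timedelta(days=40)).isoformat()),  # expired but still active -> kept
    (3, 0, (now + timedelta(days=1)).isoformat()),   # inactive but not expired -> kept
]
conn.executemany("INSERT INTO user_sessions VALUES (?, ?, ?)", rows)

# One DELETE with a WHERE clause instead of N individual deletes.
cur = conn.execute(
    "DELETE FROM user_sessions WHERE is_active = 0 AND expires_at < ?",
    (now.isoformat(),),
)
assert cur.rowcount == 1
assert conn.execute("SELECT COUNT(*) FROM user_sessions").fetchone()[0] == 2
```

The single-statement form lets the database use its indexes and avoids a round trip per row, which is the efficiency the docstring in the diff claims.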
     async def cleanup_expired_for_user(self, db: AsyncSession, *, user_id: str) -> int:
-        """
-        Clean up expired and inactive sessions for a specific user.
-
-        Uses single bulk DELETE query for efficiency instead of N individual deletes.
-
-        Args:
-            db: Database session
-            user_id: User ID to cleanup sessions for
-
-        Returns:
-            Number of sessions deleted
-        """
+        """Clean up expired and inactive sessions for a specific user."""
         try:
             # Validate UUID
             try:
                 uuid_obj = uuid.UUID(user_id)
             except (ValueError, AttributeError):
-                logger.error(f"Invalid UUID format: {user_id}")
-                raise ValueError(f"Invalid user ID format: {user_id}")
+                logger.error("Invalid UUID format: %s", user_id)
+                raise InvalidInputError(f"Invalid user ID format: {user_id}")

             now = datetime.now(UTC)

             # Use bulk DELETE with WHERE clause - single query
             stmt = delete(UserSession).where(
                 and_(
                     UserSession.user_id == uuid_obj,
@@ -368,30 +259,22 @@ class CRUDSession(CRUDBase[UserSession, SessionCreate, SessionUpdate]):

             if count > 0:
                 logger.info(
-                    f"Cleaned up {count} expired sessions for user {user_id} using bulk DELETE"
+                    "Cleaned up %s expired sessions for user %s using bulk DELETE",
+                    count,
+                    user_id,
                 )

             return count
         except Exception as e:
             await db.rollback()
             logger.error(
-                f"Error cleaning up expired sessions for user {user_id}: {e!s}"
+                "Error cleaning up expired sessions for user %s: %s", user_id, e
             )
             raise

     async def get_user_session_count(self, db: AsyncSession, *, user_id: str) -> int:
-        """
-        Get count of active sessions for a user.
-
-        Args:
-            db: Database session
-            user_id: User ID
-
-        Returns:
-            Number of active sessions
-        """
+        """Get count of active sessions for a user."""
         try:
             # Convert user_id string to UUID if needed
             user_uuid = UUID(user_id) if isinstance(user_id, str) else user_id

             result = await db.execute(
@@ -401,7 +284,7 @@ class CRUDSession(CRUDBase[UserSession, SessionCreate, SessionUpdate]):
             )
             return result.scalar_one()
         except Exception as e:
-            logger.error(f"Error counting sessions for user {user_id}: {e!s}")
+            logger.error("Error counting sessions for user %s: %s", user_id, e)
             raise

     async def get_all_sessions(
@@ -413,31 +296,16 @@ class CRUDSession(CRUDBase[UserSession, SessionCreate, SessionUpdate]):
         active_only: bool = True,
         with_user: bool = True,
     ) -> tuple[list[UserSession], int]:
-        """
-        Get all sessions across all users with pagination (admin only).
-
-        Args:
-            db: Database session
-            skip: Number of records to skip
-            limit: Maximum number of records to return
-            active_only: If True, return only active sessions
-            with_user: If True, eager load user relationship to prevent N+1
-
-        Returns:
-            Tuple of (list of UserSession objects, total count)
-        """
+        """Get all sessions across all users with pagination (admin only)."""
         try:
             # Build query
             query = select(UserSession)

             # Add eager loading if requested to prevent N+1 queries
             if with_user:
                 query = query.options(joinedload(UserSession.user))

             if active_only:
                 query = query.where(UserSession.is_active)

             # Get total count
             count_query = select(func.count(UserSession.id))
             if active_only:
                 count_query = count_query.where(UserSession.is_active)
@@ -445,7 +313,6 @@ class CRUDSession(CRUDBase[UserSession, SessionCreate, SessionUpdate]):
             count_result = await db.execute(count_query)
             total = count_result.scalar_one()

             # Apply pagination and ordering
             query = (
                 query.order_by(UserSession.last_used_at.desc())
                 .offset(skip)
@@ -458,9 +325,9 @@ class CRUDSession(CRUDBase[UserSession, SessionCreate, SessionUpdate]):
             return sessions, total

         except Exception as e:
-            logger.error(f"Error getting all sessions: {e!s}", exc_info=True)
+            logger.exception("Error getting all sessions: %s", e)
             raise


-# Create singleton instance
-session = CRUDSession(UserSession)
+# Singleton instance
+session_repo = SessionRepository(UserSession)
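A recurring change in this file is swapping eager f-string interpolation (`logger.error(f"... {jti}")`) for lazy %-style arguments (`logger.error("... %s", jti)`). A minimal standalone sketch of why that matters: with lazy arguments, formatting is deferred until the record is actually emitted, so suppressed messages cost nothing.

```python
import logging

class Expensive:
    """Object whose string conversion is costly; counts how often it is rendered."""
    renders = 0

    def __str__(self):
        Expensive.renders += 1
        return "expensive-repr"

logging.basicConfig(level=logging.WARNING)
log = logging.getLogger("demo")
obj = Expensive()

# Lazy %-style: DEBUG is below the WARNING threshold, so the argument
# is never converted to a string.
log.debug("value: %s", obj)
assert Expensive.renders == 0

# Eager f-string: the string is built before logging even looks at the level.
log.debug(f"value: {obj}")
assert Expensive.renders == 1
```

The same deferral is why linters such as Ruff flag f-strings in logging calls (the `G004` rule), which is plausibly what motivated this sweep.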
155
backend/app/crud/user.py → backend/app/repositories/user.py
Executable file → Normal file
@@ -1,5 +1,5 @@
-# app/crud/user_async.py
-"""Async CRUD operations for User model using SQLAlchemy 2.0 patterns."""
+# app/repositories/user.py
+"""Repository for User model async database operations using SQLAlchemy 2.0 patterns."""

 import logging
 from datetime import UTC, datetime
@@ -11,15 +11,16 @@ from sqlalchemy.exc import IntegrityError
 from sqlalchemy.ext.asyncio import AsyncSession

 from app.core.auth import get_password_hash_async
-from app.crud.base import CRUDBase
+from app.core.repository_exceptions import DuplicateEntryError, InvalidInputError
 from app.models.user import User
+from app.repositories.base import BaseRepository
 from app.schemas.users import UserCreate, UserUpdate

 logger = logging.getLogger(__name__)


-class CRUDUser(CRUDBase[User, UserCreate, UserUpdate]):
-    """Async CRUD operations for User model."""
+class UserRepository(BaseRepository[User, UserCreate, UserUpdate]):
+    """Repository for User model."""

     async def get_by_email(self, db: AsyncSession, *, email: str) -> User | None:
         """Get user by email address."""
@@ -27,13 +28,12 @@ class CRUDUser(CRUDBase[User, UserCreate, UserUpdate]):
             result = await db.execute(select(User).where(User.email == email))
             return result.scalar_one_or_none()
         except Exception as e:
-            logger.error(f"Error getting user by email {email}: {e!s}")
+            logger.error("Error getting user by email %s: %s", email, e)
             raise

     async def create(self, db: AsyncSession, *, obj_in: UserCreate) -> User:
         """Create a new user with async password hashing and error handling."""
         try:
             # Hash password asynchronously to avoid blocking event loop
             password_hash = await get_password_hash_async(obj_in.password)

             db_obj = User(
@@ -57,13 +57,49 @@ class CRUDUser(CRUDBase[User, UserCreate, UserUpdate]):
             await db.rollback()
             error_msg = str(e.orig) if hasattr(e, "orig") else str(e)
             if "email" in error_msg.lower():
-                logger.warning(f"Duplicate email attempted: {obj_in.email}")
-                raise ValueError(f"User with email {obj_in.email} already exists")
-            logger.error(f"Integrity error creating user: {error_msg}")
-            raise ValueError(f"Database integrity error: {error_msg}")
+                logger.warning("Duplicate email attempted: %s", obj_in.email)
+                raise DuplicateEntryError(
+                    f"User with email {obj_in.email} already exists"
+                )
+            logger.error("Integrity error creating user: %s", error_msg)
+            raise DuplicateEntryError(f"Database integrity error: {error_msg}")
         except Exception as e:
             await db.rollback()
-            logger.error(f"Unexpected error creating user: {e!s}", exc_info=True)
+            logger.exception("Unexpected error creating user: %s", e)
             raise

+    async def create_oauth_user(
+        self,
+        db: AsyncSession,
+        *,
+        email: str,
+        first_name: str = "User",
+        last_name: str | None = None,
+    ) -> User:
+        """Create a new passwordless user for OAuth sign-in."""
+        try:
+            db_obj = User(
+                email=email,
+                password_hash=None,  # OAuth-only user
+                first_name=first_name,
+                last_name=last_name,
+                is_active=True,
+                is_superuser=False,
+            )
+            db.add(db_obj)
+            await db.flush()  # Get user.id without committing
+            return db_obj
+        except IntegrityError as e:
+            await db.rollback()
+            error_msg = str(e.orig) if hasattr(e, "orig") else str(e)
+            if "email" in error_msg.lower():
+                logger.warning("Duplicate email attempted: %s", email)
+                raise DuplicateEntryError(f"User with email {email} already exists")
+            logger.error("Integrity error creating OAuth user: %s", error_msg)
+            raise DuplicateEntryError(f"Database integrity error: {error_msg}")
+        except Exception as e:
+            await db.rollback()
+            logger.exception("Unexpected error creating OAuth user: %s", e)
+            raise
     async def update(
@@ -75,8 +111,6 @@ class CRUDUser(CRUDBase[User, UserCreate, UserUpdate]):
         else:
             update_data = obj_in.model_dump(exclude_unset=True)

-        # Handle password separately if it exists in update data
-        # Hash password asynchronously to avoid blocking event loop
         if "password" in update_data:
             update_data["password_hash"] = await get_password_hash_async(
                 update_data["password"]
@@ -85,6 +119,15 @@ class CRUDUser(CRUDBase[User, UserCreate, UserUpdate]):

         return await super().update(db, db_obj=db_obj, obj_in=update_data)

+    async def update_password(
+        self, db: AsyncSession, *, user: User, password_hash: str
+    ) -> User:
+        """Set a new password hash on a user and commit."""
+        user.password_hash = password_hash
+        await db.commit()
+        await db.refresh(user)
+        return user

     async def get_multi_with_total(
         self,
         db: AsyncSession,
@@ -96,43 +139,23 @@ class CRUDUser(CRUDBase[User, UserCreate, UserUpdate]):
         filters: dict[str, Any] | None = None,
         search: str | None = None,
     ) -> tuple[list[User], int]:
-        """
-        Get multiple users with total count, filtering, sorting, and search.
-
-        Args:
-            db: Database session
-            skip: Number of records to skip
-            limit: Maximum number of records to return
-            sort_by: Field name to sort by
-            sort_order: Sort order ("asc" or "desc")
-            filters: Dictionary of filters (field_name: value)
-            search: Search term to match against email, first_name, last_name
-
-        Returns:
-            Tuple of (users list, total count)
-        """
-        # Validate pagination
+        """Get multiple users with total count, filtering, sorting, and search."""
         if skip < 0:
-            raise ValueError("skip must be non-negative")
+            raise InvalidInputError("skip must be non-negative")
         if limit < 0:
-            raise ValueError("limit must be non-negative")
+            raise InvalidInputError("limit must be non-negative")
         if limit > 1000:
-            raise ValueError("Maximum limit is 1000")
+            raise InvalidInputError("Maximum limit is 1000")

         try:
             # Build base query
             query = select(User)

             # Exclude soft-deleted users
             query = query.where(User.deleted_at.is_(None))

             # Apply filters
             if filters:
                 for field, value in filters.items():
                     if hasattr(User, field) and value is not None:
                         query = query.where(getattr(User, field) == value)

             # Apply search
             if search:
                 search_filter = or_(
                     User.email.ilike(f"%{search}%"),
@@ -141,14 +164,12 @@ class CRUDUser(CRUDBase[User, UserCreate, UserUpdate]):
                 )
                 query = query.where(search_filter)

             # Get total count
-            from sqlalchemy import func
-
             count_query = select(func.count()).select_from(query.alias())
             count_result = await db.execute(count_query)
             total = count_result.scalar_one()

             # Apply sorting
             if sort_by and hasattr(User, sort_by):
                 sort_column = getattr(User, sort_by)
                 if sort_order.lower() == "desc":
@@ -156,7 +177,6 @@ class CRUDUser(CRUDBase[User, UserCreate, UserUpdate]):
                 else:
                     query = query.order_by(sort_column.asc())

-            # Apply pagination
             query = query.offset(skip).limit(limit)
             result = await db.execute(query)
             users = list(result.scalars().all())
@@ -164,32 +184,21 @@ class CRUDUser(CRUDBase[User, UserCreate, UserUpdate]):
             return users, total

         except Exception as e:
-            logger.error(f"Error retrieving paginated users: {e!s}")
+            logger.error("Error retrieving paginated users: %s", e)
             raise
     async def bulk_update_status(
         self, db: AsyncSession, *, user_ids: list[UUID], is_active: bool
     ) -> int:
-        """
-        Bulk update is_active status for multiple users.
-
-        Args:
-            db: Database session
-            user_ids: List of user IDs to update
-            is_active: New active status
-
-        Returns:
-            Number of users updated
-        """
+        """Bulk update is_active status for multiple users."""
         try:
             if not user_ids:
                 return 0

             # Use UPDATE with WHERE IN for efficiency
             stmt = (
                 update(User)
                 .where(User.id.in_(user_ids))
-                .where(User.deleted_at.is_(None))  # Don't update deleted users
+                .where(User.deleted_at.is_(None))
                 .values(is_active=is_active, updated_at=datetime.now(UTC))
             )

@@ -197,12 +206,14 @@ class CRUDUser(CRUDBase[User, UserCreate, UserUpdate]):
             await db.commit()

             updated_count = result.rowcount
-            logger.info(f"Bulk updated {updated_count} users to is_active={is_active}")
+            logger.info(
+                "Bulk updated %s users to is_active=%s", updated_count, is_active
+            )
             return updated_count

         except Exception as e:
             await db.rollback()
-            logger.error(f"Error bulk updating user status: {e!s}", exc_info=True)
+            logger.exception("Error bulk updating user status: %s", e)
             raise

     async def bulk_soft_delete(
@@ -212,34 +223,20 @@ class CRUDUser(CRUDBase[User, UserCreate, UserUpdate]):
         user_ids: list[UUID],
         exclude_user_id: UUID | None = None,
     ) -> int:
-        """
-        Bulk soft delete multiple users.
-
-        Args:
-            db: Database session
-            user_ids: List of user IDs to delete
-            exclude_user_id: Optional user ID to exclude (e.g., the admin performing the action)
-
-        Returns:
-            Number of users deleted
-        """
+        """Bulk soft delete multiple users."""
         try:
             if not user_ids:
                 return 0

             # Remove excluded user from list
             filtered_ids = [uid for uid in user_ids if uid != exclude_user_id]

             if not filtered_ids:
                 return 0

             # Use UPDATE with WHERE IN for efficiency
             stmt = (
                 update(User)
                 .where(User.id.in_(filtered_ids))
-                .where(
-                    User.deleted_at.is_(None)
-                )  # Don't re-delete already deleted users
+                .where(User.deleted_at.is_(None))
                 .values(
                     deleted_at=datetime.now(UTC),
                     is_active=False,
@@ -251,22 +248,22 @@ class CRUDUser(CRUDBase[User, UserCreate, UserUpdate]):
             await db.commit()

             deleted_count = result.rowcount
-            logger.info(f"Bulk soft deleted {deleted_count} users")
+            logger.info("Bulk soft deleted %s users", deleted_count)
             return deleted_count

         except Exception as e:
             await db.rollback()
-            logger.error(f"Error bulk deleting users: {e!s}", exc_info=True)
+            logger.exception("Error bulk deleting users: %s", e)
             raise

     def is_active(self, user: User) -> bool:
         """Check if user is active."""
-        return user.is_active
+        return bool(user.is_active)

     def is_superuser(self, user: User) -> bool:
         """Check if user is a superuser."""
-        return user.is_superuser
+        return bool(user.is_superuser)


-# Create a singleton instance for use across the application
-user = CRUDUser(User)
+# Singleton instance
+user_repo = UserRepository(User)
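A theme in the `create` and `create_oauth_user` hunks above is translating the low-level `IntegrityError` into a domain-level `DuplicateEntryError` so callers never depend on SQLAlchemy internals. A standalone sketch of that translation pattern (the exception classes here are minimal stand-ins, not the app's real `app.core.repository_exceptions` types):

```python
class DuplicateEntryError(Exception):
    """Stand-in for the app's domain-level duplicate-entry exception."""

class FakeIntegrityError(Exception):
    """Stand-in for sqlalchemy.exc.IntegrityError, carrying an .orig attribute."""
    def __init__(self, orig):
        super().__init__(str(orig))
        self.orig = orig

def translate(exc) -> DuplicateEntryError:
    # Mirrors the diff's branching: duplicate email vs. generic integrity error.
    error_msg = str(exc.orig) if hasattr(exc, "orig") else str(exc)
    if "email" in error_msg.lower():
        return DuplicateEntryError("User with this email already exists")
    return DuplicateEntryError(f"Database integrity error: {error_msg}")

dup = translate(FakeIntegrityError('duplicate key violates unique constraint "users_email_key"'))
other = translate(FakeIntegrityError("foreign key violation"))
assert "email" in str(dup)
assert str(other).startswith("Database integrity error")
```

Keeping this mapping inside the repository is what lets the service layer (see the `AuthService` hunks below this file) catch one well-named exception instead of parsing driver error strings.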
@@ -60,8 +60,8 @@ class OAuthAccountCreate(OAuthAccountBase):

     user_id: UUID
     provider_user_id: str = Field(..., max_length=255)
-    access_token_encrypted: str | None = None
-    refresh_token_encrypted: str | None = None
+    access_token: str | None = None
+    refresh_token: str | None = None
     token_expires_at: datetime | None = None


@@ -48,7 +48,7 @@ class OrganizationCreate(OrganizationBase):
     """Schema for creating a new organization."""

     name: str = Field(..., min_length=1, max_length=255)
-    slug: str = Field(..., min_length=1, max_length=255)
+    slug: str = Field(..., min_length=1, max_length=255)  # pyright: ignore[reportIncompatibleVariableOverride]


 class OrganizationUpdate(BaseModel):
@@ -1,5 +1,19 @@
 # app/services/__init__.py
+from . import oauth_provider_service
 from .auth_service import AuthService
 from .oauth_service import OAuthService
+from .organization_service import OrganizationService, organization_service
+from .session_service import SessionService, session_service
+from .user_service import UserService, user_service

-__all__ = ["AuthService", "OAuthService"]
+__all__ = [
+    "AuthService",
+    "OAuthService",
+    "OrganizationService",
+    "SessionService",
+    "UserService",
+    "oauth_provider_service",
+    "organization_service",
+    "session_service",
+    "user_service",
+]
@@ -2,7 +2,6 @@
 import logging
 from uuid import UUID

-from sqlalchemy import select
 from sqlalchemy.ext.asyncio import AsyncSession

 from app.core.auth import (
@@ -14,12 +13,18 @@ from app.core.auth import (
     verify_password_async,
 )
 from app.core.config import settings
-from app.core.exceptions import AuthenticationError
+from app.core.exceptions import AuthenticationError, DuplicateError
+from app.core.repository_exceptions import DuplicateEntryError
 from app.models.user import User
+from app.repositories.user import user_repo
 from app.schemas.users import Token, UserCreate, UserResponse

 logger = logging.getLogger(__name__)

+# Pre-computed bcrypt hash used for constant-time comparison when user is not found,
+# preventing timing attacks that could enumerate valid email addresses.
+_DUMMY_HASH = "$2b$12$EixZaYVK1fsbw1ZfbX3OXePaWxn96p36zLFbnJHfxPSEFBzXKiHia"


 class AuthService:
     """Service for handling authentication operations"""
@@ -39,10 +44,12 @@ class AuthService:
         Returns:
             User if authenticated, None otherwise
         """
-        result = await db.execute(select(User).where(User.email == email))
-        user = result.scalar_one_or_none()
+        user = await user_repo.get_by_email(db, email=email)

         if not user:
+            # Perform a dummy verification to match timing of a real bcrypt check,
+            # preventing email enumeration via response-time differences.
+            await verify_password_async(password, _DUMMY_HASH)
             return None

         # Verify password asynchronously to avoid blocking event loop
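The `_DUMMY_HASH` hunk above makes the unknown-email path do the same expensive hash verification as the known-email path, so response timing no longer reveals which emails exist. A standalone sketch of the pattern using stdlib PBKDF2 in place of bcrypt (all names here are illustrative, not the app's API):

```python
import hashlib
import hmac

def slow_verify(password: str, stored_hash: str) -> bool:
    # Deliberately costly check standing in for bcrypt verification.
    digest = hashlib.pbkdf2_hmac("sha256", password.encode(), b"salt", 50_000).hex()
    return hmac.compare_digest(digest, stored_hash)

USERS = {"alice@example.com": hashlib.pbkdf2_hmac("sha256", b"s3cret", b"salt", 50_000).hex()}
# Pre-computed hash of a throwaway password, analogous to _DUMMY_HASH.
_DUMMY = hashlib.pbkdf2_hmac("sha256", b"dummy", b"salt", 50_000).hex()

def authenticate(email: str, password: str):
    stored = USERS.get(email)
    if stored is None:
        slow_verify(password, _DUMMY)  # burn the same work for unknown emails
        return None
    return email if slow_verify(password, stored) else None

assert authenticate("alice@example.com", "s3cret") == "alice@example.com"
assert authenticate("nobody@example.com", "whatever") is None
```

Without the dummy verification, the unknown-email branch returns almost instantly while the wrong-password branch takes a full hash's worth of time, which is exactly the measurable difference an attacker probes.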
@@ -71,40 +78,23 @@ class AuthService:
|
||||
"""
|
||||
try:
|
||||
# Check if user already exists
|
||||
result = await db.execute(select(User).where(User.email == user_data.email))
|
||||
existing_user = result.scalar_one_or_none()
|
||||
existing_user = await user_repo.get_by_email(db, email=user_data.email)
|
||||
if existing_user:
|
||||
raise AuthenticationError("User with this email already exists")
|
||||
raise DuplicateError("User with this email already exists")
|
||||
|
||||
# Create new user with async password hashing
|
||||
# Hash password asynchronously to avoid blocking event loop
|
||||
hashed_password = await get_password_hash_async(user_data.password)
|
||||
# Delegate creation (hashing + commit) to the repository
|
||||
user = await user_repo.create(db, obj_in=user_data)
|
||||
|
||||
# Create user object from model
|
||||
user = User(
|
||||
email=user_data.email,
|
||||
password_hash=hashed_password,
|
||||
first_name=user_data.first_name,
|
||||
last_name=user_data.last_name,
|
||||
phone_number=user_data.phone_number,
|
||||
is_active=True,
|
||||
is_superuser=False,
|
||||
)
|
||||
|
||||
db.add(user)
|
||||
await db.commit()
|
||||
await db.refresh(user)
|
||||
|
||||
logger.info(f"User created successfully: {user.email}")
|
||||
logger.info("User created successfully: %s", user.email)
|
||||
return user
|
||||
|
||||
except AuthenticationError:
|
||||
# Re-raise authentication errors without rollback
|
||||
except (AuthenticationError, DuplicateError):
|
||||
# Re-raise API exceptions without rollback
|
||||
raise
|
||||
except DuplicateEntryError as e:
|
||||
raise DuplicateError(str(e))
|
||||
except Exception as e:
|
||||
# Rollback on any database errors
|
||||
await db.rollback()
|
||||
logger.error(f"Error creating user: {e!s}", exc_info=True)
|
||||
logger.exception("Error creating user: %s", e)
|
||||
raise AuthenticationError(f"Failed to create user: {e!s}")
|
||||
|
||||
@staticmethod
@@ -168,8 +158,7 @@ class AuthService:
user_id = token_data.user_id

# Get user from database
result = await db.execute(select(User).where(User.id == user_id))
user = result.scalar_one_or_none()
user = await user_repo.get(db, id=str(user_id))
if not user or not user.is_active:
raise TokenInvalidError("Invalid user or inactive account")

@@ -177,7 +166,7 @@ class AuthService:
return AuthService.create_tokens(user)

except (TokenExpiredError, TokenInvalidError) as e:
logger.warning(f"Token refresh failed: {e!s}")
logger.warning("Token refresh failed: %s", e)
raise

@staticmethod
@@ -200,8 +189,7 @@ class AuthService:
AuthenticationError: If current password is incorrect or update fails
"""
try:
result = await db.execute(select(User).where(User.id == user_id))
user = result.scalar_one_or_none()
user = await user_repo.get(db, id=str(user_id))
if not user:
raise AuthenticationError("User not found")

@@ -210,10 +198,10 @@ class AuthService:
raise AuthenticationError("Current password is incorrect")

# Hash new password asynchronously to avoid blocking event loop
user.password_hash = await get_password_hash_async(new_password)
await db.commit()
new_hash = await get_password_hash_async(new_password)
await user_repo.update_password(db, user=user, password_hash=new_hash)

logger.info(f"Password changed successfully for user {user_id}")
logger.info("Password changed successfully for user %s", user_id)
return True

except AuthenticationError:
@@ -222,7 +210,34 @@ class AuthService:
except Exception as e:
# Rollback on any database errors
await db.rollback()
logger.error(
f"Error changing password for user {user_id}: {e!s}", exc_info=True
)
logger.exception("Error changing password for user %s: %s", user_id, e)
raise AuthenticationError(f"Failed to change password: {e!s}")

@staticmethod
async def reset_password(
db: AsyncSession, *, email: str, new_password: str
) -> User:
"""
Reset a user's password without requiring the current password.

Args:
db: Database session
email: User email address
new_password: New password to set

Returns:
Updated user

Raises:
AuthenticationError: If user not found or inactive
"""
user = await user_repo.get_by_email(db, email=email)
if not user:
raise AuthenticationError("User not found")
if not user.is_active:
raise AuthenticationError("User account is inactive")

new_hash = await get_password_hash_async(new_password)
user = await user_repo.update_password(db, user=user, password_hash=new_hash)
logger.info("Password reset successfully for %s", email)
return user

@@ -58,8 +58,8 @@ class ConsoleEmailBackend(EmailBackend):
logger.info("=" * 80)
logger.info("EMAIL SENT (Console Backend)")
logger.info("=" * 80)
logger.info(f"To: {', '.join(to)}")
logger.info(f"Subject: {subject}")
logger.info("To: %s", ", ".join(to))
logger.info("Subject: %s", subject)
logger.info("-" * 80)
if text_content:
logger.info("Plain Text Content:")
@@ -199,7 +199,7 @@ The {settings.PROJECT_NAME} Team
text_content=text_content,
)
except Exception as e:
logger.error(f"Failed to send password reset email to {to_email}: {e!s}")
logger.error("Failed to send password reset email to %s: %s", to_email, e)
return False

async def send_email_verification(
@@ -287,7 +287,7 @@ The {settings.PROJECT_NAME} Team
text_content=text_content,
)
except Exception as e:
logger.error(f"Failed to send verification email to {to_email}: {e!s}")
logger.error("Failed to send verification email to %s: %s", to_email, e)
return False


@@ -25,15 +25,19 @@ from datetime import UTC, datetime, timedelta
from typing import Any
from uuid import UUID

from jose import jwt
from sqlalchemy import and_, delete, select
import jwt
from jwt.exceptions import ExpiredSignatureError, InvalidTokenError
from sqlalchemy.ext.asyncio import AsyncSession

from app.core.config import settings
from app.models.oauth_authorization_code import OAuthAuthorizationCode
from app.models.oauth_client import OAuthClient
from app.models.oauth_provider_token import OAuthConsent, OAuthProviderRefreshToken
from app.models.user import User
from app.repositories.oauth_authorization_code import oauth_authorization_code_repo
from app.repositories.oauth_client import oauth_client_repo
from app.repositories.oauth_consent import oauth_consent_repo
from app.repositories.oauth_provider_token import oauth_provider_token_repo
from app.repositories.user import user_repo
from app.schemas.oauth import OAuthClientCreate

logger = logging.getLogger(__name__)
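The import block above replaces direct model queries with repository singletons, and every hunk below delegates its SQLAlchemy statements to them. A minimal, synchronous sketch of that repository idea — `sqlite3` stands in for the project's async session, and all names here are illustrative, not the project's API:

```python
import sqlite3


class UserRepo:
    """Route all persistence through one object so services never build SQL."""

    def __init__(self, conn: sqlite3.Connection) -> None:
        self.conn = conn

    def create(self, email: str) -> int:
        cur = self.conn.execute("INSERT INTO users (email) VALUES (?)", (email,))
        self.conn.commit()
        return cur.lastrowid

    def get_by_email(self, email: str):
        cur = self.conn.execute(
            "SELECT id, email FROM users WHERE email = ?", (email,)
        )
        return cur.fetchone()


conn = sqlite3.connect(":memory:")
conn.execute("CREATE TABLE users (id INTEGER PRIMARY KEY, email TEXT UNIQUE)")
repo = UserRepo(conn)
repo.create("alice@example.com")
```

The payoff mirrored in this diff: callers shrink to one-line repo calls, and the query/commit details live in a single testable place.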

@@ -135,7 +139,7 @@ def verify_pkce(code_verifier: str, code_challenge: str, method: str) -> bool:
if method != "S256":
# SECURITY: Reject any method other than S256
# 'plain' method provides no security against code interception attacks
logger.warning(f"PKCE verification rejected for unsupported method: {method}")
logger.warning("PKCE verification rejected for unsupported method: %s", method)
return False

# SHA-256 hash, then base64url encode (RFC 7636 Section 4.2)
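The comment above names the S256 transform; RFC 7636 defines it as BASE64URL(SHA256(ascii(verifier))) with padding stripped. A self-contained sketch (helper names are ours, not the project's):

```python
import base64
import hashlib
import secrets


def s256_challenge(code_verifier: str) -> str:
    # RFC 7636 Section 4.2: SHA-256 the ASCII verifier, then base64url-encode
    # the digest and strip the trailing '=' padding.
    digest = hashlib.sha256(code_verifier.encode("ascii")).digest()
    return base64.urlsafe_b64encode(digest).rstrip(b"=").decode("ascii")


def verify_pkce_s256(code_verifier: str, code_challenge: str) -> bool:
    # Constant-time comparison avoids leaking prefix matches.
    return secrets.compare_digest(s256_challenge(code_verifier), code_challenge)


verifier = (
    base64.urlsafe_b64encode(secrets.token_bytes(32)).rstrip(b"=").decode("ascii")
)
challenge = s256_challenge(verifier)
```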
@@ -161,15 +165,7 @@ def join_scope(scopes: list[str]) -> str:

async def get_client(db: AsyncSession, client_id: str) -> OAuthClient | None:
"""Get OAuth client by client_id."""
result = await db.execute(
select(OAuthClient).where(
and_(
OAuthClient.client_id == client_id,
OAuthClient.is_active == True, # noqa: E712
)
)
)
return result.scalar_one_or_none()
return await oauth_client_repo.get_by_client_id(db, client_id=client_id)


async def validate_client(
@@ -204,21 +200,19 @@ async def validate_client(
if not client.client_secret_hash:
raise InvalidClientError("Client not configured with secret")

# SECURITY: Verify secret using bcrypt (not SHA-256)
# Supports both bcrypt and legacy SHA-256 hashes for migration
# SECURITY: Verify secret using bcrypt
from app.core.auth import verify_password

stored_hash = str(client.client_secret_hash)

if stored_hash.startswith("$2"):
# New bcrypt format
if not verify_password(client_secret, stored_hash):
raise InvalidClientError("Invalid client secret")
else:
# Legacy SHA-256 format
computed_hash = hashlib.sha256(client_secret.encode()).hexdigest()
if not secrets.compare_digest(computed_hash, stored_hash):
raise InvalidClientError("Invalid client secret")
if not stored_hash.startswith("$2"):
raise InvalidClientError(
"Client secret uses deprecated hash format. "
"Please regenerate your client credentials."
)

if not verify_password(client_secret, stored_hash):
raise InvalidClientError("Invalid client secret")

return client
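The rewritten check above no longer falls back to legacy SHA-256: anything that does not look like a bcrypt hash is rejected outright. Modular-crypt bcrypt strings start with a `$2a$`/`$2b$`/`$2y$` prefix, which is what the `startswith("$2")` gate tests. A reduced sketch of just that gate (the sample hashes are illustrative; the project's `verify_password` is assumed to wrap bcrypt):

```python
def is_bcrypt_hash(stored_hash: str) -> bool:
    # bcrypt's modular-crypt format begins with $2a$, $2b$, or $2y$.
    return stored_hash.startswith("$2")


# A 64-char hex digest — the legacy SHA-256 storage format this diff retires.
legacy = "5e884898da28047151d0e56f8dc6292773603d0d6aabbdd62a11ef721d1542d8"
# Shape of a bcrypt hash (prefix, cost, salt+digest); value is made up.
modern = "$2b$12$KIXQJ0bXg8pDkW3F1O5u1eF0eYB0m9mVn0v8m9XH1pGzj3C4s5a6m"
```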

@@ -263,7 +257,9 @@ def validate_scopes(client: OAuthClient, requested_scopes: list[str]) -> list[st
# Warn if some scopes were filtered out
invalid = requested - allowed
if invalid:
logger.warning(f"Client {client.client_id} requested invalid scopes: {invalid}")
logger.warning(
"Client %s requested invalid scopes: %s", client.client_id, invalid
)

return list(valid)

@@ -311,25 +307,24 @@ async def create_authorization_code(
minutes=AUTHORIZATION_CODE_EXPIRY_MINUTES
)

auth_code = OAuthAuthorizationCode(
await oauth_authorization_code_repo.create_code(
db,
code=code,
client_id=client.client_id,
user_id=user.id,
redirect_uri=redirect_uri,
scope=scope,
expires_at=expires_at,
code_challenge=code_challenge,
code_challenge_method=code_challenge_method,
state=state,
nonce=nonce,
expires_at=expires_at,
used=False,
)

db.add(auth_code)
await db.commit()

logger.info(
f"Created authorization code for user {user.id} and client {client.client_id}"
"Created authorization code for user %s and client %s",
user.id,
client.client_id,
)
return code

@@ -366,35 +361,20 @@ async def exchange_authorization_code(
"""
# Atomically mark code as used and fetch it (prevents race condition)
# RFC 6749 Section 4.1.2: Authorization codes MUST be single-use
from sqlalchemy import update

# First, atomically mark the code as used and get affected count
update_stmt = (
update(OAuthAuthorizationCode)
.where(
and_(
OAuthAuthorizationCode.code == code,
OAuthAuthorizationCode.used == False, # noqa: E712
)
)
.values(used=True)
.returning(OAuthAuthorizationCode.id)
updated_id = await oauth_authorization_code_repo.consume_code_atomically(
db, code=code
)
result = await db.execute(update_stmt)
updated_id = result.scalar_one_or_none()

if not updated_id:
# Either code doesn't exist or was already used
# Check if it exists to provide appropriate error
check_result = await db.execute(
select(OAuthAuthorizationCode).where(OAuthAuthorizationCode.code == code)
)
existing_code = check_result.scalar_one_or_none()
existing_code = await oauth_authorization_code_repo.get_by_code(db, code=code)

if existing_code and existing_code.used:
# Code reuse is a security incident - revoke all tokens for this grant
logger.warning(
f"Authorization code reuse detected for client {existing_code.client_id}"
"Authorization code reuse detected for client %s",
existing_code.client_id,
)
await revoke_tokens_for_user_client(
db, UUID(str(existing_code.user_id)), str(existing_code.client_id)
@@ -404,11 +384,9 @@ async def exchange_authorization_code(
raise InvalidGrantError("Invalid authorization code")

# Now fetch the full auth code record
auth_code_result = await db.execute(
select(OAuthAuthorizationCode).where(OAuthAuthorizationCode.id == updated_id)
)
auth_code = auth_code_result.scalar_one()
await db.commit()
auth_code = await oauth_authorization_code_repo.get_by_id(db, code_id=updated_id)
if auth_code is None:
raise InvalidGrantError("Authorization code not found after consumption")

if auth_code.is_expired:
raise InvalidGrantError("Authorization code has expired")
@@ -452,8 +430,7 @@ async def exchange_authorization_code(
raise InvalidGrantError("PKCE required for public clients")

# Get user
user_result = await db.execute(select(User).where(User.id == auth_code.user_id))
user = user_result.scalar_one_or_none()
user = await user_repo.get(db, id=str(auth_code.user_id))
if not user or not user.is_active:
raise InvalidGrantError("User not found or inactive")
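The repository's `consume_code_atomically` preserves the single-use guarantee the inline `UPDATE ... RETURNING` enforced: one conditional `UPDATE` both tests and flips `used`, so under concurrent exchanges exactly one caller wins. A standalone sketch with `sqlite3` (schema and names are illustrative):

```python
import sqlite3

conn = sqlite3.connect(":memory:")
conn.execute("CREATE TABLE auth_codes (code TEXT PRIMARY KEY, used INTEGER DEFAULT 0)")
conn.execute("INSERT INTO auth_codes (code) VALUES ('abc123')")


def consume_code(conn: sqlite3.Connection, code: str) -> bool:
    # The WHERE clause makes test-and-set one statement: concurrent callers
    # race on the row, and exactly one sees rowcount == 1.
    cur = conn.execute(
        "UPDATE auth_codes SET used = 1 WHERE code = ? AND used = 0", (code,)
    )
    conn.commit()
    return cur.rowcount == 1


first = consume_code(conn, "abc123")
second = consume_code(conn, "abc123")
```

A read-then-write version (`SELECT`, check `used` in Python, then `UPDATE`) would let two requests both pass the check — the race the diff's comment warns about.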

@@ -543,7 +520,8 @@ async def create_tokens(
refresh_token_hash = hash_token(refresh_token)

# Store refresh token in database
refresh_token_record = OAuthProviderRefreshToken(
await oauth_provider_token_repo.create_token(
db,
token_hash=refresh_token_hash,
jti=jti,
client_id=client.client_id,
@@ -553,10 +531,8 @@ async def create_tokens(
device_info=device_info,
ip_address=ip_address,
)
db.add(refresh_token_record)
await db.commit()

logger.info(f"Issued tokens for user {user.id} to client {client.client_id}")
logger.info("Issued tokens for user %s to client %s", user.id, client.client_id)

return {
"access_token": access_token,
@@ -599,12 +575,9 @@ async def refresh_tokens(
"""
# Find refresh token
token_hash = hash_token(refresh_token)
result = await db.execute(
select(OAuthProviderRefreshToken).where(
OAuthProviderRefreshToken.token_hash == token_hash
)
token_record = await oauth_provider_token_repo.get_by_token_hash(
db, token_hash=token_hash
)
token_record: OAuthProviderRefreshToken | None = result.scalar_one_or_none()

if not token_record:
raise InvalidGrantError("Invalid refresh token")
@@ -612,7 +585,7 @@ async def refresh_tokens(
if token_record.revoked:
# Token reuse after revocation - security incident
logger.warning(
f"Revoked refresh token reuse detected for client {token_record.client_id}"
"Revoked refresh token reuse detected for client %s", token_record.client_id
)
raise InvalidGrantError("Refresh token has been revoked")

@@ -631,8 +604,7 @@ async def refresh_tokens(
)

# Get user
user_result = await db.execute(select(User).where(User.id == token_record.user_id))
user = user_result.scalar_one_or_none()
user = await user_repo.get(db, id=str(token_record.user_id))
if not user or not user.is_active:
raise InvalidGrantError("User not found or inactive")

@@ -648,9 +620,7 @@ async def refresh_tokens(
final_scope = token_scope

# Revoke old refresh token (token rotation)
token_record.revoked = True # type: ignore[assignment]
token_record.last_used_at = datetime.now(UTC) # type: ignore[assignment]
await db.commit()
await oauth_provider_token_repo.revoke(db, token=token_record)

# Issue new tokens
device = str(token_record.device_info) if token_record.device_info else None
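The rotation step above revokes the presented refresh token before minting its replacement, and records are looked up by `hash_token` digest so the table never stores a raw token. A minimal in-memory sketch of both ideas (the dict stands in for the `OAuthProviderRefreshToken` table; `hash_token` here is our assumption of a plain SHA-256 digest):

```python
import hashlib
import secrets


def hash_token(token: str) -> str:
    # Store only the digest: a leaked table row cannot be replayed as a token.
    return hashlib.sha256(token.encode()).hexdigest()


store: dict[str, dict] = {}

refresh = secrets.token_urlsafe(32)
store[hash_token(refresh)] = {"revoked": False}

# Rotation: look the record up by digest, revoke it, mint a replacement.
record = store[hash_token(refresh)]
record["revoked"] = True
new_refresh = secrets.token_urlsafe(32)
store[hash_token(new_refresh)] = {"revoked": False}
```

Because the old digest stays in the table as `revoked`, a later replay of the old token is detectable — the "reuse after revocation" branch in the hunk above.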
@@ -697,28 +667,22 @@ async def revoke_token(
# Try as refresh token first (more likely)
if token_type_hint != "access_token":
token_hash = hash_token(token)
result = await db.execute(
select(OAuthProviderRefreshToken).where(
OAuthProviderRefreshToken.token_hash == token_hash
)
refresh_record = await oauth_provider_token_repo.get_by_token_hash(
db, token_hash=token_hash
)
refresh_record = result.scalar_one_or_none()

if refresh_record:
# Validate client if provided
if client_id and refresh_record.client_id != client_id:
raise InvalidClientError("Token was not issued to this client")

refresh_record.revoked = True # type: ignore[assignment]
await db.commit()
logger.info(f"Revoked refresh token {refresh_record.jti[:8]}...")
await oauth_provider_token_repo.revoke(db, token=refresh_record)
logger.info("Revoked refresh token %s...", refresh_record.jti[:8])
return True

# Try as access token (JWT)
if token_type_hint != "refresh_token":
try:
from jose.exceptions import JWTError

payload = jwt.decode(
token,
settings.SECRET_KEY,
@@ -731,22 +695,18 @@ async def revoke_token(
jti = payload.get("jti")
if jti:
# Find and revoke the associated refresh token
result = await db.execute(
select(OAuthProviderRefreshToken).where(
OAuthProviderRefreshToken.jti == jti
)
)
refresh_record = result.scalar_one_or_none()
refresh_record = await oauth_provider_token_repo.get_by_jti(db, jti=jti)
if refresh_record:
if client_id and refresh_record.client_id != client_id:
raise InvalidClientError("Token was not issued to this client")
refresh_record.revoked = True # type: ignore[assignment]
await db.commit()
await oauth_provider_token_repo.revoke(db, token=refresh_record)
logger.info(
f"Revoked refresh token via access token JTI {jti[:8]}..."
"Revoked refresh token via access token JTI %s...", jti[:8]
)
return True
except (JWTError, Exception): # noqa: S110 - Intentional: invalid JWT not an error
except InvalidTokenError:
pass
except Exception: # noqa: S110 - Intentional: invalid JWT not an error
pass

return False
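The `except` rewrites above lean on PyJWT's exception hierarchy: `ExpiredSignatureError` is a subclass of `InvalidTokenError`, so when both are handled the expired-token branch must come first, and a bare `except InvalidTokenError` cleanly catches tampered or malformed tokens. A short sketch (assumes the `PyJWT` package; the key is illustrative, not the app's `SECRET_KEY`):

```python
import jwt
from jwt.exceptions import ExpiredSignatureError, InvalidTokenError

secret = "example-secret"
token = jwt.encode({"sub": "user-1", "jti": "abc12345"}, secret, algorithm="HS256")
payload = jwt.decode(token, secret, algorithms=["HS256"])

# Corrupting the signature raises some subclass of InvalidTokenError.
try:
    jwt.decode(token[:-2] + "xx", secret, algorithms=["HS256"])
    tampered_ok = True
except InvalidTokenError:
    tampered_ok = False
```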
@@ -770,26 +730,13 @@ async def revoke_tokens_for_user_client(
Returns:
Number of tokens revoked
"""
result = await db.execute(
select(OAuthProviderRefreshToken).where(
and_(
OAuthProviderRefreshToken.user_id == user_id,
OAuthProviderRefreshToken.client_id == client_id,
OAuthProviderRefreshToken.revoked == False, # noqa: E712
)
)
count = await oauth_provider_token_repo.revoke_all_for_user_client(
db, user_id=user_id, client_id=client_id
)
tokens = result.scalars().all()

count = 0
for token in tokens:
token.revoked = True # type: ignore[assignment]
count += 1

if count > 0:
await db.commit()
logger.warning(
f"Revoked {count} tokens for user {user_id} and client {client_id}"
"Revoked %s tokens for user %s and client %s", count, user_id, client_id
)

return count
@@ -808,24 +755,10 @@ async def revoke_all_user_tokens(db: AsyncSession, user_id: UUID) -> int:
Returns:
Number of tokens revoked
"""
result = await db.execute(
select(OAuthProviderRefreshToken).where(
and_(
OAuthProviderRefreshToken.user_id == user_id,
OAuthProviderRefreshToken.revoked == False, # noqa: E712
)
)
)
tokens = result.scalars().all()

count = 0
for token in tokens:
token.revoked = True # type: ignore[assignment]
count += 1
count = await oauth_provider_token_repo.revoke_all_for_user(db, user_id=user_id)

if count > 0:
await db.commit()
logger.info(f"Revoked {count} OAuth provider tokens for user {user_id}")
logger.info("Revoked %s OAuth provider tokens for user %s", count, user_id)

return count

@@ -864,8 +797,6 @@ async def introspect_token(
# Try as access token (JWT) first
if token_type_hint != "refresh_token":
try:
from jose.exceptions import ExpiredSignatureError, JWTError

payload = jwt.decode(
token,
settings.SECRET_KEY,
@@ -878,12 +809,7 @@ async def introspect_token(
# Check if associated refresh token is revoked
jti = payload.get("jti")
if jti:
result = await db.execute(
select(OAuthProviderRefreshToken).where(
OAuthProviderRefreshToken.jti == jti
)
)
refresh_record = result.scalar_one_or_none()
refresh_record = await oauth_provider_token_repo.get_by_jti(db, jti=jti)
if refresh_record and refresh_record.revoked:
return {"active": False}

@@ -901,18 +827,17 @@ async def introspect_token(
}
except ExpiredSignatureError:
return {"active": False}
except (JWTError, Exception): # noqa: S110 - Intentional: invalid JWT falls through to refresh token check
except InvalidTokenError:
pass
except Exception: # noqa: S110 - Intentional: invalid JWT falls through to refresh token check
pass

# Try as refresh token
if token_type_hint != "access_token":
token_hash = hash_token(token)
result = await db.execute(
select(OAuthProviderRefreshToken).where(
OAuthProviderRefreshToken.token_hash == token_hash
)
refresh_record = await oauth_provider_token_repo.get_by_token_hash(
db, token_hash=token_hash
)
refresh_record = result.scalar_one_or_none()

if refresh_record and refresh_record.is_valid:
return {
@@ -937,17 +862,11 @@ async def get_consent(
db: AsyncSession,
user_id: UUID,
client_id: str,
) -> OAuthConsent | None:
):
"""Get existing consent record for user-client pair."""
result = await db.execute(
select(OAuthConsent).where(
and_(
OAuthConsent.user_id == user_id,
OAuthConsent.client_id == client_id,
)
)
return await oauth_consent_repo.get_consent(
db, user_id=user_id, client_id=client_id
)
return result.scalar_one_or_none()


async def check_consent(
@@ -972,31 +891,15 @@ async def grant_consent(
user_id: UUID,
client_id: str,
scopes: list[str],
) -> OAuthConsent:
):
"""
Grant or update consent for a user-client pair.

If consent already exists, updates the granted scopes.
"""
consent = await get_consent(db, user_id, client_id)

if consent:
# Merge scopes
granted = str(consent.granted_scopes) if consent.granted_scopes else ""
existing = set(parse_scope(granted))
new_scopes = existing | set(scopes)
consent.granted_scopes = join_scope(list(new_scopes)) # type: ignore[assignment]
else:
consent = OAuthConsent(
user_id=user_id,
client_id=client_id,
granted_scopes=join_scope(scopes),
)
db.add(consent)

await db.commit()
await db.refresh(consent)
return consent
return await oauth_consent_repo.grant_consent(
db, user_id=user_id, client_id=client_id, scopes=scopes
)
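The consent-merge logic being moved into `grant_consent` boils down to a set union over space-delimited scope strings (RFC 6749 Section 3.3). A reduced sketch — the `parse_scope`/`join_scope` bodies here are our assumption of split/join-on-space semantics, with sorting added only to make the output deterministic:

```python
def parse_scope(scope: str) -> list[str]:
    # RFC 6749 Section 3.3: scope is a space-delimited list of case-sensitive strings.
    return scope.split() if scope else []


def join_scope(scopes: list[str]) -> str:
    return " ".join(sorted(scopes))


granted = "openid profile"        # scopes already on the consent record
requested = ["profile", "email"]  # scopes from the new authorization request
merged = set(parse_scope(granted)) | set(requested)
```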


async def revoke_consent(
@@ -1009,21 +912,13 @@ async def revoke_consent(

Returns True if consent was found and revoked.
"""
# Delete consent record
result = await db.execute(
delete(OAuthConsent).where(
and_(
OAuthConsent.user_id == user_id,
OAuthConsent.client_id == client_id,
)
)
)

# Revoke all tokens
# Revoke all tokens first
await revoke_tokens_for_user_client(db, user_id, client_id)

await db.commit()
return result.rowcount > 0 # type: ignore[attr-defined]
# Delete consent record
return await oauth_consent_repo.revoke_consent(
db, user_id=user_id, client_id=client_id
)


# ============================================================================
@@ -1031,6 +926,26 @@ async def revoke_consent(
# ============================================================================


async def register_client(db: AsyncSession, client_data: OAuthClientCreate) -> tuple:
"""Create a new OAuth client. Returns (client, secret)."""
return await oauth_client_repo.create_client(db, obj_in=client_data)


async def list_clients(db: AsyncSession) -> list:
"""List all registered OAuth clients."""
return await oauth_client_repo.get_all_clients(db)


async def delete_client_by_id(db: AsyncSession, client_id: str) -> None:
"""Delete an OAuth client by client_id."""
await oauth_client_repo.delete_client(db, client_id=client_id)


async def list_user_consents(db: AsyncSession, user_id: UUID) -> list[dict]:
"""Get all OAuth consents for a user with client details."""
return await oauth_consent_repo.get_user_consents_with_clients(db, user_id=user_id)


async def cleanup_expired_codes(db: AsyncSession) -> int:
"""
Delete expired authorization codes.
@@ -1040,13 +955,7 @@ async def cleanup_expired_codes(db: AsyncSession) -> int:
Returns:
Number of codes deleted
"""
result = await db.execute(
delete(OAuthAuthorizationCode).where(
OAuthAuthorizationCode.expires_at < datetime.now(UTC)
)
)
await db.commit()
return result.rowcount # type: ignore[attr-defined]
return await oauth_authorization_code_repo.cleanup_expired(db)


async def cleanup_expired_tokens(db: AsyncSession) -> int:
@@ -1058,12 +967,4 @@ async def cleanup_expired_tokens(db: AsyncSession) -> int:
Returns:
Number of tokens deleted
"""
# Delete tokens that are both expired AND revoked (or just very old)
cutoff = datetime.now(UTC) - timedelta(days=7)
result = await db.execute(
delete(OAuthProviderRefreshToken).where(
OAuthProviderRefreshToken.expires_at < cutoff
)
)
await db.commit()
return result.rowcount # type: ignore[attr-defined]
return await oauth_provider_token_repo.cleanup_expired(db, cutoff_days=7)


@@ -19,14 +19,15 @@ from typing import TypedDict, cast
from uuid import UUID

from authlib.integrations.httpx_client import AsyncOAuth2Client
from sqlalchemy import select
from sqlalchemy.ext.asyncio import AsyncSession

from app.core.auth import create_access_token, create_refresh_token
from app.core.config import settings
from app.core.exceptions import AuthenticationError
from app.crud import oauth_account, oauth_state
from app.models.user import User
from app.repositories.oauth_account import oauth_account_repo as oauth_account
from app.repositories.oauth_state import oauth_state_repo as oauth_state
from app.repositories.user import user_repo
from app.schemas.oauth import (
OAuthAccountCreate,
OAuthCallbackResponse,
@@ -38,19 +39,22 @@ from app.schemas.oauth import (
logger = logging.getLogger(__name__)


class OAuthProviderConfig(TypedDict, total=False):
"""Type definition for OAuth provider configuration."""

class _OAuthProviderConfigRequired(TypedDict):
name: str
icon: str
authorize_url: str
token_url: str
userinfo_url: str
email_url: str # Optional, GitHub-only
scopes: list[str]
supports_pkce: bool


class OAuthProviderConfig(_OAuthProviderConfigRequired, total=False):
"""Type definition for OAuth provider configuration."""

email_url: str # Optional, GitHub-only


# Provider configurations
OAUTH_PROVIDERS: dict[str, OAuthProviderConfig] = {
"google": {
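The `_OAuthProviderConfigRequired` base plus `total=False` subclass introduced above is the standard way to mix required and optional keys in a single `TypedDict`, since one class cannot declare both directly. A reduced sketch with hypothetical field names:

```python
from typing import TypedDict


class _ConfigRequired(TypedDict):
    name: str
    token_url: str


class Config(_ConfigRequired, total=False):
    email_url: str  # optional key, mirroring the GitHub-only field


# email_url may be omitted without a type error...
cfg: Config = {"name": "github", "token_url": "https://example.test/token"}
missing_before = "email_url" not in cfg
# ...and added later; type checkers then require guarded access (hence the
# pyright suppression on config["email_url"] in the hunk below).
cfg["email_url"] = "https://example.test/user/emails"
```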
@@ -215,7 +219,7 @@ class OAuthService:
**auth_params,
)

logger.info(f"OAuth authorization URL created for {provider}")
logger.info("OAuth authorization URL created for %s", provider)
return url, state

@staticmethod
@@ -250,8 +254,9 @@ class OAuthService:
# This prevents authorization code injection attacks (RFC 6749 Section 10.6)
if state_record.redirect_uri != redirect_uri:
logger.warning(
f"OAuth redirect_uri mismatch: expected {state_record.redirect_uri}, "
f"got {redirect_uri}"
"OAuth redirect_uri mismatch: expected %s, got %s",
state_record.redirect_uri,
redirect_uri,
)
raise AuthenticationError("Redirect URI mismatch")

@@ -295,7 +300,7 @@ class OAuthService:
except AuthenticationError:
raise
except Exception as e:
logger.error(f"OAuth token exchange failed: {e!s}")
logger.error("OAuth token exchange failed: %s", e)
raise AuthenticationError("Failed to exchange authorization code")

# Get user info from provider
@@ -308,7 +313,7 @@ class OAuthService:
client, provider, config, access_token
)
except Exception as e:
logger.error(f"Failed to get user info: {e!s}")
logger.error("Failed to get user info: %s", e)
raise AuthenticationError(
"Failed to get user information from provider"
)
@@ -343,18 +348,17 @@ class OAuthService:
await oauth_account.update_tokens(
db,
account=existing_oauth,
access_token_encrypted=token.get("access_token"), refresh_token_encrypted=token.get("refresh_token"), token_expires_at=datetime.now(UTC)
access_token=token.get("access_token"),
refresh_token=token.get("refresh_token"),
token_expires_at=datetime.now(UTC)
+ timedelta(seconds=token.get("expires_in", 3600)),
)

logger.info(f"OAuth login successful for {user.email} via {provider}")
logger.info("OAuth login successful for %s via %s", user.email, provider)

elif state_record.user_id:
# Account linking flow (user is already logged in)
result = await db.execute(
select(User).where(User.id == state_record.user_id)
)
user = result.scalar_one_or_none()
user = await user_repo.get(db, id=str(state_record.user_id))

if not user:
raise AuthenticationError("User not found for account linking")
@@ -375,24 +379,23 @@ class OAuthService:
provider=provider,
provider_user_id=provider_user_id,
provider_email=provider_email,
access_token_encrypted=token.get("access_token"), refresh_token_encrypted=token.get("refresh_token"), token_expires_at=datetime.now(UTC)
access_token=token.get("access_token"),
refresh_token=token.get("refresh_token"),
token_expires_at=datetime.now(UTC)
+ timedelta(seconds=token.get("expires_in", 3600))
if token.get("expires_in")
else None,
)
await oauth_account.create_account(db, obj_in=oauth_create)

logger.info(f"OAuth account linked: {provider} -> {user.email}")
logger.info("OAuth account linked: %s -> %s", provider, user.email)

else:
# New OAuth login - check for existing user by email
user = None

if provider_email and settings.OAUTH_AUTO_LINK_BY_EMAIL:
result = await db.execute(
select(User).where(User.email == provider_email)
)
user = result.scalar_one_or_none()
user = await user_repo.get_by_email(db, email=provider_email)

if user:
# Auto-link to existing user
@@ -407,7 +410,9 @@ class OAuthService:
if existing_provider:
# This shouldn't happen if we got here, but safety check
logger.warning(
f"OAuth account already linked (race condition?): {provider} -> {user.email}"
"OAuth account already linked (race condition?): %s -> %s",
provider,
user.email,
)
else:
# Create OAuth account link
@@ -416,8 +421,8 @@ class OAuthService:
provider=provider,
provider_user_id=provider_user_id,
provider_email=provider_email,
access_token_encrypted=token.get("access_token"),
refresh_token_encrypted=token.get("refresh_token"),
access_token=token.get("access_token"),
refresh_token=token.get("refresh_token"),
token_expires_at=datetime.now(UTC)
+ timedelta(seconds=token.get("expires_in", 3600))
if token.get("expires_in")
@@ -425,7 +430,9 @@ class OAuthService:
)
await oauth_account.create_account(db, obj_in=oauth_create)

logger.info(f"OAuth auto-linked by email: {provider} -> {user.email}")
logger.info(
"OAuth auto-linked by email: %s -> %s", provider, user.email
)

else:
# Create new user
@@ -445,7 +452,7 @@ class OAuthService:
)
is_new_user = True

logger.info(f"New user created via OAuth: {user.email} ({provider})")
logger.info("New user created via OAuth: %s (%s)", user.email, provider)

# Generate JWT tokens
claims = {
@@ -486,7 +493,7 @@ class OAuthService:
# GitHub requires separate request for email
if provider == "github" and not user_info.get("email"):
email_resp = await client.get(
config["email_url"],
config["email_url"], # pyright: ignore[reportTypedDictNotRequiredAccess]
headers=headers,
)
email_resp.raise_for_status()
@@ -530,8 +537,9 @@ class OAuthService:
AuthenticationError: If verification fails
"""
import httpx
from jose import jwt as jose_jwt
from jose.exceptions import JWTError
import jwt as pyjwt
from jwt.algorithms import RSAAlgorithm
from jwt.exceptions import InvalidTokenError

try:
# Fetch Google's public keys (JWKS)
@@ -545,24 +553,27 @@ class OAuthService:
jwks = jwks_response.json()

# Get the key ID from the token header
unverified_header = jose_jwt.get_unverified_header(id_token)
unverified_header = pyjwt.get_unverified_header(id_token)
kid = unverified_header.get("kid")
if not kid:
raise AuthenticationError("ID token missing key ID (kid)")

# Find the matching public key
public_key = None
jwk_data = None
for key in jwks.get("keys", []):
if key.get("kid") == kid:
public_key = key
jwk_data = key
break

if not public_key:
if not jwk_data:
raise AuthenticationError("ID token signed with unknown key")

# Convert JWK to a public key object for PyJWT
public_key = RSAAlgorithm.from_jwk(jwk_data)

# Verify the token signature and decode claims
# jose library will verify signature against the JWK
payload = jose_jwt.decode(
# PyJWT will verify signature against the RSA public key
payload = pyjwt.decode(
id_token,
public_key,
algorithms=["RS256"], # Google uses RS256
@@ -581,23 +592,24 @@ class OAuthService:
token_nonce = payload.get("nonce")
if token_nonce != expected_nonce:
|
||||
logger.warning(
|
||||
f"OAuth ID token nonce mismatch: expected {expected_nonce}, "
|
||||
f"got {token_nonce}"
|
||||
"OAuth ID token nonce mismatch: expected %s, got %s",
|
||||
expected_nonce,
|
||||
token_nonce,
|
||||
)
|
||||
raise AuthenticationError("Invalid ID token nonce")
|
||||
|
||||
logger.debug("Google ID token verified successfully")
|
||||
return payload
|
||||
|
||||
except JWTError as e:
|
||||
logger.warning(f"Google ID token verification failed: {e}")
|
||||
except InvalidTokenError as e:
|
||||
logger.warning("Google ID token verification failed: %s", e)
|
||||
raise AuthenticationError("Invalid ID token signature")
|
||||
except httpx.HTTPError as e:
|
||||
logger.error(f"Failed to fetch Google JWKS: {e}")
|
||||
logger.error("Failed to fetch Google JWKS: %s", e)
|
||||
# If we can't verify the ID token, fail closed for security
|
||||
raise AuthenticationError("Failed to verify ID token")
|
||||
except Exception as e:
|
||||
logger.error(f"Unexpected error verifying Google ID token: {e}")
|
||||
logger.error("Unexpected error verifying Google ID token: %s", e)
|
||||
raise AuthenticationError("ID token verification error")
|
||||
|
||||
@staticmethod
|
||||
@@ -644,14 +656,15 @@ class OAuthService:
|
||||
provider=provider,
|
||||
provider_user_id=provider_user_id,
|
||||
provider_email=email,
|
||||
access_token_encrypted=token.get("access_token"), refresh_token_encrypted=token.get("refresh_token"), token_expires_at=datetime.now(UTC)
|
||||
access_token=token.get("access_token"),
|
||||
refresh_token=token.get("refresh_token"),
|
||||
token_expires_at=datetime.now(UTC)
|
||||
+ timedelta(seconds=token.get("expires_in", 3600))
|
||||
if token.get("expires_in")
|
||||
else None,
|
||||
)
|
||||
await oauth_account.create_account(db, obj_in=oauth_create)
|
||||
|
||||
await db.commit()
|
||||
await db.refresh(user)
|
||||
|
||||
return user
|
||||
@@ -698,9 +711,23 @@ class OAuthService:
|
||||
if not deleted:
|
||||
raise AuthenticationError(f"No {provider} account found to unlink")
|
||||
|
||||
logger.info(f"OAuth provider unlinked: {provider} from {user.email}")
|
||||
logger.info("OAuth provider unlinked: %s from %s", provider, user.email)
|
||||
return True
|
||||
|
||||
@staticmethod
|
||||
async def get_user_accounts(db: AsyncSession, *, user_id: UUID) -> list:
|
||||
"""Get all OAuth accounts linked to a user."""
|
||||
return await oauth_account.get_user_accounts(db, user_id=user_id)
|
||||
|
||||
@staticmethod
|
||||
async def get_user_account_by_provider(
|
||||
db: AsyncSession, *, user_id: UUID, provider: str
|
||||
):
|
||||
"""Get a specific OAuth account for a user and provider."""
|
||||
return await oauth_account.get_user_account_by_provider(
|
||||
db, user_id=user_id, provider=provider
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
async def cleanup_expired_states(db: AsyncSession) -> int:
|
||||
"""
|
||||
|
||||
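A recurring change in the hunks above is swapping f-string log messages for `%s` placeholders. The point of the lazy style is that the arguments are only rendered when a record is actually emitted; a stdlib-only sketch (logger name and classes are illustrative):

```python
import io
import logging

class Expensive:
    """Tracks how many times it is actually rendered to a string."""
    renders = 0

    def __str__(self) -> str:
        Expensive.renders += 1
        return "expensive-value"

# Isolated logger so the demo is deterministic
logger = logging.getLogger("oauth_demo")
logger.setLevel(logging.INFO)
logger.addHandler(logging.StreamHandler(io.StringIO()))
logger.propagate = False

obj = Expensive()
logger.debug("provider payload: %s", obj)  # below INFO: argument never rendered
logger.info("provider payload: %s", obj)   # emitted: rendered exactly once
assert Expensive.renders == 1
```

With `f"payload: {obj}"` the string is built on every call, even when the log level filters the record out.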
155	backend/app/services/organization_service.py	Normal file
@@ -0,0 +1,155 @@
# app/services/organization_service.py
"""Service layer for organization operations — delegates to OrganizationRepository."""

import logging
from typing import Any
from uuid import UUID

from sqlalchemy.ext.asyncio import AsyncSession

from app.core.exceptions import NotFoundError
from app.models.organization import Organization
from app.models.user_organization import OrganizationRole, UserOrganization
from app.repositories.organization import OrganizationRepository, organization_repo
from app.schemas.organizations import OrganizationCreate, OrganizationUpdate

logger = logging.getLogger(__name__)


class OrganizationService:
    """Service for organization management operations."""

    def __init__(
        self, organization_repository: OrganizationRepository | None = None
    ) -> None:
        self._repo = organization_repository or organization_repo

    async def get_organization(self, db: AsyncSession, org_id: str) -> Organization:
        """Get organization by ID, raising NotFoundError if not found."""
        org = await self._repo.get(db, id=org_id)
        if not org:
            raise NotFoundError(f"Organization {org_id} not found")
        return org

    async def create_organization(
        self, db: AsyncSession, *, obj_in: OrganizationCreate
    ) -> Organization:
        """Create a new organization."""
        return await self._repo.create(db, obj_in=obj_in)

    async def update_organization(
        self,
        db: AsyncSession,
        *,
        org: Organization,
        obj_in: OrganizationUpdate | dict[str, Any],
    ) -> Organization:
        """Update an existing organization."""
        return await self._repo.update(db, db_obj=org, obj_in=obj_in)

    async def remove_organization(self, db: AsyncSession, org_id: str) -> None:
        """Permanently delete an organization by ID."""
        await self._repo.remove(db, id=org_id)

    async def get_member_count(self, db: AsyncSession, *, organization_id: UUID) -> int:
        """Get number of active members in an organization."""
        return await self._repo.get_member_count(db, organization_id=organization_id)

    async def get_multi_with_member_counts(
        self,
        db: AsyncSession,
        *,
        skip: int = 0,
        limit: int = 100,
        is_active: bool | None = None,
        search: str | None = None,
    ) -> tuple[list[dict[str, Any]], int]:
        """List organizations with member counts and pagination."""
        return await self._repo.get_multi_with_member_counts(
            db, skip=skip, limit=limit, is_active=is_active, search=search
        )

    async def get_user_organizations_with_details(
        self,
        db: AsyncSession,
        *,
        user_id: UUID,
        is_active: bool | None = None,
    ) -> list[dict[str, Any]]:
        """Get all organizations a user belongs to, with membership details."""
        return await self._repo.get_user_organizations_with_details(
            db, user_id=user_id, is_active=is_active
        )

    async def get_organization_members(
        self,
        db: AsyncSession,
        *,
        organization_id: UUID,
        skip: int = 0,
        limit: int = 100,
        is_active: bool | None = True,
    ) -> tuple[list[dict[str, Any]], int]:
        """Get members of an organization with pagination."""
        return await self._repo.get_organization_members(
            db,
            organization_id=organization_id,
            skip=skip,
            limit=limit,
            is_active=is_active,
        )

    async def add_member(
        self,
        db: AsyncSession,
        *,
        organization_id: UUID,
        user_id: UUID,
        role: OrganizationRole = OrganizationRole.MEMBER,
    ) -> UserOrganization:
        """Add a user to an organization."""
        return await self._repo.add_user(
            db, organization_id=organization_id, user_id=user_id, role=role
        )

    async def remove_member(
        self,
        db: AsyncSession,
        *,
        organization_id: UUID,
        user_id: UUID,
    ) -> bool:
        """Remove a user from an organization. Returns True if found and removed."""
        return await self._repo.remove_user(
            db, organization_id=organization_id, user_id=user_id
        )

    async def get_user_role_in_org(
        self, db: AsyncSession, *, user_id: UUID, organization_id: UUID
    ) -> OrganizationRole | None:
        """Get the role of a user in an organization."""
        return await self._repo.get_user_role_in_org(
            db, user_id=user_id, organization_id=organization_id
        )

    async def get_org_distribution(
        self, db: AsyncSession, *, limit: int = 6
    ) -> list[dict[str, Any]]:
        """Return top organizations by member count for admin dashboard."""
        from sqlalchemy import func, select

        result = await db.execute(
            select(
                Organization.name,
                func.count(UserOrganization.user_id).label("count"),
            )
            .join(UserOrganization, Organization.id == UserOrganization.organization_id)
            .group_by(Organization.name)
            .order_by(func.count(UserOrganization.user_id).desc())
            .limit(limit)
        )
        return [{"name": row.name, "value": row.count} for row in result.all()]


# Default singleton
organization_service = OrganizationService()
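The service above accepts its repository through the constructor, which is what makes it unit-testable without a database. A minimal stdlib-only sketch of that injection pattern, with hypothetical stand-in classes (`FakeOrgRepo`, `OrgService`) in place of the real ones:

```python
import asyncio

class NotFoundError(Exception):
    pass

class FakeOrgRepo:
    """In-memory stand-in for OrganizationRepository."""
    def __init__(self, orgs: dict):
        self._orgs = orgs

    async def get(self, db, *, id):
        return self._orgs.get(id)

class OrgService:
    """Same shape as OrganizationService: repo injected, defaulting to a singleton."""
    def __init__(self, repo):
        self._repo = repo

    async def get_organization(self, db, org_id: str):
        org = await self._repo.get(db, id=org_id)
        if not org:
            raise NotFoundError(f"Organization {org_id} not found")
        return org

async def main() -> str:
    # A test can pass db=None because the fake never touches a session
    service = OrgService(FakeOrgRepo({"org-1": {"name": "Acme"}}))
    org = await service.get_organization(db=None, org_id="org-1")
    return org["name"]

print(asyncio.run(main()))
```

The default-singleton fallback (`organization_repository or organization_repo`) keeps production wiring zero-config while leaving the test seam open.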
@@ -8,7 +8,7 @@ import logging
from datetime import UTC, datetime

from app.core.database import SessionLocal
-from app.crud.session import session as session_crud
+from app.repositories.session import session_repo as session_repo

logger = logging.getLogger(__name__)

@@ -32,15 +32,15 @@ async def cleanup_expired_sessions(keep_days: int = 30) -> int:

    async with SessionLocal() as db:
        try:
-            # Use CRUD method to cleanup
-            count = await session_crud.cleanup_expired(db, keep_days=keep_days)
+            # Use repository method to cleanup
+            count = await session_repo.cleanup_expired(db, keep_days=keep_days)

-            logger.info(f"Session cleanup complete: {count} sessions deleted")
+            logger.info("Session cleanup complete: %s sessions deleted", count)

            return count

        except Exception as e:
-            logger.error(f"Error during session cleanup: {e!s}", exc_info=True)
+            logger.exception("Error during session cleanup: %s", e)
            return 0


@@ -79,10 +79,10 @@ async def get_session_statistics() -> dict:
            "expired": expired_sessions,
        }

-        logger.info(f"Session statistics: {stats}")
+        logger.info("Session statistics: %s", stats)

        return stats

    except Exception as e:
-        logger.error(f"Error getting session statistics: {e!s}", exc_info=True)
+        logger.exception("Error getting session statistics: %s", e)
        return {}
97	backend/app/services/session_service.py	Normal file
@@ -0,0 +1,97 @@
# app/services/session_service.py
"""Service layer for session operations — delegates to SessionRepository."""

import logging
from datetime import datetime

from sqlalchemy.ext.asyncio import AsyncSession

from app.models.user_session import UserSession
from app.repositories.session import SessionRepository, session_repo
from app.schemas.sessions import SessionCreate

logger = logging.getLogger(__name__)


class SessionService:
    """Service for user session management operations."""

    def __init__(self, session_repository: SessionRepository | None = None) -> None:
        self._repo = session_repository or session_repo

    async def create_session(
        self, db: AsyncSession, *, obj_in: SessionCreate
    ) -> UserSession:
        """Create a new session record."""
        return await self._repo.create_session(db, obj_in=obj_in)

    async def get_session(
        self, db: AsyncSession, session_id: str
    ) -> UserSession | None:
        """Get session by ID."""
        return await self._repo.get(db, id=session_id)

    async def get_user_sessions(
        self, db: AsyncSession, *, user_id: str, active_only: bool = True
    ) -> list[UserSession]:
        """Get all sessions for a user."""
        return await self._repo.get_user_sessions(
            db, user_id=user_id, active_only=active_only
        )

    async def get_active_by_jti(
        self, db: AsyncSession, *, jti: str
    ) -> UserSession | None:
        """Get active session by refresh token JTI."""
        return await self._repo.get_active_by_jti(db, jti=jti)

    async def get_by_jti(self, db: AsyncSession, *, jti: str) -> UserSession | None:
        """Get session by refresh token JTI (active or inactive)."""
        return await self._repo.get_by_jti(db, jti=jti)

    async def deactivate(
        self, db: AsyncSession, *, session_id: str
    ) -> UserSession | None:
        """Deactivate a session (logout from device)."""
        return await self._repo.deactivate(db, session_id=session_id)

    async def deactivate_all_user_sessions(
        self, db: AsyncSession, *, user_id: str
    ) -> int:
        """Deactivate all sessions for a user. Returns count deactivated."""
        return await self._repo.deactivate_all_user_sessions(db, user_id=user_id)

    async def update_refresh_token(
        self,
        db: AsyncSession,
        *,
        session: UserSession,
        new_jti: str,
        new_expires_at: datetime,
    ) -> UserSession:
        """Update session with a rotated refresh token."""
        return await self._repo.update_refresh_token(
            db, session=session, new_jti=new_jti, new_expires_at=new_expires_at
        )

    async def cleanup_expired_for_user(self, db: AsyncSession, *, user_id: str) -> int:
        """Remove expired sessions for a user. Returns count removed."""
        return await self._repo.cleanup_expired_for_user(db, user_id=user_id)

    async def get_all_sessions(
        self,
        db: AsyncSession,
        *,
        skip: int = 0,
        limit: int = 100,
        active_only: bool = True,
        with_user: bool = True,
    ) -> tuple[list[UserSession], int]:
        """Get all sessions with pagination (admin only)."""
        return await self._repo.get_all_sessions(
            db, skip=skip, limit=limit, active_only=active_only, with_user=with_user
        )


# Default singleton
session_service = SessionService()
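`update_refresh_token` above supports refresh-token rotation: each refresh replaces the stored JTI and expiry so the previous token can no longer be matched by `get_active_by_jti`. A stdlib-only sketch of the idea (the `Session` class and field names are illustrative stand-ins, not the real model):

```python
import uuid
from datetime import datetime, timedelta, timezone

class Session:
    """Illustrative stand-in for a UserSession row."""
    def __init__(self) -> None:
        self.refresh_token_jti = str(uuid.uuid4())
        self.expires_at = datetime.now(timezone.utc) + timedelta(days=7)

def rotate(session: Session) -> str:
    """Replace the session's JTI and expiry; return the old JTI for revocation."""
    old_jti = session.refresh_token_jti
    session.refresh_token_jti = str(uuid.uuid4())
    session.expires_at = datetime.now(timezone.utc) + timedelta(days=7)
    return old_jti

s = Session()
old = rotate(s)
assert old != s.refresh_token_jti           # the old token no longer matches
assert s.expires_at > datetime.now(timezone.utc)
```

A stolen refresh token presented after rotation then fails the JTI lookup, which is the main security benefit of rotating on every refresh.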
120	backend/app/services/user_service.py	Normal file
@@ -0,0 +1,120 @@
# app/services/user_service.py
"""Service layer for user operations — delegates to UserRepository."""

import logging
from typing import Any
from uuid import UUID

from sqlalchemy.ext.asyncio import AsyncSession

from app.core.exceptions import NotFoundError
from app.models.user import User
from app.repositories.user import UserRepository, user_repo
from app.schemas.users import UserCreate, UserUpdate

logger = logging.getLogger(__name__)


class UserService:
    """Service for user management operations."""

    def __init__(self, user_repository: UserRepository | None = None) -> None:
        self._repo = user_repository or user_repo

    async def get_user(self, db: AsyncSession, user_id: str) -> User:
        """Get user by ID, raising NotFoundError if not found."""
        user = await self._repo.get(db, id=user_id)
        if not user:
            raise NotFoundError(f"User {user_id} not found")
        return user

    async def get_by_email(self, db: AsyncSession, email: str) -> User | None:
        """Get user by email address."""
        return await self._repo.get_by_email(db, email=email)

    async def create_user(self, db: AsyncSession, user_data: UserCreate) -> User:
        """Create a new user."""
        return await self._repo.create(db, obj_in=user_data)

    async def update_user(
        self, db: AsyncSession, *, user: User, obj_in: UserUpdate | dict[str, Any]
    ) -> User:
        """Update an existing user."""
        return await self._repo.update(db, db_obj=user, obj_in=obj_in)

    async def soft_delete_user(self, db: AsyncSession, user_id: str) -> None:
        """Soft-delete a user by ID."""
        await self._repo.soft_delete(db, id=user_id)

    async def list_users(
        self,
        db: AsyncSession,
        *,
        skip: int = 0,
        limit: int = 100,
        sort_by: str | None = None,
        sort_order: str = "asc",
        filters: dict[str, Any] | None = None,
        search: str | None = None,
    ) -> tuple[list[User], int]:
        """List users with pagination, sorting, filtering, and search."""
        return await self._repo.get_multi_with_total(
            db,
            skip=skip,
            limit=limit,
            sort_by=sort_by,
            sort_order=sort_order,
            filters=filters,
            search=search,
        )

    async def bulk_update_status(
        self, db: AsyncSession, *, user_ids: list[UUID], is_active: bool
    ) -> int:
        """Bulk update active status for multiple users. Returns count updated."""
        return await self._repo.bulk_update_status(
            db, user_ids=user_ids, is_active=is_active
        )

    async def bulk_soft_delete(
        self,
        db: AsyncSession,
        *,
        user_ids: list[UUID],
        exclude_user_id: UUID | None = None,
    ) -> int:
        """Bulk soft-delete multiple users. Returns count deleted."""
        return await self._repo.bulk_soft_delete(
            db, user_ids=user_ids, exclude_user_id=exclude_user_id
        )

    async def get_stats(self, db: AsyncSession) -> dict[str, Any]:
        """Return user stats needed for the admin dashboard."""
        from sqlalchemy import func, select

        total_users = (
            await db.execute(select(func.count()).select_from(User))
        ).scalar() or 0
        active_count = (
            await db.execute(
                select(func.count()).select_from(User).where(User.is_active)
            )
        ).scalar() or 0
        inactive_count = (
            await db.execute(
                select(func.count()).select_from(User).where(User.is_active.is_(False))
            )
        ).scalar() or 0
        all_users = list(
            (await db.execute(select(User).order_by(User.created_at))).scalars().all()
        )
        return {
            "total_users": total_users,
            "active_count": active_count,
            "inactive_count": inactive_count,
            "all_users": all_users,
        }


# Default singleton
user_service = UserService()
@@ -65,10 +65,10 @@ async def setup_async_test_db():
    async with test_engine.begin() as conn:
        await conn.run_sync(Base.metadata.create_all)

-    AsyncTestingSessionLocal = sessionmaker(
+    AsyncTestingSessionLocal = sessionmaker(  # pyright: ignore[reportCallIssue]
        autocommit=False,
        autoflush=False,
-        bind=test_engine,
+        bind=test_engine,  # pyright: ignore[reportArgumentType]
        expire_on_commit=False,
        class_=AsyncSession,
    )
@@ -79,12 +79,13 @@ This FastAPI backend application follows a **clean layered architecture** patter

### Authentication & Security

-- **python-jose**: JWT token generation and validation
-  - Cryptographic signing
+- **PyJWT**: JWT token generation and validation
+  - Cryptographic signing (HS256, RS256)
  - Token expiration handling
  - Claims validation
+  - JWK support for Google ID token verification

-- **passlib + bcrypt**: Password hashing
+- **bcrypt**: Password hashing
  - Industry-standard bcrypt algorithm
  - Configurable cost factor
  - Salt generation
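For intuition about what the PyJWT dependency handles on the HS256 path, here is a stdlib-only sketch of JWT signing and verification. Real code should use PyJWT itself; this is purely illustrative:

```python
import base64
import hashlib
import hmac
import json

def b64url(data: bytes) -> bytes:
    """Base64url without padding, as the JWT spec requires."""
    return base64.urlsafe_b64encode(data).rstrip(b"=")

def sign_hs256(payload: dict, secret: bytes) -> str:
    header = b64url(json.dumps({"alg": "HS256", "typ": "JWT"}).encode())
    body = b64url(json.dumps(payload).encode())
    signing_input = header + b"." + body
    sig = b64url(hmac.new(secret, signing_input, hashlib.sha256).digest())
    return (signing_input + b"." + sig).decode()

def verify_hs256(token: str, secret: bytes) -> bool:
    signing_input, _, sig = token.rpartition(".")
    expected = b64url(
        hmac.new(secret, signing_input.encode(), hashlib.sha256).digest()
    ).decode()
    # Constant-time comparison to avoid timing side channels
    return hmac.compare_digest(sig, expected)

tok = sign_hs256({"sub": "user-1"}, b"dev-secret")
assert verify_hs256(tok, b"dev-secret")
assert not verify_hs256(tok, b"wrong-secret")
```

PyJWT layers expiration handling, claims validation, and the RS256/JWK machinery mentioned above on top of this core signing scheme.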
@@ -117,7 +118,8 @@ backend/
│   ├── api/                      # API layer
│   │   ├── dependencies/         # Dependency injection
│   │   │   ├── auth.py           # Authentication dependencies
-│   │   │   └── permissions.py    # Authorization dependencies
+│   │   │   ├── permissions.py    # Authorization dependencies
+│   │   │   └── services.py       # Service singleton injection
│   │   ├── routes/               # API endpoints
│   │   │   ├── auth.py           # Authentication routes
│   │   │   ├── users.py          # User management routes
@@ -131,13 +133,14 @@ backend/
│   │   ├── config.py             # Application configuration
│   │   ├── database.py           # Database connection
│   │   ├── exceptions.py         # Custom exception classes
+│   │   ├── repository_exceptions.py  # Repository-level exception hierarchy
│   │   └── middleware.py         # Custom middleware
│   │
-│   ├── crud/                     # Database operations
-│   │   ├── base.py               # Generic CRUD base class
-│   │   ├── user.py               # User CRUD operations
-│   │   ├── session.py            # Session CRUD operations
-│   │   └── organization.py       # Organization CRUD
+│   ├── repositories/             # Data access layer
+│   │   ├── base.py               # Generic repository base class
+│   │   ├── user.py               # User repository
+│   │   ├── session.py            # Session repository
+│   │   └── organization.py       # Organization repository
│   │
│   ├── models/                   # SQLAlchemy models
│   │   ├── base.py               # Base model with mixins
@@ -153,8 +156,11 @@ backend/
│   │   ├── sessions.py           # Session schemas
│   │   └── organizations.py      # Organization schemas
│   │
-│   ├── services/                 # Business logic
+│   ├── services/                 # Business logic layer
│   │   ├── auth_service.py       # Authentication service
+│   │   ├── user_service.py       # User management service
+│   │   ├── session_service.py    # Session management service
+│   │   ├── organization_service.py  # Organization service
│   │   ├── email_service.py      # Email service
│   │   └── session_cleanup.py    # Background cleanup
│   │
@@ -168,20 +174,25 @@ backend/
│
├── tests/                        # Test suite
│   ├── api/                      # Integration tests
-│   ├── crud/                     # CRUD tests
+│   ├── repositories/             # Repository unit tests
+│   ├── services/                 # Service unit tests
│   ├── models/                   # Model tests
-│   ├── services/                 # Service tests
│   └── conftest.py               # Test configuration
│
├── docs/                         # Documentation
│   ├── ARCHITECTURE.md           # This file
│   ├── CODING_STANDARDS.md       # Coding standards
│   ├── COMMON_PITFALLS.md        # Common mistakes to avoid
│   ├── E2E_TESTING.md            # E2E testing guide
│   └── FEATURE_EXAMPLE.md        # Feature implementation guide
│
-├── requirements.txt              # Python dependencies
-├── pytest.ini                    # Pytest configuration
-├── .coveragerc                   # Coverage configuration
-└── alembic.ini                   # Alembic configuration
+├── pyproject.toml                # Dependencies, tool configs (Ruff, pytest, coverage, Pyright)
+├── uv.lock                       # Locked dependency versions (commit to git)
+├── Makefile                      # Development commands (quality, security, testing)
+├── .pre-commit-config.yaml       # Pre-commit hook configuration
+├── .secrets.baseline             # detect-secrets baseline (known false positives)
+├── alembic.ini                   # Alembic configuration
+└── migrate.py                    # Migration helper script
```

## Layered Architecture
@@ -214,11 +225,11 @@ The application follows a strict 5-layer architecture:
└──────────────────────────┬──────────────────────────────────┘
                           │ calls
┌──────────────────────────▼──────────────────────────────────┐
-│                     CRUD Layer (crud/)                      │
+│              Repository Layer (repositories/)               │
│  - Database operations                                      │
│  - Query building                                           │
│  - Transaction management                                   │
-│  - Error handling                                           │
+│  - Custom repository exceptions                             │
│  - No business logic                                        │
└──────────────────────────┬──────────────────────────────────┘
                           │ uses
┌──────────────────────────▼──────────────────────────────────┐
@@ -262,7 +273,7 @@ async def get_current_user_info(

**Rules**:
- Should NOT contain business logic
-- Should NOT directly perform database operations (use CRUD or services)
+- Should NOT directly call repositories (use services injected via `dependencies/services.py`)
- Must validate all input via Pydantic schemas
- Must specify response models
- Should apply appropriate rate limits
@@ -279,9 +290,9 @@ async def get_current_user_info(

**Example**:
```python
-def get_current_user(
+async def get_current_user(
    token: str = Depends(oauth2_scheme),
-    db: Session = Depends(get_db)
+    db: AsyncSession = Depends(get_db)
) -> User:
    """
    Extract and validate user from JWT token.
@@ -295,7 +306,7 @@ def get_current_user(
    except Exception:
        raise AuthenticationError("Invalid authentication credentials")

-    user = user_crud.get(db, id=user_id)
+    user = await user_repo.get(db, id=user_id)
    if not user:
        raise AuthenticationError("User not found")
|
||||
**Responsibility**: Implement complex business logic
|
||||
|
||||
**Key Functions**:
|
||||
- Orchestrate multiple CRUD operations
|
||||
- Orchestrate multiple repository operations
|
||||
- Implement business rules
|
||||
- Handle external service integration
|
||||
- Coordinate transactions
|
||||
@@ -323,9 +334,9 @@ def get_current_user(
class AuthService:
    """Authentication service with business logic."""

-    def login(
+    async def login(
        self,
-        db: Session,
+        db: AsyncSession,
        email: str,
        password: str,
        request: Request
@@ -339,8 +350,8 @@ class AuthService:
        3. Generate tokens
        4. Return tokens and user info
        """
-        # Validate credentials
-        user = user_crud.get_by_email(db, email=email)
+        # Validate credentials via repository
+        user = await user_repo.get_by_email(db, email=email)
        if not user or not verify_password(password, user.hashed_password):
            raise AuthenticationError("Invalid credentials")

@@ -350,11 +361,10 @@ class AuthService:
        # Extract device info
        device_info = extract_device_info(request)

-        # Create session
-        session = session_crud.create_session(
+        # Create session via repository
+        session = await session_repo.create(
            db,
-            user_id=user.id,
-            device_info=device_info
+            obj_in=SessionCreate(user_id=user.id, **device_info)
        )

        # Generate tokens
@@ -373,75 +383,60 @@ class AuthService:

**Rules**:
- Contains business logic, not just data operations
-- Can call multiple CRUD operations
+- Can call multiple repository operations
- Should handle complex workflows
- Must maintain data consistency
- Should use transactions when needed

-#### 4. CRUD Layer (`app/crud/`)
+#### 4. Repository Layer (`app/repositories/`)

-**Responsibility**: Database operations and queries
+**Responsibility**: Database operations and queries — no business logic

**Key Functions**:
- Create, read, update, delete operations
- Build database queries
- Handle database errors
+- Raise custom repository exceptions (`DuplicateEntryError`, `IntegrityConstraintError`)
- Manage soft deletes
- Implement pagination and filtering

**Example**:
```python
-class CRUDSession(CRUDBase[UserSession, SessionCreate, SessionUpdate]):
-    """CRUD operations for user sessions."""
+class SessionRepository(RepositoryBase[UserSession, SessionCreate, SessionUpdate]):
+    """Repository for user sessions — database operations only."""

-    def get_by_jti(self, db: Session, jti: UUID) -> Optional[UserSession]:
+    async def get_by_jti(self, db: AsyncSession, *, jti: str) -> UserSession | None:
        """Get session by refresh token JTI."""
-        try:
-            return (
-                db.query(UserSession)
-                .filter(UserSession.refresh_token_jti == jti)
-                .first()
-            )
-        except Exception as e:
-            logger.error(f"Error getting session by JTI: {str(e)}")
-            return None
+        result = await db.execute(
+            select(UserSession).where(UserSession.refresh_token_jti == jti)
+        )
+        return result.scalar_one_or_none()

-    def get_active_by_jti(
-        self,
-        db: Session,
-        jti: UUID
-    ) -> Optional[UserSession]:
-        """Get active session by refresh token JTI."""
-        session = self.get_by_jti(db, jti=jti)
-        if session and session.is_active and not session.is_expired:
-            return session
-        return None

-    def deactivate(self, db: Session, session_id: UUID) -> bool:
+    async def deactivate(self, db: AsyncSession, *, session_id: UUID) -> bool:
        """Deactivate a session (logout)."""
        try:
-            session = self.get(db, id=session_id)
+            session = await self.get(db, id=session_id)
            if not session:
                return False

            session.is_active = False
-            db.commit()
+            await db.commit()
            logger.info(f"Session {session_id} deactivated")
            return True

        except Exception as e:
-            db.rollback()
+            await db.rollback()
            logger.error(f"Error deactivating session: {str(e)}")
            return False
```

**Rules**:
- Should NOT contain business logic
- Must handle database exceptions
- Must use parameterized queries (SQLAlchemy does this)
+- Must raise custom repository exceptions (not raw `ValueError`/`IntegrityError`)
+- Must use async SQLAlchemy 2.0 `select()` API (never `db.query()`)
- Should log all database errors
- Must rollback on errors
- Should use soft deletes when possible
+- **Never imported directly by routes** — always called through services

#### 5. Data Layer (`app/models/` + `app/schemas/`)
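The rule about custom repository exceptions can be sketched without SQLAlchemy: the repository catches the driver-level integrity error and re-raises a domain exception that the service layer can handle without knowing the database. All names below are illustrative stand-ins:

```python
class IntegrityError(Exception):
    """Stand-in for sqlalchemy.exc.IntegrityError."""

class DuplicateEntryError(Exception):
    """Domain-level exception the service layer catches."""

class _FakeTable:
    """Simulates a table with a UNIQUE constraint on email."""
    def __init__(self) -> None:
        self._emails: set[str] = set()

    def insert(self, email: str) -> None:
        if email in self._emails:
            raise IntegrityError("UNIQUE constraint failed: users.email")
        self._emails.add(email)

class UserRepository:
    def __init__(self) -> None:
        self._table = _FakeTable()

    def create(self, email: str) -> None:
        try:
            self._table.insert(email)
        except IntegrityError as exc:
            # Translate the driver error into a domain exception
            raise DuplicateEntryError(f"User {email} already exists") from exc

repo = UserRepository()
repo.create("a@example.com")
try:
    repo.create("a@example.com")
except DuplicateEntryError:
    print("duplicate detected")
```

The payoff is that services and routes never import `sqlalchemy.exc`, so swapping the storage layer does not ripple upward.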
@@ -546,51 +541,23 @@ SessionLocal = sessionmaker(
#### Dependency Injection Pattern

```python
-def get_db() -> Generator[Session, None, None]:
+async def get_db() -> AsyncGenerator[AsyncSession, None]:
    """
-    Database session dependency for FastAPI routes.
+    Async database session dependency for FastAPI routes.

-    Automatically commits on success, rolls back on error.
+    The session is passed to service methods; commit/rollback is
+    managed inside service or repository methods.
    """
-    db = SessionLocal()
-    try:
+    async with AsyncSessionLocal() as db:
        yield db
-    finally:
-        db.close()

-# Usage in routes
+# Usage in routes — always through a service, never direct repository
@router.get("/users")
-def list_users(db: Session = Depends(get_db)):
-    return user_crud.get_multi(db)
-```
-
-#### Context Manager Pattern
-
-```python
-@contextmanager
-def transaction_scope() -> Generator[Session, None, None]:
-    """
-    Context manager for database transactions.
-
-    Use for complex operations requiring multiple steps.
-    Automatically commits on success, rolls back on error.
-    """
-    db = SessionLocal()
-    try:
-        yield db
-        db.commit()
-    except Exception:
-        db.rollback()
-        raise
-    finally:
-        db.close()
-
-# Usage in services
-def complex_operation():
-    with transaction_scope() as db:
-        user = user_crud.create(db, obj_in=user_data)
-        session = session_crud.create(db, session_data)
-        return user, session
+async def list_users(
+    user_service: UserService = Depends(get_user_service),
+    db: AsyncSession = Depends(get_db),
+):
+    return await user_service.get_users(db)
```

### Model Mixins

@@ -782,22 +749,15 @@ def get_profile(

```python
@router.delete("/sessions/{session_id}")
-def revoke_session(
+async def revoke_session(
    session_id: UUID,
    current_user: User = Depends(get_current_user),
-    db: Session = Depends(get_db)
+    session_service: SessionService = Depends(get_session_service),
+    db: AsyncSession = Depends(get_db),
):
    """Users can only revoke their own sessions."""
-    session = session_crud.get(db, id=session_id)
-
-    if not session:
-        raise NotFoundError("Session not found")
-
-    # Check ownership
-    if session.user_id != current_user.id:
-        raise AuthorizationError("You can only revoke your own sessions")
-
-    session_crud.deactivate(db, session_id=session_id)
+    # SessionService verifies ownership and raises NotFoundError / AuthorizationError
+    await session_service.revoke_session(db, session_id=session_id, user_id=current_user.id)
    return MessageResponse(success=True, message="Session revoked")
```

@@ -1061,23 +1021,27 @@ from app.services.session_cleanup import cleanup_expired_sessions

scheduler = AsyncIOScheduler()

-@app.on_event("startup")
-async def startup_event():
-    """Start background jobs on application startup."""
-    if not settings.IS_TEST:  # Don't run in tests
+@asynccontextmanager
+async def lifespan(app: FastAPI):
+    """Application lifespan context manager."""
+    # Startup
+    if os.getenv("IS_TEST", "False") != "True":
        scheduler.add_job(
            cleanup_expired_sessions,
            "cron",
            hour=2,  # Run at 2 AM daily
-            id="cleanup_expired_sessions"
+            id="cleanup_expired_sessions",
+            replace_existing=True,
        )
        scheduler.start()
        logger.info("Background jobs started")

-@app.on_event("shutdown")
-async def shutdown_event():
-    """Stop background jobs on application shutdown."""
-    scheduler.shutdown()
+    yield
+
+    # Shutdown
+    if os.getenv("IS_TEST", "False") != "True":
+        scheduler.shutdown()
+    await close_async_db()  # Dispose database engine connections
```

### Job Implementation

@@ -1092,8 +1056,8 @@ async def cleanup_expired_sessions():
    Runs daily at 2 AM. Removes sessions expired for more than 30 days.
    """
    try:
-        with transaction_scope() as db:
-            count = session_crud.cleanup_expired(db, keep_days=30)
+        async with AsyncSessionLocal() as db:
+            count = await session_repo.cleanup_expired(db, keep_days=30)
        logger.info(f"Cleaned up {count} expired sessions")
    except Exception as e:
        logger.error(f"Error cleaning up sessions: {str(e)}", exc_info=True)
@@ -1110,7 +1074,7 @@ async def cleanup_expired_sessions():
│ Integration │  ← API endpoint tests
│    Tests    │
├─────────────┤
-│    Unit     │  ← CRUD, services, utilities
+│    Unit     │  ← repositories, services, utilities
│    Tests    │
└─────────────┘
```
@@ -1205,6 +1169,8 @@ app.add_middleware(

## Performance Considerations

+> 📖 For the full benchmarking guide (how to run, read results, write new benchmarks, and manage baselines), see **[BENCHMARKS.md](BENCHMARKS.md)**.
+
### Database Connection Pooling

- Pool size: 20 connections

311 backend/docs/BENCHMARKS.md Normal file
@@ -0,0 +1,311 @@
# Performance Benchmarks Guide

Automated performance benchmarking infrastructure using **pytest-benchmark** to detect latency regressions in critical API endpoints.

## Table of Contents

- [Why Benchmark?](#why-benchmark)
- [Quick Start](#quick-start)
- [How It Works](#how-it-works)
- [Understanding Results](#understanding-results)
- [Test Organization](#test-organization)
- [Writing Benchmark Tests](#writing-benchmark-tests)
- [Baseline Management](#baseline-management)
- [CI/CD Integration](#cicd-integration)
- [Troubleshooting](#troubleshooting)

---

## Why Benchmark?

Performance regressions are silent bugs — they don't break tests or cause errors, but they degrade the user experience over time. Common causes include:

- **Unintended N+1 queries** after adding a relationship
- **Heavier serialization** after adding new fields to a response model
- **Middleware overhead** from new security headers or logging
- **Dependency upgrades** that introduce slower code paths

Without automated benchmarks, these regressions go unnoticed until users complain. Performance benchmarks serve as an **early warning system** — they measure endpoint latency on every run and flag significant deviations from an established baseline.

### What benchmarks give you

| Benefit | Description |
|---------|-------------|
| **Regression detection** | Automatically flags when an endpoint becomes significantly slower |
| **Baseline tracking** | Stores known-good performance numbers for comparison |
| **Confidence in refactors** | Verify that code changes don't degrade response times |
| **Visibility** | Makes performance a first-class, measurable quality attribute |

---

## Quick Start

```bash
# Run benchmarks (no comparison, just see current numbers)
make benchmark

# Save current results as the baseline
make benchmark-save

# Run benchmarks and compare against the saved baseline
make benchmark-check
```

---

## How It Works

The benchmarking system has three layers:

### 1. pytest-benchmark integration

[pytest-benchmark](https://pytest-benchmark.readthedocs.io/) is a pytest plugin that provides a `benchmark` fixture. It handles:

- **Calibration**: Automatically determines how many iterations to run for statistical significance
- **Timing**: Uses `time.perf_counter` for high-resolution measurements
- **Statistics**: Computes min, max, mean, median, standard deviation, IQR, and outlier detection
- **Comparison**: Compares current results against saved baselines and flags regressions

### 2. Benchmark types

The test suite includes two categories of performance tests:

| Type | How it works | Examples |
|------|-------------|----------|
| **pytest-benchmark tests** | Uses the `benchmark` fixture for precise, multi-round timing | `test_health_endpoint_performance`, `test_openapi_schema_performance`, `test_password_hashing_performance`, `test_password_verification_performance`, `test_access_token_creation_performance`, `test_refresh_token_creation_performance`, `test_token_decode_performance` |
| **Manual latency tests** | Uses `time.perf_counter` with explicit thresholds (for async endpoints that pytest-benchmark doesn't support natively) | `test_login_latency`, `test_get_current_user_latency`, `test_register_latency`, `test_token_refresh_latency`, `test_sessions_list_latency`, `test_user_profile_update_latency` |

### 3. Regression detection

When running `make benchmark-check`, the system:

1. Runs all benchmark tests
2. Compares results against the saved baseline (`.benchmarks/` directory)
3. **Fails the build** if any test's mean time degrades by more than **200%** relative to the baseline (i.e., more than 3× the baseline mean)

The `200%` threshold in `--benchmark-compare-fail=mean:200%` means "fail if the mean increased by more than 200% relative to the baseline." This is deliberately generous to avoid false positives from normal run-to-run variance while still catching real regressions.
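The comparison rule reduces to a simple percentage check on the mean. A minimal sketch of that rule, for illustration only (a simplified model, not pytest-benchmark's actual implementation):

```python
def regressed(current_mean: float, baseline_mean: float, threshold_pct: float = 200.0) -> bool:
    """Return True when the mean degraded by more than threshold_pct percent vs. the baseline."""
    increase_pct = (current_mean - baseline_mean) / baseline_mean * 100
    return increase_pct > threshold_pct

# A 100% increase (2x the baseline) still passes; a 250% increase (3.5x) fails.
print(regressed(2.0, 1.0))  # False
print(regressed(3.5, 1.0))  # True
```

With the default `200.0`, the pass/fail boundary sits exactly at 3× the baseline mean, which is why the threshold is generous enough to absorb run-to-run noise.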

---

## Understanding Results

A typical benchmark output looks like this:

```
--------------------------------------------------------------------------------------- benchmark: 2 tests --------------------------------------------------------------------------------------
Name (time in ms)                       Min             Max            Mean          StdDev          Median             IQR        Outliers        OPS            Rounds  Iterations
-------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------
test_health_endpoint_performance     0.9841 (1.0)    1.5513 (1.0)    1.1390 (1.0)   0.1098 (1.0)    1.1151 (1.0)    0.1672 (1.0)      39;2   877.9666 (1.0)        133           1
test_openapi_schema_performance      1.6523 (1.68)   2.0892 (1.35)   1.7843 (1.57)  0.1553 (1.41)   1.7200 (1.54)   0.1727 (1.03)      2;0   560.4471 (0.64)        10           1
-------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------
```

### Column reference

| Column | Meaning |
|--------|---------|
| **Min** | Fastest single execution |
| **Max** | Slowest single execution |
| **Mean** | Average across all rounds — the primary metric for regression detection |
| **StdDev** | How much results vary between rounds (lower = more stable) |
| **Median** | Middle value, less sensitive to outliers than mean |
| **IQR** | Interquartile range — spread of the middle 50% of results |
| **Outliers** | Format `A;B` — A = results more than 1 StdDev from the mean, B = results more than 1.5 IQR outside the quartiles |
| **OPS** | Operations per second (`1 / Mean`) |
| **Rounds** | How many times the test was executed (auto-calibrated) |
| **Iterations** | Iterations per round (usually 1 for ms-scale tests) |

### The ratio numbers `(1.0)`, `(1.68)`, etc.

These show how each test compares **to the best result in that column**. The fastest test is always `(1.0)`, and others show their relative factor. For example, `(1.68)` means "1.68× slower than the fastest."
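The factors are just each value divided by the column's best value. Reproducing the Mean column's factors from the sample output above:

```python
# Mean times (ms) taken from the sample output above
means_ms = {
    "test_health_endpoint_performance": 1.1390,
    "test_openapi_schema_performance": 1.7843,
}

# Each factor is the value divided by the best (smallest) value in the column
best = min(means_ms.values())
ratios = {name: round(mean / best, 2) for name, mean in means_ms.items()}
print(ratios["test_openapi_schema_performance"])  # 1.57, matching (1.57) in the Mean column

# OPS is simply 1 / Mean, converted from ms to seconds; small differences vs. the
# table's 877.9666 come from the table's mean being rounded to 4 decimals here.
print(round(1000 / means_ms["test_health_endpoint_performance"], 4))  # 877.9631
```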

### Color coding

- **Green**: The fastest (best) value in each column
- **Red**: The slowest (worst) value in each column

This is a **relative ranking within the current run** — red does NOT mean the test failed or that performance is bad. It simply highlights which endpoint is the slower one in the group.

### What's "normal"?

For this project's current endpoints:

| Test | Expected range | Why |
|------|---------------|-----|
| `GET /health` | ~1–1.5ms | Minimal logic, mocked DB check |
| `GET /api/v1/openapi.json` | ~1.5–2.5ms | Serializes entire API schema |
| `get_password_hash` | ~200ms | CPU-bound bcrypt hashing |
| `verify_password` | ~200ms | CPU-bound bcrypt verification |
| `create_access_token` | ~17–20µs | JWT encoding with HMAC-SHA256 |
| `create_refresh_token` | ~17–20µs | JWT encoding with HMAC-SHA256 |
| `decode_token` | ~20–25µs | JWT decoding and claim validation |
| `POST /api/v1/auth/login` | < 500ms threshold | Includes bcrypt password verification |
| `POST /api/v1/auth/register` | < 500ms threshold | Includes bcrypt password hashing |
| `POST /api/v1/auth/refresh` | < 200ms threshold | Token rotation + DB session update |
| `GET /api/v1/users/me` | < 200ms threshold | DB lookup + token validation |
| `GET /api/v1/sessions/me` | < 200ms threshold | Session list query + token validation |
| `PATCH /api/v1/users/me` | < 200ms threshold | DB update + token validation |

---

## Test Organization

```
backend/tests/
├── benchmarks/
│   └── test_endpoint_performance.py   # All performance benchmark tests
│
backend/.benchmarks/                   # Saved baselines (auto-generated)
└── Linux-CPython-3.12-64bit/
    └── 0001_baseline.json             # Platform-specific baseline file
```

### Test markers

All benchmark tests use the `@pytest.mark.benchmark` marker. The `--benchmark-only` flag ensures that only tests using the `benchmark` fixture are executed during benchmark runs, while manual latency tests (async) are skipped.

---

## Writing Benchmark Tests

### Stateless endpoint (using pytest-benchmark fixture)

```python
import pytest
from fastapi.testclient import TestClient

def test_my_endpoint_performance(sync_client, benchmark):
    """Benchmark: GET /my-endpoint should respond within acceptable latency."""
    result = benchmark(sync_client.get, "/my-endpoint")
    assert result.status_code == 200
```

The `benchmark` fixture handles all timing, calibration, and statistics automatically. Just pass it the callable and arguments.

### Async / DB-dependent endpoint (manual timing)

For async endpoints that require database access, use manual timing with an explicit threshold:

```python
import time
import pytest

MAX_RESPONSE_MS = 300

@pytest.mark.asyncio
async def test_my_async_endpoint_latency(client, setup_fixture):
    """Performance: endpoint must respond under threshold."""
    iterations = 5
    total_ms = 0.0

    for _ in range(iterations):
        start = time.perf_counter()
        response = await client.get("/api/v1/my-endpoint")
        elapsed_ms = (time.perf_counter() - start) * 1000
        total_ms += elapsed_ms
        assert response.status_code == 200

    mean_ms = total_ms / iterations
    assert mean_ms < MAX_RESPONSE_MS, (
        f"Latency regression: {mean_ms:.1f}ms exceeds {MAX_RESPONSE_MS}ms threshold"
    )
```

### Guidelines for new benchmarks

1. **Benchmark critical paths** — endpoints users hit frequently or where latency matters most
2. **Mock external dependencies** for stateless tests to isolate endpoint overhead
3. **Set generous thresholds** for manual tests — account for CI variability
4. **Keep benchmarks fast** — they run on every check, so avoid heavy setup

---

## Baseline Management

### Saving a baseline

```bash
make benchmark-save
```

This runs all benchmarks and saves results to `.benchmarks/<platform>/0001_baseline.json`. The baseline captures:

- Mean, min, max, median, stddev for each test
- Machine info (CPU, OS, Python version)
- Timestamp
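Because the baseline is plain JSON, it can be inspected or post-processed directly. A sketch using a trimmed, hypothetical excerpt (the `benchmarks`/`stats` field layout follows pytest-benchmark's JSON output, where times are stored in seconds; the exact file contents will differ):

```python
import json

# Hypothetical, trimmed excerpt of .benchmarks/<platform>/0001_baseline.json
baseline_json = """
{
  "benchmarks": [
    {"name": "test_health_endpoint_performance", "stats": {"mean": 0.001139}}
  ]
}
"""

data = json.loads(baseline_json)
for bench in data["benchmarks"]:
    # Convert seconds to milliseconds for readability
    print(f"{bench['name']}: mean={bench['stats']['mean'] * 1000:.4f}ms")
```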

### Comparing against baseline

```bash
make benchmark-check
```

If no baseline exists, this command automatically creates one and prints a warning. On subsequent runs, it compares current results against the saved baseline.

### When to update the baseline

- **After intentional performance changes** (e.g., you optimized an endpoint — save the new, faster baseline)
- **After infrastructure changes** (e.g., new CI runner, different hardware)
- **After adding new benchmark tests** (the new tests need a baseline entry)

```bash
# Update the baseline after intentional changes
make benchmark-save
```

### Version control

The `.benchmarks/` directory can be committed to version control so that CI pipelines can compare against a known-good baseline. However, since benchmark results are machine-specific, you may prefer to generate baselines in CI rather than committing local results.

---

## CI/CD Integration

Add benchmark checking to your CI pipeline to catch regressions on every PR:

```yaml
# Example GitHub Actions step
- name: Performance regression check
  run: |
    cd backend
    make benchmark-save   # Create baseline from main branch
    # ... apply PR changes ...
    make benchmark-check  # Compare PR against baseline
```

A more robust approach:

1. Save the baseline on the `main` branch after each merge
2. On PR branches, run `make benchmark-check` against the `main` baseline
3. The pipeline fails if any endpoint regresses beyond the 200% threshold

---

## Troubleshooting

### "No benchmark baseline found" warning

```
⚠️ No benchmark baseline found. Run 'make benchmark-save' first to create one.
```

This means no baseline file exists yet. The command will auto-create one. Future runs of `make benchmark-check` will compare against it.

### Machine info mismatch warning

```
WARNING: benchmark machine_info is different
```

This is expected when comparing baselines generated on a different machine or OS. The comparison still works, but absolute numbers may differ. Re-save the baseline on the current machine if needed.

### High variance (large StdDev)

If StdDev is high relative to the Mean, results may be unreliable. Common causes:

- System under load during benchmark run
- Garbage collection interference
- Thermal throttling

Try running benchmarks on an idle system or increasing `min_rounds` in `pyproject.toml`.

### Only 7 of 13 tests run

The async tests (`test_login_latency`, `test_get_current_user_latency`, `test_register_latency`, `test_token_refresh_latency`, `test_sessions_list_latency`, `test_user_profile_update_latency`) are skipped during `--benchmark-only` runs because they don't use the `benchmark` fixture. They run as part of the normal test suite (`make test`) with manual threshold assertions.
@@ -75,15 +75,14 @@ def create_user(db: Session, user_in: UserCreate) -> User:
### 4. Code Formatting

Use automated formatters:
-- **Black**: Code formatting
-- **isort**: Import sorting
-- **flake8**: Linting
+- **Ruff**: Code formatting and linting (replaces Black, isort, flake8)
+- **pyright**: Static type checking

-Run before committing:
+Run before committing (or use `make validate`):
```bash
-black app tests
-isort app tests
-flake8 app tests
+uv run ruff format app tests
+uv run ruff check app tests
+uv run pyright app
```

## Code Organization

@@ -94,19 +93,17 @@ Follow the 5-layer architecture strictly:

```
API Layer (routes/)
-    ↓ calls
-Dependencies (dependencies/)
-    ↓ injects
+    ↓ calls (via service injected from dependencies/services.py)
Service Layer (services/)
    ↓ calls
-CRUD Layer (crud/)
+Repository Layer (repositories/)
    ↓ uses
Models & Schemas (models/, schemas/)
```

**Rules:**
-- Routes should NOT directly call CRUD operations (use services when business logic is needed)
-- CRUD operations should NOT contain business logic
+- Routes must NEVER import repositories directly — always use a service
+- Services call repositories; repositories contain only database operations
- Models should NOT import from higher layers
- Each layer should only depend on the layer directly below it

@@ -125,7 +122,7 @@ from sqlalchemy.orm import Session

# 3. Local application imports
from app.api.dependencies.auth import get_current_user
-from app.crud import user_crud
+from app.api.dependencies.services import get_user_service
from app.models.user import User
from app.schemas.users import UserResponse, UserCreate
```

@@ -217,7 +214,7 @@ if not user:

### Error Handling Pattern

-Always follow this pattern in CRUD operations (Async version):
+Always follow this pattern in repository operations (Async version):

```python
from sqlalchemy.exc import IntegrityError, OperationalError, DataError
@@ -430,7 +427,7 @@ backend/app/alembic/versions/

## Database Operations

-### Async CRUD Pattern
+### Async Repository Pattern

**IMPORTANT**: This application uses **async SQLAlchemy** with modern patterns for better performance and testability.

@@ -442,19 +439,19 @@ backend/app/alembic/versions/
4. **Testability**: Easy to mock and test
5. **Consistent Ordering**: Always order queries for pagination

-### Use the Async CRUD Base Class
+### Use the Async Repository Base Class

-Always inherit from `CRUDBase` for database operations:
+Always inherit from `RepositoryBase` for database operations:

```python
from sqlalchemy.ext.asyncio import AsyncSession
from sqlalchemy import select
-from app.crud.base import CRUDBase
+from app.repositories.base import RepositoryBase
from app.models.user import User
from app.schemas.users import UserCreate, UserUpdate

-class CRUDUser(CRUDBase[User, UserCreate, UserUpdate]):
-    """CRUD operations for User model."""
+class UserRepository(RepositoryBase[User, UserCreate, UserUpdate]):
+    """Repository for User model — database operations only."""

    async def get_by_email(
        self,
@@ -467,7 +464,7 @@ class CRUDUser(CRUDBase[User, UserCreate, UserUpdate]):
    )
    return result.scalar_one_or_none()

-user_crud = CRUDUser(User)
+user_repo = UserRepository(User)
```

**Key Points:**
@@ -476,6 +473,7 @@ user_crud = CRUDUser(User)
- Use `await db.execute()` for queries
- Use `.scalar_one_or_none()` instead of `.first()`
- Use `T | None` instead of `Optional[T]`
+- Repository instances are used internally by services — never import them in routes

### Modern SQLAlchemy Patterns

@@ -563,13 +561,13 @@ async def create_user(
    The database session is automatically managed by FastAPI.
    Commit on success, rollback on error.
    """
-    return await user_crud.create(db, obj_in=user_in)
+    return await user_service.create_user(db, obj_in=user_in)
```

**Key Points:**
- Route functions must be `async def`
- Database parameter is `AsyncSession`
-- Always `await` CRUD operations
+- Always `await` repository operations

#### In Services (Multiple Operations)

@@ -582,12 +580,11 @@ async def complex_operation(
    """
    Perform multiple database operations atomically.

-    The session automatically commits on success or rolls back on error.
+    Services call repositories; commit/rollback is handled inside
+    each repository method.
    """
-    user = await user_crud.create(db, obj_in=user_data)
-    session = await session_crud.create(db, obj_in=session_data)
-
-    # Commit is handled by the route's dependency
+    user = await user_repo.create(db, obj_in=user_data)
+    session = await session_repo.create(db, obj_in=session_data)
    return user, session
```

@@ -597,10 +594,10 @@ Prefer soft deletes over hard deletes for audit trails:

```python
# Good - Soft delete (sets deleted_at)
-await user_crud.soft_delete(db, id=user_id)
+await user_repo.soft_delete(db, id=user_id)

# Acceptable only when required - Hard delete
-user_crud.remove(db, id=user_id)
+await user_repo.remove(db, id=user_id)
```

### Query Patterns

@@ -740,9 +737,10 @@ Always implement pagination for list endpoints:
from app.schemas.common import PaginationParams, PaginatedResponse

@router.get("/users", response_model=PaginatedResponse[UserResponse])
-def list_users(
+async def list_users(
    pagination: PaginationParams = Depends(),
-    db: Session = Depends(get_db)
+    user_service: UserService = Depends(get_user_service),
+    db: AsyncSession = Depends(get_db),
):
    """
    List all users with pagination.
@@ -750,10 +748,8 @@ def list_users(
    Default page size: 20
    Maximum page size: 100
    """
-    users, total = user_crud.get_multi_with_total(
-        db,
-        skip=pagination.offset,
-        limit=pagination.limit
+    users, total = await user_service.get_users(
+        db, skip=pagination.offset, limit=pagination.limit
    )
    return PaginatedResponse(data=users, pagination=pagination.create_meta(total))
```

@@ -816,19 +812,17 @@ def admin_route(
    pass

# Check ownership
-def delete_resource(
+async def delete_resource(
    resource_id: UUID,
    current_user: User = Depends(get_current_user),
-    db: Session = Depends(get_db)
+    resource_service: ResourceService = Depends(get_resource_service),
+    db: AsyncSession = Depends(get_db),
):
-    resource = resource_crud.get(db, id=resource_id)
-    if not resource:
-        raise NotFoundError("Resource not found")
-
-    if resource.user_id != current_user.id and not current_user.is_superuser:
-        raise AuthorizationError("You can only delete your own resources")
-
-    resource_crud.remove(db, id=resource_id)
+    # Service handles ownership check and raises appropriate errors
+    await resource_service.delete_resource(
+        db, resource_id=resource_id, user_id=current_user.id,
+        is_superuser=current_user.is_superuser,
+    )
```

### Input Validation

@@ -862,9 +856,9 @@ tests/
├── api/              # Integration tests
│   ├── test_users.py
│   └── test_auth.py
-├── crud/             # Unit tests for CRUD
-├── models/           # Model tests
-└── services/         # Service tests
+├── repositories/     # Unit tests for repositories
+├── services/         # Unit tests for services
+└── models/           # Model tests
```

### Async Testing with pytest-asyncio

@@ -927,7 +921,7 @@ async def test_user(db_session: AsyncSession) -> User:
@pytest.mark.asyncio
async def test_get_user(db_session: AsyncSession, test_user: User):
    """Test retrieving a user by ID."""
-    user = await user_crud.get(db_session, id=test_user.id)
+    user = await user_repo.get(db_session, id=test_user.id)
    assert user is not None
    assert user.email == test_user.email
```

@@ -334,14 +334,14 @@ def login(request: Request, credentials: OAuth2PasswordRequestForm):
# ❌ WRONG - Returns password hash!
@router.get("/users/{user_id}")
def get_user(user_id: UUID, db: Session = Depends(get_db)) -> User:
-    return user_crud.get(db, id=user_id)  # Returns ORM model with ALL fields!
+    return user_repo.get(db, id=user_id)  # Returns ORM model with ALL fields!
```

```python
# ✅ CORRECT - Use response schema
@router.get("/users/{user_id}", response_model=UserResponse)
def get_user(user_id: UUID, db: Session = Depends(get_db)):
-    user = user_crud.get(db, id=user_id)
+    user = user_repo.get(db, id=user_id)
    if not user:
        raise HTTPException(status_code=404, detail="User not found")
    return user  # Pydantic filters to only UserResponse fields
@@ -506,8 +506,8 @@ def revoke_session(
    current_user: User = Depends(get_current_user),
    db: Session = Depends(get_db)
):
-    session = session_crud.get(db, id=session_id)
-    session_crud.deactivate(db, session_id=session_id)
+    session = session_repo.get(db, id=session_id)
+    session_repo.deactivate(db, session_id=session_id)
    # BUG: User can revoke ANYONE'S session!
    return {"message": "Session revoked"}
```
@@ -520,7 +520,7 @@ def revoke_session(
    current_user: User = Depends(get_current_user),
    db: Session = Depends(get_db)
):
-    session = session_crud.get(db, id=session_id)
+    session = session_repo.get(db, id=session_id)

    if not session:
        raise NotFoundError("Session not found")
@@ -529,7 +529,7 @@ def revoke_session(
    if session.user_id != current_user.id:
        raise AuthorizationError("You can only revoke your own sessions")

-    session_crud.deactivate(db, session_id=session_id)
+    session_repo.deactivate(db, session_id=session_id)
    return {"message": "Session revoked"}
```

@@ -616,7 +616,43 @@ def create_user(
    return user
```

-**Rule**: Add type hints to ALL functions. Use `mypy` to enforce type checking.
+**Rule**: Add type hints to ALL functions. Use `pyright` to enforce type checking (`make type-check`).

---

+### ❌ PITFALL #19: Importing Repositories Directly in Routes
+
+**Issue**: Routes should never call repositories directly. The layered architecture requires all business operations to go through the service layer.
+
+```python
+# ❌ WRONG - Route bypasses service layer
+from app.repositories.session import session_repo
+
+@router.get("/sessions/me")
+async def list_sessions(
+    current_user: User = Depends(get_current_active_user),
+    db: AsyncSession = Depends(get_db),
+):
+    return await session_repo.get_user_sessions(db, user_id=current_user.id)
+```
+
+```python
+# ✅ CORRECT - Route calls service injected via dependency
+from app.api.dependencies.services import get_session_service
+from app.services.session_service import SessionService
+
+@router.get("/sessions/me")
+async def list_sessions(
+    current_user: User = Depends(get_current_active_user),
+    session_service: SessionService = Depends(get_session_service),
+    db: AsyncSession = Depends(get_db),
+):
+    return await session_service.get_user_sessions(db, user_id=current_user.id)
+```
+
+**Rule**: Routes import from `app.api.dependencies.services`, never from `app.repositories.*`. Services are the only callers of repositories.
+
+---

@@ -649,6 +685,11 @@ Use this checklist to catch issues before code review:
|
||||
- [ ] Resource ownership verification
|
||||
- [ ] CORS configured (no wildcards in production)
|
||||
|
||||
### Architecture
|
||||
- [ ] Routes never import repositories directly (only services)
|
||||
- [ ] Services call repositories; repositories call database only
|
||||
- [ ] New service registered in `app/api/dependencies/services.py`
|
||||
|
||||
### Python
|
||||
- [ ] Use `==` not `is` for value comparison
|
||||
- [ ] No mutable default arguments
|
||||
@@ -661,21 +702,18 @@ Use this checklist to catch issues before code review:
|
||||
|
||||
### Pre-commit Checks
|
||||
|
||||
Add these to your development workflow:
|
||||
Add these to your development workflow (or use `make validate`):
|
||||
|
||||
```bash
|
||||
# Format code
|
||||
black app tests
|
||||
isort app tests
|
||||
# Format + lint (Ruff replaces Black, isort, flake8)
|
||||
uv run ruff format app tests
|
||||
uv run ruff check app tests
|
||||
|
||||
# Type checking
|
||||
mypy app --strict
|
||||
|
||||
# Linting
|
||||
flake8 app tests
|
||||
uv run pyright app
|
||||
|
||||
# Run tests
|
||||
pytest --cov=app --cov-report=term-missing
|
||||
IS_TEST=True uv run pytest --cov=app --cov-report=term-missing
|
||||
|
||||
# Check coverage (should be 80%+)
|
||||
coverage report --fail-under=80
|
||||
@@ -693,6 +731,6 @@ Add new entries when:
|
||||
|
||||
---
|
||||
|
||||
**Last Updated**: 2025-10-31
|
||||
**Issues Cataloged**: 18 common pitfalls
|
||||
**Last Updated**: 2026-02-28
|
||||
**Issues Cataloged**: 19 common pitfalls
|
||||
**Remember**: This document exists because these issues HAVE occurred. Don't skip it.
|
||||
|
||||
@@ -99,7 +99,7 @@ backend/tests/
|
||||
│ └── test_database_workflows.py # PostgreSQL workflow tests
|
||||
│
|
||||
├── api/ # Integration tests (SQLite, fast)
|
||||
├── crud/ # Unit tests
|
||||
├── repositories/ # Repository unit tests
|
||||
└── conftest.py # Standard fixtures
|
||||
```
|
||||
|
||||
|
||||
File diff suppressed because it is too large
@@ -1,4 +1,4 @@
-#!/bin/bash
+#!/bin/sh
 set -e
 echo "Starting Backend"
@@ -20,43 +20,36 @@ dependencies = [
     "uvicorn>=0.34.0",
     "pydantic>=2.10.6",
     "pydantic-settings>=2.2.1",
-    "python-multipart>=0.0.19",
+    "python-multipart>=0.0.22",
     "fastapi-utils==0.8.0",
 
     # Database
     "sqlalchemy>=2.0.29",
     "alembic>=1.14.1",
     "psycopg2-binary>=2.9.9",
     "asyncpg>=0.29.0",
     "aiosqlite==0.21.0",
 
     # Environment configuration
     "python-dotenv>=1.0.1",
 
     # API utilities
     "email-validator>=2.1.0.post1",
     "ujson>=5.9.0",
 
     # CORS and security
     "starlette>=0.40.0",
     "starlette-csrf>=1.4.5",
     "slowapi>=0.1.9",
 
     # Utilities
     "httpx>=0.27.0",
     "tenacity>=8.2.3",
     "pytz>=2024.1",
-    "pillow>=10.3.0",
+    "pillow>=12.1.1",
     "apscheduler==3.11.0",
 
-    # Security and authentication (pinned for reproducibility)
-    "python-jose==3.4.0",
-    "passlib==1.7.4",
+    # Security and authentication
+    "PyJWT>=2.9.0",
     "bcrypt==4.2.1",
-    "cryptography==44.0.1",
-
+    "cryptography>=46.0.5",
     # OAuth authentication
-    "authlib>=1.3.0",
+    "authlib>=1.6.6",
+    "urllib3>=2.6.3",
 ]
 
 # Development dependencies
@@ -72,7 +65,18 @@ dev = [
 
     # Development tools
     "ruff>=0.8.0",              # All-in-one: linting, formatting, import sorting
-    "mypy>=1.8.0",              # Type checking
+    "pyright>=1.1.390",         # Type checking
+
+    # Security auditing
+    "pip-audit>=2.7.0",         # Dependency vulnerability scanning (PyPA/OSV)
+    "pip-licenses>=4.0.0",      # License compliance checking
+    "detect-secrets>=1.5.0",    # Hardcoded secrets detection
+
+    # Performance benchmarking
+    "pytest-benchmark>=4.0.0",  # Performance regression detection
+
+    # Pre-commit hooks
+    "pre-commit>=4.0.0",        # Git pre-commit hook framework
 ]
 
 # E2E testing with real PostgreSQL (requires Docker)
@@ -131,6 +135,8 @@ select = [
     "RUF",    # Ruff-specific
     "ASYNC",  # flake8-async
     "S",      # flake8-bandit (security)
+    "G",      # flake8-logging-format (logging best practices)
+    "T20",    # flake8-print (no print statements in production code)
 ]
 
 # Ignore specific rules
@@ -154,11 +160,13 @@ unfixable = []
 [tool.ruff.lint.per-file-ignores]
 "app/alembic/env.py" = ["E402", "F403", "F405"]  # Alembic requires specific import order
 "app/alembic/versions/*.py" = ["E402"]  # Migration files have specific structure
-"tests/**/*.py" = ["S101", "N806", "B017", "N817", "S110", "ASYNC251", "RUF043"]  # pytest: asserts, CamelCase fixtures, blind exceptions, try-pass patterns, and async test helpers are intentional
+"tests/**/*.py" = ["S101", "N806", "B017", "N817", "ASYNC251", "RUF043", "T20"]  # pytest: asserts, CamelCase fixtures, blind exceptions, async test helpers, and print for debugging are intentional
 "app/models/__init__.py" = ["F401"]  # __init__ files re-export modules
 "app/models/base.py" = ["F401"]  # Re-exports Base for use by other models
 "app/utils/test_utils.py" = ["N806"]  # SQLAlchemy session factories use CamelCase convention
 "app/main.py" = ["N806"]  # Constants use UPPER_CASE convention
+"app/init_db.py" = ["T20"]  # CLI script uses print for user-facing output
+"migrate.py" = ["T20"]  # CLI script uses print for user-facing output
 
 # ============================================================================
 # Ruff Import Sorting (isort replacement)
@@ -185,120 +193,6 @@ indent-style = "space"
 skip-magic-trailing-comma = false
 line-ending = "lf"
 
-# ============================================================================
-# mypy Configuration - Type Checking
-# ============================================================================
-[tool.mypy]
-python_version = "3.12"
-warn_return_any = false  # SQLAlchemy queries return Any - overly strict
-warn_unused_configs = true
-disallow_untyped_defs = false  # Gradual typing - enable later
-disallow_incomplete_defs = false
-check_untyped_defs = true
-no_implicit_optional = true
-warn_redundant_casts = true
-warn_unused_ignores = true
-warn_no_return = true
-strict_equality = true
-ignore_missing_imports = false
-explicit_package_bases = true
-namespace_packages = true
-
-# Pydantic plugin for better validation
-plugins = ["pydantic.mypy"]
-
-# Per-module options
-[[tool.mypy.overrides]]
-module = "alembic.*"
-ignore_errors = true
-
-[[tool.mypy.overrides]]
-module = "app.alembic.*"
-ignore_errors = true
-
-[[tool.mypy.overrides]]
-module = "sqlalchemy.*"
-ignore_missing_imports = true
-
-[[tool.mypy.overrides]]
-module = "fastapi_utils.*"
-ignore_missing_imports = true
-
-[[tool.mypy.overrides]]
-module = "slowapi.*"
-ignore_missing_imports = true
-
-[[tool.mypy.overrides]]
-module = "jose.*"
-ignore_missing_imports = true
-
-[[tool.mypy.overrides]]
-module = "passlib.*"
-ignore_missing_imports = true
-
-[[tool.mypy.overrides]]
-module = "pydantic_settings.*"
-ignore_missing_imports = true
-
-[[tool.mypy.overrides]]
-module = "fastapi.*"
-ignore_missing_imports = true
-
-[[tool.mypy.overrides]]
-module = "apscheduler.*"
-ignore_missing_imports = true
-
-[[tool.mypy.overrides]]
-module = "starlette.*"
-ignore_missing_imports = true
-
-[[tool.mypy.overrides]]
-module = "authlib.*"
-ignore_missing_imports = true
-
-# SQLAlchemy ORM models - Column descriptors cause type confusion
-[[tool.mypy.overrides]]
-module = "app.models.*"
-disable_error_code = ["assignment", "arg-type", "return-value"]
-
-# CRUD operations - Generic ModelType and SQLAlchemy Result issues
-[[tool.mypy.overrides]]
-module = "app.crud.*"
-disable_error_code = ["attr-defined", "assignment", "arg-type", "return-value"]
-
-# API routes - SQLAlchemy Column to Pydantic schema conversions
-[[tool.mypy.overrides]]
-module = "app.api.routes.*"
-disable_error_code = ["arg-type", "call-arg", "call-overload", "assignment"]
-
-# API dependencies - Similar SQLAlchemy Column issues
-[[tool.mypy.overrides]]
-module = "app.api.dependencies.*"
-disable_error_code = ["arg-type"]
-
-# FastAPI exception handlers have correct signatures despite mypy warnings
-[[tool.mypy.overrides]]
-module = "app.main"
-disable_error_code = ["arg-type"]
-
-# Auth service - SQLAlchemy Column issues
-[[tool.mypy.overrides]]
-module = "app.services.auth_service"
-disable_error_code = ["assignment", "arg-type"]
-
-# Test utils - Testing patterns
-[[tool.mypy.overrides]]
-module = "app.utils.auth_test_utils"
-disable_error_code = ["assignment", "arg-type"]
-
-# ============================================================================
-# Pydantic mypy plugin configuration
-# ============================================================================
-[tool.pydantic-mypy]
-init_forbid_extra = true
-init_typed = true
-warn_required_dynamic_aliases = true
-
 # ============================================================================
 # Pytest Configuration
 # ============================================================================
@@ -315,12 +209,15 @@ addopts = [
     "--cov=app",
     "--cov-report=term-missing",
     "--cov-report=html",
+    "--ignore=tests/benchmarks",  # benchmarks are incompatible with xdist; run via 'make benchmark'
+    "-p", "no:benchmark",  # disable pytest-benchmark plugin during normal runs (conflicts with xdist)
 ]
 markers = [
     "sqlite: marks tests that should run on SQLite (mocked).",
     "postgres: marks tests that require a real PostgreSQL database.",
     "e2e: marks end-to-end tests requiring Docker containers.",
     "schemathesis: marks Schemathesis-generated API tests.",
+    "benchmark: marks performance benchmark tests.",
 ]
 asyncio_default_fixture_loop_scope = "function"
backend/pyrightconfig.json (new file, 23 lines)
@@ -0,0 +1,23 @@
+{
+  "include": ["app"],
+  "exclude": ["app/alembic"],
+  "pythonVersion": "3.12",
+  "venvPath": ".",
+  "venv": ".venv",
+  "typeCheckingMode": "standard",
+  "reportMissingImports": true,
+  "reportMissingTypeStubs": false,
+  "reportUnknownMemberType": false,
+  "reportUnknownVariableType": false,
+  "reportUnknownArgumentType": false,
+  "reportUnknownParameterType": false,
+  "reportUnknownLambdaType": false,
+  "reportReturnType": true,
+  "reportUnusedImport": false,
+  "reportGeneralTypeIssues": false,
+  "reportAttributeAccessIssue": false,
+  "reportArgumentType": false,
+  "strictListInference": false,
+  "strictDictionaryInference": false,
+  "strictSetInference": false
+}
@@ -147,7 +147,7 @@ class TestAdminCreateUser:
             headers={"Authorization": f"Bearer {superuser_token}"},
         )
 
-        assert response.status_code == status.HTTP_404_NOT_FOUND
+        assert response.status_code == status.HTTP_409_CONFLICT
 
 
 class TestAdminGetUser:
@@ -565,7 +565,7 @@ class TestAdminCreateOrganization:
             headers={"Authorization": f"Bearer {superuser_token}"},
         )
 
-        assert response.status_code == status.HTTP_404_NOT_FOUND
+        assert response.status_code == status.HTTP_409_CONFLICT
 
 
 class TestAdminGetOrganization:
@@ -45,7 +45,7 @@ class TestAdminListUsersFilters:
     async def test_list_users_database_error_propagates(self, client, superuser_token):
         """Test that database errors propagate correctly (covers line 118-120)."""
         with patch(
-            "app.api.routes.admin.user_crud.get_multi_with_total",
+            "app.api.routes.admin.user_service.list_users",
             side_effect=Exception("DB error"),
         ):
             with pytest.raises(Exception):
@@ -74,8 +74,8 @@ class TestAdminCreateUserErrors:
             },
         )
 
-        # Should get error for duplicate email
-        assert response.status_code == status.HTTP_404_NOT_FOUND
+        # Should get conflict for duplicate email
+        assert response.status_code == status.HTTP_409_CONFLICT
 
     @pytest.mark.asyncio
     async def test_create_user_unexpected_error_propagates(
@@ -83,7 +83,7 @@ class TestAdminCreateUserErrors:
     ):
         """Test unexpected errors during user creation (covers line 151-153)."""
         with patch(
-            "app.api.routes.admin.user_crud.create",
+            "app.api.routes.admin.user_service.create_user",
             side_effect=RuntimeError("Unexpected error"),
         ):
             with pytest.raises(RuntimeError):
@@ -135,7 +135,7 @@ class TestAdminUpdateUserErrors:
     ):
         """Test unexpected errors during user update (covers line 206-208)."""
         with patch(
-            "app.api.routes.admin.user_crud.update",
+            "app.api.routes.admin.user_service.update_user",
             side_effect=RuntimeError("Update failed"),
         ):
             with pytest.raises(RuntimeError):
@@ -166,7 +166,7 @@ class TestAdminDeleteUserErrors:
     ):
         """Test unexpected errors during user deletion (covers line 238-240)."""
         with patch(
-            "app.api.routes.admin.user_crud.soft_delete",
+            "app.api.routes.admin.user_service.soft_delete_user",
             side_effect=Exception("Delete failed"),
         ):
             with pytest.raises(Exception):
@@ -196,7 +196,7 @@ class TestAdminActivateUserErrors:
     ):
         """Test unexpected errors during user activation (covers line 282-284)."""
         with patch(
-            "app.api.routes.admin.user_crud.update",
+            "app.api.routes.admin.user_service.update_user",
             side_effect=Exception("Activation failed"),
         ):
             with pytest.raises(Exception):
@@ -238,7 +238,7 @@ class TestAdminDeactivateUserErrors:
     ):
         """Test unexpected errors during user deactivation (covers line 326-328)."""
         with patch(
-            "app.api.routes.admin.user_crud.update",
+            "app.api.routes.admin.user_service.update_user",
             side_effect=Exception("Deactivation failed"),
         ):
             with pytest.raises(Exception):
@@ -258,7 +258,7 @@ class TestAdminListOrganizationsErrors:
     async def test_list_organizations_database_error(self, client, superuser_token):
         """Test list organizations with database error (covers line 427-456)."""
         with patch(
-            "app.api.routes.admin.organization_crud.get_multi_with_member_counts",
+            "app.api.routes.admin.organization_service.get_multi_with_member_counts",
             side_effect=Exception("DB error"),
         ):
             with pytest.raises(Exception):
@@ -299,14 +299,14 @@ class TestAdminCreateOrganizationErrors:
             },
         )
 
-        # Should get error for duplicate slug
-        assert response.status_code == status.HTTP_404_NOT_FOUND
+        # Should get conflict for duplicate slug
+        assert response.status_code == status.HTTP_409_CONFLICT
 
     @pytest.mark.asyncio
     async def test_create_organization_unexpected_error(self, client, superuser_token):
         """Test unexpected errors during organization creation (covers line 484-485)."""
         with patch(
-            "app.api.routes.admin.organization_crud.create",
+            "app.api.routes.admin.organization_service.create_organization",
             side_effect=RuntimeError("Creation failed"),
         ):
             with pytest.raises(RuntimeError):
@@ -367,7 +367,7 @@ class TestAdminUpdateOrganizationErrors:
         org_id = org.id
 
         with patch(
-            "app.api.routes.admin.organization_crud.update",
+            "app.api.routes.admin.organization_service.update_organization",
             side_effect=Exception("Update failed"),
         ):
             with pytest.raises(Exception):
@@ -412,7 +412,7 @@ class TestAdminDeleteOrganizationErrors:
         org_id = org.id
 
         with patch(
-            "app.api.routes.admin.organization_crud.remove",
+            "app.api.routes.admin.organization_service.remove_organization",
             side_effect=Exception("Delete failed"),
         ):
             with pytest.raises(Exception):
@@ -456,7 +456,7 @@ class TestAdminListOrganizationMembersErrors:
         org_id = org.id
 
         with patch(
-            "app.api.routes.admin.organization_crud.get_organization_members",
+            "app.api.routes.admin.organization_service.get_organization_members",
             side_effect=Exception("DB error"),
         ):
             with pytest.raises(Exception):
@@ -531,7 +531,7 @@ class TestAdminAddOrganizationMemberErrors:
         org_id = org.id
 
         with patch(
-            "app.api.routes.admin.organization_crud.add_user",
+            "app.api.routes.admin.organization_service.add_member",
             side_effect=Exception("Add failed"),
         ):
             with pytest.raises(Exception):
@@ -587,7 +587,7 @@ class TestAdminRemoveOrganizationMemberErrors:
         org_id = org.id
 
         with patch(
-            "app.api.routes.admin.organization_crud.remove_user",
+            "app.api.routes.admin.organization_service.remove_member",
             side_effect=Exception("Remove failed"),
         ):
             with pytest.raises(Exception):
@@ -19,7 +19,7 @@ class TestLoginSessionCreationFailure:
         """Test that login succeeds even if session creation fails."""
         # Mock session creation to fail
         with patch(
-            "app.api.routes.auth.session_crud.create_session",
+            "app.api.routes.auth.session_service.create_session",
             side_effect=Exception("Session creation failed"),
         ):
             response = await client.post(
@@ -43,7 +43,7 @@ class TestOAuthLoginSessionCreationFailure:
     ):
         """Test OAuth login succeeds even if session creation fails."""
         with patch(
-            "app.api.routes.auth.session_crud.create_session",
+            "app.api.routes.auth.session_service.create_session",
             side_effect=Exception("Session failed"),
         ):
             response = await client.post(
@@ -76,7 +76,7 @@ class TestRefreshTokenSessionUpdateFailure:
 
         # Mock session update to fail
         with patch(
-            "app.api.routes.auth.session_crud.update_refresh_token",
+            "app.api.routes.auth.session_service.update_refresh_token",
             side_effect=Exception("Update failed"),
         ):
             response = await client.post(
@@ -130,7 +130,7 @@ class TestLogoutWithNonExistentSession:
         tokens = response.json()
 
         # Mock session lookup to return None
-        with patch("app.api.routes.auth.session_crud.get_by_jti", return_value=None):
+        with patch("app.api.routes.auth.session_service.get_by_jti", return_value=None):
             response = await client.post(
                 "/api/v1/auth/logout",
                 headers={"Authorization": f"Bearer {tokens['access_token']}"},
@@ -157,7 +157,7 @@ class TestLogoutUnexpectedError:
 
         # Mock to raise unexpected error
         with patch(
-            "app.api.routes.auth.session_crud.get_by_jti",
+            "app.api.routes.auth.session_service.get_by_jti",
             side_effect=Exception("Unexpected error"),
         ):
             response = await client.post(
@@ -186,7 +186,7 @@ class TestLogoutAllUnexpectedError:
 
         # Mock to raise database error
         with patch(
-            "app.api.routes.auth.session_crud.deactivate_all_user_sessions",
+            "app.api.routes.auth.session_service.deactivate_all_user_sessions",
             side_effect=Exception("DB error"),
         ):
             response = await client.post(
@@ -212,7 +212,7 @@ class TestPasswordResetConfirmSessionInvalidation:
 
         # Mock session invalidation to fail
         with patch(
-            "app.api.routes.auth.session_crud.deactivate_all_user_sessions",
+            "app.api.routes.auth.session_service.deactivate_all_user_sessions",
             side_effect=Exception("Invalidation failed"),
         ):
             response = await client.post(
@@ -334,7 +334,7 @@ class TestPasswordResetConfirm:
         token = create_password_reset_token(async_test_user.email)
 
         # Mock the database commit to raise an exception
-        with patch("app.api.routes.auth.user_crud.get_by_email") as mock_get:
+        with patch("app.services.auth_service.user_repo.get_by_email") as mock_get:
             mock_get.side_effect = Exception("Database error")
 
             response = await client.post(
@@ -12,8 +12,8 @@ These tests prevent real-world attack scenarios.
 import pytest
 from httpx import AsyncClient
 
-from app.crud.session import session as session_crud
 from app.models.user import User
+from app.repositories.session import session_repo as session_repo
 
 
 class TestRevokedSessionSecurity:
@@ -117,7 +117,7 @@ class TestRevokedSessionSecurity:
 
         async with SessionLocal() as session:
             # Find and delete the session
-            db_session = await session_crud.get_by_jti(session, jti=jti)
+            db_session = await session_repo.get_by_jti(session, jti=jti)
             if db_session:
                 await session.delete(db_session)
                 await session.commit()
@@ -8,7 +8,7 @@ from uuid import uuid4
 
 import pytest
 
-from app.crud.oauth import oauth_account
+from app.repositories.oauth_account import oauth_account_repo as oauth_account
 from app.schemas.oauth import OAuthAccountCreate
 
 
@@ -349,7 +349,7 @@ class TestOAuthProviderEndpoints:
         _test_engine, AsyncTestingSessionLocal = async_test_db
 
         # Create a test client
-        from app.crud.oauth import oauth_client
+        from app.repositories.oauth_client import oauth_client_repo as oauth_client
         from app.schemas.oauth import OAuthClientCreate
 
         async with AsyncTestingSessionLocal() as session:
@@ -386,7 +386,7 @@ class TestOAuthProviderEndpoints:
         _test_engine, AsyncTestingSessionLocal = async_test_db
 
         # Create a test client
-        from app.crud.oauth import oauth_client
+        from app.repositories.oauth_client import oauth_client_repo as oauth_client
         from app.schemas.oauth import OAuthClientCreate
 
         async with AsyncTestingSessionLocal() as session:
@@ -537,7 +537,7 @@ class TestOrganizationExceptionHandlers:
     ):
         """Test generic exception handler in get_my_organizations (covers lines 81-83)."""
         with patch(
-            "app.crud.organization.organization.get_user_organizations_with_details",
+            "app.api.routes.organizations.organization_service.get_user_organizations_with_details",
            side_effect=Exception("Database connection lost"),
         ):
             # The exception handler logs and re-raises, so we expect the exception
@@ -554,7 +554,7 @@ class TestOrganizationExceptionHandlers:
     ):
         """Test generic exception handler in get_organization (covers lines 124-128)."""
         with patch(
-            "app.crud.organization.organization.get",
+            "app.api.routes.organizations.organization_service.get_organization",
             side_effect=Exception("Database timeout"),
         ):
             with pytest.raises(Exception, match="Database timeout"):
@@ -569,7 +569,7 @@ class TestOrganizationExceptionHandlers:
     ):
         """Test generic exception handler in get_organization_members (covers lines 170-172)."""
         with patch(
-            "app.crud.organization.organization.get_organization_members",
+            "app.api.routes.organizations.organization_service.get_organization_members",
             side_effect=Exception("Connection pool exhausted"),
         ):
             with pytest.raises(Exception, match="Connection pool exhausted"):
@@ -591,11 +591,11 @@ class TestOrganizationExceptionHandlers:
         admin_token = login_response.json()["access_token"]
 
         with patch(
-            "app.crud.organization.organization.get",
+            "app.api.routes.organizations.organization_service.get_organization",
             return_value=test_org_with_user_admin,
         ):
             with patch(
-                "app.crud.organization.organization.update",
+                "app.api.routes.organizations.organization_service.update_organization",
                 side_effect=Exception("Write lock timeout"),
             ):
                 with pytest.raises(Exception, match="Write lock timeout"):
@@ -11,9 +11,9 @@ These tests prevent unauthorized access and privilege escalation.
 import pytest
 from httpx import AsyncClient
 
-from app.crud.user import user as user_crud
 from app.models.organization import Organization
 from app.models.user import User
+from app.repositories.user import user_repo as user_repo
 
 
 class TestInactiveUserBlocking:
@@ -50,7 +50,7 @@ class TestInactiveUserBlocking:
 
         # Step 2: Admin deactivates the user
         async with SessionLocal() as session:
-            user = await user_crud.get(session, id=async_test_user.id)
+            user = await user_repo.get(session, id=async_test_user.id)
             user.is_active = False
             await session.commit()
 
@@ -80,7 +80,7 @@ class TestInactiveUserBlocking:
 
         # Deactivate user
         async with SessionLocal() as session:
-            user = await user_crud.get(session, id=async_test_user.id)
+            user = await user_repo.get(session, id=async_test_user.id)
             user.is_active = False
             await session.commit()
@@ -39,7 +39,7 @@ async def async_test_user2(async_test_db):
     _test_engine, SessionLocal = async_test_db
 
     async with SessionLocal() as session:
-        from app.crud.user import user as user_crud
+        from app.repositories.user import user_repo as user_repo
        from app.schemas.users import UserCreate
 
         user_data = UserCreate(
@@ -48,7 +48,7 @@ async def async_test_user2(async_test_db):
             first_name="Test",
             last_name="User2",
         )
-        user = await user_crud.create(session, obj_in=user_data)
+        user = await user_repo.create(session, obj_in=user_data)
         await session.commit()
         await session.refresh(user)
         return user
@@ -191,9 +191,9 @@ class TestRevokeSession:
 
         # Verify session is deactivated
         async with SessionLocal() as session:
-            from app.crud.session import session as session_crud
+            from app.repositories.session import session_repo as session_repo
 
-            revoked_session = await session_crud.get(session, id=str(session_id))
+            revoked_session = await session_repo.get(session, id=str(session_id))
             assert revoked_session.is_active is False
 
     @pytest.mark.asyncio
@@ -267,8 +267,8 @@ class TestCleanupExpiredSessions:
         """Test successfully cleaning up expired sessions."""
         _test_engine, SessionLocal = async_test_db
 
-        # Create expired and active sessions using CRUD to avoid greenlet issues
-        from app.crud.session import session as session_crud
+        # Create expired and active sessions using repository to avoid greenlet issues
+        from app.repositories.session import session_repo as session_repo
         from app.schemas.sessions import SessionCreate
 
         async with SessionLocal() as db:
@@ -282,7 +282,7 @@ class TestCleanupExpiredSessions:
                 expires_at=datetime.now(UTC) - timedelta(days=1),
                 last_used_at=datetime.now(UTC) - timedelta(days=2),
             )
-            e1 = await session_crud.create_session(db, obj_in=e1_data)
+            e1 = await session_repo.create_session(db, obj_in=e1_data)
             e1.is_active = False
             db.add(e1)
 
@@ -296,7 +296,7 @@ class TestCleanupExpiredSessions:
                 expires_at=datetime.now(UTC) - timedelta(hours=1),
                 last_used_at=datetime.now(UTC) - timedelta(hours=2),
             )
-            e2 = await session_crud.create_session(db, obj_in=e2_data)
+            e2 = await session_repo.create_session(db, obj_in=e2_data)
             e2.is_active = False
             db.add(e2)
 
@@ -310,7 +310,7 @@ class TestCleanupExpiredSessions:
                 expires_at=datetime.now(UTC) + timedelta(days=7),
                 last_used_at=datetime.now(UTC),
             )
-            await session_crud.create_session(db, obj_in=a1_data)
+            await session_repo.create_session(db, obj_in=a1_data)
             await db.commit()
 
         # Cleanup expired sessions
@@ -333,8 +333,8 @@ class TestCleanupExpiredSessions:
         """Test cleanup when no sessions are expired."""
         _test_engine, SessionLocal = async_test_db
 
-        # Create only active sessions using CRUD
-        from app.crud.session import session as session_crud
+        # Create only active sessions using repository
+        from app.repositories.session import session_repo as session_repo
         from app.schemas.sessions import SessionCreate
 
         async with SessionLocal() as db:
@@ -347,7 +347,7 @@ class TestCleanupExpiredSessions:
                 expires_at=datetime.now(UTC) + timedelta(days=7),
                 last_used_at=datetime.now(UTC),
             )
-            await session_crud.create_session(db, obj_in=a1_data)
+            await session_repo.create_session(db, obj_in=a1_data)
             await db.commit()
 
         response = await client.delete(
@@ -384,7 +384,7 @@ class TestSessionsAdditionalCases:
 
         # Create multiple sessions
        async with SessionLocal() as session:
-            from app.crud.session import session as session_crud
+            from app.repositories.session import session_repo as session_repo
             from app.schemas.sessions import SessionCreate
 
             for i in range(5):
@@ -397,7 +397,7 @@ class TestSessionsAdditionalCases:
                     expires_at=datetime.now(UTC) + timedelta(days=7),
                     last_used_at=datetime.now(UTC),
                 )
-                await session_crud.create_session(session, obj_in=session_data)
+                await session_repo.create_session(session, obj_in=session_data)
             await session.commit()
 
         response = await client.get(
@@ -431,7 +431,7 @@ class TestSessionsAdditionalCases:
         """Test cleanup with mix of active/inactive and expired/not-expired sessions."""
         _test_engine, SessionLocal = async_test_db
 
-        from app.crud.session import session as session_crud
+        from app.repositories.session import session_repo as session_repo
         from app.schemas.sessions import SessionCreate
 
         async with SessionLocal() as db:
@@ -445,7 +445,7 @@ class TestSessionsAdditionalCases:
                 expires_at=datetime.now(UTC) - timedelta(days=1),
                 last_used_at=datetime.now(UTC) - timedelta(days=2),
             )
-            e1 = await session_crud.create_session(db, obj_in=e1_data)
+            e1 = await session_repo.create_session(db, obj_in=e1_data)
             e1.is_active = False
             db.add(e1)
 
@@ -459,7 +459,7 @@ class TestSessionsAdditionalCases:
                 expires_at=datetime.now(UTC) - timedelta(hours=1),
                 last_used_at=datetime.now(UTC) - timedelta(hours=2),
             )
-            await session_crud.create_session(db, obj_in=e2_data)
+            await session_repo.create_session(db, obj_in=e2_data)
 
             await db.commit()
 
@@ -502,10 +502,10 @@ class TestSessionExceptionHandlers:
         """Test list_sessions handles database errors (covers lines 104-106)."""
         from unittest.mock import patch
 
-        from app.crud import session as session_module
+        from app.repositories import session as session_module
 
         with patch.object(
-            session_module.session,
+            session_module.session_repo,
             "get_user_sessions",
             side_effect=Exception("Database error"),
         ):
@@ -527,10 +527,10 @@ class TestSessionExceptionHandlers:
         from unittest.mock import patch
         from uuid import uuid4
 
-        from app.crud import session as session_module
+        from app.repositories import session as session_module
 
         # First create a session to revoke
-        from app.crud.session import session as session_crud
+        from app.repositories.session import session_repo as session_repo
         from app.schemas.sessions import SessionCreate
 
         _test_engine, AsyncTestingSessionLocal = async_test_db
@@ -545,12 +545,12 @@ class TestSessionExceptionHandlers:
             last_used_at=datetime.now(UTC),
             expires_at=datetime.now(UTC) + timedelta(days=60),
         )
-        user_session = await session_crud.create_session(db, obj_in=session_in)
+        user_session = await session_repo.create_session(db, obj_in=session_in)
         session_id = user_session.id
 
         # Mock the deactivate method to raise an exception
         with patch.object(
-            session_module.session,
+            session_module.session_repo,
             "deactivate",
             side_effect=Exception("Database connection lost"),
|
||||
):
|
||||
@@ -568,10 +568,10 @@ class TestSessionExceptionHandlers:
|
||||
"""Test cleanup_expired_sessions handles database errors (covers lines 233-236)."""
|
||||
from unittest.mock import patch
|
||||
|
||||
from app.crud import session as session_module
|
||||
from app.repositories import session as session_module
|
||||
|
||||
with patch.object(
|
||||
session_module.session,
|
||||
session_module.session_repo,
|
||||
"cleanup_expired_for_user",
|
||||
side_effect=Exception("Cleanup failed"),
|
||||
):
|
||||
|
||||
@@ -157,7 +157,7 @@ class TestListUsers:
         response = await client.get("/api/v1/users")
         assert response.status_code == status.HTTP_401_UNAUTHORIZED

-    # Note: Removed test_list_users_unexpected_error because mocking at CRUD level
+    # Note: Removed test_list_users_unexpected_error because mocking at repository level
     # causes the exception to be raised before FastAPI can handle it properly

@@ -99,7 +99,8 @@ class TestUpdateCurrentUser:
         from unittest.mock import patch

         with patch(
-            "app.api.routes.users.user_crud.update", side_effect=Exception("DB error")
+            "app.api.routes.users.user_service.update_user",
+            side_effect=Exception("DB error"),
         ):
             with pytest.raises(Exception):
                 await client.patch(
@@ -134,7 +135,7 @@ class TestUpdateCurrentUser:
         from unittest.mock import patch

         with patch(
-            "app.api.routes.users.user_crud.update",
+            "app.api.routes.users.user_service.update_user",
             side_effect=ValueError("Invalid value"),
         ):
             with pytest.raises(ValueError):
@@ -224,7 +225,8 @@ class TestUpdateUserById:
         from unittest.mock import patch

         with patch(
-            "app.api.routes.users.user_crud.update", side_effect=ValueError("Invalid")
+            "app.api.routes.users.user_service.update_user",
+            side_effect=ValueError("Invalid"),
         ):
             with pytest.raises(ValueError):
                 await client.patch(
@@ -241,7 +243,8 @@ class TestUpdateUserById:
         from unittest.mock import patch

         with patch(
-            "app.api.routes.users.user_crud.update", side_effect=Exception("Unexpected")
+            "app.api.routes.users.user_service.update_user",
+            side_effect=Exception("Unexpected"),
         ):
             with pytest.raises(Exception):
                 await client.patch(
@@ -354,7 +357,7 @@ class TestDeleteUserById:
         from unittest.mock import patch

         with patch(
-            "app.api.routes.users.user_crud.soft_delete",
+            "app.api.routes.users.user_service.soft_delete_user",
             side_effect=ValueError("Cannot delete"),
         ):
             with pytest.raises(ValueError):
@@ -371,7 +374,7 @@ class TestDeleteUserById:
         from unittest.mock import patch

         with patch(
-            "app.api.routes.users.user_crud.soft_delete",
+            "app.api.routes.users.user_service.soft_delete_user",
             side_effect=Exception("Unexpected"),
         ):
             with pytest.raises(Exception):

backend/tests/crud/__init__.py → backend/tests/benchmarks/__init__.py (0 changes, Executable file → Normal file)
backend/tests/benchmarks/test_endpoint_performance.py (327 additions, Normal file)
@@ -0,0 +1,327 @@
+"""
+Performance Benchmark Tests.
+
+These tests establish baseline performance metrics for critical API endpoints
+and core operations, detecting regressions when response times degrade.
+
+Usage:
+    make benchmark        # Run benchmarks and save baseline
+    make benchmark-check  # Run benchmarks and compare against saved baseline
+
+Baselines are stored in .benchmarks/ and should be committed to version control
+so CI can detect performance regressions across commits.
+"""
+
+import time
+import uuid
+from unittest.mock import patch
+
+import pytest
+import pytest_asyncio
+from fastapi.testclient import TestClient
+
+from app.core.auth import (
+    create_access_token,
+    create_refresh_token,
+    decode_token,
+    get_password_hash,
+    verify_password,
+)
+from app.main import app
+
+pytestmark = [pytest.mark.benchmark]
+
+# Pre-computed hash for sync benchmarks (avoids hashing in every iteration)
+_BENCH_PASSWORD = "BenchPass123!"
+_BENCH_HASH = get_password_hash(_BENCH_PASSWORD)
+
+
+# =============================================================================
+# Fixtures
+# =============================================================================
+
+
+@pytest.fixture
+def sync_client():
+    """Create a FastAPI test client with mocked database for stateless endpoints."""
+    with patch("app.main.check_database_health") as mock_health_check:
+        mock_health_check.return_value = True
+        yield TestClient(app)
+
+
+# =============================================================================
+# Stateless Endpoint Benchmarks (no DB required)
+# =============================================================================
+
+
+def test_health_endpoint_performance(sync_client, benchmark):
+    """Benchmark: GET /health should respond within acceptable latency."""
+    result = benchmark(sync_client.get, "/health")
+    assert result.status_code == 200
+
+
+def test_openapi_schema_performance(sync_client, benchmark):
+    """Benchmark: OpenAPI schema generation should not regress."""
+    result = benchmark(sync_client.get, "/api/v1/openapi.json")
+    assert result.status_code == 200
+
+
+# =============================================================================
+# Core Crypto & Token Benchmarks (no DB required)
+#
+# These benchmark the CPU-intensive operations that underpin auth:
+# password hashing, verification, and JWT creation/decoding.
+# =============================================================================
+
+
+def test_password_hashing_performance(benchmark):
+    """Benchmark: bcrypt password hashing (CPU-bound, ~100ms expected)."""
+    result = benchmark(get_password_hash, _BENCH_PASSWORD)
+    assert result.startswith("$2b$")
+
+
+def test_password_verification_performance(benchmark):
+    """Benchmark: bcrypt password verification against a known hash."""
+    result = benchmark(verify_password, _BENCH_PASSWORD, _BENCH_HASH)
+    assert result is True
+
+
+def test_access_token_creation_performance(benchmark):
+    """Benchmark: JWT access token generation."""
+    user_id = str(uuid.uuid4())
+    token = benchmark(create_access_token, user_id)
+    assert isinstance(token, str)
+    assert len(token) > 0
+
+
+def test_refresh_token_creation_performance(benchmark):
+    """Benchmark: JWT refresh token generation."""
+    user_id = str(uuid.uuid4())
+    token = benchmark(create_refresh_token, user_id)
+    assert isinstance(token, str)
+    assert len(token) > 0
+
+
+def test_token_decode_performance(benchmark):
+    """Benchmark: JWT token decoding and validation."""
+    user_id = str(uuid.uuid4())
+    token = create_access_token(user_id)
+    payload = benchmark(decode_token, token, "access")
+    assert payload.sub == user_id
+
+
+# =============================================================================
+# Database-dependent Endpoint Benchmarks (async, manual timing)
+#
+# pytest-benchmark does not support async functions natively. These tests
+# measure latency manually and assert against a maximum threshold (in ms)
+# to catch performance regressions.
+# =============================================================================
+
+MAX_LOGIN_MS = 500
+MAX_GET_USER_MS = 200
+MAX_REGISTER_MS = 500
+MAX_TOKEN_REFRESH_MS = 200
+MAX_SESSIONS_LIST_MS = 200
+MAX_USER_UPDATE_MS = 200
+
+
+@pytest_asyncio.fixture
+async def bench_user(async_test_db):
+    """Create a test user for benchmark tests."""
+    from app.models.user import User
+
+    _test_engine, AsyncTestingSessionLocal = async_test_db
+
+    async with AsyncTestingSessionLocal() as session:
+        user = User(
+            id=uuid.uuid4(),
+            email="bench@example.com",
+            password_hash=get_password_hash("BenchPass123!"),
+            first_name="Bench",
+            last_name="User",
+            is_active=True,
+            is_superuser=False,
+        )
+        session.add(user)
+        await session.commit()
+        await session.refresh(user)
+        return user
+
+
+@pytest_asyncio.fixture
+async def bench_token(client, bench_user):
+    """Get an auth token for the benchmark user."""
+    response = await client.post(
+        "/api/v1/auth/login",
+        json={"email": "bench@example.com", "password": "BenchPass123!"},
+    )
+    assert response.status_code == 200, f"Login failed: {response.text}"
+    return response.json()["access_token"]
+
+
+@pytest_asyncio.fixture
+async def bench_refresh_token(client, bench_user):
+    """Get a refresh token for the benchmark user."""
+    response = await client.post(
+        "/api/v1/auth/login",
+        json={"email": "bench@example.com", "password": "BenchPass123!"},
+    )
+    assert response.status_code == 200, f"Login failed: {response.text}"
+    return response.json()["refresh_token"]
+
+
+@pytest.mark.asyncio
+async def test_login_latency(client, bench_user):
+    """Performance: POST /api/v1/auth/login must respond under threshold."""
+    iterations = 5
+    total_ms = 0.0
+
+    for _ in range(iterations):
+        start = time.perf_counter()
+        response = await client.post(
+            "/api/v1/auth/login",
+            json={"email": "bench@example.com", "password": "BenchPass123!"},
+        )
+        elapsed_ms = (time.perf_counter() - start) * 1000
+        total_ms += elapsed_ms
+        assert response.status_code == 200
+
+    mean_ms = total_ms / iterations
+    print(f"\n Login mean latency: {mean_ms:.1f}ms (threshold: {MAX_LOGIN_MS}ms)")
+    assert mean_ms < MAX_LOGIN_MS, (
+        f"Login latency regression: {mean_ms:.1f}ms exceeds {MAX_LOGIN_MS}ms threshold"
+    )
+
+
+@pytest.mark.asyncio
+async def test_get_current_user_latency(client, bench_token):
+    """Performance: GET /api/v1/users/me must respond under threshold."""
+    iterations = 10
+    total_ms = 0.0
+
+    for _ in range(iterations):
+        start = time.perf_counter()
+        response = await client.get(
+            "/api/v1/users/me",
+            headers={"Authorization": f"Bearer {bench_token}"},
+        )
+        elapsed_ms = (time.perf_counter() - start) * 1000
+        total_ms += elapsed_ms
+        assert response.status_code == 200
+
+    mean_ms = total_ms / iterations
+    print(
+        f"\n Get user mean latency: {mean_ms:.1f}ms (threshold: {MAX_GET_USER_MS}ms)"
+    )
+    assert mean_ms < MAX_GET_USER_MS, (
+        f"Get user latency regression: {mean_ms:.1f}ms exceeds {MAX_GET_USER_MS}ms threshold"
+    )
+
+
+@pytest.mark.asyncio
+async def test_register_latency(client):
+    """Performance: POST /api/v1/auth/register must respond under threshold."""
+    iterations = 3
+    total_ms = 0.0
+
+    for i in range(iterations):
+        start = time.perf_counter()
+        response = await client.post(
+            "/api/v1/auth/register",
+            json={
+                "email": f"benchreg{i}@example.com",
+                "password": "BenchRegPass123!",
+                "first_name": "Bench",
+                "last_name": "Register",
+            },
+        )
+        elapsed_ms = (time.perf_counter() - start) * 1000
+        total_ms += elapsed_ms
+        assert response.status_code == 201, f"Register failed: {response.text}"
+
+    mean_ms = total_ms / iterations
+    print(
+        f"\n Register mean latency: {mean_ms:.1f}ms (threshold: {MAX_REGISTER_MS}ms)"
+    )
+    assert mean_ms < MAX_REGISTER_MS, (
+        f"Register latency regression: {mean_ms:.1f}ms exceeds {MAX_REGISTER_MS}ms threshold"
+    )
+
+
+@pytest.mark.asyncio
+async def test_token_refresh_latency(client, bench_refresh_token):
+    """Performance: POST /api/v1/auth/refresh must respond under threshold."""
+    iterations = 5
+    total_ms = 0.0
+
+    for _ in range(iterations):
+        start = time.perf_counter()
+        response = await client.post(
+            "/api/v1/auth/refresh",
+            json={"refresh_token": bench_refresh_token},
+        )
+        elapsed_ms = (time.perf_counter() - start) * 1000
+        total_ms += elapsed_ms
+        assert response.status_code == 200, f"Refresh failed: {response.text}"
+        # Use the new refresh token for the next iteration
+        bench_refresh_token = response.json()["refresh_token"]
+
+    mean_ms = total_ms / iterations
+    print(
+        f"\n Token refresh mean latency: {mean_ms:.1f}ms (threshold: {MAX_TOKEN_REFRESH_MS}ms)"
+    )
+    assert mean_ms < MAX_TOKEN_REFRESH_MS, (
+        f"Token refresh latency regression: {mean_ms:.1f}ms exceeds {MAX_TOKEN_REFRESH_MS}ms threshold"
+    )
+
+
+@pytest.mark.asyncio
+async def test_sessions_list_latency(client, bench_token):
+    """Performance: GET /api/v1/sessions/me must respond under threshold."""
+    iterations = 10
+    total_ms = 0.0
+
+    for _ in range(iterations):
+        start = time.perf_counter()
+        response = await client.get(
+            "/api/v1/sessions/me",
+            headers={"Authorization": f"Bearer {bench_token}"},
+        )
+        elapsed_ms = (time.perf_counter() - start) * 1000
+        total_ms += elapsed_ms
+        assert response.status_code == 200
+
+    mean_ms = total_ms / iterations
+    print(
+        f"\n Sessions list mean latency: {mean_ms:.1f}ms (threshold: {MAX_SESSIONS_LIST_MS}ms)"
+    )
+    assert mean_ms < MAX_SESSIONS_LIST_MS, (
+        f"Sessions list latency regression: {mean_ms:.1f}ms exceeds {MAX_SESSIONS_LIST_MS}ms threshold"
+    )
+
+
+@pytest.mark.asyncio
+async def test_user_profile_update_latency(client, bench_token):
+    """Performance: PATCH /api/v1/users/me must respond under threshold."""
+    iterations = 5
+    total_ms = 0.0
+
+    for i in range(iterations):
+        start = time.perf_counter()
+        response = await client.patch(
+            "/api/v1/users/me",
+            headers={"Authorization": f"Bearer {bench_token}"},
+            json={"first_name": f"Bench{i}"},
+        )
+        elapsed_ms = (time.perf_counter() - start) * 1000
+        total_ms += elapsed_ms
+        assert response.status_code == 200, f"Update failed: {response.text}"
+
+    mean_ms = total_ms / iterations
+    print(
+        f"\n User update mean latency: {mean_ms:.1f}ms (threshold: {MAX_USER_UPDATE_MS}ms)"
+    )
+    assert mean_ms < MAX_USER_UPDATE_MS, (
+        f"User update latency regression: {mean_ms:.1f}ms exceeds {MAX_USER_UPDATE_MS}ms threshold"
+    )
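The async latency tests above repeat the same manual timing loop, since pytest-benchmark cannot drive coroutines directly. A minimal sketch of how that loop could be factored into a reusable helper; the `measure_mean_latency_ms` name and the dummy coroutine are illustrative, not part of this repo:

```python
import asyncio
import time


async def measure_mean_latency_ms(call, iterations: int) -> float:
    """Await an async callable repeatedly and return the mean latency in ms."""
    total_ms = 0.0
    for _ in range(iterations):
        start = time.perf_counter()
        await call()
        total_ms += (time.perf_counter() - start) * 1000
    return total_ms / iterations


async def _fake_endpoint():
    # Stand-in for an HTTP call; sleeps roughly 5ms.
    await asyncio.sleep(0.005)


mean_ms = asyncio.run(measure_mean_latency_ms(_fake_endpoint, iterations=5))
print(f"mean latency: {mean_ms:.1f}ms")
```

Each test would then reduce to one `await measure_mean_latency_ms(...)` call plus its threshold assertion.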
@@ -2,8 +2,8 @@
 import uuid
 from datetime import UTC, datetime, timedelta

+import jwt
 import pytest
-from jose import jwt

 from app.core.auth import (
     TokenExpiredError,
@@ -215,6 +215,7 @@ class TestTokenDecoding:
         payload = {
             "sub": 123,  # sub should be a string, not an integer
             "exp": int((now + timedelta(minutes=30)).timestamp()),
             "iat": int(now.timestamp()),
         }

         token = jwt.encode(payload, settings.SECRET_KEY, algorithm=settings.ALGORITHM)

@@ -9,8 +9,8 @@ Critical security tests covering:
 These tests cover critical security vulnerabilities that could be exploited.
 """

+import jwt
 import pytest
-from jose import jwt

 from app.core.auth import TokenInvalidError, create_access_token, decode_token
 from app.core.config import settings
@@ -38,8 +38,8 @@ class TestJWTAlgorithmSecurityAttacks:
         Attacker creates a token with "alg: none" to bypass signature verification.

         NOTE: Lines 209 and 212 in auth.py are DEFENSIVE CODE that's never reached
-        because python-jose library rejects "none" algorithm tokens BEFORE we get there.
-        This is good for security! The library throws JWTError which becomes TokenInvalidError.
+        because PyJWT rejects "none" algorithm tokens BEFORE we get there.
+        This is good for security! The library throws InvalidTokenError which becomes TokenInvalidError.

         This test verifies the overall protection works, even though our defensive
         checks at lines 209-212 don't execute because the library catches it first.
@@ -108,36 +108,33 @@ class TestJWTAlgorithmSecurityAttacks:
         Test that tokens with wrong algorithm are rejected.

         Attack Scenario:
-            Attacker changes algorithm from HS256 to RS256, attempting to use
-            the public key as the HMAC secret. This could allow token forgery.
+            Attacker changes the "alg" header to RS256 while keeping an HMAC
+            signature, attempting algorithm confusion to forge tokens.

         Reference: https://www.nccgroup.com/us/about-us/newsroom-and-events/blog/2019/january/jwt-algorithm-confusion/
-
-        NOTE: Like the "none" algorithm test, python-jose library catches this
-        before our defensive checks at line 212. This is good for security!
         """
+        import base64
+        import json
         import time

         now = int(time.time())

         # Create a valid payload
         payload = {"sub": "user123", "exp": now + 3600, "iat": now, "type": "access"}

-        # Encode with wrong algorithm (RS256 instead of HS256)
-        # This simulates an attacker trying algorithm substitution
-        wrong_algorithm = "RS256" if settings.ALGORITHM == "HS256" else "HS256"
+        # Hand-craft a token claiming RS256 in the header — PyJWT cannot encode
+        # RS256 with an HMAC key, so we craft the header manually (same technique
+        # as the "alg: none" tests) to produce a token that actually reaches decode_token.
+        header = {"alg": "RS256", "typ": "JWT"}
+        header_encoded = (
+            base64.urlsafe_b64encode(json.dumps(header).encode()).decode().rstrip("=")
+        )
+        payload_encoded = (
+            base64.urlsafe_b64encode(json.dumps(payload).encode()).decode().rstrip("=")
+        )
+        # Attach a fake signature to form a complete (but invalid) JWT
+        malicious_token = f"{header_encoded}.{payload_encoded}.fakesignature"

-        try:
-            malicious_token = jwt.encode(
-                payload, settings.SECRET_KEY, algorithm=wrong_algorithm
-            )
-
-            # Should reject the token (library catches mismatch)
-            with pytest.raises(TokenInvalidError):
-                decode_token(malicious_token)
-        except Exception:
-            # If encoding fails, that's also acceptable (library protection)
-            pass
+        with pytest.raises(TokenInvalidError):
+            decode_token(malicious_token)

     def test_reject_hs384_when_hs256_expected(self):
         """
@@ -151,17 +148,11 @@ class TestJWTAlgorithmSecurityAttacks:

         payload = {"sub": "user123", "exp": now + 3600, "iat": now, "type": "access"}

-        # Create token with HS384 instead of HS256
-        try:
-            malicious_token = jwt.encode(
-                payload, settings.SECRET_KEY, algorithm="HS384"
-            )
+        # Create token with HS384 instead of HS256 (HMAC key works with HS384)
+        malicious_token = jwt.encode(payload, settings.SECRET_KEY, algorithm="HS384")

-            with pytest.raises(TokenInvalidError):
-                decode_token(malicious_token)
-        except Exception:
-            # If encoding fails, that's also fine
-            pass
+        with pytest.raises(TokenInvalidError):
+            decode_token(malicious_token)

     def test_valid_token_with_correct_algorithm_accepted(self):
         """

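The defence these tests rely on, verifying only with a pinned algorithm, can be sketched with the standard library alone. The `decode_pinned` helper and `SECRET` below are illustrative stand-ins; the repo's real check lives inside `decode_token`:

```python
import base64
import hashlib
import hmac
import json

SECRET = b"test-secret"  # illustrative only


def _b64(data: bytes) -> str:
    return base64.urlsafe_b64encode(data).decode().rstrip("=")


def sign_hs256(header: dict, payload: dict) -> str:
    """Build a JWT-shaped token signed with HMAC-SHA256."""
    signing_input = f"{_b64(json.dumps(header).encode())}.{_b64(json.dumps(payload).encode())}"
    sig = hmac.new(SECRET, signing_input.encode(), hashlib.sha256).digest()
    return f"{signing_input}.{_b64(sig)}"


def decode_pinned(token: str) -> dict:
    """Accept only HS256; reject any other declared algorithm before verifying."""
    header_b64, payload_b64, sig_b64 = token.split(".")
    header = json.loads(base64.urlsafe_b64decode(header_b64 + "=" * (-len(header_b64) % 4)))
    if header.get("alg") != "HS256":
        raise ValueError("algorithm not allowed")
    expected = hmac.new(SECRET, f"{header_b64}.{payload_b64}".encode(), hashlib.sha256).digest()
    if not hmac.compare_digest(_b64(expected), sig_b64):
        raise ValueError("bad signature")
    return json.loads(base64.urlsafe_b64decode(payload_b64 + "=" * (-len(payload_b64) % 4)))


good = sign_hs256({"alg": "HS256", "typ": "JWT"}, {"sub": "user123"})
assert decode_pinned(good)["sub"] == "user123"

# A token whose header claims RS256 must be rejected on the alg check alone.
forged = sign_hs256({"alg": "RS256", "typ": "JWT"}, {"sub": "user123"})
try:
    decode_pinned(forged)
    raise AssertionError("forged token accepted")
except ValueError as exc:
    print("rejected:", exc)
```

Pinning the accepted algorithm list (as PyJWT's `algorithms=` argument does) is what closes both the "alg: none" and the RS256-confusion attacks.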
@@ -46,7 +46,7 @@ async def login_user(client, email: str, password: str = "SecurePassword123!"):

 async def create_superuser(e2e_db_session, email: str, password: str):
     """Create a superuser directly in the database."""
-    from app.crud.user import user as user_crud
+    from app.repositories.user import user_repo as user_repo
     from app.schemas.users import UserCreate

     user_in = UserCreate(
@@ -56,7 +56,7 @@ async def create_superuser(e2e_db_session, email: str, password: str):
         last_name="User",
         is_superuser=True,
     )
-    user = await user_crud.create(e2e_db_session, obj_in=user_in)
+    user = await user_repo.create(e2e_db_session, obj_in=user_in)
     return user

@@ -27,13 +27,16 @@ except ImportError:
 pytestmark = [
     pytest.mark.e2e,
     pytest.mark.schemathesis,
     pytest.mark.skipif(
         not SCHEMATHESIS_AVAILABLE,
         reason="schemathesis not installed - run: make install-e2e",
     ),
 ]


 if not SCHEMATHESIS_AVAILABLE:

     def test_schemathesis_compatibility():
         """Gracefully handle missing schemathesis dependency."""
         pytest.skip("schemathesis not installed - run: make install-e2e")


 if SCHEMATHESIS_AVAILABLE:
     from app.main import app

@@ -46,7 +46,7 @@ async def register_and_login(client, email: str, password: str = "SecurePassword

 async def create_superuser_and_login(client, db_session):
     """Helper to create a superuser directly in DB and login."""
-    from app.crud.user import user as user_crud
+    from app.repositories.user import user_repo as user_repo
     from app.schemas.users import UserCreate

     email = f"admin-{uuid4().hex[:8]}@example.com"
@@ -60,7 +60,7 @@ async def create_superuser_and_login(client, db_session):
         last_name="User",
         is_superuser=True,
     )
-    await user_crud.create(db_session, obj_in=user_in)
+    await user_repo.create(db_session, obj_in=user_in)

     # Login
     login_resp = await client.post(

backend/tests/repositories/__init__.py (0 changes, Executable file)
@@ -1,6 +1,6 @@
-# tests/crud/test_base.py
+# tests/repositories/test_base.py
 """
-Comprehensive tests for CRUDBase class covering all error paths and edge cases.
+Comprehensive tests for BaseRepository class covering all error paths and edge cases.
 """

 from datetime import UTC
@@ -11,11 +11,16 @@ import pytest
 from sqlalchemy.exc import DataError, IntegrityError, OperationalError
 from sqlalchemy.orm import joinedload

-from app.crud.user import user as user_crud
+from app.core.repository_exceptions import (
+    DuplicateEntryError,
+    IntegrityConstraintError,
+    InvalidInputError,
+)
+from app.repositories.user import user_repo as user_repo
 from app.schemas.users import UserCreate, UserUpdate

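The `DuplicateEntryError` and `InvalidInputError` imports above come from the repository layer's exception translation: driver-level errors are caught at the repository boundary and re-raised as domain exceptions. A minimal self-contained sketch of that pattern; all names and the simulated integrity error are simplified stand-ins for the repo's real `BaseRepository` and SQLAlchemy's `IntegrityError`:

```python
class RepositoryError(Exception):
    """Base class for repository-layer errors."""


class DuplicateEntryError(RepositoryError):
    """Raised when a unique constraint is violated."""


class InvalidInputError(RepositoryError):
    """Raised for invalid pagination or filter arguments."""


class FakeIntegrityError(Exception):
    """Stand-in for sqlalchemy.exc.IntegrityError."""


class UserRepository:
    def __init__(self) -> None:
        self._emails: set[str] = set()

    def create(self, email: str) -> str:
        # Translate the low-level driver error into a domain-level error
        # so callers never depend on database exception types.
        try:
            if email in self._emails:
                raise FakeIntegrityError("UNIQUE constraint failed: users.email")
            self._emails.add(email)
            return email
        except FakeIntegrityError as exc:
            raise DuplicateEntryError(f"User already exists: {exc}") from exc

    def get_multi(self, skip: int = 0, limit: int = 100) -> list[str]:
        # Validate pagination before touching storage, mirroring the tests below.
        if skip < 0:
            raise InvalidInputError("skip must be non-negative")
        if limit < 0:
            raise InvalidInputError("limit must be non-negative")
        if limit > 1000:
            raise InvalidInputError("Maximum limit is 1000")
        return sorted(self._emails)[skip : skip + limit]


repo = UserRepository()
repo.create("a@example.com")
try:
    repo.create("a@example.com")
except DuplicateEntryError as exc:
    print("translated:", exc)
```

The payoff is that API routes and services catch one stable exception family regardless of which database driver sits underneath.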
-class TestCRUDBaseGet:
+class TestRepositoryBaseGet:
     """Tests for get method covering UUID validation and options."""

     @pytest.mark.asyncio
@@ -24,7 +29,7 @@ class TestCRUDBaseGet:
         _test_engine, SessionLocal = async_test_db

         async with SessionLocal() as session:
-            result = await user_crud.get(session, id="invalid-uuid")
+            result = await user_repo.get(session, id="invalid-uuid")
             assert result is None

     @pytest.mark.asyncio
@@ -33,7 +38,7 @@ class TestCRUDBaseGet:
         _test_engine, SessionLocal = async_test_db

         async with SessionLocal() as session:
-            result = await user_crud.get(session, id=12345)  # int instead of UUID
+            result = await user_repo.get(session, id=12345)  # int instead of UUID
             assert result is None

     @pytest.mark.asyncio
@@ -43,7 +48,7 @@ class TestCRUDBaseGet:

         async with SessionLocal() as session:
             # Pass UUID object directly
-            result = await user_crud.get(session, id=async_test_user.id)
+            result = await user_repo.get(session, id=async_test_user.id)
             assert result is not None
             assert result.id == async_test_user.id

@@ -55,7 +60,7 @@ class TestCRUDBaseGet:
         async with SessionLocal() as session:
             # Test that options parameter is accepted and doesn't error
             # We pass an empty list which still tests the code path
-            result = await user_crud.get(
+            result = await user_repo.get(
                 session, id=str(async_test_user.id), options=[]
             )
             assert result is not None
@@ -69,10 +74,10 @@ class TestCRUDBaseGet:
             # Mock execute to raise an exception
             with patch.object(session, "execute", side_effect=Exception("DB error")):
                 with pytest.raises(Exception, match="DB error"):
-                    await user_crud.get(session, id=str(uuid4()))
+                    await user_repo.get(session, id=str(uuid4()))


-class TestCRUDBaseGetMulti:
+class TestRepositoryBaseGetMulti:
     """Tests for get_multi method covering pagination validation and options."""

     @pytest.mark.asyncio
@@ -81,8 +86,8 @@ class TestCRUDBaseGetMulti:
         _test_engine, SessionLocal = async_test_db

         async with SessionLocal() as session:
-            with pytest.raises(ValueError, match="skip must be non-negative"):
-                await user_crud.get_multi(session, skip=-1)
+            with pytest.raises(InvalidInputError, match="skip must be non-negative"):
+                await user_repo.get_multi(session, skip=-1)

     @pytest.mark.asyncio
     async def test_get_multi_negative_limit(self, async_test_db):
@@ -90,8 +95,8 @@ class TestCRUDBaseGetMulti:
         _test_engine, SessionLocal = async_test_db

         async with SessionLocal() as session:
-            with pytest.raises(ValueError, match="limit must be non-negative"):
-                await user_crud.get_multi(session, limit=-1)
+            with pytest.raises(InvalidInputError, match="limit must be non-negative"):
+                await user_repo.get_multi(session, limit=-1)

     @pytest.mark.asyncio
     async def test_get_multi_limit_too_large(self, async_test_db):
@@ -99,8 +104,8 @@ class TestCRUDBaseGetMulti:
         _test_engine, SessionLocal = async_test_db

         async with SessionLocal() as session:
-            with pytest.raises(ValueError, match="Maximum limit is 1000"):
-                await user_crud.get_multi(session, limit=1001)
+            with pytest.raises(InvalidInputError, match="Maximum limit is 1000"):
+                await user_repo.get_multi(session, limit=1001)

     @pytest.mark.asyncio
     async def test_get_multi_with_options(self, async_test_db, async_test_user):
@@ -109,7 +114,7 @@ class TestCRUDBaseGetMulti:

         async with SessionLocal() as session:
             # Test that options parameter is accepted
-            results = await user_crud.get_multi(session, skip=0, limit=10, options=[])
+            results = await user_repo.get_multi(session, skip=0, limit=10, options=[])
             assert isinstance(results, list)

     @pytest.mark.asyncio
@@ -120,10 +125,10 @@ class TestCRUDBaseGetMulti:
         async with SessionLocal() as session:
             with patch.object(session, "execute", side_effect=Exception("DB error")):
                 with pytest.raises(Exception, match="DB error"):
-                    await user_crud.get_multi(session)
+                    await user_repo.get_multi(session)


-class TestCRUDBaseCreate:
+class TestRepositoryBaseCreate:
     """Tests for create method covering various error conditions."""

     @pytest.mark.asyncio
@@ -140,8 +145,8 @@ class TestCRUDBaseCreate:
             last_name="Duplicate",
         )

-        with pytest.raises(ValueError, match="already exists"):
-            await user_crud.create(session, obj_in=user_data)
+        with pytest.raises(DuplicateEntryError, match="already exists"):
+            await user_repo.create(session, obj_in=user_data)

     @pytest.mark.asyncio
     async def test_create_integrity_error_non_duplicate(self, async_test_db):
@@ -165,12 +170,14 @@ class TestCRUDBaseCreate:
             last_name="User",
         )

-        with pytest.raises(ValueError, match="Database integrity error"):
-            await user_crud.create(session, obj_in=user_data)
+        with pytest.raises(
+            DuplicateEntryError, match="Database integrity error"
+        ):
+            await user_repo.create(session, obj_in=user_data)

     @pytest.mark.asyncio
     async def test_create_operational_error(self, async_test_db):
-        """Test create with OperationalError (user CRUD catches as generic Exception)."""
+        """Test create with OperationalError (user repository catches as generic Exception)."""
         _test_engine, SessionLocal = async_test_db

         async with SessionLocal() as session:
@@ -188,13 +195,13 @@ class TestCRUDBaseCreate:
             last_name="User",
         )

-        # User CRUD catches this as generic Exception and re-raises
+        # User repository catches this as generic Exception and re-raises
         with pytest.raises(OperationalError):
-            await user_crud.create(session, obj_in=user_data)
+            await user_repo.create(session, obj_in=user_data)

     @pytest.mark.asyncio
     async def test_create_data_error(self, async_test_db):
-        """Test create with DataError (user CRUD catches as generic Exception)."""
+        """Test create with DataError (user repository catches as generic Exception)."""
         _test_engine, SessionLocal = async_test_db

         async with SessionLocal() as session:
@@ -210,9 +217,9 @@ class TestCRUDBaseCreate:
             last_name="User",
         )

-        # User CRUD catches this as generic Exception and re-raises
+        # User repository catches this as generic Exception and re-raises
         with pytest.raises(DataError):
-            await user_crud.create(session, obj_in=user_data)
+            await user_repo.create(session, obj_in=user_data)

     @pytest.mark.asyncio
     async def test_create_unexpected_error(self, async_test_db):
@@ -231,10 +238,10 @@ class TestCRUDBaseCreate:
         )

         with pytest.raises(RuntimeError, match="Unexpected error"):
-            await user_crud.create(session, obj_in=user_data)
+            await user_repo.create(session, obj_in=user_data)


-class TestCRUDBaseUpdate:
+class TestRepositoryBaseUpdate:
     """Tests for update method covering error conditions."""

     @pytest.mark.asyncio
@@ -244,7 +251,7 @@ class TestCRUDBaseUpdate:

         # Create another user
         async with SessionLocal() as session:
-            from app.crud.user import user as user_crud
+            from app.repositories.user import user_repo as user_repo

             user2_data = UserCreate(
                 email="user2@example.com",
@@ -252,12 +259,12 @@ class TestCRUDBaseUpdate:
                 first_name="User",
                 last_name="Two",
             )
-            user2 = await user_crud.create(session, obj_in=user2_data)
+            user2 = await user_repo.create(session, obj_in=user2_data)
             await session.commit()

         # Try to update user2 with user1's email
         async with SessionLocal() as session:
-            user2_obj = await user_crud.get(session, id=str(user2.id))
+            user2_obj = await user_repo.get(session, id=str(user2.id))
|
||||
|
||||
with patch.object(
|
||||
session,
|
||||
@@ -268,8 +275,8 @@ class TestCRUDBaseUpdate:
|
||||
):
|
||||
update_data = UserUpdate(email=async_test_user.email)
|
||||
|
||||
with pytest.raises(ValueError, match="already exists"):
|
||||
await user_crud.update(
|
||||
with pytest.raises(DuplicateEntryError, match="already exists"):
|
||||
await user_repo.update(
|
||||
session, db_obj=user2_obj, obj_in=update_data
|
||||
)
|
||||
|
||||
@@ -279,10 +286,10 @@ class TestCRUDBaseUpdate:
|
||||
_test_engine, SessionLocal = async_test_db
|
||||
|
||||
async with SessionLocal() as session:
|
||||
user = await user_crud.get(session, id=str(async_test_user.id))
|
||||
user = await user_repo.get(session, id=str(async_test_user.id))
|
||||
|
||||
# Update with dict (tests lines 164-165)
|
||||
updated = await user_crud.update(
|
||||
updated = await user_repo.update(
|
||||
session, db_obj=user, obj_in={"first_name": "UpdatedName"}
|
||||
)
|
||||
assert updated.first_name == "UpdatedName"
|
||||
@@ -293,7 +300,7 @@ class TestCRUDBaseUpdate:
|
||||
_test_engine, SessionLocal = async_test_db
|
||||
|
||||
async with SessionLocal() as session:
|
||||
user = await user_crud.get(session, id=str(async_test_user.id))
|
||||
user = await user_repo.get(session, id=str(async_test_user.id))
|
||||
|
||||
with patch.object(
|
||||
session,
|
||||
@@ -302,8 +309,10 @@ class TestCRUDBaseUpdate:
|
||||
"statement", {}, Exception("constraint failed")
|
||||
),
|
||||
):
|
||||
with pytest.raises(ValueError, match="Database integrity error"):
|
||||
await user_crud.update(
|
||||
with pytest.raises(
|
||||
IntegrityConstraintError, match="Database integrity error"
|
||||
):
|
||||
await user_repo.update(
|
||||
session, db_obj=user, obj_in={"first_name": "Test"}
|
||||
)
|
||||
|
||||
@@ -313,7 +322,7 @@ class TestCRUDBaseUpdate:
|
||||
_test_engine, SessionLocal = async_test_db
|
||||
|
||||
async with SessionLocal() as session:
|
||||
user = await user_crud.get(session, id=str(async_test_user.id))
|
||||
user = await user_repo.get(session, id=str(async_test_user.id))
|
||||
|
||||
with patch.object(
|
||||
session,
|
||||
@@ -322,8 +331,10 @@ class TestCRUDBaseUpdate:
|
||||
"statement", {}, Exception("connection error")
|
||||
),
|
||||
):
|
||||
with pytest.raises(ValueError, match="Database operation failed"):
|
||||
await user_crud.update(
|
||||
with pytest.raises(
|
||||
IntegrityConstraintError, match="Database operation failed"
|
||||
):
|
||||
await user_repo.update(
|
||||
session, db_obj=user, obj_in={"first_name": "Test"}
|
||||
)
|
||||
|
||||
@@ -333,18 +344,18 @@ class TestCRUDBaseUpdate:
|
||||
_test_engine, SessionLocal = async_test_db
|
||||
|
||||
async with SessionLocal() as session:
|
||||
user = await user_crud.get(session, id=str(async_test_user.id))
|
||||
user = await user_repo.get(session, id=str(async_test_user.id))
|
||||
|
||||
with patch.object(
|
||||
session, "commit", side_effect=RuntimeError("Unexpected")
|
||||
):
|
||||
with pytest.raises(RuntimeError):
|
||||
await user_crud.update(
|
||||
await user_repo.update(
|
||||
session, db_obj=user, obj_in={"first_name": "Test"}
|
||||
)
|
||||
|
||||
|
||||
class TestCRUDBaseRemove:
|
||||
class TestRepositoryBaseRemove:
|
||||
"""Tests for remove method covering UUID validation and error conditions."""
|
||||
|
||||
@pytest.mark.asyncio
|
||||
@@ -353,7 +364,7 @@ class TestCRUDBaseRemove:
|
||||
_test_engine, SessionLocal = async_test_db
|
||||
|
||||
async with SessionLocal() as session:
|
||||
result = await user_crud.remove(session, id="invalid-uuid")
|
||||
result = await user_repo.remove(session, id="invalid-uuid")
|
||||
assert result is None
|
||||
|
||||
@pytest.mark.asyncio
|
||||
@@ -369,13 +380,13 @@ class TestCRUDBaseRemove:
|
||||
first_name="To",
|
||||
last_name="Delete",
|
||||
)
|
||||
user = await user_crud.create(session, obj_in=user_data)
|
||||
user = await user_repo.create(session, obj_in=user_data)
|
||||
user_id = user.id
|
||||
await session.commit()
|
||||
|
||||
# Delete with UUID object
|
||||
async with SessionLocal() as session:
|
||||
result = await user_crud.remove(session, id=user_id) # UUID object
|
||||
result = await user_repo.remove(session, id=user_id) # UUID object
|
||||
assert result is not None
|
||||
assert result.id == user_id
|
||||
|
||||
@@ -385,7 +396,7 @@ class TestCRUDBaseRemove:
|
||||
_test_engine, SessionLocal = async_test_db
|
||||
|
||||
async with SessionLocal() as session:
|
||||
result = await user_crud.remove(session, id=str(uuid4()))
|
||||
result = await user_repo.remove(session, id=str(uuid4()))
|
||||
assert result is None
|
||||
|
||||
@pytest.mark.asyncio
|
||||
@@ -403,9 +414,10 @@ class TestCRUDBaseRemove:
|
||||
),
|
||||
):
|
||||
with pytest.raises(
|
||||
ValueError, match="Cannot delete.*referenced by other records"
|
||||
IntegrityConstraintError,
|
||||
match="Cannot delete.*referenced by other records",
|
||||
):
|
||||
await user_crud.remove(session, id=str(async_test_user.id))
|
||||
await user_repo.remove(session, id=str(async_test_user.id))
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_remove_unexpected_error(self, async_test_db, async_test_user):
|
||||
@@ -417,10 +429,10 @@ class TestCRUDBaseRemove:
|
||||
session, "commit", side_effect=RuntimeError("Unexpected")
|
||||
):
|
||||
with pytest.raises(RuntimeError):
|
||||
await user_crud.remove(session, id=str(async_test_user.id))
|
||||
await user_repo.remove(session, id=str(async_test_user.id))
|
||||
|
||||
|
||||
class TestCRUDBaseGetMultiWithTotal:
|
||||
class TestRepositoryBaseGetMultiWithTotal:
|
||||
"""Tests for get_multi_with_total method covering pagination, filtering, sorting."""
|
||||
|
||||
@pytest.mark.asyncio
|
||||
@@ -429,7 +441,7 @@ class TestCRUDBaseGetMultiWithTotal:
|
||||
_test_engine, SessionLocal = async_test_db
|
||||
|
||||
async with SessionLocal() as session:
|
||||
items, total = await user_crud.get_multi_with_total(
|
||||
items, total = await user_repo.get_multi_with_total(
|
||||
session, skip=0, limit=10
|
||||
)
|
||||
assert isinstance(items, list)
|
||||
@@ -442,8 +454,8 @@ class TestCRUDBaseGetMultiWithTotal:
|
||||
_test_engine, SessionLocal = async_test_db
|
||||
|
||||
async with SessionLocal() as session:
|
||||
with pytest.raises(ValueError, match="skip must be non-negative"):
|
||||
await user_crud.get_multi_with_total(session, skip=-1)
|
||||
with pytest.raises(InvalidInputError, match="skip must be non-negative"):
|
||||
await user_repo.get_multi_with_total(session, skip=-1)
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_get_multi_with_total_negative_limit(self, async_test_db):
|
||||
@@ -451,8 +463,8 @@ class TestCRUDBaseGetMultiWithTotal:
|
||||
_test_engine, SessionLocal = async_test_db
|
||||
|
||||
async with SessionLocal() as session:
|
||||
with pytest.raises(ValueError, match="limit must be non-negative"):
|
||||
await user_crud.get_multi_with_total(session, limit=-1)
|
||||
with pytest.raises(InvalidInputError, match="limit must be non-negative"):
|
||||
await user_repo.get_multi_with_total(session, limit=-1)
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_get_multi_with_total_limit_too_large(self, async_test_db):
|
||||
@@ -460,8 +472,8 @@ class TestCRUDBaseGetMultiWithTotal:
|
||||
_test_engine, SessionLocal = async_test_db
|
||||
|
||||
async with SessionLocal() as session:
|
||||
with pytest.raises(ValueError, match="Maximum limit is 1000"):
|
||||
await user_crud.get_multi_with_total(session, limit=1001)
|
||||
with pytest.raises(InvalidInputError, match="Maximum limit is 1000"):
|
||||
await user_repo.get_multi_with_total(session, limit=1001)
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_get_multi_with_total_with_filters(
|
||||
@@ -472,7 +484,7 @@ class TestCRUDBaseGetMultiWithTotal:
|
||||
|
||||
async with SessionLocal() as session:
|
||||
filters = {"email": async_test_user.email}
|
||||
items, total = await user_crud.get_multi_with_total(
|
||||
items, total = await user_repo.get_multi_with_total(
|
||||
session, filters=filters
|
||||
)
|
||||
assert total == 1
|
||||
@@ -500,12 +512,12 @@ class TestCRUDBaseGetMultiWithTotal:
|
||||
first_name="ZZZ",
|
||||
last_name="User",
|
||||
)
|
||||
await user_crud.create(session, obj_in=user_data1)
|
||||
await user_crud.create(session, obj_in=user_data2)
|
||||
await user_repo.create(session, obj_in=user_data1)
|
||||
await user_repo.create(session, obj_in=user_data2)
|
||||
await session.commit()
|
||||
|
||||
async with SessionLocal() as session:
|
||||
items, total = await user_crud.get_multi_with_total(
|
||||
items, total = await user_repo.get_multi_with_total(
|
||||
session, sort_by="email", sort_order="asc"
|
||||
)
|
||||
assert total >= 3
|
||||
@@ -533,12 +545,12 @@ class TestCRUDBaseGetMultiWithTotal:
|
||||
first_name="CCC",
|
||||
last_name="User",
|
||||
)
|
||||
await user_crud.create(session, obj_in=user_data1)
|
||||
await user_crud.create(session, obj_in=user_data2)
|
||||
await user_repo.create(session, obj_in=user_data1)
|
||||
await user_repo.create(session, obj_in=user_data2)
|
||||
await session.commit()
|
||||
|
||||
async with SessionLocal() as session:
|
||||
items, _total = await user_crud.get_multi_with_total(
|
||||
items, _total = await user_repo.get_multi_with_total(
|
||||
session, sort_by="email", sort_order="desc", limit=1
|
||||
)
|
||||
assert len(items) == 1
|
||||
@@ -558,19 +570,19 @@ class TestCRUDBaseGetMultiWithTotal:
|
||||
first_name=f"User{i}",
|
||||
last_name="Test",
|
||||
)
|
||||
await user_crud.create(session, obj_in=user_data)
|
||||
await user_repo.create(session, obj_in=user_data)
|
||||
await session.commit()
|
||||
|
||||
async with SessionLocal() as session:
|
||||
# Get first page
|
||||
items1, total = await user_crud.get_multi_with_total(
|
||||
items1, total = await user_repo.get_multi_with_total(
|
||||
session, skip=0, limit=2
|
||||
)
|
||||
assert len(items1) == 2
|
||||
assert total >= 3
|
||||
|
||||
# Get second page
|
||||
items2, total2 = await user_crud.get_multi_with_total(
|
||||
items2, total2 = await user_repo.get_multi_with_total(
|
||||
session, skip=2, limit=2
|
||||
)
|
||||
assert len(items2) >= 1
|
||||
@@ -582,7 +594,7 @@ class TestCRUDBaseGetMultiWithTotal:
|
||||
assert ids1.isdisjoint(ids2)
|
||||
|
||||
|
||||
class TestCRUDBaseCount:
|
||||
class TestRepositoryBaseCount:
|
||||
"""Tests for count method."""
|
||||
|
||||
@pytest.mark.asyncio
|
||||
@@ -591,7 +603,7 @@ class TestCRUDBaseCount:
|
||||
_test_engine, SessionLocal = async_test_db
|
||||
|
||||
async with SessionLocal() as session:
|
||||
count = await user_crud.count(session)
|
||||
count = await user_repo.count(session)
|
||||
assert isinstance(count, int)
|
||||
assert count >= 1 # At least the test user
|
||||
|
||||
@@ -602,7 +614,7 @@ class TestCRUDBaseCount:
|
||||
|
||||
# Create additional users
|
||||
async with SessionLocal() as session:
|
||||
initial_count = await user_crud.count(session)
|
||||
initial_count = await user_repo.count(session)
|
||||
|
||||
user_data1 = UserCreate(
|
||||
email="count1@example.com",
|
||||
@@ -616,12 +628,12 @@ class TestCRUDBaseCount:
|
||||
first_name="Count",
|
||||
last_name="Two",
|
||||
)
|
||||
await user_crud.create(session, obj_in=user_data1)
|
||||
await user_crud.create(session, obj_in=user_data2)
|
||||
await user_repo.create(session, obj_in=user_data1)
|
||||
await user_repo.create(session, obj_in=user_data2)
|
||||
await session.commit()
|
||||
|
||||
async with SessionLocal() as session:
|
||||
new_count = await user_crud.count(session)
|
||||
new_count = await user_repo.count(session)
|
||||
assert new_count == initial_count + 2
|
||||
|
||||
@pytest.mark.asyncio
|
||||
@@ -632,10 +644,10 @@ class TestCRUDBaseCount:
|
||||
async with SessionLocal() as session:
|
||||
with patch.object(session, "execute", side_effect=Exception("DB error")):
|
||||
with pytest.raises(Exception, match="DB error"):
|
||||
await user_crud.count(session)
|
||||
await user_repo.count(session)
|
||||
|
||||
|
||||
class TestCRUDBaseExists:
|
||||
class TestRepositoryBaseExists:
|
||||
"""Tests for exists method."""
|
||||
|
||||
@pytest.mark.asyncio
|
||||
@@ -644,7 +656,7 @@ class TestCRUDBaseExists:
|
||||
_test_engine, SessionLocal = async_test_db
|
||||
|
||||
async with SessionLocal() as session:
|
||||
result = await user_crud.exists(session, id=str(async_test_user.id))
|
||||
result = await user_repo.exists(session, id=str(async_test_user.id))
|
||||
assert result is True
|
||||
|
||||
@pytest.mark.asyncio
|
||||
@@ -653,7 +665,7 @@ class TestCRUDBaseExists:
|
||||
_test_engine, SessionLocal = async_test_db
|
||||
|
||||
async with SessionLocal() as session:
|
||||
result = await user_crud.exists(session, id=str(uuid4()))
|
||||
result = await user_repo.exists(session, id=str(uuid4()))
|
||||
assert result is False
|
||||
|
||||
@pytest.mark.asyncio
|
||||
@@ -662,11 +674,11 @@ class TestCRUDBaseExists:
|
||||
_test_engine, SessionLocal = async_test_db
|
||||
|
||||
async with SessionLocal() as session:
|
||||
result = await user_crud.exists(session, id="invalid-uuid")
|
||||
result = await user_repo.exists(session, id="invalid-uuid")
|
||||
assert result is False
|
||||
|
||||
|
||||
class TestCRUDBaseSoftDelete:
|
||||
class TestRepositoryBaseSoftDelete:
|
||||
"""Tests for soft_delete method."""
|
||||
|
||||
@pytest.mark.asyncio
|
||||
@@ -682,13 +694,13 @@ class TestCRUDBaseSoftDelete:
|
||||
first_name="Soft",
|
||||
last_name="Delete",
|
||||
)
|
||||
user = await user_crud.create(session, obj_in=user_data)
|
||||
user = await user_repo.create(session, obj_in=user_data)
|
||||
user_id = user.id
|
||||
await session.commit()
|
||||
|
||||
# Soft delete the user
|
||||
async with SessionLocal() as session:
|
||||
deleted = await user_crud.soft_delete(session, id=str(user_id))
|
||||
deleted = await user_repo.soft_delete(session, id=str(user_id))
|
||||
assert deleted is not None
|
||||
assert deleted.deleted_at is not None
|
||||
|
||||
@@ -698,7 +710,7 @@ class TestCRUDBaseSoftDelete:
|
||||
_test_engine, SessionLocal = async_test_db
|
||||
|
||||
async with SessionLocal() as session:
|
||||
result = await user_crud.soft_delete(session, id="invalid-uuid")
|
||||
result = await user_repo.soft_delete(session, id="invalid-uuid")
|
||||
assert result is None
|
||||
|
||||
@pytest.mark.asyncio
|
||||
@@ -707,7 +719,7 @@ class TestCRUDBaseSoftDelete:
|
||||
_test_engine, SessionLocal = async_test_db
|
||||
|
||||
async with SessionLocal() as session:
|
||||
result = await user_crud.soft_delete(session, id=str(uuid4()))
|
||||
result = await user_repo.soft_delete(session, id=str(uuid4()))
|
||||
assert result is None
|
||||
|
||||
@pytest.mark.asyncio
|
||||
@@ -723,18 +735,18 @@ class TestCRUDBaseSoftDelete:
|
||||
first_name="Soft",
|
||||
last_name="Delete2",
|
||||
)
|
||||
user = await user_crud.create(session, obj_in=user_data)
|
||||
user = await user_repo.create(session, obj_in=user_data)
|
||||
user_id = user.id
|
||||
await session.commit()
|
||||
|
||||
# Soft delete with UUID object
|
||||
async with SessionLocal() as session:
|
||||
deleted = await user_crud.soft_delete(session, id=user_id) # UUID object
|
||||
deleted = await user_repo.soft_delete(session, id=user_id) # UUID object
|
||||
assert deleted is not None
|
||||
assert deleted.deleted_at is not None
|
||||
|
||||
|
||||
class TestCRUDBaseRestore:
|
||||
class TestRepositoryBaseRestore:
|
||||
"""Tests for restore method."""
|
||||
|
||||
@pytest.mark.asyncio
|
||||
@@ -750,16 +762,16 @@ class TestCRUDBaseRestore:
|
||||
first_name="Restore",
|
||||
last_name="Test",
|
||||
)
|
||||
user = await user_crud.create(session, obj_in=user_data)
|
||||
user = await user_repo.create(session, obj_in=user_data)
|
||||
user_id = user.id
|
||||
await session.commit()
|
||||
|
||||
async with SessionLocal() as session:
|
||||
await user_crud.soft_delete(session, id=str(user_id))
|
||||
await user_repo.soft_delete(session, id=str(user_id))
|
||||
|
||||
# Restore the user
|
||||
async with SessionLocal() as session:
|
||||
restored = await user_crud.restore(session, id=str(user_id))
|
||||
restored = await user_repo.restore(session, id=str(user_id))
|
||||
assert restored is not None
|
||||
assert restored.deleted_at is None
|
||||
|
||||
@@ -769,7 +781,7 @@ class TestCRUDBaseRestore:
|
||||
_test_engine, SessionLocal = async_test_db
|
||||
|
||||
async with SessionLocal() as session:
|
||||
result = await user_crud.restore(session, id="invalid-uuid")
|
||||
result = await user_repo.restore(session, id="invalid-uuid")
|
||||
assert result is None
|
||||
|
||||
@pytest.mark.asyncio
|
||||
@@ -778,7 +790,7 @@ class TestCRUDBaseRestore:
|
||||
_test_engine, SessionLocal = async_test_db
|
||||
|
||||
async with SessionLocal() as session:
|
||||
result = await user_crud.restore(session, id=str(uuid4()))
|
||||
result = await user_repo.restore(session, id=str(uuid4()))
|
||||
assert result is None
|
||||
|
||||
@pytest.mark.asyncio
|
||||
@@ -788,7 +800,7 @@ class TestCRUDBaseRestore:
|
||||
|
||||
async with SessionLocal() as session:
|
||||
# Try to restore a user that's not deleted
|
||||
result = await user_crud.restore(session, id=str(async_test_user.id))
|
||||
result = await user_repo.restore(session, id=str(async_test_user.id))
|
||||
assert result is None
|
||||
|
||||
@pytest.mark.asyncio
|
||||
@@ -804,21 +816,21 @@ class TestCRUDBaseRestore:
|
||||
first_name="Restore",
|
||||
last_name="Test2",
|
||||
)
|
||||
user = await user_crud.create(session, obj_in=user_data)
|
||||
user = await user_repo.create(session, obj_in=user_data)
|
||||
user_id = user.id
|
||||
await session.commit()
|
||||
|
||||
async with SessionLocal() as session:
|
||||
await user_crud.soft_delete(session, id=str(user_id))
|
||||
await user_repo.soft_delete(session, id=str(user_id))
|
||||
|
||||
# Restore with UUID object
|
||||
async with SessionLocal() as session:
|
||||
restored = await user_crud.restore(session, id=user_id) # UUID object
|
||||
restored = await user_repo.restore(session, id=user_id) # UUID object
|
||||
assert restored is not None
|
||||
assert restored.deleted_at is None
|
||||
|
||||
|
||||
class TestCRUDBasePaginationValidation:
|
||||
class TestRepositoryBasePaginationValidation:
|
||||
"""Tests for pagination parameter validation (covers lines 254-260)."""
|
||||
|
||||
@pytest.mark.asyncio
|
||||
@@ -827,8 +839,8 @@ class TestCRUDBasePaginationValidation:
|
||||
_test_engine, SessionLocal = async_test_db
|
||||
|
||||
async with SessionLocal() as session:
|
||||
with pytest.raises(ValueError, match="skip must be non-negative"):
|
||||
await user_crud.get_multi_with_total(session, skip=-1, limit=10)
|
||||
with pytest.raises(InvalidInputError, match="skip must be non-negative"):
|
||||
await user_repo.get_multi_with_total(session, skip=-1, limit=10)
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_get_multi_with_total_negative_limit(self, async_test_db):
|
||||
@@ -836,8 +848,8 @@ class TestCRUDBasePaginationValidation:
|
||||
_test_engine, SessionLocal = async_test_db
|
||||
|
||||
async with SessionLocal() as session:
|
||||
with pytest.raises(ValueError, match="limit must be non-negative"):
|
||||
await user_crud.get_multi_with_total(session, skip=0, limit=-1)
|
||||
with pytest.raises(InvalidInputError, match="limit must be non-negative"):
|
||||
await user_repo.get_multi_with_total(session, skip=0, limit=-1)
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_get_multi_with_total_limit_too_large(self, async_test_db):
|
||||
@@ -845,8 +857,8 @@ class TestCRUDBasePaginationValidation:
|
||||
_test_engine, SessionLocal = async_test_db
|
||||
|
||||
async with SessionLocal() as session:
|
||||
with pytest.raises(ValueError, match="Maximum limit is 1000"):
|
||||
await user_crud.get_multi_with_total(session, skip=0, limit=1001)
|
||||
with pytest.raises(InvalidInputError, match="Maximum limit is 1000"):
|
||||
await user_repo.get_multi_with_total(session, skip=0, limit=1001)
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_get_multi_with_total_with_filters(
|
||||
@@ -856,7 +868,7 @@ class TestCRUDBasePaginationValidation:
|
||||
_test_engine, SessionLocal = async_test_db
|
||||
|
||||
async with SessionLocal() as session:
|
||||
users, total = await user_crud.get_multi_with_total(
|
||||
users, total = await user_repo.get_multi_with_total(
|
||||
session, skip=0, limit=10, filters={"is_active": True}
|
||||
)
|
||||
assert isinstance(users, list)
|
||||
@@ -868,7 +880,7 @@ class TestCRUDBasePaginationValidation:
|
||||
_test_engine, SessionLocal = async_test_db
|
||||
|
||||
async with SessionLocal() as session:
|
||||
users, _total = await user_crud.get_multi_with_total(
|
||||
users, _total = await user_repo.get_multi_with_total(
|
||||
session, skip=0, limit=10, sort_by="created_at", sort_order="desc"
|
||||
)
|
||||
assert isinstance(users, list)
|
||||
@@ -879,13 +891,13 @@ class TestCRUDBasePaginationValidation:
|
||||
_test_engine, SessionLocal = async_test_db
|
||||
|
||||
async with SessionLocal() as session:
|
||||
users, _total = await user_crud.get_multi_with_total(
|
||||
users, _total = await user_repo.get_multi_with_total(
|
||||
session, skip=0, limit=10, sort_by="created_at", sort_order="asc"
|
||||
)
|
||||
assert isinstance(users, list)
|
||||
|
||||
|
||||
class TestCRUDBaseModelsWithoutSoftDelete:
|
||||
class TestRepositoryBaseModelsWithoutSoftDelete:
|
||||
"""
|
||||
Test soft_delete and restore on models without deleted_at column.
|
||||
Covers lines 342-343, 383-384 - error handling for unsupported models.
|
||||
@@ -899,8 +911,8 @@ class TestCRUDBaseModelsWithoutSoftDelete:
|
||||
_test_engine, SessionLocal = async_test_db
|
||||
|
||||
# Create an organization (which doesn't have deleted_at)
|
||||
from app.crud.organization import organization as org_crud
|
||||
from app.models.organization import Organization
|
||||
from app.repositories.organization import organization_repo as org_repo
|
||||
|
||||
async with SessionLocal() as session:
|
||||
org = Organization(name="Test Org", slug="test-org")
|
||||
@@ -910,8 +922,10 @@ class TestCRUDBaseModelsWithoutSoftDelete:
|
||||
|
||||
# Try to soft delete organization (should fail)
|
||||
async with SessionLocal() as session:
|
||||
with pytest.raises(ValueError, match="does not have a deleted_at column"):
|
||||
await org_crud.soft_delete(session, id=str(org_id))
|
||||
with pytest.raises(
|
||||
InvalidInputError, match="does not have a deleted_at column"
|
||||
):
|
||||
await org_repo.soft_delete(session, id=str(org_id))
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_restore_model_without_deleted_at(self, async_test_db):
|
||||
@@ -919,8 +933,8 @@ class TestCRUDBaseModelsWithoutSoftDelete:
|
||||
_test_engine, SessionLocal = async_test_db
|
||||
|
||||
# Create an organization (which doesn't have deleted_at)
|
||||
from app.crud.organization import organization as org_crud
|
||||
from app.models.organization import Organization
|
||||
from app.repositories.organization import organization_repo as org_repo
|
||||
|
||||
async with SessionLocal() as session:
|
||||
org = Organization(name="Restore Test", slug="restore-test")
|
||||
@@ -930,11 +944,13 @@ class TestCRUDBaseModelsWithoutSoftDelete:
|
||||
|
||||
# Try to restore organization (should fail)
|
||||
async with SessionLocal() as session:
|
||||
with pytest.raises(ValueError, match="does not have a deleted_at column"):
|
||||
await org_crud.restore(session, id=str(org_id))
|
||||
with pytest.raises(
|
||||
InvalidInputError, match="does not have a deleted_at column"
|
||||
):
|
||||
await org_repo.restore(session, id=str(org_id))
|
||||
|
||||
|
||||
class TestCRUDBaseEagerLoadingWithRealOptions:
|
||||
class TestRepositoryBaseEagerLoadingWithRealOptions:
|
||||
"""
|
||||
Test eager loading with actual SQLAlchemy load options.
|
||||
Covers lines 77-78, 119-120 - options loop execution.
|
||||
@@ -950,8 +966,8 @@ class TestCRUDBaseEagerLoadingWithRealOptions:
|
||||
_test_engine, SessionLocal = async_test_db
|
||||
|
||||
# Create a session for the user
|
||||
from app.crud.session import session as session_crud
|
||||
from app.models.user_session import UserSession
|
||||
from app.repositories.session import session_repo as session_repo
|
||||
|
||||
async with SessionLocal() as session:
|
||||
user_session = UserSession(
|
||||
@@ -969,7 +985,7 @@ class TestCRUDBaseEagerLoadingWithRealOptions:
|
||||
|
||||
# Get session with eager loading of user relationship
|
||||
async with SessionLocal() as session:
|
||||
result = await session_crud.get(
|
||||
result = await session_repo.get(
|
||||
session,
|
||||
id=str(session_id),
|
||||
options=[joinedload(UserSession.user)], # Real option, not empty list
|
||||
@@ -989,8 +1005,8 @@ class TestCRUDBaseEagerLoadingWithRealOptions:
|
||||
_test_engine, SessionLocal = async_test_db
|
||||
|
||||
# Create multiple sessions for the user
|
||||
from app.crud.session import session as session_crud
|
||||
from app.models.user_session import UserSession
|
||||
from app.repositories.session import session_repo as session_repo
|
||||
|
||||
async with SessionLocal() as session:
|
||||
for i in range(3):
|
||||
@@ -1008,7 +1024,7 @@ class TestCRUDBaseEagerLoadingWithRealOptions:
|
||||
|
||||
# Get sessions with eager loading
|
||||
async with SessionLocal() as session:
|
||||
results = await session_crud.get_multi(
|
||||
results = await session_repo.get_multi(
|
||||
session,
|
||||
skip=0,
|
||||
limit=10,
|
||||
@@ -1,6 +1,6 @@
|
||||
# tests/crud/test_base_db_failures.py
|
||||
# tests/repositories/test_base_db_failures.py
|
||||
"""
|
||||
Comprehensive tests for base CRUD database failure scenarios.
|
||||
Comprehensive tests for base repository database failure scenarios.
|
||||
Tests exception handling, rollbacks, and error messages.
|
||||
"""
|
||||
|
||||
@@ -10,16 +10,17 @@ from uuid import uuid4
|
||||
import pytest
|
||||
from sqlalchemy.exc import DataError, OperationalError
|
||||
|
||||
from app.crud.user import user as user_crud
|
||||
from app.core.repository_exceptions import IntegrityConstraintError
|
||||
from app.repositories.user import user_repo as user_repo
|
||||
from app.schemas.users import UserCreate
|
||||
|
||||
|
||||
class TestBaseCRUDCreateFailures:
|
||||
"""Test base CRUD create method exception handling."""
|
||||
class TestBaseRepositoryCreateFailures:
|
||||
"""Test base repository create method exception handling."""
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_create_operational_error_triggers_rollback(self, async_test_db):
|
||||
"""Test that OperationalError triggers rollback (User CRUD catches as Exception)."""
|
||||
"""Test that OperationalError triggers rollback (User repository catches as Exception)."""
|
||||
_test_engine, SessionLocal = async_test_db
|
||||
|
||||
async with SessionLocal() as session:
|
||||
@@ -40,16 +41,16 @@ class TestBaseCRUDCreateFailures:
|
||||
last_name="User",
|
||||
)
|
||||
|
||||
# User CRUD catches this as generic Exception and re-raises
|
||||
# User repository catches this as generic Exception and re-raises
|
||||
with pytest.raises(OperationalError):
|
||||
await user_crud.create(session, obj_in=user_data)
|
||||
await user_repo.create(session, obj_in=user_data)
|
||||
|
||||
# Verify rollback was called
|
||||
mock_rollback.assert_called_once()
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_create_data_error_triggers_rollback(self, async_test_db):
|
||||
"""Test that DataError triggers rollback (User CRUD catches as Exception)."""
|
||||
"""Test that DataError triggers rollback (User repository catches as Exception)."""
|
||||
_test_engine, SessionLocal = async_test_db
|
||||
|
||||
async with SessionLocal() as session:
|
||||
@@ -68,9 +69,9 @@ class TestBaseCRUDCreateFailures:
|
||||
last_name="User",
|
||||
)
|
||||
|
||||
-                    # User CRUD catches this as generic Exception and re-raises
+                    # User repository catches this as generic Exception and re-raises
                     with pytest.raises(DataError):
-                        await user_crud.create(session, obj_in=user_data)
+                        await user_repo.create(session, obj_in=user_data)

                     mock_rollback.assert_called_once()


@@ -96,13 +97,13 @@ class TestBaseCRUDCreateFailures:
                )

                with pytest.raises(RuntimeError, match="Unexpected database error"):
-                    await user_crud.create(session, obj_in=user_data)
+                    await user_repo.create(session, obj_in=user_data)

                mock_rollback.assert_called_once()


-class TestBaseCRUDUpdateFailures:
-    """Test base CRUD update method exception handling."""
+class TestBaseRepositoryUpdateFailures:
+    """Test base repository update method exception handling."""

     @pytest.mark.asyncio
     async def test_update_operational_error(self, async_test_db, async_test_user):
@@ -110,7 +111,7 @@ class TestBaseCRUDUpdateFailures:
         _test_engine, SessionLocal = async_test_db

         async with SessionLocal() as session:
-            user = await user_crud.get(session, id=str(async_test_user.id))
+            user = await user_repo.get(session, id=str(async_test_user.id))

             async def mock_commit():
                 raise OperationalError("Connection timeout", {}, Exception("Timeout"))
@@ -119,8 +120,10 @@ class TestBaseCRUDUpdateFailures:
                with patch.object(
                    session, "rollback", new_callable=AsyncMock
                ) as mock_rollback:
-                    with pytest.raises(ValueError, match="Database operation failed"):
-                        await user_crud.update(
+                    with pytest.raises(
+                        IntegrityConstraintError, match="Database operation failed"
+                    ):
+                        await user_repo.update(
                            session, db_obj=user, obj_in={"first_name": "Updated"}
                        )

@@ -132,7 +135,7 @@ class TestBaseCRUDUpdateFailures:
         _test_engine, SessionLocal = async_test_db

         async with SessionLocal() as session:
-            user = await user_crud.get(session, id=str(async_test_user.id))
+            user = await user_repo.get(session, id=str(async_test_user.id))

             async def mock_commit():
                 raise DataError("Invalid data", {}, Exception("Data type mismatch"))
@@ -141,8 +144,10 @@ class TestBaseCRUDUpdateFailures:
                with patch.object(
                    session, "rollback", new_callable=AsyncMock
                ) as mock_rollback:
-                    with pytest.raises(ValueError, match="Database operation failed"):
-                        await user_crud.update(
+                    with pytest.raises(
+                        IntegrityConstraintError, match="Database operation failed"
+                    ):
+                        await user_repo.update(
                            session, db_obj=user, obj_in={"first_name": "Updated"}
                        )

@@ -154,7 +159,7 @@ class TestBaseCRUDUpdateFailures:
         _test_engine, SessionLocal = async_test_db

         async with SessionLocal() as session:
-            user = await user_crud.get(session, id=str(async_test_user.id))
+            user = await user_repo.get(session, id=str(async_test_user.id))

             async def mock_commit():
                 raise KeyError("Unexpected error")
@@ -164,15 +169,15 @@ class TestBaseCRUDUpdateFailures:
                    session, "rollback", new_callable=AsyncMock
                ) as mock_rollback:
                    with pytest.raises(KeyError):
-                        await user_crud.update(
+                        await user_repo.update(
                            session, db_obj=user, obj_in={"first_name": "Updated"}
                        )

                    mock_rollback.assert_called_once()


-class TestBaseCRUDRemoveFailures:
-    """Test base CRUD remove method exception handling."""
+class TestBaseRepositoryRemoveFailures:
+    """Test base repository remove method exception handling."""

     @pytest.mark.asyncio
     async def test_remove_unexpected_error_triggers_rollback(
@@ -191,12 +196,12 @@ class TestBaseCRUDRemoveFailures:
                    session, "rollback", new_callable=AsyncMock
                ) as mock_rollback:
                    with pytest.raises(RuntimeError, match="Database write failed"):
-                        await user_crud.remove(session, id=str(async_test_user.id))
+                        await user_repo.remove(session, id=str(async_test_user.id))

                    mock_rollback.assert_called_once()


-class TestBaseCRUDGetMultiWithTotalFailures:
+class TestBaseRepositoryGetMultiWithTotalFailures:
     """Test get_multi_with_total exception handling."""

     @pytest.mark.asyncio
@@ -212,10 +217,10 @@ class TestBaseCRUDGetMultiWithTotalFailures:

            with patch.object(session, "execute", side_effect=mock_execute):
                with pytest.raises(OperationalError):
-                    await user_crud.get_multi_with_total(session, skip=0, limit=10)
+                    await user_repo.get_multi_with_total(session, skip=0, limit=10)


-class TestBaseCRUDCountFailures:
+class TestBaseRepositoryCountFailures:
     """Test count method exception handling."""

     @pytest.mark.asyncio
@@ -230,10 +235,10 @@ class TestBaseCRUDCountFailures:

            with patch.object(session, "execute", side_effect=mock_execute):
                with pytest.raises(OperationalError):
-                    await user_crud.count(session)
+                    await user_repo.count(session)


-class TestBaseCRUDSoftDeleteFailures:
+class TestBaseRepositorySoftDeleteFailures:
     """Test soft_delete method exception handling."""

     @pytest.mark.asyncio
@@ -253,12 +258,12 @@ class TestBaseCRUDSoftDeleteFailures:
                    session, "rollback", new_callable=AsyncMock
                ) as mock_rollback:
                    with pytest.raises(RuntimeError, match="Soft delete failed"):
-                        await user_crud.soft_delete(session, id=str(async_test_user.id))
+                        await user_repo.soft_delete(session, id=str(async_test_user.id))

                    mock_rollback.assert_called_once()


-class TestBaseCRUDRestoreFailures:
+class TestBaseRepositoryRestoreFailures:
     """Test restore method exception handling."""

     @pytest.mark.asyncio
@@ -274,12 +279,12 @@ class TestBaseCRUDRestoreFailures:
                first_name="Restore",
                last_name="Test",
            )
-            user = await user_crud.create(session, obj_in=user_data)
+            user = await user_repo.create(session, obj_in=user_data)
            user_id = user.id
            await session.commit()

        async with SessionLocal() as session:
-            await user_crud.soft_delete(session, id=str(user_id))
+            await user_repo.soft_delete(session, id=str(user_id))

        # Now test restore failure
        async with SessionLocal() as session:
@@ -292,12 +297,12 @@ class TestBaseCRUDRestoreFailures:
                    session, "rollback", new_callable=AsyncMock
                ) as mock_rollback:
                    with pytest.raises(RuntimeError, match="Restore failed"):
-                        await user_crud.restore(session, id=str(user_id))
+                        await user_repo.restore(session, id=str(user_id))

                    mock_rollback.assert_called_once()


-class TestBaseCRUDGetFailures:
+class TestBaseRepositoryGetFailures:
     """Test get method exception handling."""

     @pytest.mark.asyncio
@@ -312,10 +317,10 @@ class TestBaseCRUDGetFailures:

            with patch.object(session, "execute", side_effect=mock_execute):
                with pytest.raises(OperationalError):
-                    await user_crud.get(session, id=str(uuid4()))
+                    await user_repo.get(session, id=str(uuid4()))


-class TestBaseCRUDGetMultiFailures:
+class TestBaseRepositoryGetMultiFailures:
     """Test get_multi method exception handling."""

     @pytest.mark.asyncio
@@ -330,4 +335,4 @@ class TestBaseCRUDGetMultiFailures:

            with patch.object(session, "execute", side_effect=mock_execute):
                with pytest.raises(OperationalError):
-                    await user_crud.get_multi(session, skip=0, limit=10)
+                    await user_repo.get_multi(session, skip=0, limit=10)
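The renamed tests above exercise a base repository that rolls back on failure and translates low-level database errors into domain exceptions such as `IntegrityConstraintError`. A minimal sketch of that pattern, using stand-in exception and session classes in place of SQLAlchemy (the real `app/repositories` implementation may differ):

```python
import asyncio


class RepositoryError(Exception):
    """Base class for domain-level repository errors (assumed name)."""


class IntegrityConstraintError(RepositoryError):
    """Raised when the database rejects a write, mirroring the tests' import."""


class IntegrityViolation(Exception):
    """Stand-in for sqlalchemy.exc.IntegrityError in this sketch."""


class BaseRepository:
    async def update(self, session, *, db_obj, obj_in):
        for field, value in obj_in.items():
            setattr(db_obj, field, value)
        try:
            await session.commit()
        except IntegrityViolation as exc:
            # Roll back first so the session stays usable, then translate
            # the driver error into a domain exception the API layer owns.
            await session.rollback()
            raise IntegrityConstraintError("Database operation failed") from exc
        except Exception:
            # Unexpected errors still trigger a rollback but propagate as-is,
            # matching the KeyError test above.
            await session.rollback()
            raise
        return db_obj


class FakeSession:
    """Minimal async session double whose commit always violates a constraint."""

    def __init__(self):
        self.rolled_back = False

    async def commit(self):
        raise IntegrityViolation("UNIQUE constraint failed")

    async def rollback(self):
        self.rolled_back = True


class Obj:
    pass


async def demo():
    session, obj = FakeSession(), Obj()
    try:
        await BaseRepository().update(
            session, db_obj=obj, obj_in={"first_name": "Updated"}
        )
    except IntegrityConstraintError:
        pass
    return session.rolled_back, obj.first_name


rolled_back, name = asyncio.run(demo())
```

The tests patch `session.rollback` with an `AsyncMock` and assert it was called exactly once, which is the contract this sketch preserves.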
@@ -1,18 +1,21 @@
-# tests/crud/test_oauth.py
+# tests/repositories/test_oauth.py
 """
-Comprehensive tests for OAuth CRUD operations.
+Comprehensive tests for OAuth repository operations.
 """

 from datetime import UTC, datetime, timedelta

 import pytest

-from app.crud.oauth import oauth_account, oauth_client, oauth_state
+from app.core.repository_exceptions import DuplicateEntryError
+from app.repositories.oauth_account import oauth_account_repo as oauth_account
+from app.repositories.oauth_client import oauth_client_repo as oauth_client
+from app.repositories.oauth_state import oauth_state_repo as oauth_state
 from app.schemas.oauth import OAuthAccountCreate, OAuthClientCreate, OAuthStateCreate


-class TestOAuthAccountCRUD:
-    """Tests for OAuth account CRUD operations."""
+class TestOAuthAccountRepository:
+    """Tests for OAuth account repository operations."""

     @pytest.mark.asyncio
     async def test_create_account(self, async_test_db, async_test_user):
@@ -60,7 +63,8 @@ class TestOAuthAccountCRUD:

            # SQLite returns different error message than PostgreSQL
            with pytest.raises(
-                ValueError, match="(already linked|UNIQUE constraint failed)"
+                DuplicateEntryError,
+                match="(already linked|UNIQUE constraint failed|Failed to create)",
            ):
                await oauth_account.create_account(session, obj_in=account_data2)

@@ -256,17 +260,17 @@ class TestOAuthAccountCRUD:
            updated = await oauth_account.update_tokens(
                session,
                account=account,
-                access_token_encrypted="new_access_token",
-                refresh_token_encrypted="new_refresh_token",
+                access_token="new_access_token",
+                refresh_token="new_refresh_token",
                token_expires_at=new_expires,
            )

-            assert updated.access_token_encrypted == "new_access_token"
-            assert updated.refresh_token_encrypted == "new_refresh_token"
+            assert updated.access_token == "new_access_token"
+            assert updated.refresh_token == "new_refresh_token"


-class TestOAuthStateCRUD:
-    """Tests for OAuth state CRUD operations."""
+class TestOAuthStateRepository:
+    """Tests for OAuth state repository operations."""

     @pytest.mark.asyncio
     async def test_create_state(self, async_test_db):
@@ -372,8 +376,8 @@ class TestOAuthStateCRUD:
            assert result is not None


-class TestOAuthClientCRUD:
-    """Tests for OAuth client CRUD operations (provider mode)."""
+class TestOAuthClientRepository:
+    """Tests for OAuth client repository operations (provider mode)."""

     @pytest.mark.asyncio
     async def test_create_public_client(self, async_test_db):
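The duplicate-account test above passes one regex to `pytest.raises(..., match=...)` to accept either the PostgreSQL or the SQLite wording, because `match` is applied with `re.search` against the string form of the raised exception. A standalone illustration of that matching rule (the error messages here are hypothetical):

```python
import re

# The alternation pattern from the duplicate-account test above.
pattern = "(already linked|UNIQUE constraint failed|Failed to create)"

# Hypothetical error messages from different database backends.
postgres_msg = "OAuth account already linked to another user"
sqlite_msg = "UNIQUE constraint failed: oauth_accounts.provider_user_id"
unrelated_msg = "connection refused"

# pytest.raises(match=...) succeeds exactly when re.search finds the pattern.
assert re.search(pattern, postgres_msg)
assert re.search(pattern, sqlite_msg)
assert re.search(pattern, unrelated_msg) is None
```

Because it is a search rather than a full match, the pattern only needs to cover the backend-specific fragment, not the whole message.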
@@ -1,6 +1,6 @@
-# tests/crud/test_organization_async.py
+# tests/repositories/test_organization_async.py
 """
-Comprehensive tests for async organization CRUD operations.
+Comprehensive tests for async organization repository operations.
 """

 from unittest.mock import AsyncMock, MagicMock, patch
@@ -9,9 +9,10 @@ from uuid import uuid4
 import pytest
 from sqlalchemy import select

-from app.crud.organization import organization as organization_crud
+from app.core.repository_exceptions import DuplicateEntryError, IntegrityConstraintError
 from app.models.organization import Organization
 from app.models.user_organization import OrganizationRole, UserOrganization
+from app.repositories.organization import organization_repo as organization_repo
 from app.schemas.organizations import OrganizationCreate


@@ -34,7 +35,7 @@ class TestGetBySlug:

        # Get by slug
        async with AsyncTestingSessionLocal() as session:
-            result = await organization_crud.get_by_slug(session, slug="test-org")
+            result = await organization_repo.get_by_slug(session, slug="test-org")
            assert result is not None
            assert result.id == org_id
            assert result.slug == "test-org"
@@ -45,7 +46,7 @@ class TestGetBySlug:
        _test_engine, AsyncTestingSessionLocal = async_test_db

        async with AsyncTestingSessionLocal() as session:
-            result = await organization_crud.get_by_slug(session, slug="nonexistent")
+            result = await organization_repo.get_by_slug(session, slug="nonexistent")
            assert result is None


@@ -54,7 +55,7 @@ class TestCreate:

     @pytest.mark.asyncio
     async def test_create_success(self, async_test_db):
-        """Test successfully creating an organization_crud."""
+        """Test successfully creating an organization_repo."""
         _test_engine, AsyncTestingSessionLocal = async_test_db

         async with AsyncTestingSessionLocal() as session:
@@ -65,7 +66,7 @@ class TestCreate:
                is_active=True,
                settings={"key": "value"},
            )
-            result = await organization_crud.create(session, obj_in=org_in)
+            result = await organization_repo.create(session, obj_in=org_in)

            assert result.name == "New Org"
            assert result.slug == "new-org"
@@ -87,8 +88,8 @@ class TestCreate:
        # Try to create second with same slug
        async with AsyncTestingSessionLocal() as session:
            org_in = OrganizationCreate(name="Org 2", slug="duplicate-slug")
-            with pytest.raises(ValueError, match="already exists"):
-                await organization_crud.create(session, obj_in=org_in)
+            with pytest.raises(DuplicateEntryError, match="already exists"):
+                await organization_repo.create(session, obj_in=org_in)

     @pytest.mark.asyncio
     async def test_create_without_settings(self, async_test_db):
@@ -97,7 +98,7 @@ class TestCreate:

        async with AsyncTestingSessionLocal() as session:
            org_in = OrganizationCreate(name="No Settings Org", slug="no-settings")
-            result = await organization_crud.create(session, obj_in=org_in)
+            result = await organization_repo.create(session, obj_in=org_in)

            assert result.settings == {}

@@ -118,7 +119,7 @@ class TestGetMultiWithFilters:
            await session.commit()

        async with AsyncTestingSessionLocal() as session:
-            orgs, total = await organization_crud.get_multi_with_filters(session)
+            orgs, total = await organization_repo.get_multi_with_filters(session)
            assert total == 5
            assert len(orgs) == 5

@@ -134,7 +135,7 @@ class TestGetMultiWithFilters:
            await session.commit()

        async with AsyncTestingSessionLocal() as session:
-            orgs, total = await organization_crud.get_multi_with_filters(
+            orgs, total = await organization_repo.get_multi_with_filters(
                session, is_active=True
            )
            assert total == 1
@@ -156,7 +157,7 @@ class TestGetMultiWithFilters:
            await session.commit()

        async with AsyncTestingSessionLocal() as session:
-            orgs, total = await organization_crud.get_multi_with_filters(
+            orgs, total = await organization_repo.get_multi_with_filters(
                session, search="tech"
            )
            assert total == 1
@@ -174,7 +175,7 @@ class TestGetMultiWithFilters:
            await session.commit()

        async with AsyncTestingSessionLocal() as session:
-            orgs, total = await organization_crud.get_multi_with_filters(
+            orgs, total = await organization_repo.get_multi_with_filters(
                session, skip=2, limit=3
            )
            assert total == 10
@@ -192,7 +193,7 @@ class TestGetMultiWithFilters:
            await session.commit()

        async with AsyncTestingSessionLocal() as session:
-            orgs, _total = await organization_crud.get_multi_with_filters(
+            orgs, _total = await organization_repo.get_multi_with_filters(
                session, sort_by="name", sort_order="asc"
            )
            assert orgs[0].name == "A Org"
@@ -204,7 +205,7 @@ class TestGetMemberCount:

     @pytest.mark.asyncio
     async def test_get_member_count_success(self, async_test_db, async_test_user):
-        """Test getting member count for organization_crud."""
+        """Test getting member count for organization_repo."""
         _test_engine, AsyncTestingSessionLocal = async_test_db

         async with AsyncTestingSessionLocal() as session:
@@ -224,7 +225,7 @@ class TestGetMemberCount:
            org_id = org.id

        async with AsyncTestingSessionLocal() as session:
-            count = await organization_crud.get_member_count(
+            count = await organization_repo.get_member_count(
                session, organization_id=org_id
            )
            assert count == 1
@@ -241,7 +242,7 @@ class TestGetMemberCount:
            org_id = org.id

        async with AsyncTestingSessionLocal() as session:
-            count = await organization_crud.get_member_count(
+            count = await organization_repo.get_member_count(
                session, organization_id=org_id
            )
            assert count == 0
@@ -252,7 +253,7 @@ class TestAddUser:

     @pytest.mark.asyncio
     async def test_add_user_success(self, async_test_db, async_test_user):
-        """Test successfully adding a user to organization_crud."""
+        """Test successfully adding a user to organization_repo."""
         _test_engine, AsyncTestingSessionLocal = async_test_db

         async with AsyncTestingSessionLocal() as session:
@@ -262,7 +263,7 @@ class TestAddUser:
            org_id = org.id

        async with AsyncTestingSessionLocal() as session:
-            result = await organization_crud.add_user(
+            result = await organization_repo.add_user(
                session,
                organization_id=org_id,
                user_id=async_test_user.id,
@@ -295,8 +296,8 @@ class TestAddUser:
            org_id = org.id

        async with AsyncTestingSessionLocal() as session:
-            with pytest.raises(ValueError, match="already a member"):
-                await organization_crud.add_user(
+            with pytest.raises(DuplicateEntryError, match="already a member"):
+                await organization_repo.add_user(
                    session, organization_id=org_id, user_id=async_test_user.id
                )

@@ -321,7 +322,7 @@ class TestAddUser:
            org_id = org.id

        async with AsyncTestingSessionLocal() as session:
-            result = await organization_crud.add_user(
+            result = await organization_repo.add_user(
                session,
                organization_id=org_id,
                user_id=async_test_user.id,
@@ -337,7 +338,7 @@ class TestRemoveUser:

     @pytest.mark.asyncio
     async def test_remove_user_success(self, async_test_db, async_test_user):
-        """Test successfully removing a user from organization_crud."""
+        """Test successfully removing a user from organization_repo."""
         _test_engine, AsyncTestingSessionLocal = async_test_db

         async with AsyncTestingSessionLocal() as session:
@@ -356,7 +357,7 @@ class TestRemoveUser:
            org_id = org.id

        async with AsyncTestingSessionLocal() as session:
-            result = await organization_crud.remove_user(
+            result = await organization_repo.remove_user(
                session, organization_id=org_id, user_id=async_test_user.id
            )

@@ -384,7 +385,7 @@ class TestRemoveUser:
            org_id = org.id

        async with AsyncTestingSessionLocal() as session:
-            result = await organization_crud.remove_user(
+            result = await organization_repo.remove_user(
                session, organization_id=org_id, user_id=uuid4()
            )

@@ -415,7 +416,7 @@ class TestUpdateUserRole:
            org_id = org.id

        async with AsyncTestingSessionLocal() as session:
-            result = await organization_crud.update_user_role(
+            result = await organization_repo.update_user_role(
                session,
                organization_id=org_id,
                user_id=async_test_user.id,
@@ -438,7 +439,7 @@ class TestUpdateUserRole:
            org_id = org.id

        async with AsyncTestingSessionLocal() as session:
-            result = await organization_crud.update_user_role(
+            result = await organization_repo.update_user_role(
                session,
                organization_id=org_id,
                user_id=uuid4(),
@@ -474,7 +475,7 @@ class TestGetOrganizationMembers:
            org_id = org.id

        async with AsyncTestingSessionLocal() as session:
-            members, total = await organization_crud.get_organization_members(
+            members, total = await organization_repo.get_organization_members(
                session, organization_id=org_id
            )

@@ -507,7 +508,7 @@ class TestGetOrganizationMembers:
            org_id = org.id

        async with AsyncTestingSessionLocal() as session:
-            members, total = await organization_crud.get_organization_members(
+            members, total = await organization_repo.get_organization_members(
                session, organization_id=org_id, skip=0, limit=10
            )

@@ -538,7 +539,7 @@ class TestGetUserOrganizations:
            await session.commit()

        async with AsyncTestingSessionLocal() as session:
-            orgs = await organization_crud.get_user_organizations(
+            orgs = await organization_repo.get_user_organizations(
                session, user_id=async_test_user.id
            )

@@ -574,7 +575,7 @@ class TestGetUserOrganizations:
            await session.commit()

        async with AsyncTestingSessionLocal() as session:
-            orgs = await organization_crud.get_user_organizations(
+            orgs = await organization_repo.get_user_organizations(
                session, user_id=async_test_user.id, is_active=True
            )

@@ -587,7 +588,7 @@ class TestGetUserRole:

     @pytest.mark.asyncio
     async def test_get_user_role_in_org_success(self, async_test_db, async_test_user):
-        """Test getting user role in organization_crud."""
+        """Test getting user role in organization_repo."""
         _test_engine, AsyncTestingSessionLocal = async_test_db

         async with AsyncTestingSessionLocal() as session:
@@ -606,7 +607,7 @@ class TestGetUserRole:
            org_id = org.id

        async with AsyncTestingSessionLocal() as session:
-            role = await organization_crud.get_user_role_in_org(
+            role = await organization_repo.get_user_role_in_org(
                session, user_id=async_test_user.id, organization_id=org_id
            )

@@ -624,7 +625,7 @@ class TestGetUserRole:
            org_id = org.id

        async with AsyncTestingSessionLocal() as session:
-            role = await organization_crud.get_user_role_in_org(
+            role = await organization_repo.get_user_role_in_org(
                session, user_id=uuid4(), organization_id=org_id
            )

@@ -655,7 +656,7 @@ class TestIsUserOrgOwner:
            org_id = org.id

        async with AsyncTestingSessionLocal() as session:
-            is_owner = await organization_crud.is_user_org_owner(
+            is_owner = await organization_repo.is_user_org_owner(
                session, user_id=async_test_user.id, organization_id=org_id
            )

@@ -682,7 +683,7 @@ class TestIsUserOrgOwner:
            org_id = org.id

        async with AsyncTestingSessionLocal() as session:
-            is_owner = await organization_crud.is_user_org_owner(
+            is_owner = await organization_repo.is_user_org_owner(
                session, user_id=async_test_user.id, organization_id=org_id
            )

@@ -719,7 +720,7 @@ class TestGetMultiWithMemberCounts:
            (
                orgs_with_counts,
                total,
-            ) = await organization_crud.get_multi_with_member_counts(session)
+            ) = await organization_repo.get_multi_with_member_counts(session)

            assert total == 2
            assert len(orgs_with_counts) == 2
@@ -744,7 +745,7 @@ class TestGetMultiWithMemberCounts:
            (
                orgs_with_counts,
                total,
-            ) = await organization_crud.get_multi_with_member_counts(
+            ) = await organization_repo.get_multi_with_member_counts(
                session, is_active=True
            )

@@ -766,7 +767,7 @@ class TestGetMultiWithMemberCounts:
            (
                orgs_with_counts,
                total,
-            ) = await organization_crud.get_multi_with_member_counts(
+            ) = await organization_repo.get_multi_with_member_counts(
                session, search="tech"
            )

@@ -800,7 +801,7 @@ class TestGetUserOrganizationsWithDetails:

        async with AsyncTestingSessionLocal() as session:
            orgs_with_details = (
-                await organization_crud.get_user_organizations_with_details(
+                await organization_repo.get_user_organizations_with_details(
                    session, user_id=async_test_user.id
                )
            )
@@ -840,7 +841,7 @@ class TestGetUserOrganizationsWithDetails:

        async with AsyncTestingSessionLocal() as session:
            orgs_with_details = (
-                await organization_crud.get_user_organizations_with_details(
+                await organization_repo.get_user_organizations_with_details(
                    session, user_id=async_test_user.id, is_active=True
                )
            )
@@ -873,7 +874,7 @@ class TestIsUserOrgAdmin:
            org_id = org.id

        async with AsyncTestingSessionLocal() as session:
-            is_admin = await organization_crud.is_user_org_admin(
+            is_admin = await organization_repo.is_user_org_admin(
                session, user_id=async_test_user.id, organization_id=org_id
            )

@@ -900,7 +901,7 @@ class TestIsUserOrgAdmin:
            org_id = org.id

        async with AsyncTestingSessionLocal() as session:
-            is_admin = await organization_crud.is_user_org_admin(
+            is_admin = await organization_repo.is_user_org_admin(
                session, user_id=async_test_user.id, organization_id=org_id
            )

@@ -927,7 +928,7 @@ class TestIsUserOrgAdmin:
            org_id = org.id

        async with AsyncTestingSessionLocal() as session:
-            is_admin = await organization_crud.is_user_org_admin(
+            is_admin = await organization_repo.is_user_org_admin(
                session, user_id=async_test_user.id, organization_id=org_id
            )

@@ -936,7 +937,7 @@ class TestIsUserOrgAdmin:

 class TestOrganizationExceptionHandlers:
     """
-    Test exception handlers in organization CRUD methods.
+    Test exception handlers in organization repository methods.
     Uses mocks to trigger database errors and verify proper error handling.
     Covers lines: 33-35, 57-62, 114-116, 130-132, 207-209, 258-260, 291-294, 326-329, 385-387, 409-411, 466-468, 491-493
     """
@@ -951,7 +952,7 @@ class TestOrganizationExceptionHandlers:
                session, "execute", side_effect=Exception("Database connection lost")
            ):
                with pytest.raises(Exception, match="Database connection lost"):
-                    await organization_crud.get_by_slug(session, slug="test-slug")
+                    await organization_repo.get_by_slug(session, slug="test-slug")

     @pytest.mark.asyncio
     async def test_create_integrity_error_non_slug(self, async_test_db):
@@ -972,8 +973,10 @@ class TestOrganizationExceptionHandlers:
            with patch.object(session, "commit", side_effect=mock_commit):
                with patch.object(session, "rollback", new_callable=AsyncMock):
                    org_in = OrganizationCreate(name="Test", slug="test")
-                    with pytest.raises(ValueError, match="Database integrity error"):
-                        await organization_crud.create(session, obj_in=org_in)
+                    with pytest.raises(
+                        IntegrityConstraintError, match="Database integrity error"
+                    ):
+                        await organization_repo.create(session, obj_in=org_in)

     @pytest.mark.asyncio
     async def test_create_unexpected_error(self, async_test_db):
@@ -987,7 +990,7 @@ class TestOrganizationExceptionHandlers:
                with patch.object(session, "rollback", new_callable=AsyncMock):
                    org_in = OrganizationCreate(name="Test", slug="test")
                    with pytest.raises(RuntimeError, match="Unexpected error"):
-                        await organization_crud.create(session, obj_in=org_in)
+                        await organization_repo.create(session, obj_in=org_in)

     @pytest.mark.asyncio
     async def test_get_multi_with_filters_database_error(self, async_test_db):
@@ -999,7 +1002,7 @@ class TestOrganizationExceptionHandlers:
                session, "execute", side_effect=Exception("Query timeout")
            ):
                with pytest.raises(Exception, match="Query timeout"):
-                    await organization_crud.get_multi_with_filters(session)
+                    await organization_repo.get_multi_with_filters(session)

     @pytest.mark.asyncio
     async def test_get_member_count_database_error(self, async_test_db):
@@ -1013,7 +1016,7 @@ class TestOrganizationExceptionHandlers:
                session, "execute", side_effect=Exception("Count query failed")
            ):
                with pytest.raises(Exception, match="Count query failed"):
-                    await organization_crud.get_member_count(
+                    await organization_repo.get_member_count(
                        session, organization_id=uuid4()
                    )

@@ -1027,7 +1030,7 @@ class TestOrganizationExceptionHandlers:
                session, "execute", side_effect=Exception("Complex query failed")
            ):
                with pytest.raises(Exception, match="Complex query failed"):
-                    await organization_crud.get_multi_with_member_counts(session)
+                    await organization_repo.get_multi_with_member_counts(session)

     @pytest.mark.asyncio
     async def test_add_user_integrity_error(self, async_test_db, async_test_user):
@@ -1058,9 +1061,10 @@ class TestOrganizationExceptionHandlers:
            with patch.object(session, "commit", side_effect=mock_commit):
                with patch.object(session, "rollback", new_callable=AsyncMock):
                    with pytest.raises(
-                        ValueError, match="Failed to add user to organization"
+                        IntegrityConstraintError,
+                        match="Failed to add user to organization",
                    ):
-                        await organization_crud.add_user(
+                        await organization_repo.add_user(
                            session,
                            organization_id=org_id,
                            user_id=async_test_user.id,
@@ -1078,7 +1082,7 @@ class TestOrganizationExceptionHandlers:
                session, "execute", side_effect=Exception("Delete failed")
            ):
                with pytest.raises(Exception, match="Delete failed"):
-                    await organization_crud.remove_user(
+                    await organization_repo.remove_user(
                        session, organization_id=uuid4(), user_id=async_test_user.id
                    )

@@ -1096,7 +1100,7 @@ class TestOrganizationExceptionHandlers:
                session, "execute", side_effect=Exception("Update failed")
            ):
                with pytest.raises(Exception, match="Update failed"):
-                    await organization_crud.update_user_role(
+                    await organization_repo.update_user_role(
                        session,
                        organization_id=uuid4(),
                        user_id=async_test_user.id,
@@ -1115,7 +1119,7 @@ class TestOrganizationExceptionHandlers:
                session, "execute", side_effect=Exception("Members query failed")
            ):
                with pytest.raises(Exception, match="Members query failed"):
-                    await organization_crud.get_organization_members(
+                    await organization_repo.get_organization_members(
                        session, organization_id=uuid4()
                    )

@@ -1131,7 +1135,7 @@ class TestOrganizationExceptionHandlers:
                session, "execute", side_effect=Exception("User orgs query failed")
            ):
                with pytest.raises(Exception, match="User orgs query failed"):
-                    await organization_crud.get_user_organizations(
+                    await organization_repo.get_user_organizations(
                        session, user_id=async_test_user.id
                    )

@@ -1147,7 +1151,7 @@ class TestOrganizationExceptionHandlers:
                session, "execute", side_effect=Exception("Details query failed")
            ):
                with pytest.raises(Exception, match="Details query failed"):
-                    await organization_crud.get_user_organizations_with_details(
+                    await organization_repo.get_user_organizations_with_details(
                        session, user_id=async_test_user.id
                    )

@@ -1165,6 +1169,6 @@ class TestOrganizationExceptionHandlers:
                session, "execute", side_effect=Exception("Role query failed")
            ):
                with pytest.raises(Exception, match="Role query failed"):
-                    await organization_crud.get_user_role_in_org(
+                    await organization_repo.get_user_role_in_org(
                        session, user_id=async_test_user.id, organization_id=uuid4()
                    )
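The organization tests distinguish a predictable duplicate (`DuplicateEntryError` matching "already a member") from other integrity failures (`IntegrityConstraintError`). One way a repository method can make that distinction is to pre-check membership before inserting; a stdlib-only sketch with an in-memory store standing in for the SQLAlchemy-backed repository (class and method names assumed from the test imports, not the actual `app.repositories.organization` code):

```python
class RepositoryError(Exception):
    """Base class for domain-level repository errors (assumed name)."""


class IntegrityConstraintError(RepositoryError):
    """Generic integrity failure, mirroring the tests' import."""


class DuplicateEntryError(IntegrityConstraintError):
    """Specific, predictable conflict: the row already exists."""


class InMemoryOrgRepo:
    """Toy membership store used only to illustrate the error contract."""

    def __init__(self):
        # Maps (organization_id, user_id) -> role.
        self._members = {}

    def add_user(self, *, organization_id, user_id, role="member"):
        key = (organization_id, user_id)
        if key in self._members:
            # Predictable conflict: surface it as the specific domain error
            # the "already a member" test expects.
            raise DuplicateEntryError("User is already a member of this organization")
        self._members[key] = role
        return {"organization_id": organization_id, "user_id": user_id, "role": role}

    def get_member_count(self, *, organization_id):
        return sum(1 for (org, _user) in self._members if org == organization_id)


repo = InMemoryOrgRepo()
repo.add_user(organization_id="org-1", user_id="u-1")
try:
    repo.add_user(organization_id="org-1", user_id="u-1")
    duplicate_raised = False
except DuplicateEntryError:
    duplicate_raised = True
```

Because `DuplicateEntryError` subclasses `IntegrityConstraintError` in this sketch, callers that only care about integrity failures can still catch the broader type.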
@@ -1,6 +1,6 @@
-# tests/crud/test_session_async.py
+# tests/repositories/test_session_async.py
 """
-Comprehensive tests for async session CRUD operations.
+Comprehensive tests for async session repository operations.
 """

 from datetime import UTC, datetime, timedelta
@@ -8,8 +8,9 @@ from uuid import uuid4

 import pytest

-from app.crud.session import session as session_crud
+from app.core.repository_exceptions import InvalidInputError
 from app.models.user_session import UserSession
+from app.repositories.session import session_repo as session_repo
 from app.schemas.sessions import SessionCreate


@@ -36,7 +37,7 @@ class TestGetByJti:
         await session.commit()

         async with AsyncTestingSessionLocal() as session:
-            result = await session_crud.get_by_jti(session, jti="test_jti_123")
+            result = await session_repo.get_by_jti(session, jti="test_jti_123")
             assert result is not None
             assert result.refresh_token_jti == "test_jti_123"

@@ -46,7 +47,7 @@ class TestGetByJti:
         _test_engine, AsyncTestingSessionLocal = async_test_db

         async with AsyncTestingSessionLocal() as session:
-            result = await session_crud.get_by_jti(session, jti="nonexistent")
+            result = await session_repo.get_by_jti(session, jti="nonexistent")
             assert result is None


@@ -73,7 +74,7 @@ class TestGetActiveByJti:
         await session.commit()

         async with AsyncTestingSessionLocal() as session:
-            result = await session_crud.get_active_by_jti(session, jti="active_jti")
+            result = await session_repo.get_active_by_jti(session, jti="active_jti")
             assert result is not None
             assert result.is_active is True

@@ -97,7 +98,7 @@ class TestGetActiveByJti:
         await session.commit()

         async with AsyncTestingSessionLocal() as session:
-            result = await session_crud.get_active_by_jti(session, jti="inactive_jti")
+            result = await session_repo.get_active_by_jti(session, jti="inactive_jti")
             assert result is None


@@ -134,7 +135,7 @@ class TestGetUserSessions:
         await session.commit()

         async with AsyncTestingSessionLocal() as session:
-            results = await session_crud.get_user_sessions(
+            results = await session_repo.get_user_sessions(
                 session, user_id=str(async_test_user.id), active_only=True
             )
             assert len(results) == 1
@@ -161,7 +162,7 @@ class TestGetUserSessions:
         await session.commit()

         async with AsyncTestingSessionLocal() as session:
-            results = await session_crud.get_user_sessions(
+            results = await session_repo.get_user_sessions(
                 session, user_id=str(async_test_user.id), active_only=False
             )
             assert len(results) == 3
@@ -172,7 +173,7 @@ class TestCreateSession:

     @pytest.mark.asyncio
     async def test_create_session_success(self, async_test_db, async_test_user):
-        """Test successfully creating a session_crud."""
+        """Test successfully creating a session_repo."""
         _test_engine, AsyncTestingSessionLocal = async_test_db

         async with AsyncTestingSessionLocal() as session:
@@ -188,7 +189,7 @@ class TestCreateSession:
                 location_city="San Francisco",
                 location_country="USA",
             )
-            result = await session_crud.create_session(session, obj_in=session_data)
+            result = await session_repo.create_session(session, obj_in=session_data)

             assert result.user_id == async_test_user.id
             assert result.refresh_token_jti == "new_jti"
@@ -201,7 +202,7 @@ class TestDeactivate:

     @pytest.mark.asyncio
     async def test_deactivate_success(self, async_test_db, async_test_user):
-        """Test successfully deactivating a session_crud."""
+        """Test successfully deactivating a session_repo."""
         _test_engine, AsyncTestingSessionLocal = async_test_db

         async with AsyncTestingSessionLocal() as session:
@@ -220,7 +221,7 @@ class TestDeactivate:
             session_id = user_session.id

         async with AsyncTestingSessionLocal() as session:
-            result = await session_crud.deactivate(session, session_id=str(session_id))
+            result = await session_repo.deactivate(session, session_id=str(session_id))
             assert result is not None
             assert result.is_active is False

@@ -230,7 +231,7 @@ class TestDeactivate:
         _test_engine, AsyncTestingSessionLocal = async_test_db

         async with AsyncTestingSessionLocal() as session:
-            result = await session_crud.deactivate(session, session_id=str(uuid4()))
+            result = await session_repo.deactivate(session, session_id=str(uuid4()))
             assert result is None


@@ -261,7 +262,7 @@ class TestDeactivateAllUserSessions:
         await session.commit()

         async with AsyncTestingSessionLocal() as session:
-            count = await session_crud.deactivate_all_user_sessions(
+            count = await session_repo.deactivate_all_user_sessions(
                 session, user_id=str(async_test_user.id)
             )
             assert count == 2
@@ -291,7 +292,7 @@ class TestUpdateLastUsed:
             await session.refresh(user_session)

             old_time = user_session.last_used_at
-            result = await session_crud.update_last_used(session, session=user_session)
+            result = await session_repo.update_last_used(session, session=user_session)

             assert result.last_used_at > old_time

@@ -320,7 +321,7 @@ class TestGetUserSessionCount:
         await session.commit()

         async with AsyncTestingSessionLocal() as session:
-            count = await session_crud.get_user_session_count(
+            count = await session_repo.get_user_session_count(
                 session, user_id=str(async_test_user.id)
             )
             assert count == 3
@@ -331,7 +332,7 @@ class TestGetUserSessionCount:
         _test_engine, AsyncTestingSessionLocal = async_test_db

         async with AsyncTestingSessionLocal() as session:
-            count = await session_crud.get_user_session_count(
+            count = await session_repo.get_user_session_count(
                 session, user_id=str(uuid4())
             )
             assert count == 0
@@ -363,7 +364,7 @@ class TestUpdateRefreshToken:
             new_jti = "new_jti_123"
             new_expires = datetime.now(UTC) + timedelta(days=14)

-            result = await session_crud.update_refresh_token(
+            result = await session_repo.update_refresh_token(
                 session,
                 session=user_session,
                 new_jti=new_jti,
@@ -409,7 +410,7 @@ class TestCleanupExpired:

         # Cleanup
         async with AsyncTestingSessionLocal() as session:
-            count = await session_crud.cleanup_expired(session, keep_days=30)
+            count = await session_repo.cleanup_expired(session, keep_days=30)
             assert count == 1

     @pytest.mark.asyncio
@@ -435,7 +436,7 @@ class TestCleanupExpired:

         # Cleanup
         async with AsyncTestingSessionLocal() as session:
-            count = await session_crud.cleanup_expired(session, keep_days=30)
+            count = await session_repo.cleanup_expired(session, keep_days=30)
             assert count == 0  # Should not delete recent sessions

     @pytest.mark.asyncio
@@ -461,7 +462,7 @@ class TestCleanupExpired:

         # Cleanup
         async with AsyncTestingSessionLocal() as session:
-            count = await session_crud.cleanup_expired(session, keep_days=30)
+            count = await session_repo.cleanup_expired(session, keep_days=30)
             assert count == 0  # Should not delete active sessions


@@ -492,7 +493,7 @@ class TestCleanupExpiredForUser:

         # Cleanup for user
         async with AsyncTestingSessionLocal() as session:
-            count = await session_crud.cleanup_expired_for_user(
+            count = await session_repo.cleanup_expired_for_user(
                 session, user_id=str(async_test_user.id)
             )
             assert count == 1
@@ -503,8 +504,8 @@ class TestCleanupExpiredForUser:
         _test_engine, AsyncTestingSessionLocal = async_test_db

         async with AsyncTestingSessionLocal() as session:
-            with pytest.raises(ValueError, match="Invalid user ID format"):
-                await session_crud.cleanup_expired_for_user(
+            with pytest.raises(InvalidInputError, match="Invalid user ID format"):
+                await session_repo.cleanup_expired_for_user(
                     session, user_id="not-a-valid-uuid"
                 )

@@ -532,7 +533,7 @@ class TestCleanupExpiredForUser:

         # Cleanup
         async with AsyncTestingSessionLocal() as session:
-            count = await session_crud.cleanup_expired_for_user(
+            count = await session_repo.cleanup_expired_for_user(
                 session, user_id=str(async_test_user.id)
             )
             assert count == 0  # Should not delete active sessions
@@ -564,7 +565,7 @@ class TestGetUserSessionsWithUser:

         # Get with user relationship
         async with AsyncTestingSessionLocal() as session:
-            results = await session_crud.get_user_sessions(
+            results = await session_repo.get_user_sessions(
                 session, user_id=str(async_test_user.id), with_user=True
             )
             assert len(results) >= 1
@@ -1,6 +1,6 @@
-# tests/crud/test_session_db_failures.py
+# tests/repositories/test_session_db_failures.py
 """
-Comprehensive tests for session CRUD database failure scenarios.
+Comprehensive tests for session repository database failure scenarios.
 """

 from datetime import UTC, datetime, timedelta
@@ -10,12 +10,13 @@ from uuid import uuid4
 import pytest
 from sqlalchemy.exc import OperationalError

-from app.crud.session import session as session_crud
+from app.core.repository_exceptions import IntegrityConstraintError
 from app.models.user_session import UserSession
+from app.repositories.session import session_repo as session_repo
 from app.schemas.sessions import SessionCreate


-class TestSessionCRUDGetByJtiFailures:
+class TestSessionRepositoryGetByJtiFailures:
     """Test get_by_jti exception handling."""

     @pytest.mark.asyncio
@@ -30,10 +31,10 @@ class TestSessionCRUDGetByJtiFailures:

         with patch.object(session, "execute", side_effect=mock_execute):
             with pytest.raises(OperationalError):
-                await session_crud.get_by_jti(session, jti="test_jti")
+                await session_repo.get_by_jti(session, jti="test_jti")


-class TestSessionCRUDGetActiveByJtiFailures:
+class TestSessionRepositoryGetActiveByJtiFailures:
     """Test get_active_by_jti exception handling."""

     @pytest.mark.asyncio
@@ -48,10 +49,10 @@ class TestSessionCRUDGetActiveByJtiFailures:

         with patch.object(session, "execute", side_effect=mock_execute):
             with pytest.raises(OperationalError):
-                await session_crud.get_active_by_jti(session, jti="test_jti")
+                await session_repo.get_active_by_jti(session, jti="test_jti")


-class TestSessionCRUDGetUserSessionsFailures:
+class TestSessionRepositoryGetUserSessionsFailures:
     """Test get_user_sessions exception handling."""

     @pytest.mark.asyncio
@@ -68,12 +69,12 @@ class TestSessionCRUDGetUserSessionsFailures:

         with patch.object(session, "execute", side_effect=mock_execute):
             with pytest.raises(OperationalError):
-                await session_crud.get_user_sessions(
+                await session_repo.get_user_sessions(
                     session, user_id=str(async_test_user.id)
                 )


-class TestSessionCRUDCreateSessionFailures:
+class TestSessionRepositoryCreateSessionFailures:
     """Test create_session exception handling."""

     @pytest.mark.asyncio
@@ -102,8 +103,10 @@ class TestSessionCRUDCreateSessionFailures:
                 last_used_at=datetime.now(UTC),
             )

-            with pytest.raises(ValueError, match="Failed to create session"):
-                await session_crud.create_session(session, obj_in=session_data)
+            with pytest.raises(
+                IntegrityConstraintError, match="Failed to create session"
+            ):
+                await session_repo.create_session(session, obj_in=session_data)

             mock_rollback.assert_called_once()

@@ -133,13 +136,15 @@ class TestSessionCRUDCreateSessionFailures:
                 last_used_at=datetime.now(UTC),
             )

-            with pytest.raises(ValueError, match="Failed to create session"):
-                await session_crud.create_session(session, obj_in=session_data)
+            with pytest.raises(
+                IntegrityConstraintError, match="Failed to create session"
+            ):
+                await session_repo.create_session(session, obj_in=session_data)

             mock_rollback.assert_called_once()


-class TestSessionCRUDDeactivateFailures:
+class TestSessionRepositoryDeactivateFailures:
     """Test deactivate exception handling."""

     @pytest.mark.asyncio
@@ -177,14 +182,14 @@ class TestSessionCRUDDeactivateFailures:
                 session, "rollback", new_callable=AsyncMock
             ) as mock_rollback:
                 with pytest.raises(OperationalError):
-                    await session_crud.deactivate(
+                    await session_repo.deactivate(
                         session, session_id=str(session_id)
                     )

                 mock_rollback.assert_called_once()


-class TestSessionCRUDDeactivateAllFailures:
+class TestSessionRepositoryDeactivateAllFailures:
     """Test deactivate_all_user_sessions exception handling."""

     @pytest.mark.asyncio
@@ -204,14 +209,14 @@ class TestSessionCRUDDeactivateAllFailures:
                 session, "rollback", new_callable=AsyncMock
             ) as mock_rollback:
                 with pytest.raises(OperationalError):
-                    await session_crud.deactivate_all_user_sessions(
+                    await session_repo.deactivate_all_user_sessions(
                         session, user_id=str(async_test_user.id)
                     )

                 mock_rollback.assert_called_once()


-class TestSessionCRUDUpdateLastUsedFailures:
+class TestSessionRepositoryUpdateLastUsedFailures:
     """Test update_last_used exception handling."""

     @pytest.mark.asyncio
@@ -254,12 +259,12 @@ class TestSessionCRUDUpdateLastUsedFailures:
                 session, "rollback", new_callable=AsyncMock
             ) as mock_rollback:
                 with pytest.raises(OperationalError):
-                    await session_crud.update_last_used(session, session=sess)
+                    await session_repo.update_last_used(session, session=sess)

                 mock_rollback.assert_called_once()


-class TestSessionCRUDUpdateRefreshTokenFailures:
+class TestSessionRepositoryUpdateRefreshTokenFailures:
     """Test update_refresh_token exception handling."""

     @pytest.mark.asyncio
@@ -302,7 +307,7 @@ class TestSessionCRUDUpdateRefreshTokenFailures:
                 session, "rollback", new_callable=AsyncMock
             ) as mock_rollback:
                 with pytest.raises(OperationalError):
-                    await session_crud.update_refresh_token(
+                    await session_repo.update_refresh_token(
                         session,
                         session=sess,
                         new_jti=str(uuid4()),
@@ -312,7 +317,7 @@ class TestSessionCRUDUpdateRefreshTokenFailures:
                 mock_rollback.assert_called_once()


-class TestSessionCRUDCleanupExpiredFailures:
+class TestSessionRepositoryCleanupExpiredFailures:
     """Test cleanup_expired exception handling."""

     @pytest.mark.asyncio
@@ -332,12 +337,12 @@ class TestSessionCRUDCleanupExpiredFailures:
                 session, "rollback", new_callable=AsyncMock
             ) as mock_rollback:
                 with pytest.raises(OperationalError):
-                    await session_crud.cleanup_expired(session, keep_days=30)
+                    await session_repo.cleanup_expired(session, keep_days=30)

                 mock_rollback.assert_called_once()


-class TestSessionCRUDCleanupExpiredForUserFailures:
+class TestSessionRepositoryCleanupExpiredForUserFailures:
     """Test cleanup_expired_for_user exception handling."""

     @pytest.mark.asyncio
@@ -357,14 +362,14 @@ class TestSessionCRUDCleanupExpiredForUserFailures:
                 session, "rollback", new_callable=AsyncMock
             ) as mock_rollback:
                 with pytest.raises(OperationalError):
-                    await session_crud.cleanup_expired_for_user(
+                    await session_repo.cleanup_expired_for_user(
                         session, user_id=str(async_test_user.id)
                     )

                 mock_rollback.assert_called_once()


-class TestSessionCRUDGetUserSessionCountFailures:
+class TestSessionRepositoryGetUserSessionCountFailures:
     """Test get_user_session_count exception handling."""

     @pytest.mark.asyncio
@@ -381,6 +386,6 @@ class TestSessionCRUDGetUserSessionCountFailures:

         with patch.object(session, "execute", side_effect=mock_execute):
             with pytest.raises(OperationalError):
-                await session_crud.get_user_session_count(
+                await session_repo.get_user_session_count(
                     session, user_id=str(async_test_user.id)
                 )
@@ -1,11 +1,12 @@
-# tests/crud/test_user_async.py
+# tests/repositories/test_user_async.py
 """
-Comprehensive tests for async user CRUD operations.
+Comprehensive tests for async user repository operations.
 """

 import pytest

-from app.crud.user import user as user_crud
+from app.core.repository_exceptions import DuplicateEntryError, InvalidInputError
+from app.repositories.user import user_repo as user_repo
 from app.schemas.users import UserCreate, UserUpdate


@@ -18,7 +19,7 @@ class TestGetByEmail:
         _test_engine, AsyncTestingSessionLocal = async_test_db

         async with AsyncTestingSessionLocal() as session:
-            result = await user_crud.get_by_email(session, email=async_test_user.email)
+            result = await user_repo.get_by_email(session, email=async_test_user.email)
             assert result is not None
             assert result.email == async_test_user.email
             assert result.id == async_test_user.id
@@ -29,7 +30,7 @@ class TestGetByEmail:
         _test_engine, AsyncTestingSessionLocal = async_test_db

         async with AsyncTestingSessionLocal() as session:
-            result = await user_crud.get_by_email(
+            result = await user_repo.get_by_email(
                 session, email="nonexistent@example.com"
             )
             assert result is None
@@ -40,7 +41,7 @@ class TestCreate:

     @pytest.mark.asyncio
     async def test_create_user_success(self, async_test_db):
-        """Test successfully creating a user_crud."""
+        """Test successfully creating a user_repo."""
         _test_engine, AsyncTestingSessionLocal = async_test_db

         async with AsyncTestingSessionLocal() as session:
@@ -51,7 +52,7 @@ class TestCreate:
                 last_name="User",
                 phone_number="+1234567890",
             )
-            result = await user_crud.create(session, obj_in=user_data)
+            result = await user_repo.create(session, obj_in=user_data)

             assert result.email == "newuser@example.com"
             assert result.first_name == "New"
@@ -75,7 +76,7 @@ class TestCreate:
                 last_name="User",
                 is_superuser=True,
             )
-            result = await user_crud.create(session, obj_in=user_data)
+            result = await user_repo.create(session, obj_in=user_data)

             assert result.is_superuser is True
             assert result.email == "superuser@example.com"
@@ -93,8 +94,8 @@ class TestCreate:
                 last_name="User",
             )

-            with pytest.raises(ValueError) as exc_info:
-                await user_crud.create(session, obj_in=user_data)
+            with pytest.raises(DuplicateEntryError) as exc_info:
+                await user_repo.create(session, obj_in=user_data)

             assert "already exists" in str(exc_info.value).lower()

@@ -109,12 +110,12 @@ class TestUpdate:

         async with AsyncTestingSessionLocal() as session:
             # Get fresh copy of user
-            user = await user_crud.get(session, id=str(async_test_user.id))
+            user = await user_repo.get(session, id=str(async_test_user.id))

             update_data = UserUpdate(
                 first_name="Updated", last_name="Name", phone_number="+9876543210"
             )
-            result = await user_crud.update(session, db_obj=user, obj_in=update_data)
+            result = await user_repo.update(session, db_obj=user, obj_in=update_data)

             assert result.first_name == "Updated"
             assert result.last_name == "Name"
@@ -133,16 +134,16 @@ class TestUpdate:
                 first_name="Pass",
                 last_name="Test",
             )
-            user = await user_crud.create(session, obj_in=user_data)
+            user = await user_repo.create(session, obj_in=user_data)
             user_id = user.id
             old_password_hash = user.password_hash

         # Update the password
         async with AsyncTestingSessionLocal() as session:
-            user = await user_crud.get(session, id=str(user_id))
+            user = await user_repo.get(session, id=str(user_id))

             update_data = UserUpdate(password="NewDifferentPassword123!")
-            result = await user_crud.update(session, db_obj=user, obj_in=update_data)
+            result = await user_repo.update(session, db_obj=user, obj_in=update_data)

             await session.refresh(result)
             assert result.password_hash != old_password_hash
@@ -157,10 +158,10 @@ class TestUpdate:
         _test_engine, AsyncTestingSessionLocal = async_test_db

         async with AsyncTestingSessionLocal() as session:
-            user = await user_crud.get(session, id=str(async_test_user.id))
+            user = await user_repo.get(session, id=str(async_test_user.id))

             update_dict = {"first_name": "DictUpdate"}
-            result = await user_crud.update(session, db_obj=user, obj_in=update_dict)
+            result = await user_repo.update(session, db_obj=user, obj_in=update_dict)

             assert result.first_name == "DictUpdate"

@@ -174,7 +175,7 @@ class TestGetMultiWithTotal:
         _test_engine, AsyncTestingSessionLocal = async_test_db

         async with AsyncTestingSessionLocal() as session:
-            users, total = await user_crud.get_multi_with_total(
+            users, total = await user_repo.get_multi_with_total(
                 session, skip=0, limit=10
             )
             assert total >= 1
@@ -195,10 +196,10 @@ class TestGetMultiWithTotal:
                     first_name=f"User{i}",
                     last_name="Test",
                 )
-                await user_crud.create(session, obj_in=user_data)
+                await user_repo.create(session, obj_in=user_data)

         async with AsyncTestingSessionLocal() as session:
-            users, _total = await user_crud.get_multi_with_total(
+            users, _total = await user_repo.get_multi_with_total(
                 session, skip=0, limit=10, sort_by="email", sort_order="asc"
             )

@@ -221,10 +222,10 @@ class TestGetMultiWithTotal:
                     first_name=f"User{i}",
                     last_name="Test",
                 )
-                await user_crud.create(session, obj_in=user_data)
+                await user_repo.create(session, obj_in=user_data)

         async with AsyncTestingSessionLocal() as session:
-            users, _total = await user_crud.get_multi_with_total(
+            users, _total = await user_repo.get_multi_with_total(
                 session, skip=0, limit=10, sort_by="email", sort_order="desc"
             )

@@ -246,7 +247,7 @@ class TestGetMultiWithTotal:
                 first_name="Active",
                 last_name="User",
             )
-            await user_crud.create(session, obj_in=active_user)
+            await user_repo.create(session, obj_in=active_user)

             inactive_user = UserCreate(
                 email="inactive@example.com",
@@ -254,15 +255,15 @@ class TestGetMultiWithTotal:
                 first_name="Inactive",
                 last_name="User",
             )
-            created_inactive = await user_crud.create(session, obj_in=inactive_user)
+            created_inactive = await user_repo.create(session, obj_in=inactive_user)

             # Deactivate the user
-            await user_crud.update(
+            await user_repo.update(
                 session, db_obj=created_inactive, obj_in={"is_active": False}
             )

         async with AsyncTestingSessionLocal() as session:
-            users, _total = await user_crud.get_multi_with_total(
+            users, _total = await user_repo.get_multi_with_total(
                 session, skip=0, limit=100, filters={"is_active": True}
             )

@@ -282,10 +283,10 @@ class TestGetMultiWithTotal:
                 first_name="Searchable",
                 last_name="UserName",
             )
-            await user_crud.create(session, obj_in=user_data)
+            await user_repo.create(session, obj_in=user_data)

         async with AsyncTestingSessionLocal() as session:
-            users, total = await user_crud.get_multi_with_total(
+            users, total = await user_repo.get_multi_with_total(
                 session, skip=0, limit=100, search="Searchable"
             )

@@ -306,16 +307,16 @@ class TestGetMultiWithTotal:
                     first_name=f"Page{i}",
                     last_name="User",
                 )
-                await user_crud.create(session, obj_in=user_data)
+                await user_repo.create(session, obj_in=user_data)

         async with AsyncTestingSessionLocal() as session:
             # Get first page
-            users_page1, total = await user_crud.get_multi_with_total(
+            users_page1, total = await user_repo.get_multi_with_total(
                 session, skip=0, limit=2
             )

             # Get second page
-            users_page2, total2 = await user_crud.get_multi_with_total(
+            users_page2, total2 = await user_repo.get_multi_with_total(
                 session, skip=2, limit=2
             )

@@ -330,8 +331,8 @@ class TestGetMultiWithTotal:
         _test_engine, AsyncTestingSessionLocal = async_test_db

         async with AsyncTestingSessionLocal() as session:
-            with pytest.raises(ValueError) as exc_info:
-                await user_crud.get_multi_with_total(session, skip=-1, limit=10)
+            with pytest.raises(InvalidInputError) as exc_info:
+                await user_repo.get_multi_with_total(session, skip=-1, limit=10)

             assert "skip must be non-negative" in str(exc_info.value)

@@ -341,8 +342,8 @@ class TestGetMultiWithTotal:
         _test_engine, AsyncTestingSessionLocal = async_test_db

         async with AsyncTestingSessionLocal() as session:
-            with pytest.raises(ValueError) as exc_info:
-                await user_crud.get_multi_with_total(session, skip=0, limit=-1)
+            with pytest.raises(InvalidInputError) as exc_info:
+                await user_repo.get_multi_with_total(session, skip=0, limit=-1)

             assert "limit must be non-negative" in str(exc_info.value)

@@ -352,8 +353,8 @@ class TestGetMultiWithTotal:
         _test_engine, AsyncTestingSessionLocal = async_test_db

         async with AsyncTestingSessionLocal() as session:
-            with pytest.raises(ValueError) as exc_info:
-                await user_crud.get_multi_with_total(session, skip=0, limit=1001)
+            with pytest.raises(InvalidInputError) as exc_info:
+                await user_repo.get_multi_with_total(session, skip=0, limit=1001)

             assert "Maximum limit is 1000" in str(exc_info.value)

@@ -376,12 +377,12 @@ class TestBulkUpdateStatus:
                     first_name=f"Bulk{i}",
                     last_name="User",
                 )
-                user = await user_crud.create(session, obj_in=user_data)
+                user = await user_repo.create(session, obj_in=user_data)
                 user_ids.append(user.id)

         # Bulk deactivate
         async with AsyncTestingSessionLocal() as session:
-            count = await user_crud.bulk_update_status(
+            count = await user_repo.bulk_update_status(
                 session, user_ids=user_ids, is_active=False
             )
             assert count == 3
@@ -389,7 +390,7 @@ class TestBulkUpdateStatus:
         # Verify all are inactive
         async with AsyncTestingSessionLocal() as session:
             for user_id in user_ids:
-                user = await user_crud.get(session, id=str(user_id))
+                user = await user_repo.get(session, id=str(user_id))
                 assert user.is_active is False

     @pytest.mark.asyncio
@@ -398,7 +399,7 @@ class TestBulkUpdateStatus:
         _test_engine, AsyncTestingSessionLocal = async_test_db

         async with AsyncTestingSessionLocal() as session:
-            count = await user_crud.bulk_update_status(
+            count = await user_repo.bulk_update_status(
                 session, user_ids=[], is_active=False
             )
             assert count == 0
@@ -416,21 +417,21 @@ class TestBulkUpdateStatus:
                 first_name="Reactivate",
                 last_name="User",
             )
-            user = await user_crud.create(session, obj_in=user_data)
+            user = await user_repo.create(session, obj_in=user_data)
             # Deactivate
-            await user_crud.update(session, db_obj=user, obj_in={"is_active": False})
+            await user_repo.update(session, db_obj=user, obj_in={"is_active": False})
             user_id = user.id

         # Reactivate
         async with AsyncTestingSessionLocal() as session:
-            count = await user_crud.bulk_update_status(
+            count = await user_repo.bulk_update_status(
                 session, user_ids=[user_id], is_active=True
             )
             assert count == 1

         # Verify active
         async with AsyncTestingSessionLocal() as session:
-            user = await user_crud.get(session, id=str(user_id))
+            user = await user_repo.get(session, id=str(user_id))
             assert user.is_active is True


@@ -452,24 +453,24 @@ class TestBulkSoftDelete:
                     first_name=f"Delete{i}",
                     last_name="User",
                 )
-                user = await user_crud.create(session, obj_in=user_data)
+                user = await user_repo.create(session, obj_in=user_data)
                 user_ids.append(user.id)

         # Bulk delete
         async with AsyncTestingSessionLocal() as session:
-            count = await user_crud.bulk_soft_delete(session, user_ids=user_ids)
+            count = await user_repo.bulk_soft_delete(session, user_ids=user_ids)
             assert count == 3

         # Verify all are soft deleted
         async with AsyncTestingSessionLocal() as session:
             for user_id in user_ids:
-                user = await user_crud.get(session, id=str(user_id))
+                user = await user_repo.get(session, id=str(user_id))
                 assert user.deleted_at is not None
                 assert user.is_active is False

     @pytest.mark.asyncio
     async def test_bulk_soft_delete_with_exclusion(self, async_test_db):
-        """Test bulk soft delete with excluded user_crud."""
+        """Test bulk soft delete with excluded user_repo."""
         _test_engine, AsyncTestingSessionLocal = async_test_db

         # Create multiple users
@@ -482,20 +483,20 @@ class TestBulkSoftDelete:
                     first_name=f"Exclude{i}",
                     last_name="User",
                 )
-                user = await user_crud.create(session, obj_in=user_data)
+                user = await user_repo.create(session, obj_in=user_data)
                 user_ids.append(user.id)

         # Bulk delete, excluding first user
         exclude_id = user_ids[0]
         async with AsyncTestingSessionLocal() as session:
-            count = await user_crud.bulk_soft_delete(
+            count = await user_repo.bulk_soft_delete(
                 session, user_ids=user_ids, exclude_user_id=exclude_id
             )
             assert count == 2  # Only 2 deleted

         # Verify excluded user is NOT deleted
         async with AsyncTestingSessionLocal() as session:
-            excluded_user = await user_crud.get(session, id=str(exclude_id))
+            excluded_user = await user_repo.get(session, id=str(exclude_id))
             assert excluded_user.deleted_at is None

     @pytest.mark.asyncio
@@ -504,7 +505,7 @@ class TestBulkSoftDelete:
         _test_engine, AsyncTestingSessionLocal = async_test_db

         async with AsyncTestingSessionLocal() as session:
-            count = await user_crud.bulk_soft_delete(session, user_ids=[])
+            count = await user_repo.bulk_soft_delete(session, user_ids=[])
             assert count == 0

     @pytest.mark.asyncio
@@ -520,12 +521,12 @@ class TestBulkSoftDelete:
                 first_name="Only",
                 last_name="User",
             )
-            user = await user_crud.create(session, obj_in=user_data)
+            user = await user_repo.create(session, obj_in=user_data)
             user_id = user.id

         # Try to delete but exclude
         async with AsyncTestingSessionLocal() as session:
-            count = await user_crud.bulk_soft_delete(
+            count = await user_repo.bulk_soft_delete(
                 session, user_ids=[user_id], exclude_user_id=user_id
             )
             assert count == 0
@@ -543,15 +544,15 @@ class TestBulkSoftDelete:
                 first_name="PreDeleted",
                 last_name="User",
             )
-            user = await user_crud.create(session, obj_in=user_data)
+            user = await user_repo.create(session, obj_in=user_data)
             user_id = user.id

             # First deletion
-            await user_crud.bulk_soft_delete(session, user_ids=[user_id])
+            await user_repo.bulk_soft_delete(session, user_ids=[user_id])

         # Try to delete again
         async with AsyncTestingSessionLocal() as session:
-            count = await user_crud.bulk_soft_delete(session, user_ids=[user_id])
+            count = await user_repo.bulk_soft_delete(session, user_ids=[user_id])
             assert count == 0  # Already deleted


@@ -560,16 +561,16 @@ class TestUtilityMethods:

     @pytest.mark.asyncio
     async def test_is_active_true(self, async_test_db, async_test_user):
-        """Test is_active returns True for active user_crud."""
+        """Test is_active returns True for active user_repo."""
         _test_engine, AsyncTestingSessionLocal = async_test_db

         async with AsyncTestingSessionLocal() as session:
-            user = await user_crud.get(session, id=str(async_test_user.id))
-            assert user_crud.is_active(user) is True
+            user = await user_repo.get(session, id=str(async_test_user.id))
+            assert user_repo.is_active(user) is True

     @pytest.mark.asyncio
     async def test_is_active_false(self, async_test_db):
-        """Test is_active returns False for inactive user_crud."""
+        """Test is_active returns False for inactive user_repo."""
         _test_engine, AsyncTestingSessionLocal = async_test_db

         async with AsyncTestingSessionLocal() as session:
@@ -579,10 +580,10 @@ class TestUtilityMethods:
                 first_name="Inactive",
                 last_name="User",
             )
-            user = await user_crud.create(session, obj_in=user_data)
-            await user_crud.update(session, db_obj=user, obj_in={"is_active": False})
+            user = await user_repo.create(session, obj_in=user_data)
+            await user_repo.update(session, db_obj=user, obj_in={"is_active": False})

-            assert user_crud.is_active(user) is False
+            assert user_repo.is_active(user) is False

     @pytest.mark.asyncio
     async def test_is_superuser_true(self, async_test_db, async_test_superuser):
@@ -590,22 +591,22 @@ class TestUtilityMethods:
         _test_engine, AsyncTestingSessionLocal = async_test_db

         async with AsyncTestingSessionLocal() as session:
-            user = await user_crud.get(session, id=str(async_test_superuser.id))
|
||||
assert user_crud.is_superuser(user) is True
|
||||
user = await user_repo.get(session, id=str(async_test_superuser.id))
|
||||
assert user_repo.is_superuser(user) is True
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_is_superuser_false(self, async_test_db, async_test_user):
|
||||
"""Test is_superuser returns False for regular user_crud."""
|
||||
"""Test is_superuser returns False for regular user_repo."""
|
||||
_test_engine, AsyncTestingSessionLocal = async_test_db
|
||||
|
||||
async with AsyncTestingSessionLocal() as session:
|
||||
user = await user_crud.get(session, id=str(async_test_user.id))
|
||||
assert user_crud.is_superuser(user) is False
|
||||
user = await user_repo.get(session, id=str(async_test_user.id))
|
||||
assert user_repo.is_superuser(user) is False
|
||||
|
||||
|
||||
class TestUserExceptionHandlers:
|
||||
"""
|
||||
Test exception handlers in user CRUD methods.
|
||||
Test exception handlers in user repository methods.
|
||||
Covers lines: 30-32, 205-208, 257-260
|
||||
"""
|
||||
|
||||
@@ -621,7 +622,7 @@ class TestUserExceptionHandlers:
|
||||
session, "execute", side_effect=Exception("Database query failed")
|
||||
):
|
||||
with pytest.raises(Exception, match="Database query failed"):
|
||||
await user_crud.get_by_email(session, email="test@example.com")
|
||||
await user_repo.get_by_email(session, email="test@example.com")
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_bulk_update_status_database_error(
|
||||
@@ -639,7 +640,7 @@ class TestUserExceptionHandlers:
|
||||
):
|
||||
with patch.object(session, "rollback", new_callable=AsyncMock):
|
||||
with pytest.raises(Exception, match="Bulk update failed"):
|
||||
await user_crud.bulk_update_status(
|
||||
await user_repo.bulk_update_status(
|
||||
session, user_ids=[async_test_user.id], is_active=False
|
||||
)
|
||||
|
||||
@@ -659,6 +660,6 @@ class TestUserExceptionHandlers:
|
||||
):
|
||||
with patch.object(session, "rollback", new_callable=AsyncMock):
|
||||
with pytest.raises(Exception, match="Bulk delete failed"):
|
||||
await user_crud.bulk_soft_delete(
|
||||
await user_repo.bulk_soft_delete(
|
||||
session, user_ids=[async_test_user.id]
|
||||
)
|
||||
@@ -10,6 +10,7 @@ from app.core.auth import (
|
||||
get_password_hash,
|
||||
verify_password,
|
||||
)
|
||||
from app.core.exceptions import DuplicateError
|
||||
from app.models.user import User
|
||||
from app.schemas.users import Token, UserCreate
|
||||
from app.services.auth_service import AuthenticationError, AuthService
|
||||
@@ -152,9 +153,9 @@ class TestAuthServiceUserCreation:
|
||||
last_name="User",
|
||||
)
|
||||
|
||||
# Should raise AuthenticationError
|
||||
# Should raise DuplicateError for duplicate email
|
||||
async with AsyncTestingSessionLocal() as session:
|
||||
with pytest.raises(AuthenticationError):
|
||||
with pytest.raises(DuplicateError):
|
||||
await AuthService.create_user(db=session, user_data=user_data)
|
||||
|
||||
|
||||
|
||||
@@ -269,18 +269,18 @@ class TestClientValidation:
|
||||
async def test_validate_client_legacy_sha256_hash(
|
||||
self, db, confidential_client_legacy_hash
|
||||
):
|
||||
"""Test validating a client with legacy SHA-256 hash (backward compatibility)."""
|
||||
"""Test that legacy SHA-256 hash is rejected with clear error message."""
|
||||
client, secret = confidential_client_legacy_hash
|
||||
validated = await service.validate_client(db, client.client_id, secret)
|
||||
assert validated.client_id == client.client_id
|
||||
with pytest.raises(service.InvalidClientError, match="deprecated hash format"):
|
||||
await service.validate_client(db, client.client_id, secret)
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_validate_client_legacy_sha256_wrong_secret(
|
||||
self, db, confidential_client_legacy_hash
|
||||
):
|
||||
"""Test legacy SHA-256 client rejects wrong secret."""
|
||||
"""Test that legacy SHA-256 client with wrong secret is rejected."""
|
||||
client, _ = confidential_client_legacy_hash
|
||||
with pytest.raises(service.InvalidClientError, match="Invalid client secret"):
|
||||
with pytest.raises(service.InvalidClientError, match="deprecated hash format"):
|
||||
await service.validate_client(db, client.client_id, "wrong_secret")
|
||||
|
||||
def test_validate_redirect_uri_success(self, public_client):
|
||||
|
||||
@@ -11,7 +11,8 @@ from uuid import uuid4
|
||||
import pytest
|
||||
|
||||
from app.core.exceptions import AuthenticationError
|
||||
from app.crud.oauth import oauth_account, oauth_state
|
||||
from app.repositories.oauth_account import oauth_account_repo as oauth_account
|
||||
from app.repositories.oauth_state import oauth_state_repo as oauth_state
|
||||
from app.schemas.oauth import OAuthAccountCreate, OAuthStateCreate
|
||||
from app.services.oauth_service import OAUTH_PROVIDERS, OAuthService
|
||||
|
||||
|
||||
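The updated `TestClientValidation` tests above assert that legacy SHA-256 client secrets are now refused even when the secret is correct. A minimal sketch of that control flow is below; it is an illustration inferred from the test expectations, not the project's actual `service` module, and the `verify_client_secret` name, the `$2` bcrypt-prefix check, and the plain string comparison are all assumptions:

```python
class InvalidClientError(Exception):
    """Raised when client credentials cannot be validated."""


def verify_client_secret(stored_hash: str, presented_secret: str) -> bool:
    """Reject legacy hash formats before any secret comparison is attempted."""
    # bcrypt hashes begin with "$2"; anything else (e.g. a bare hex SHA-256
    # digest) is treated as a deprecated format. Legacy entries are refused
    # even when the secret is correct, which is why both tests above expect
    # InvalidClientError regardless of the secret passed in.
    if not stored_hash.startswith("$2"):
        raise InvalidClientError("deprecated hash format")
    # Placeholder comparison; the real service would use bcrypt.checkpw.
    return stored_hash == presented_secret
```

The point of the design is that failing closed on old hashes forces clients to re-hash their secrets, instead of silently accepting a weaker digest forever.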