23 Commits
dev ... dev

Author SHA1 Message Date
Felipe Cardoso
a94e29d99c chore(frontend): remove unnecessary newline in overrides field of package.json 2026-03-01 19:40:11 +01:00
Felipe Cardoso
81e48c73ca fix(tests): handle missing schemathesis gracefully in API contract tests
- Replaced `pytest.mark.skipif` with `pytest.skip` to better manage scenarios where `schemathesis` is not installed.
- Added a fallback test function to ensure explicit handling for missing dependencies.
2026-03-01 19:32:49 +01:00
Felipe Cardoso
a3f78dc801 refactor(tests): replace crud references with repo across repository test files
- Updated import statements and test logic to align with `repositories` naming changes.
- Adjusted documentation and test names for consistency with the updated naming convention.
- Improved test descriptions to reflect the repository-based structure.
2026-03-01 19:22:16 +01:00
Felipe Cardoso
07309013d7 chore(frontend): update scripts and docs to use bun run test for consistency
- Replaced `bun test` with `bun run test` in all documentation and scripts for uniformity.
- Removed outdated `glob` override in package configurations.
2026-03-01 18:44:48 +01:00
Felipe Cardoso
846fc31190 feat(api): enhance KeyMap and FieldsConfig handling for improved flexibility
- Added support for unmapped fields in `KeyMap` definitions and parsing.
- Updated `buildKeyMap` to allow aliasing keys without transport layer mappings.
- Improved parameter assignment logic to handle optional `in` mappings.
- Enhanced handling of `allowExtra` fields for more concise and robust configurations.
2026-03-01 18:01:34 +01:00
Felipe Cardoso
ff7a67cb58 chore(frontend): migrate from npm to Bun for dependency management and scripts
- Updated README to replace npm commands with Bun equivalents.
- Added `bun.lock` file to track Bun-managed dependencies.
2026-03-01 18:00:43 +01:00
Felipe Cardoso
0760a8284d feat(tests): add comprehensive benchmarks for auth and performance-critical endpoints
- Introduced benchmarks for password hashing, verification, and JWT token operations.
- Added latency tests for `/register`, `/refresh`, `/sessions`, and `/users/me` endpoints.
- Updated `BENCHMARKS.md` with new tests, thresholds, and execution details.
2026-03-01 17:01:44 +01:00
Felipe Cardoso
ce4d0c7b0d feat(backend): enhance performance benchmarking with baseline detection and documentation
- Updated `make benchmark-check` in Makefile to detect and handle missing baselines, creating them if not found.
- Added `.benchmarks` directory to `.gitignore` for local baseline exclusions.
- Linked benchmarking documentation in `ARCHITECTURE.md` and added comprehensive `BENCHMARKS.md` guide.
2026-03-01 16:30:06 +01:00
Felipe Cardoso
4ceb8ad98c feat(backend): add performance benchmarks and API security tests
- Introduced `benchmark`, `benchmark-save`, and `benchmark-check` Makefile targets for performance testing.
- Added API security fuzzing through the `test-api-security` Makefile target, leveraging Schemathesis.
- Updated Dockerfiles to use Alpine for security and CVE mitigation.
- Enhanced security with `scan-image` and `scan-images` targets for Docker image vulnerability scanning via Trivy.
- Integrated `pytest-benchmark` for performance regression detection, with tests for key API endpoints.
- Extended `uv.lock` and `pyproject.toml` to include performance benchmarking dependencies.
2026-03-01 16:16:18 +01:00
Felipe Cardoso
f8aafb250d fix(backend): suppress license-check output in Makefile for cleaner logs
- Redirect pip-licenses output to `/dev/null` to reduce noise during license checks.
- Retain success and compliance messages for clear feedback.
2026-03-01 14:24:22 +01:00
Felipe Cardoso
4385d20ca6 fix(tests): simplify invalid token test logic in test_auth_security.py
- Removed unnecessary try-except block for JWT encoding failures.
- Adjusted test to directly verify `TokenInvalidError` during decoding.
- Clarified comment on HMAC algorithm compatibility (`HS384` vs. `HS256`).
2026-03-01 14:24:17 +01:00
Felipe Cardoso
1a36907f10 refactor(backend): replace python-jose and passlib with PyJWT and bcrypt for security and simplicity
- Migrated JWT token handling from `python-jose` to `PyJWT`, reducing dependencies and improving error clarity.
- Replaced `passlib` bcrypt integration with direct `bcrypt` usage for password hashing.
- Updated `Makefile`, removing unused CVE ignore based on the replaced dependencies.
- Reflected changes in `ARCHITECTURE.md` and adjusted function headers in `auth.py`.
- Cleaned up `uv.lock` and `pyproject.toml` to remove unused dependencies (`ecdsa`, `rsa`, etc.) and add `PyJWT`.
- Refactored tests and services to align with the updated libraries (`PyJWT` error handling, decoding, and validation).
2026-03-01 14:02:04 +01:00
Felipe Cardoso
0553a1fc53 refactor(logging): switch to parameterized logging for improved performance and clarity
- Replaced f-strings with parameterized logging calls across routes, services, and repositories to optimize log message evaluation.
- Improved exception handling by using `logger.exception` where appropriate for automatic traceback logging.
2026-03-01 13:38:15 +01:00
Felipe Cardoso
57e969ed67 chore(backend): extend Makefile with audit, validation, and security targets
- Added `dep-audit`, `license-check`, `audit`, `validate-all`, and `check` targets for security and quality checks.
- Updated `.PHONY` to include new targets.
- Enhanced `help` command documentation with descriptions of the new commands.
- Updated `ARCHITECTURE.md`, `CLAUDE.md`, and `uv.lock` to reflect related changes. Upgraded dependencies where necessary.
2026-03-01 12:03:34 +01:00
Felipe Cardoso
68275b1dd3 refactor(docs): update architecture to reflect repository migration
- Rename CRUD layer to Repository layer throughout architecture documentation.
- Update dependency injection examples to use repository classes.
- Add async SQLAlchemy pattern for Repository methods (`select()` and transactions).
- Replace CRUD references in FEATURE_EXAMPLE.md with Repository-focused implementation details.
- Highlight repository class responsibilities and remove outdated CRUD patterns.
2026-03-01 11:13:51 +01:00
Felipe Cardoso
80d2dc0cb2 fix(backend): clear VIRTUAL_ENV before invoking pyright
Prevents a spurious warning when the shell's VIRTUAL_ENV points to a
different project's venv. Pyright detects the mismatch and warns; clearing
the variable inline forces pyright to resolve the venv from pyrightconfig.json.
2026-02-28 19:48:33 +01:00
Felipe Cardoso
a8aa416ecb refactor(backend): migrate type checking from mypy to pyright
Replace mypy>=1.8.0 with pyright>=1.1.390. Remove all [tool.mypy] and
[tool.pydantic-mypy] sections from pyproject.toml and add
pyrightconfig.json (standard mode, SQLAlchemy false-positive rules
suppressed globally).

Fixes surfaced by pyright:
- Remove unreachable except AuthError clauses in login/login_oauth (same class as AuthenticationError)
- Fix Pydantic v2 list Field: min_items/max_items → min_length/max_length
- Split OAuthProviderConfig TypedDict into required + optional(email_url) inheritance
- Move JWTError/ExpiredSignatureError from lazy try-block imports to module level
- Add timezone-aware guard to UserSession.is_expired to match sibling models
- Fix is_active: bool → bool | None in three organization repo signatures
- Initialize search_filter = None before conditional block (possibly unbound fix)
- Add bool() casts to model is_expired and repo is_active/is_superuser returns
- Restructure except (JWTError, Exception) into separate except clauses
2026-02-28 19:12:40 +01:00
Felipe Cardoso
4c6bf55bcc Refactor(backend): improve formatting in services, repositories & tests
- Consistently format multi-line function headers, exception handling, and repository method calls for readability.
- Reorganize misplaced imports across modules (e.g., services & tests) into proper sorted order.
- Adjust indentation, line breaks, and spacing inconsistencies in tests and migration files.
- Cleanup unnecessary trailing newlines and reorganize `__all__` declarations for consistency.
2026-02-28 18:37:56 +01:00
Felipe Cardoso
98b455fdc3 refactor(backend): enforce route→service→repo layered architecture
- introduce custom repository exception hierarchy (DuplicateEntryError,
  IntegrityConstraintError, InvalidInputError) replacing raw ValueError
- eliminate all direct repository imports and raw SQL from route layer
- add UserService, SessionService, OrganizationService to service layer
- add get_stats/get_org_distribution service methods replacing admin inline SQL
- fix timing side-channel in authenticate_user via dummy bcrypt check
- replace SHA-256 client secret fallback with explicit InvalidClientError
- replace assert with InvalidGrantError in authorization code exchange
- replace N+1 token revocation loops with bulk UPDATE statements
- rename oauth account token fields (drop misleading 'encrypted' suffix)
- add Alembic migration 0003 for token field column rename
- add 45 new service/repository tests; 975 passing, 94% coverage
2026-02-27 09:32:57 +01:00
Felipe Cardoso
0646c96b19 Add semicolons to mockServiceWorker.js for consistent style compliance
- Updated `mockServiceWorker.js` by adding missing semicolons across the file for improved code consistency and adherence to style guidelines.
- Refactored multi-line logical expressions into single-line where applicable, maintaining readability.
2026-01-01 13:21:31 +01:00
Felipe Cardoso
62afb328fe Upgrade dependencies in package-lock.json
- Upgraded various dependencies across `@esbuild`, `@eslint`, `@hey-api`, and `@img` packages to their latest versions.
- Removed unused `json5` dependency under `@babel/core`.
- Ensured integrity hashes are updated to reflect changes.
2026-01-01 13:21:23 +01:00
Felipe Cardoso
b9a746bc16 Refactor component props formatting for consistency in extends usage across UI and documentation files 2026-01-01 13:19:36 +01:00
Felipe Cardoso
de8e18e97d Update GitHub repository URLs across components and tests
- Replaced all occurrences of the previous repository URL (`your-org/fast-next-template`) with the updated repository URL (`cardosofelipe/pragma-stack.git`) in both frontend components and test files.
- Adjusted related test assertions and documentation links accordingly.
2026-01-01 13:15:08 +01:00
139 changed files with 10064 additions and 22857 deletions

View File

@@ -41,7 +41,7 @@ To enable CI/CD workflows:
- Runs on: Push to main/develop, PRs affecting frontend code - Runs on: Push to main/develop, PRs affecting frontend code
- Tests: Frontend unit tests (Jest) - Tests: Frontend unit tests (Jest)
- Coverage: Uploads to Codecov - Coverage: Uploads to Codecov
- Fast: Uses npm cache - Fast: Uses bun cache
### `e2e-tests.yml` ### `e2e-tests.yml`
- Runs on: All pushes and PRs - Runs on: All pushes and PRs

2
.gitignore vendored
View File

@@ -187,7 +187,7 @@ coverage.xml
.hypothesis/ .hypothesis/
.pytest_cache/ .pytest_cache/
cover/ cover/
backend/.benchmarks
# Translations # Translations
*.mo *.mo
*.pot *.pot

View File

@@ -13,10 +13,10 @@ uv run uvicorn app.main:app --reload # Start dev server
# Frontend (Node.js) # Frontend (Node.js)
cd frontend cd frontend
npm install # Install dependencies bun install # Install dependencies
npm run dev # Start dev server bun run dev # Start dev server
npm run generate:api # Generate API client from OpenAPI bun run generate:api # Generate API client from OpenAPI
npm run test:e2e # Run E2E tests bun run test:e2e # Run E2E tests
``` ```
**Access points:** **Access points:**
@@ -37,7 +37,7 @@ Default superuser (change in production):
│ ├── app/ │ ├── app/
│ │ ├── api/ # API routes (auth, users, organizations, admin) │ │ ├── api/ # API routes (auth, users, organizations, admin)
│ │ ├── core/ # Core functionality (auth, config, database) │ │ ├── core/ # Core functionality (auth, config, database)
│ │ ├── crud/ # Database CRUD operations │ │ ├── repositories/ # Repository pattern (database operations)
│ │ ├── models/ # SQLAlchemy ORM models │ │ ├── models/ # SQLAlchemy ORM models
│ │ ├── schemas/ # Pydantic request/response schemas │ │ ├── schemas/ # Pydantic request/response schemas
│ │ ├── services/ # Business logic layer │ │ ├── services/ # Business logic layer
@@ -113,7 +113,7 @@ OAUTH_ISSUER=https://api.yourdomain.com # JWT issuer URL (must be HTTPS in
### Database Pattern ### Database Pattern
- **Async SQLAlchemy 2.0** with PostgreSQL - **Async SQLAlchemy 2.0** with PostgreSQL
- **Connection pooling**: 20 base connections, 50 max overflow - **Connection pooling**: 20 base connections, 50 max overflow
- **CRUD base class**: `crud/base.py` with common operations - **Repository base class**: `repositories/base.py` with common operations
- **Migrations**: Alembic with helper script `migrate.py` - **Migrations**: Alembic with helper script `migrate.py`
- `python migrate.py auto "message"` - Generate and apply - `python migrate.py auto "message"` - Generate and apply
- `python migrate.py list` - View history - `python migrate.py list` - View history
@@ -121,7 +121,7 @@ OAUTH_ISSUER=https://api.yourdomain.com # JWT issuer URL (must be HTTPS in
### Frontend State Management ### Frontend State Management
- **Zustand stores**: Lightweight state management - **Zustand stores**: Lightweight state management
- **TanStack Query**: API data fetching/caching - **TanStack Query**: API data fetching/caching
- **Auto-generated client**: From OpenAPI spec via `npm run generate:api` - **Auto-generated client**: From OpenAPI spec via `bun run generate:api`
- **Dependency Injection**: ALWAYS use `useAuth()` from `AuthContext`, NEVER import `useAuthStore` directly - **Dependency Injection**: ALWAYS use `useAuth()` from `AuthContext`, NEVER import `useAuthStore` directly
### Internationalization (i18n) ### Internationalization (i18n)
@@ -165,21 +165,25 @@ Permission dependencies in `api/dependencies/permissions.py`:
**Frontend Unit Tests (Jest):** **Frontend Unit Tests (Jest):**
- 97% coverage - 97% coverage
- Component, hook, and utility testing - Component, hook, and utility testing
- Run: `npm test` - Run: `bun run test`
- Coverage: `npm run test:coverage` - Coverage: `bun run test:coverage`
**Frontend E2E Tests (Playwright):** **Frontend E2E Tests (Playwright):**
- 56 passing, 1 skipped (zero flaky tests) - 56 passing, 1 skipped (zero flaky tests)
- Complete user flows (auth, navigation, settings) - Complete user flows (auth, navigation, settings)
- Run: `npm run test:e2e` - Run: `bun run test:e2e`
- UI mode: `npm run test:e2e:ui` - UI mode: `bun run test:e2e:ui`
### Development Tooling ### Development Tooling
**Backend:** **Backend:**
- **uv**: Modern Python package manager (10-100x faster than pip) - **uv**: Modern Python package manager (10-100x faster than pip)
- **Ruff**: All-in-one linting/formatting (replaces Black, Flake8, isort) - **Ruff**: All-in-one linting/formatting (replaces Black, Flake8, isort)
- **mypy**: Type checking with Pydantic plugin - **Pyright**: Static type checking (strict mode)
- **pip-audit**: Dependency vulnerability scanning (OSV database)
- **detect-secrets**: Hardcoded secrets detection
- **pip-licenses**: License compliance checking
- **pre-commit**: Git hook framework (Ruff, detect-secrets, standard checks)
- **Makefile**: `make help` for all commands - **Makefile**: `make help` for all commands
**Frontend:** **Frontend:**
@@ -218,11 +222,11 @@ NEXT_PUBLIC_API_URL=http://localhost:8000/api/v1
### Adding a New API Endpoint ### Adding a New API Endpoint
1. **Define schema** in `backend/app/schemas/` 1. **Define schema** in `backend/app/schemas/`
2. **Create CRUD operations** in `backend/app/crud/` 2. **Create repository** in `backend/app/repositories/`
3. **Implement route** in `backend/app/api/routes/` 3. **Implement route** in `backend/app/api/routes/`
4. **Register router** in `backend/app/api/main.py` 4. **Register router** in `backend/app/api/main.py`
5. **Write tests** in `backend/tests/api/` 5. **Write tests** in `backend/tests/api/`
6. **Generate frontend client**: `npm run generate:api` 6. **Generate frontend client**: `bun run generate:api`
### Database Migrations ### Database Migrations
@@ -239,7 +243,7 @@ python migrate.py auto "description" # Generate + apply
2. **Follow design system** (see `frontend/docs/design-system/`) 2. **Follow design system** (see `frontend/docs/design-system/`)
3. **Use dependency injection** for auth (`useAuth()` not `useAuthStore`) 3. **Use dependency injection** for auth (`useAuth()` not `useAuthStore`)
4. **Write tests** in `frontend/tests/` or `__tests__/` 4. **Write tests** in `frontend/tests/` or `__tests__/`
5. **Run type check**: `npm run type-check` 5. **Run type check**: `bun run type-check`
## Security Features ## Security Features
@@ -249,6 +253,10 @@ python migrate.py auto "description" # Generate + apply
- **CSRF protection**: Built into FastAPI - **CSRF protection**: Built into FastAPI
- **Session revocation**: Database-backed session tracking - **Session revocation**: Database-backed session tracking
- **Comprehensive security tests**: JWT algorithm attacks, session hijacking, privilege escalation - **Comprehensive security tests**: JWT algorithm attacks, session hijacking, privilege escalation
- **Dependency vulnerability scanning**: `make dep-audit` (pip-audit against OSV database)
- **License compliance**: `make license-check` (blocks GPL-3.0/AGPL)
- **Secrets detection**: Pre-commit hook blocks hardcoded secrets
- **Unified security pipeline**: `make audit` (all security checks), `make check` (quality + security + tests)
## Docker Deployment ## Docker Deployment
@@ -281,7 +289,7 @@ docker-compose exec backend python -c "from app.init_db import init_db; import a
- Authentication system (JWT with refresh tokens, OAuth/social login) - Authentication system (JWT with refresh tokens, OAuth/social login)
- **OAuth Provider Mode (MCP-ready)**: Full OAuth 2.0 Authorization Server - **OAuth Provider Mode (MCP-ready)**: Full OAuth 2.0 Authorization Server
- Session management (device tracking, revocation) - Session management (device tracking, revocation)
- User management (CRUD, password change) - User management (full lifecycle, password change)
- Organization system (multi-tenant with RBAC) - Organization system (multi-tenant with RBAC)
- Admin panel (user/org management, bulk operations) - Admin panel (user/org management, bulk operations)
- **Internationalization (i18n)** with English and Italian - **Internationalization (i18n)** with English and Italian

View File

@@ -43,7 +43,7 @@ EOF
- Check current state: `python migrate.py current` - Check current state: `python migrate.py current`
**Frontend API Client Generation:** **Frontend API Client Generation:**
- Run `npm run generate:api` after backend schema changes - Run `bun run generate:api` after backend schema changes
- Client is auto-generated from OpenAPI spec - Client is auto-generated from OpenAPI spec
- Located in `frontend/src/lib/api/generated/` - Located in `frontend/src/lib/api/generated/`
- NEVER manually edit generated files - NEVER manually edit generated files
@@ -51,10 +51,16 @@ EOF
**Testing Commands:** **Testing Commands:**
- Backend unit/integration: `IS_TEST=True uv run pytest` (always prefix with `IS_TEST=True`) - Backend unit/integration: `IS_TEST=True uv run pytest` (always prefix with `IS_TEST=True`)
- Backend E2E (requires Docker): `make test-e2e` - Backend E2E (requires Docker): `make test-e2e`
- Frontend unit: `npm test` - Frontend unit: `bun run test`
- Frontend E2E: `npm run test:e2e` - Frontend E2E: `bun run test:e2e`
- Use `make test` or `make test-cov` in backend for convenience - Use `make test` or `make test-cov` in backend for convenience
**Security & Quality Commands (Backend):**
- `make validate` — lint + format + type checks
- `make audit` — dependency vulnerabilities + license compliance
- `make validate-all` — quality + security checks
- `make check`**full pipeline**: quality + security + tests
**Backend E2E Testing (requires Docker):** **Backend E2E Testing (requires Docker):**
- Install deps: `make install-e2e` - Install deps: `make install-e2e`
- Run all E2E tests: `make test-e2e` - Run all E2E tests: `make test-e2e`
@@ -142,7 +148,7 @@ async def mock_commit():
with patch.object(session, 'commit', side_effect=mock_commit): with patch.object(session, 'commit', side_effect=mock_commit):
with patch.object(session, 'rollback', new_callable=AsyncMock) as mock_rollback: with patch.object(session, 'rollback', new_callable=AsyncMock) as mock_rollback:
with pytest.raises(OperationalError): with pytest.raises(OperationalError):
await crud_method(session, obj_in=data) await repo_method(session, obj_in=data)
mock_rollback.assert_called_once() mock_rollback.assert_called_once()
``` ```
@@ -157,14 +163,18 @@ with patch.object(session, 'commit', side_effect=mock_commit):
- Never skip security headers in production - Never skip security headers in production
- Rate limiting is configured in route decorators: `@limiter.limit("10/minute")` - Rate limiting is configured in route decorators: `@limiter.limit("10/minute")`
- Session revocation is database-backed, not just JWT expiry - Session revocation is database-backed, not just JWT expiry
- Run `make audit` to check for dependency vulnerabilities and license compliance
- Run `make check` for the full pipeline: quality + security + tests
- Pre-commit hooks enforce Ruff lint/format and detect-secrets on every commit
- Setup hooks: `cd backend && uv run pre-commit install`
### Common Workflows Guidance ### Common Workflows Guidance
**When Adding a New Feature:** **When Adding a New Feature:**
1. Start with backend schema and CRUD 1. Start with backend schema and repository
2. Implement API route with proper authorization 2. Implement API route with proper authorization
3. Write backend tests (aim for >90% coverage) 3. Write backend tests (aim for >90% coverage)
4. Generate frontend API client: `npm run generate:api` 4. Generate frontend API client: `bun run generate:api`
5. Implement frontend components 5. Implement frontend components
6. Write frontend unit tests 6. Write frontend unit tests
7. Add E2E tests for critical flows 7. Add E2E tests for critical flows
@@ -177,8 +187,8 @@ with patch.object(session, 'commit', side_effect=mock_commit):
**When Debugging:** **When Debugging:**
- Backend: Check `IS_TEST=True` environment variable is set - Backend: Check `IS_TEST=True` environment variable is set
- Frontend: Run `npm run type-check` first - Frontend: Run `bun run type-check` first
- E2E: Use `npm run test:e2e:debug` for step-by-step debugging - E2E: Use `bun run test:e2e:debug` for step-by-step debugging
- Check logs: Backend has detailed error logging - Check logs: Backend has detailed error logging
**Demo Mode (Frontend-Only Showcase):** **Demo Mode (Frontend-Only Showcase):**
@@ -186,7 +196,7 @@ with patch.object(session, 'commit', side_effect=mock_commit):
- Uses MSW (Mock Service Worker) to intercept API calls in browser - Uses MSW (Mock Service Worker) to intercept API calls in browser
- Zero backend required - perfect for Vercel deployments - Zero backend required - perfect for Vercel deployments
- **Fully Automated**: MSW handlers auto-generated from OpenAPI spec - **Fully Automated**: MSW handlers auto-generated from OpenAPI spec
- Run `npm run generate:api` → updates both API client AND MSW handlers - Run `bun run generate:api` → updates both API client AND MSW handlers
- No manual synchronization needed! - No manual synchronization needed!
- Demo credentials (any password ≥8 chars works): - Demo credentials (any password ≥8 chars works):
- User: `demo@example.com` / `DemoPass123` - User: `demo@example.com` / `DemoPass123`
@@ -214,7 +224,7 @@ with patch.object(session, 'commit', side_effect=mock_commit):
No Claude Code Skills installed yet. To create one, invoke the built-in "skill-creator" skill. No Claude Code Skills installed yet. To create one, invoke the built-in "skill-creator" skill.
**Potential skill ideas for this project:** **Potential skill ideas for this project:**
- API endpoint generator workflow (schema → CRUD → route → tests → frontend client) - API endpoint generator workflow (schema → repository → route → tests → frontend client)
- Component generator with design system compliance - Component generator with design system compliance
- Database migration troubleshooting helper - Database migration troubleshooting helper
- Test coverage analyzer and improvement suggester - Test coverage analyzer and improvement suggester

View File

@@ -91,7 +91,10 @@ Ready to write some code? Awesome!
cd backend cd backend
# Install dependencies (uv manages virtual environment automatically) # Install dependencies (uv manages virtual environment automatically)
uv sync make install-dev
# Setup pre-commit hooks
uv run pre-commit install
# Setup environment # Setup environment
cp .env.example .env cp .env.example .env
@@ -100,8 +103,14 @@ cp .env.example .env
# Run migrations # Run migrations
python migrate.py apply python migrate.py apply
# Run quality + security checks
make validate-all
# Run tests # Run tests
IS_TEST=True uv run pytest make test
# Run full pipeline (quality + security + tests)
make check
# Start dev server # Start dev server
uvicorn app.main:app --reload uvicorn app.main:app --reload
@@ -113,20 +122,20 @@ uvicorn app.main:app --reload
cd frontend cd frontend
# Install dependencies # Install dependencies
npm install bun install
# Setup environment # Setup environment
cp .env.local.example .env.local cp .env.local.example .env.local
# Generate API client # Generate API client
npm run generate:api bun run generate:api
# Run tests # Run tests
npm test bun run test
npm run test:e2e:ui bun run test:e2e:ui
# Start dev server # Start dev server
npm run dev bun run dev
``` ```
--- ---
@@ -195,7 +204,7 @@ export function UserProfile({ userId }: UserProfileProps) {
### Key Patterns ### Key Patterns
- **Backend**: Use CRUD pattern, keep routes thin, business logic in services - **Backend**: Use repository pattern, keep routes thin, business logic in services
- **Frontend**: Use React Query for server state, Zustand for client state - **Frontend**: Use React Query for server state, Zustand for client state
- **Both**: Handle errors gracefully, log appropriately, write tests - **Both**: Handle errors gracefully, log appropriately, write tests
@@ -316,7 +325,7 @@ Fixed stuff
### Before Submitting ### Before Submitting
- [ ] Code follows project style guidelines - [ ] Code follows project style guidelines
- [ ] All tests pass locally - [ ] `make check` passes (quality + security + tests) in backend
- [ ] New tests added for new features - [ ] New tests added for new features
- [ ] Documentation updated if needed - [ ] Documentation updated if needed
- [ ] No merge conflicts with `main` - [ ] No merge conflicts with `main`

View File

@@ -1,4 +1,4 @@
.PHONY: help dev dev-full prod down logs logs-dev clean clean-slate drop-db reset-db push-images deploy .PHONY: help dev dev-full prod down logs logs-dev clean clean-slate drop-db reset-db push-images deploy scan-images
VERSION ?= latest VERSION ?= latest
REGISTRY ?= ghcr.io/cardosofelipe/pragma-stack REGISTRY ?= ghcr.io/cardosofelipe/pragma-stack
@@ -21,6 +21,7 @@ help:
@echo " make prod - Start production stack" @echo " make prod - Start production stack"
@echo " make deploy - Pull and deploy latest images" @echo " make deploy - Pull and deploy latest images"
@echo " make push-images - Build and push images to registry" @echo " make push-images - Build and push images to registry"
@echo " make scan-images - Scan production images for CVEs (requires trivy)"
@echo " make logs - Follow production container logs" @echo " make logs - Follow production container logs"
@echo "" @echo ""
@echo "Cleanup:" @echo "Cleanup:"
@@ -89,6 +90,28 @@ push-images:
docker push $(REGISTRY)/backend:$(VERSION) docker push $(REGISTRY)/backend:$(VERSION)
docker push $(REGISTRY)/frontend:$(VERSION) docker push $(REGISTRY)/frontend:$(VERSION)
scan-images:
@docker info > /dev/null 2>&1 || (echo "❌ Docker is not running!"; exit 1)
@echo "🐳 Building and scanning production images for CVEs..."
docker build -t $(REGISTRY)/backend:scan --target production ./backend
docker build -t $(REGISTRY)/frontend:scan --target runner ./frontend
@echo ""
@echo "=== Backend Image Scan ==="
@if command -v trivy > /dev/null 2>&1; then \
trivy image --severity HIGH,CRITICAL --exit-code 1 $(REGISTRY)/backend:scan; \
else \
echo " Trivy not found locally, using Docker to run Trivy..."; \
docker run --rm -v /var/run/docker.sock:/var/run/docker.sock aquasec/trivy image --severity HIGH,CRITICAL --exit-code 1 $(REGISTRY)/backend:scan; \
fi
@echo ""
@echo "=== Frontend Image Scan ==="
@if command -v trivy > /dev/null 2>&1; then \
trivy image --severity HIGH,CRITICAL --exit-code 1 $(REGISTRY)/frontend:scan; \
else \
docker run --rm -v /var/run/docker.sock:/var/run/docker.sock aquasec/trivy image --severity HIGH,CRITICAL --exit-code 1 $(REGISTRY)/frontend:scan; \
fi
@echo "✅ No HIGH/CRITICAL CVEs found in production images!"
# ============================================================================ # ============================================================================
# Cleanup # Cleanup
# ============================================================================ # ============================================================================

View File

@@ -58,7 +58,7 @@ Full OAuth 2.0 Authorization Server for Model Context Protocol (MCP) and third-p
- User can belong to multiple organizations - User can belong to multiple organizations
### 🛠️ **Admin Panel** ### 🛠️ **Admin Panel**
- Complete user management (CRUD, activate/deactivate, bulk operations) - Complete user management (full lifecycle, activate/deactivate, bulk operations)
- Organization management (create, edit, delete, member management) - Organization management (create, edit, delete, member management)
- Session monitoring across all users - Session monitoring across all users
- Real-time statistics dashboard - Real-time statistics dashboard
@@ -166,7 +166,7 @@ Full OAuth 2.0 Authorization Server for Model Context Protocol (MCP) and third-p
```bash ```bash
cd frontend cd frontend
echo "NEXT_PUBLIC_DEMO_MODE=true" > .env.local echo "NEXT_PUBLIC_DEMO_MODE=true" > .env.local
npm run dev bun run dev
``` ```
**Demo Credentials:** **Demo Credentials:**
@@ -298,17 +298,17 @@ uvicorn app.main:app --reload --host 0.0.0.0 --port 8000
cd frontend cd frontend
# Install dependencies # Install dependencies
npm install bun install
# Setup environment # Setup environment
cp .env.local.example .env.local cp .env.local.example .env.local
# Edit .env.local with your backend URL # Edit .env.local with your backend URL
# Generate API client # Generate API client
npm run generate:api bun run generate:api
# Start development server # Start development server
npm run dev bun run dev
``` ```
Visit http://localhost:3000 to see your app! Visit http://localhost:3000 to see your app!
@@ -322,7 +322,7 @@ Visit http://localhost:3000 to see your app!
│ ├── app/ │ ├── app/
│ │ ├── api/ # API routes and dependencies │ │ ├── api/ # API routes and dependencies
│ │ ├── core/ # Core functionality (auth, config, database) │ │ ├── core/ # Core functionality (auth, config, database)
│ │ ├── crud/ # Database operations │ │ ├── repositories/ # Repository pattern (database operations)
│ │ ├── models/ # SQLAlchemy models │ │ ├── models/ # SQLAlchemy models
│ │ ├── schemas/ # Pydantic schemas │ │ ├── schemas/ # Pydantic schemas
│ │ ├── services/ # Business logic │ │ ├── services/ # Business logic
@@ -377,7 +377,7 @@ open htmlcov/index.html
``` ```
**Test types:** **Test types:**
- **Unit tests**: CRUD operations, utilities, business logic - **Unit tests**: Repository operations, utilities, business logic
- **Integration tests**: API endpoints with database - **Integration tests**: API endpoints with database
- **Security tests**: JWT algorithm attacks, session hijacking, privilege escalation - **Security tests**: JWT algorithm attacks, session hijacking, privilege escalation
- **Error handling tests**: Database failures, validation errors - **Error handling tests**: Database failures, validation errors
@@ -390,13 +390,13 @@ open htmlcov/index.html
cd frontend cd frontend
# Run unit tests # Run unit tests
npm test bun run test
# Run with coverage # Run with coverage
npm run test:coverage bun run test:coverage
# Watch mode # Watch mode
npm run test:watch bun run test:watch
``` ```
**Test types:** **Test types:**
@@ -414,10 +414,10 @@ npm run test:watch
cd frontend cd frontend
# Run E2E tests # Run E2E tests
npm run test:e2e bun run test:e2e
# Run E2E tests in UI mode (recommended for development) # Run E2E tests in UI mode (recommended for development)
npm run test:e2e:ui bun run test:e2e:ui
# Run specific test file # Run specific test file
npx playwright test auth-login.spec.ts npx playwright test auth-login.spec.ts
@@ -542,7 +542,7 @@ docker-compose down
### ✅ Completed ### ✅ Completed
- [x] Authentication system (JWT, refresh tokens, session management, OAuth) - [x] Authentication system (JWT, refresh tokens, session management, OAuth)
- [x] User management (CRUD, profile, password change) - [x] User management (full lifecycle, profile, password change)
- [x] Organization system with RBAC (Owner, Admin, Member) - [x] Organization system with RBAC (Owner, Admin, Member)
- [x] Admin panel (users, organizations, sessions, statistics) - [x] Admin panel (users, organizations, sessions, statistics)
- [x] **Internationalization (i18n)** with next-intl (English + Italian) - [x] **Internationalization (i18n)** with next-intl (English + Italian)

View File

@@ -11,7 +11,7 @@ omit =
app/utils/auth_test_utils.py app/utils/auth_test_utils.py
# Async implementations not yet in use # Async implementations not yet in use
app/crud/base_async.py app/repositories/base_async.py
app/core/database_async.py app/core/database_async.py
# CLI scripts - run manually, not tested # CLI scripts - run manually, not tested
@@ -23,7 +23,7 @@ omit =
app/api/routes/__init__.py app/api/routes/__init__.py
app/api/dependencies/__init__.py app/api/dependencies/__init__.py
app/core/__init__.py app/core/__init__.py
app/crud/__init__.py app/repositories/__init__.py
app/models/__init__.py app/models/__init__.py
app/schemas/__init__.py app/schemas/__init__.py
app/services/__init__.py app/services/__init__.py

View File

@@ -0,0 +1,44 @@
# Pre-commit hooks for backend quality and security checks.
#
# Install:
# cd backend && uv run pre-commit install
#
# Run manually on all files:
# cd backend && uv run pre-commit run --all-files
#
# Skip hooks temporarily:
# git commit --no-verify
#
repos:
# ── Code Quality ──────────────────────────────────────────────────────────
- repo: https://github.com/astral-sh/ruff-pre-commit
rev: v0.14.4
hooks:
- id: ruff
args: [--fix, --exit-non-zero-on-fix]
- id: ruff-format
# ── General File Hygiene ──────────────────────────────────────────────────
- repo: https://github.com/pre-commit/pre-commit-hooks
rev: v5.0.0
hooks:
- id: trailing-whitespace
- id: end-of-file-fixer
- id: check-yaml
- id: check-toml
- id: check-merge-conflict
- id: check-added-large-files
args: [--maxkb=500]
- id: debug-statements
# ── Security ──────────────────────────────────────────────────────────────
- repo: https://github.com/Yelp/detect-secrets
rev: v1.5.0
hooks:
- id: detect-secrets
args: ['--baseline', '.secrets.baseline']
exclude: |
(?x)^(
.*\.lock$|
.*\.svg$
)$

1073
backend/.secrets.baseline Normal file

File diff suppressed because it is too large Load Diff

View File

@@ -33,11 +33,11 @@ RUN chmod +x /usr/local/bin/entrypoint.sh
ENTRYPOINT ["/usr/local/bin/entrypoint.sh"] ENTRYPOINT ["/usr/local/bin/entrypoint.sh"]
# Production stage # Production stage — Alpine eliminates glibc CVEs (e.g. CVE-2026-0861)
FROM python:3.12-slim AS production FROM python:3.12-alpine AS production
# Create non-root user # Create non-root user
RUN groupadd -r appuser && useradd -r -g appuser appuser RUN addgroup -S appuser && adduser -S -G appuser appuser
WORKDIR /app WORKDIR /app
ENV PYTHONDONTWRITEBYTECODE=1 \ ENV PYTHONDONTWRITEBYTECODE=1 \
@@ -48,18 +48,18 @@ ENV PYTHONDONTWRITEBYTECODE=1 \
UV_NO_CACHE=1 UV_NO_CACHE=1
# Install system dependencies and uv # Install system dependencies and uv
RUN apt-get update && \ RUN apk add --no-cache postgresql-client curl ca-certificates && \
apt-get install -y --no-install-recommends postgresql-client curl ca-certificates && \
curl -LsSf https://astral.sh/uv/install.sh | sh && \ curl -LsSf https://astral.sh/uv/install.sh | sh && \
mv /root/.local/bin/uv* /usr/local/bin/ && \ mv /root/.local/bin/uv* /usr/local/bin/
apt-get clean && \
rm -rf /var/lib/apt/lists/*
# Copy dependency files # Copy dependency files
COPY pyproject.toml uv.lock ./ COPY pyproject.toml uv.lock ./
# Install only production dependencies using uv (no dev dependencies) # Install build dependencies, compile Python packages, then remove build deps
RUN uv sync --frozen --no-dev RUN apk add --no-cache --virtual .build-deps \
gcc g++ musl-dev python3-dev linux-headers libffi-dev openssl-dev && \
uv sync --frozen --no-dev && \
apk del .build-deps
# Copy application code # Copy application code
COPY . . COPY . .

View File

@@ -1,4 +1,7 @@
.PHONY: help lint lint-fix format format-check type-check test test-cov validate clean install-dev sync check-docker install-e2e test-e2e test-e2e-schema test-all .PHONY: help lint lint-fix format format-check type-check test test-cov validate clean install-dev sync check-docker install-e2e test-e2e test-e2e-schema test-all dep-audit license-check audit validate-all check benchmark benchmark-check benchmark-save scan-image test-api-security
# Prevent a stale VIRTUAL_ENV in the caller's shell from confusing uv
unexport VIRTUAL_ENV
# Default target # Default target
help: help:
@@ -14,8 +17,21 @@ help:
@echo " make lint-fix - Run Ruff linter with auto-fix" @echo " make lint-fix - Run Ruff linter with auto-fix"
@echo " make format - Format code with Ruff" @echo " make format - Format code with Ruff"
@echo " make format-check - Check if code is formatted" @echo " make format-check - Check if code is formatted"
@echo " make type-check - Run mypy type checking" @echo " make type-check - Run pyright type checking"
@echo " make validate - Run all checks (lint + format + types)" @echo " make validate - Run all checks (lint + format + types + schema fuzz)"
@echo ""
@echo "Performance:"
@echo " make benchmark - Run performance benchmarks"
@echo " make benchmark-save - Run benchmarks and save as baseline"
@echo " make benchmark-check - Run benchmarks and compare against baseline"
@echo ""
@echo "Security & Audit:"
@echo " make dep-audit - Scan dependencies for known vulnerabilities"
@echo " make license-check - Check dependency license compliance"
@echo " make audit - Run all security audits (deps + licenses)"
@echo " make scan-image - Scan Docker image for CVEs (requires trivy)"
@echo " make validate-all - Run all quality + security checks"
@echo " make check - Full pipeline: quality + security + tests"
@echo "" @echo ""
@echo "Testing:" @echo "Testing:"
@echo " make test - Run pytest (unit/integration, SQLite)" @echo " make test - Run pytest (unit/integration, SQLite)"
@@ -24,6 +40,7 @@ help:
@echo " make test-e2e-schema - Run Schemathesis API schema tests" @echo " make test-e2e-schema - Run Schemathesis API schema tests"
@echo " make test-all - Run all tests (unit + E2E)" @echo " make test-all - Run all tests (unit + E2E)"
@echo " make check-docker - Check if Docker is available" @echo " make check-docker - Check if Docker is available"
@echo " make check - Full pipeline: quality + security + tests"
@echo "" @echo ""
@echo "Cleanup:" @echo "Cleanup:"
@echo " make clean - Remove cache and build artifacts" @echo " make clean - Remove cache and build artifacts"
@@ -63,12 +80,52 @@ format-check:
@uv run ruff format --check app/ tests/ @uv run ruff format --check app/ tests/
type-check: type-check:
@echo "🔎 Running mypy type checking..." @echo "🔎 Running pyright type checking..."
@uv run mypy app/ @uv run pyright app/
validate: lint format-check type-check validate: lint format-check type-check test-api-security
@echo "✅ All quality checks passed!" @echo "✅ All quality checks passed!"
# API Security Testing (Schemathesis property-based fuzzing)
test-api-security: check-docker
@echo "🔐 Running Schemathesis API security fuzzing..."
@IS_TEST=True PYTHONPATH=. uv run pytest tests/e2e/ -v -m "schemathesis" --tb=short -n 0
@echo "✅ API schema security tests passed!"
# ============================================================================
# Security & Audit
# ============================================================================
dep-audit:
@echo "🔒 Scanning dependencies for known vulnerabilities..."
@uv run pip-audit --desc --skip-editable
@echo "✅ No known vulnerabilities found!"
license-check:
@echo "📜 Checking dependency license compliance..."
@uv run pip-licenses --fail-on="GPL-3.0-or-later;AGPL-3.0-or-later" --format=plain > /dev/null
@echo "✅ All dependency licenses are compliant!"
audit: dep-audit license-check
@echo "✅ All security audits passed!"
scan-image: check-docker
@echo "🐳 Scanning Docker image for OS-level CVEs with Trivy..."
@docker build -t pragma-backend:scan -q --target production .
@if command -v trivy > /dev/null 2>&1; then \
trivy image --severity HIGH,CRITICAL --exit-code 1 pragma-backend:scan; \
else \
echo " Trivy not found locally, using Docker to run Trivy..."; \
docker run --rm -v /var/run/docker.sock:/var/run/docker.sock aquasec/trivy image --severity HIGH,CRITICAL --exit-code 1 pragma-backend:scan; \
fi
@echo "✅ No HIGH/CRITICAL CVEs found in Docker image!"
validate-all: validate audit
@echo "✅ All quality + security checks passed!"
check: validate-all test
@echo "✅ Full validation pipeline complete!"
# ============================================================================ # ============================================================================
# Testing # Testing
# ============================================================================ # ============================================================================
@@ -114,6 +171,31 @@ test-e2e-schema: check-docker
@echo "🧪 Running Schemathesis API schema tests..." @echo "🧪 Running Schemathesis API schema tests..."
@IS_TEST=True PYTHONPATH=. uv run pytest tests/e2e/ -v -m "schemathesis" --tb=short -n 0 @IS_TEST=True PYTHONPATH=. uv run pytest tests/e2e/ -v -m "schemathesis" --tb=short -n 0
# ============================================================================
# Performance Benchmarks
# ============================================================================
benchmark:
@echo "⏱️ Running performance benchmarks..."
@IS_TEST=True PYTHONPATH=. uv run pytest tests/benchmarks/ -v --benchmark-only --benchmark-sort=mean -p no:xdist --override-ini='addopts='
benchmark-save:
@echo "⏱️ Running benchmarks and saving baseline..."
@IS_TEST=True PYTHONPATH=. uv run pytest tests/benchmarks/ -v --benchmark-only --benchmark-save=baseline --benchmark-sort=mean -p no:xdist --override-ini='addopts='
@echo "✅ Benchmark baseline saved to .benchmarks/"
benchmark-check:
@echo "⏱️ Running benchmarks and comparing against baseline..."
@if find .benchmarks -name '*_baseline*' -print -quit 2>/dev/null | grep -q .; then \
IS_TEST=True PYTHONPATH=. uv run pytest tests/benchmarks/ -v --benchmark-only --benchmark-compare=0001_baseline --benchmark-sort=mean --benchmark-compare-fail=mean:200% -p no:xdist --override-ini='addopts='; \
echo "✅ No performance regressions detected!"; \
else \
echo "⚠️ No benchmark baseline found. Run 'make benchmark-save' first to create one."; \
echo " Running benchmarks without comparison..."; \
IS_TEST=True PYTHONPATH=. uv run pytest tests/benchmarks/ -v --benchmark-only --benchmark-save=baseline --benchmark-sort=mean -p no:xdist --override-ini='addopts='; \
echo "✅ Benchmark baseline created. Future runs of 'make benchmark-check' will compare against it."; \
fi
test-all: test-all:
@echo "🧪 Running ALL tests (unit + E2E)..." @echo "🧪 Running ALL tests (unit + E2E)..."
@$(MAKE) test @$(MAKE) test
@@ -127,7 +209,7 @@ clean:
@echo "🧹 Cleaning up..." @echo "🧹 Cleaning up..."
@find . -type d -name "__pycache__" -exec rm -rf {} + 2>/dev/null || true @find . -type d -name "__pycache__" -exec rm -rf {} + 2>/dev/null || true
@find . -type d -name ".pytest_cache" -exec rm -rf {} + 2>/dev/null || true @find . -type d -name ".pytest_cache" -exec rm -rf {} + 2>/dev/null || true
@find . -type d -name ".mypy_cache" -exec rm -rf {} + 2>/dev/null || true @find . -type d -name ".pyright" -exec rm -rf {} + 2>/dev/null || true
@find . -type d -name ".ruff_cache" -exec rm -rf {} + 2>/dev/null || true @find . -type d -name ".ruff_cache" -exec rm -rf {} + 2>/dev/null || true
@find . -type d -name "*.egg-info" -exec rm -rf {} + 2>/dev/null || true @find . -type d -name "*.egg-info" -exec rm -rf {} + 2>/dev/null || true
@find . -type d -name "htmlcov" -exec rm -rf {} + 2>/dev/null || true @find . -type d -name "htmlcov" -exec rm -rf {} + 2>/dev/null || true

View File

@@ -14,7 +14,9 @@ Features:
- **Multi-tenancy**: Organization-based access control with roles (Owner/Admin/Member) - **Multi-tenancy**: Organization-based access control with roles (Owner/Admin/Member)
- **Testing**: 97%+ coverage with security-focused test suite - **Testing**: 97%+ coverage with security-focused test suite
- **Performance**: Async throughout, connection pooling, optimized queries - **Performance**: Async throughout, connection pooling, optimized queries
- **Modern Tooling**: uv for dependencies, Ruff for linting/formatting, mypy for type checking - **Modern Tooling**: uv for dependencies, Ruff for linting/formatting, Pyright for type checking
- **Security Auditing**: Automated dependency vulnerability scanning, license compliance, secrets detection
- **Pre-commit Hooks**: Ruff, detect-secrets, and standard checks on every commit
## Quick Start ## Quick Start
@@ -149,7 +151,7 @@ uv pip list --outdated
# Run any Python command via uv (no activation needed) # Run any Python command via uv (no activation needed)
uv run python script.py uv run python script.py
uv run pytest uv run pytest
uv run mypy app/ uv run pyright app/
# Or activate the virtual environment # Or activate the virtual environment
source .venv/bin/activate source .venv/bin/activate
@@ -171,12 +173,22 @@ make lint # Run Ruff linter (check only)
make lint-fix # Run Ruff with auto-fix make lint-fix # Run Ruff with auto-fix
make format # Format code with Ruff make format # Format code with Ruff
make format-check # Check if code is formatted make format-check # Check if code is formatted
make type-check # Run mypy type checking make type-check # Run Pyright type checking
make validate # Run all checks (lint + format + types) make validate # Run all checks (lint + format + types)
# Security & Audit
make dep-audit # Scan dependencies for known vulnerabilities (CVEs)
make license-check # Check dependency license compliance
make audit # Run all security audits (deps + licenses)
make validate-all # Run all quality + security checks
make check # Full pipeline: quality + security + tests
# Testing # Testing
make test # Run all tests make test # Run all tests
make test-cov # Run tests with coverage report make test-cov # Run tests with coverage report
make test-e2e # Run E2E tests (PostgreSQL, requires Docker)
make test-e2e-schema # Run Schemathesis API schema tests
make test-all # Run all tests (unit + E2E)
# Utilities # Utilities
make clean # Remove cache and build artifacts make clean # Remove cache and build artifacts
@@ -252,7 +264,7 @@ app/
│ ├── database.py # Database engine setup │ ├── database.py # Database engine setup
│ ├── auth.py # JWT token handling │ ├── auth.py # JWT token handling
│ └── exceptions.py # Custom exceptions │ └── exceptions.py # Custom exceptions
├── crud/ # Database operations ├── repositories/ # Repository pattern (database operations)
├── models/ # SQLAlchemy ORM models ├── models/ # SQLAlchemy ORM models
├── schemas/ # Pydantic request/response schemas ├── schemas/ # Pydantic request/response schemas
├── services/ # Business logic layer ├── services/ # Business logic layer
@@ -352,18 +364,29 @@ open htmlcov/index.html
# Using Makefile (recommended) # Using Makefile (recommended)
make lint # Ruff linting make lint # Ruff linting
make format # Ruff formatting make format # Ruff formatting
make type-check # mypy type checking make type-check # Pyright type checking
make validate # All checks at once make validate # All checks at once
# Security audits
make dep-audit # Scan dependencies for CVEs
make license-check # Check license compliance
make audit # All security audits
make validate-all # Quality + security checks
make check # Full pipeline: quality + security + tests
# Using uv directly # Using uv directly
uv run ruff check app/ tests/ uv run ruff check app/ tests/
uv run ruff format app/ tests/ uv run ruff format app/ tests/
uv run mypy app/ uv run pyright app/
``` ```
**Tools:** **Tools:**
- **Ruff**: All-in-one linting, formatting, and import sorting (replaces Black, Flake8, isort) - **Ruff**: All-in-one linting, formatting, and import sorting (replaces Black, Flake8, isort)
- **mypy**: Static type checking with Pydantic plugin - **Pyright**: Static type checking (strict mode)
- **pip-audit**: Dependency vulnerability scanning against the OSV database
- **pip-licenses**: Dependency license compliance checking
- **detect-secrets**: Hardcoded secrets/credentials detection
- **pre-commit**: Git hook framework for automated checks on every commit
All configurations are in `pyproject.toml`. All configurations are in `pyproject.toml`.
@@ -439,7 +462,7 @@ See [docs/FEATURE_EXAMPLE.md](docs/FEATURE_EXAMPLE.md) for step-by-step guide.
Quick overview: Quick overview:
1. Create Pydantic schemas in `app/schemas/` 1. Create Pydantic schemas in `app/schemas/`
2. Create CRUD operations in `app/crud/` 2. Create repository in `app/repositories/`
3. Create route in `app/api/routes/` 3. Create route in `app/api/routes/`
4. Register router in `app/api/main.py` 4. Register router in `app/api/main.py`
5. Write tests in `tests/api/` 5. Write tests in `tests/api/`
@@ -589,13 +612,42 @@ Configured in `app/core/config.py`:
- **Security Headers**: CSP, HSTS, X-Frame-Options, etc. - **Security Headers**: CSP, HSTS, X-Frame-Options, etc.
- **Input Validation**: Pydantic schemas, SQL injection prevention (ORM) - **Input Validation**: Pydantic schemas, SQL injection prevention (ORM)
### Security Auditing
Automated, deterministic security checks are built into the development workflow:
```bash
# Scan dependencies for known vulnerabilities (CVEs)
make dep-audit
# Check dependency license compliance (blocks GPL-3.0/AGPL)
make license-check
# Run all security audits
make audit
# Full pipeline: quality + security + tests
make check
```
**Pre-commit hooks** automatically run on every commit:
- **Ruff** lint + format checks
- **detect-secrets** blocks commits containing hardcoded secrets
- **Standard checks**: trailing whitespace, YAML/TOML validation, merge conflict detection, large file prevention
Setup pre-commit hooks:
```bash
uv run pre-commit install
```
### Security Best Practices ### Security Best Practices
1. **Never commit secrets**: Use `.env` files (git-ignored) 1. **Never commit secrets**: Use `.env` files (git-ignored), enforced by detect-secrets pre-commit hook
2. **Strong SECRET_KEY**: Min 32 chars, cryptographically random 2. **Strong SECRET_KEY**: Min 32 chars, cryptographically random
3. **HTTPS in production**: Required for token security 3. **HTTPS in production**: Required for token security
4. **Regular updates**: Keep dependencies current (`uv sync --upgrade`) 4. **Regular updates**: Keep dependencies current (`uv sync --upgrade`), run `make dep-audit` to check for CVEs
5. **Audit logs**: Monitor authentication events 5. **Audit logs**: Monitor authentication events
6. **Run `make check` before pushing**: Validates quality, security, and tests in one command
--- ---
@@ -645,7 +697,11 @@ logging.basicConfig(level=logging.INFO)
**Built with modern Python tooling:** **Built with modern Python tooling:**
- 🚀 **uv** - 10-100x faster dependency management - 🚀 **uv** - 10-100x faster dependency management
-**Ruff** - 10-100x faster linting & formatting -**Ruff** - 10-100x faster linting & formatting
- 🔍 **mypy** - Static type checking - 🔍 **Pyright** - Static type checking (strict mode)
-**pytest** - Comprehensive test suite -**pytest** - Comprehensive test suite
- 🔒 **pip-audit** - Dependency vulnerability scanning
- 🔑 **detect-secrets** - Hardcoded secrets detection
- 📜 **pip-licenses** - License compliance checking
- 🪝 **pre-commit** - Automated git hooks
**All configured in a single `pyproject.toml` file!** **All configured in a single `pyproject.toml` file!**

View File

@@ -40,6 +40,7 @@ def include_object(object, name, type_, reflected, compare_to):
return False return False
return True return True
# Interpret the config file for Python logging. # Interpret the config file for Python logging.
# This line sets up loggers basically. # This line sets up loggers basically.
if config.config_file_name is not None: if config.config_file_name is not None:

View File

@@ -1,262 +1,446 @@
"""initial models """initial models
Revision ID: 0001 Revision ID: 0001
Revises: Revises:
Create Date: 2025-11-27 09:08:09.464506 Create Date: 2025-11-27 09:08:09.464506
""" """
from typing import Sequence, Union
from alembic import op from collections.abc import Sequence
import sqlalchemy as sa import sqlalchemy as sa
from alembic import op
from sqlalchemy.dialects import postgresql from sqlalchemy.dialects import postgresql
# revision identifiers, used by Alembic. # revision identifiers, used by Alembic.
revision: str = '0001' revision: str = "0001"
down_revision: Union[str, None] = None down_revision: str | None = None
branch_labels: Union[str, Sequence[str], None] = None branch_labels: str | Sequence[str] | None = None
depends_on: Union[str, Sequence[str], None] = None depends_on: str | Sequence[str] | None = None
def upgrade() -> None: def upgrade() -> None:
# ### commands auto generated by Alembic - please adjust! ### # ### commands auto generated by Alembic - please adjust! ###
op.create_table('oauth_states', op.create_table(
sa.Column('state', sa.String(length=255), nullable=False), "oauth_states",
sa.Column('code_verifier', sa.String(length=128), nullable=True), sa.Column("state", sa.String(length=255), nullable=False),
sa.Column('nonce', sa.String(length=255), nullable=True), sa.Column("code_verifier", sa.String(length=128), nullable=True),
sa.Column('provider', sa.String(length=50), nullable=False), sa.Column("nonce", sa.String(length=255), nullable=True),
sa.Column('redirect_uri', sa.String(length=500), nullable=True), sa.Column("provider", sa.String(length=50), nullable=False),
sa.Column('user_id', sa.UUID(), nullable=True), sa.Column("redirect_uri", sa.String(length=500), nullable=True),
sa.Column('expires_at', sa.DateTime(timezone=True), nullable=False), sa.Column("user_id", sa.UUID(), nullable=True),
sa.Column('id', sa.UUID(), nullable=False), sa.Column("expires_at", sa.DateTime(timezone=True), nullable=False),
sa.Column('created_at', sa.DateTime(timezone=True), nullable=False), sa.Column("id", sa.UUID(), nullable=False),
sa.Column('updated_at', sa.DateTime(timezone=True), nullable=False), sa.Column("created_at", sa.DateTime(timezone=True), nullable=False),
sa.PrimaryKeyConstraint('id') sa.Column("updated_at", sa.DateTime(timezone=True), nullable=False),
sa.PrimaryKeyConstraint("id"),
) )
op.create_index(op.f('ix_oauth_states_state'), 'oauth_states', ['state'], unique=True) op.create_index(
op.create_table('organizations', op.f("ix_oauth_states_state"), "oauth_states", ["state"], unique=True
sa.Column('name', sa.String(length=255), nullable=False),
sa.Column('slug', sa.String(length=255), nullable=False),
sa.Column('description', sa.Text(), nullable=True),
sa.Column('is_active', sa.Boolean(), nullable=False),
sa.Column('settings', postgresql.JSONB(astext_type=sa.Text()), nullable=True),
sa.Column('id', sa.UUID(), nullable=False),
sa.Column('created_at', sa.DateTime(timezone=True), nullable=False),
sa.Column('updated_at', sa.DateTime(timezone=True), nullable=False),
sa.PrimaryKeyConstraint('id')
) )
op.create_index(op.f('ix_organizations_is_active'), 'organizations', ['is_active'], unique=False) op.create_table(
op.create_index(op.f('ix_organizations_name'), 'organizations', ['name'], unique=False) "organizations",
op.create_index('ix_organizations_name_active', 'organizations', ['name', 'is_active'], unique=False) sa.Column("name", sa.String(length=255), nullable=False),
op.create_index(op.f('ix_organizations_slug'), 'organizations', ['slug'], unique=True) sa.Column("slug", sa.String(length=255), nullable=False),
op.create_index('ix_organizations_slug_active', 'organizations', ['slug', 'is_active'], unique=False) sa.Column("description", sa.Text(), nullable=True),
op.create_table('users', sa.Column("is_active", sa.Boolean(), nullable=False),
sa.Column('email', sa.String(length=255), nullable=False), sa.Column("settings", postgresql.JSONB(astext_type=sa.Text()), nullable=True),
sa.Column('password_hash', sa.String(length=255), nullable=True), sa.Column("id", sa.UUID(), nullable=False),
sa.Column('first_name', sa.String(length=100), nullable=False), sa.Column("created_at", sa.DateTime(timezone=True), nullable=False),
sa.Column('last_name', sa.String(length=100), nullable=True), sa.Column("updated_at", sa.DateTime(timezone=True), nullable=False),
sa.Column('phone_number', sa.String(length=20), nullable=True), sa.PrimaryKeyConstraint("id"),
sa.Column('is_active', sa.Boolean(), nullable=False),
sa.Column('is_superuser', sa.Boolean(), nullable=False),
sa.Column('preferences', postgresql.JSONB(astext_type=sa.Text()), nullable=True),
sa.Column('locale', sa.String(length=10), nullable=True),
sa.Column('deleted_at', sa.DateTime(timezone=True), nullable=True),
sa.Column('id', sa.UUID(), nullable=False),
sa.Column('created_at', sa.DateTime(timezone=True), nullable=False),
sa.Column('updated_at', sa.DateTime(timezone=True), nullable=False),
sa.PrimaryKeyConstraint('id')
) )
op.create_index(op.f('ix_users_deleted_at'), 'users', ['deleted_at'], unique=False) op.create_index(
op.create_index(op.f('ix_users_email'), 'users', ['email'], unique=True) op.f("ix_organizations_is_active"), "organizations", ["is_active"], unique=False
op.create_index(op.f('ix_users_is_active'), 'users', ['is_active'], unique=False)
op.create_index(op.f('ix_users_is_superuser'), 'users', ['is_superuser'], unique=False)
op.create_index(op.f('ix_users_locale'), 'users', ['locale'], unique=False)
op.create_table('oauth_accounts',
sa.Column('user_id', sa.UUID(), nullable=False),
sa.Column('provider', sa.String(length=50), nullable=False),
sa.Column('provider_user_id', sa.String(length=255), nullable=False),
sa.Column('provider_email', sa.String(length=255), nullable=True),
sa.Column('access_token_encrypted', sa.String(length=2048), nullable=True),
sa.Column('refresh_token_encrypted', sa.String(length=2048), nullable=True),
sa.Column('token_expires_at', sa.DateTime(timezone=True), nullable=True),
sa.Column('id', sa.UUID(), nullable=False),
sa.Column('created_at', sa.DateTime(timezone=True), nullable=False),
sa.Column('updated_at', sa.DateTime(timezone=True), nullable=False),
sa.ForeignKeyConstraint(['user_id'], ['users.id'], ondelete='CASCADE'),
sa.PrimaryKeyConstraint('id'),
sa.UniqueConstraint('provider', 'provider_user_id', name='uq_oauth_provider_user')
) )
op.create_index(op.f('ix_oauth_accounts_provider'), 'oauth_accounts', ['provider'], unique=False) op.create_index(
op.create_index(op.f('ix_oauth_accounts_provider_email'), 'oauth_accounts', ['provider_email'], unique=False) op.f("ix_organizations_name"), "organizations", ["name"], unique=False
op.create_index(op.f('ix_oauth_accounts_user_id'), 'oauth_accounts', ['user_id'], unique=False)
op.create_index('ix_oauth_accounts_user_provider', 'oauth_accounts', ['user_id', 'provider'], unique=False)
op.create_table('oauth_clients',
sa.Column('client_id', sa.String(length=64), nullable=False),
sa.Column('client_secret_hash', sa.String(length=255), nullable=True),
sa.Column('client_name', sa.String(length=255), nullable=False),
sa.Column('client_description', sa.String(length=1000), nullable=True),
sa.Column('client_type', sa.String(length=20), nullable=False),
sa.Column('redirect_uris', postgresql.JSONB(astext_type=sa.Text()), nullable=False),
sa.Column('allowed_scopes', postgresql.JSONB(astext_type=sa.Text()), nullable=False),
sa.Column('access_token_lifetime', sa.String(length=10), nullable=False),
sa.Column('refresh_token_lifetime', sa.String(length=10), nullable=False),
sa.Column('is_active', sa.Boolean(), nullable=False),
sa.Column('owner_user_id', sa.UUID(), nullable=True),
sa.Column('mcp_server_url', sa.String(length=2048), nullable=True),
sa.Column('id', sa.UUID(), nullable=False),
sa.Column('created_at', sa.DateTime(timezone=True), nullable=False),
sa.Column('updated_at', sa.DateTime(timezone=True), nullable=False),
sa.ForeignKeyConstraint(['owner_user_id'], ['users.id'], ondelete='SET NULL'),
sa.PrimaryKeyConstraint('id')
) )
op.create_index(op.f('ix_oauth_clients_client_id'), 'oauth_clients', ['client_id'], unique=True) op.create_index(
op.create_index(op.f('ix_oauth_clients_is_active'), 'oauth_clients', ['is_active'], unique=False) "ix_organizations_name_active",
op.create_table('user_organizations', "organizations",
sa.Column('user_id', sa.UUID(), nullable=False), ["name", "is_active"],
sa.Column('organization_id', sa.UUID(), nullable=False), unique=False,
sa.Column('role', sa.Enum('OWNER', 'ADMIN', 'MEMBER', 'GUEST', name='organizationrole'), nullable=False),
sa.Column('is_active', sa.Boolean(), nullable=False),
sa.Column('custom_permissions', sa.String(length=500), nullable=True),
sa.Column('created_at', sa.DateTime(timezone=True), nullable=False),
sa.Column('updated_at', sa.DateTime(timezone=True), nullable=False),
sa.ForeignKeyConstraint(['organization_id'], ['organizations.id'], ondelete='CASCADE'),
sa.ForeignKeyConstraint(['user_id'], ['users.id'], ondelete='CASCADE'),
sa.PrimaryKeyConstraint('user_id', 'organization_id')
) )
op.create_index('ix_user_org_org_active', 'user_organizations', ['organization_id', 'is_active'], unique=False) op.create_index(
op.create_index('ix_user_org_role', 'user_organizations', ['role'], unique=False) op.f("ix_organizations_slug"), "organizations", ["slug"], unique=True
op.create_index('ix_user_org_user_active', 'user_organizations', ['user_id', 'is_active'], unique=False)
op.create_index(op.f('ix_user_organizations_is_active'), 'user_organizations', ['is_active'], unique=False)
op.create_table('user_sessions',
sa.Column('user_id', sa.UUID(), nullable=False),
sa.Column('refresh_token_jti', sa.String(length=255), nullable=False),
sa.Column('device_name', sa.String(length=255), nullable=True),
sa.Column('device_id', sa.String(length=255), nullable=True),
sa.Column('ip_address', sa.String(length=45), nullable=True),
sa.Column('user_agent', sa.String(length=500), nullable=True),
sa.Column('last_used_at', sa.DateTime(timezone=True), nullable=False),
sa.Column('expires_at', sa.DateTime(timezone=True), nullable=False),
sa.Column('is_active', sa.Boolean(), nullable=False),
sa.Column('location_city', sa.String(length=100), nullable=True),
sa.Column('location_country', sa.String(length=100), nullable=True),
sa.Column('id', sa.UUID(), nullable=False),
sa.Column('created_at', sa.DateTime(timezone=True), nullable=False),
sa.Column('updated_at', sa.DateTime(timezone=True), nullable=False),
sa.ForeignKeyConstraint(['user_id'], ['users.id'], ondelete='CASCADE'),
sa.PrimaryKeyConstraint('id')
) )
op.create_index(op.f('ix_user_sessions_is_active'), 'user_sessions', ['is_active'], unique=False) op.create_index(
op.create_index('ix_user_sessions_jti_active', 'user_sessions', ['refresh_token_jti', 'is_active'], unique=False) "ix_organizations_slug_active",
op.create_index(op.f('ix_user_sessions_refresh_token_jti'), 'user_sessions', ['refresh_token_jti'], unique=True) "organizations",
op.create_index('ix_user_sessions_user_active', 'user_sessions', ['user_id', 'is_active'], unique=False) ["slug", "is_active"],
op.create_index(op.f('ix_user_sessions_user_id'), 'user_sessions', ['user_id'], unique=False) unique=False,
op.create_table('oauth_authorization_codes',
sa.Column('code', sa.String(length=128), nullable=False),
sa.Column('client_id', sa.String(length=64), nullable=False),
sa.Column('user_id', sa.UUID(), nullable=False),
sa.Column('redirect_uri', sa.String(length=2048), nullable=False),
sa.Column('scope', sa.String(length=1000), nullable=False),
sa.Column('code_challenge', sa.String(length=128), nullable=True),
sa.Column('code_challenge_method', sa.String(length=10), nullable=True),
sa.Column('state', sa.String(length=256), nullable=True),
sa.Column('nonce', sa.String(length=256), nullable=True),
sa.Column('expires_at', sa.DateTime(timezone=True), nullable=False),
sa.Column('used', sa.Boolean(), nullable=False),
sa.Column('id', sa.UUID(), nullable=False),
sa.Column('created_at', sa.DateTime(timezone=True), nullable=False),
sa.Column('updated_at', sa.DateTime(timezone=True), nullable=False),
sa.ForeignKeyConstraint(['client_id'], ['oauth_clients.client_id'], ondelete='CASCADE'),
sa.ForeignKeyConstraint(['user_id'], ['users.id'], ondelete='CASCADE'),
sa.PrimaryKeyConstraint('id')
) )
op.create_index('ix_oauth_authorization_codes_client_user', 'oauth_authorization_codes', ['client_id', 'user_id'], unique=False) op.create_table(
op.create_index(op.f('ix_oauth_authorization_codes_code'), 'oauth_authorization_codes', ['code'], unique=True) "users",
op.create_index('ix_oauth_authorization_codes_expires_at', 'oauth_authorization_codes', ['expires_at'], unique=False) sa.Column("email", sa.String(length=255), nullable=False),
op.create_table('oauth_consents', sa.Column("password_hash", sa.String(length=255), nullable=True),
sa.Column('user_id', sa.UUID(), nullable=False), sa.Column("first_name", sa.String(length=100), nullable=False),
sa.Column('client_id', sa.String(length=64), nullable=False), sa.Column("last_name", sa.String(length=100), nullable=True),
sa.Column('granted_scopes', sa.String(length=1000), nullable=False), sa.Column("phone_number", sa.String(length=20), nullable=True),
sa.Column('id', sa.UUID(), nullable=False), sa.Column("is_active", sa.Boolean(), nullable=False),
sa.Column('created_at', sa.DateTime(timezone=True), nullable=False), sa.Column("is_superuser", sa.Boolean(), nullable=False),
sa.Column('updated_at', sa.DateTime(timezone=True), nullable=False), sa.Column(
sa.ForeignKeyConstraint(['client_id'], ['oauth_clients.client_id'], ondelete='CASCADE'), "preferences", postgresql.JSONB(astext_type=sa.Text()), nullable=True
sa.ForeignKeyConstraint(['user_id'], ['users.id'], ondelete='CASCADE'), ),
sa.PrimaryKeyConstraint('id') sa.Column("locale", sa.String(length=10), nullable=True),
sa.Column("deleted_at", sa.DateTime(timezone=True), nullable=True),
sa.Column("id", sa.UUID(), nullable=False),
sa.Column("created_at", sa.DateTime(timezone=True), nullable=False),
sa.Column("updated_at", sa.DateTime(timezone=True), nullable=False),
sa.PrimaryKeyConstraint("id"),
) )
op.create_index('ix_oauth_consents_user_client', 'oauth_consents', ['user_id', 'client_id'], unique=True) op.create_index(op.f("ix_users_deleted_at"), "users", ["deleted_at"], unique=False)
op.create_table('oauth_provider_refresh_tokens', op.create_index(op.f("ix_users_email"), "users", ["email"], unique=True)
sa.Column('token_hash', sa.String(length=64), nullable=False), op.create_index(op.f("ix_users_is_active"), "users", ["is_active"], unique=False)
sa.Column('jti', sa.String(length=64), nullable=False), op.create_index(
sa.Column('client_id', sa.String(length=64), nullable=False), op.f("ix_users_is_superuser"), "users", ["is_superuser"], unique=False
sa.Column('user_id', sa.UUID(), nullable=False), )
sa.Column('scope', sa.String(length=1000), nullable=False), op.create_index(op.f("ix_users_locale"), "users", ["locale"], unique=False)
sa.Column('expires_at', sa.DateTime(timezone=True), nullable=False), op.create_table(
sa.Column('revoked', sa.Boolean(), nullable=False), "oauth_accounts",
sa.Column('last_used_at', sa.DateTime(timezone=True), nullable=True), sa.Column("user_id", sa.UUID(), nullable=False),
sa.Column('device_info', sa.String(length=500), nullable=True), sa.Column("provider", sa.String(length=50), nullable=False),
sa.Column('ip_address', sa.String(length=45), nullable=True), sa.Column("provider_user_id", sa.String(length=255), nullable=False),
sa.Column('id', sa.UUID(), nullable=False), sa.Column("provider_email", sa.String(length=255), nullable=True),
sa.Column('created_at', sa.DateTime(timezone=True), nullable=False), sa.Column("access_token_encrypted", sa.String(length=2048), nullable=True),
sa.Column('updated_at', sa.DateTime(timezone=True), nullable=False), sa.Column("refresh_token_encrypted", sa.String(length=2048), nullable=True),
sa.ForeignKeyConstraint(['client_id'], ['oauth_clients.client_id'], ondelete='CASCADE'), sa.Column("token_expires_at", sa.DateTime(timezone=True), nullable=True),
sa.ForeignKeyConstraint(['user_id'], ['users.id'], ondelete='CASCADE'), sa.Column("id", sa.UUID(), nullable=False),
sa.PrimaryKeyConstraint('id') sa.Column("created_at", sa.DateTime(timezone=True), nullable=False),
sa.Column("updated_at", sa.DateTime(timezone=True), nullable=False),
sa.ForeignKeyConstraint(["user_id"], ["users.id"], ondelete="CASCADE"),
sa.PrimaryKeyConstraint("id"),
sa.UniqueConstraint(
"provider", "provider_user_id", name="uq_oauth_provider_user"
),
)
op.create_index(
op.f("ix_oauth_accounts_provider"), "oauth_accounts", ["provider"], unique=False
)
op.create_index(
op.f("ix_oauth_accounts_provider_email"),
"oauth_accounts",
["provider_email"],
unique=False,
)
op.create_index(
op.f("ix_oauth_accounts_user_id"), "oauth_accounts", ["user_id"], unique=False
)
op.create_index(
"ix_oauth_accounts_user_provider",
"oauth_accounts",
["user_id", "provider"],
unique=False,
)
op.create_table(
"oauth_clients",
sa.Column("client_id", sa.String(length=64), nullable=False),
sa.Column("client_secret_hash", sa.String(length=255), nullable=True),
sa.Column("client_name", sa.String(length=255), nullable=False),
sa.Column("client_description", sa.String(length=1000), nullable=True),
sa.Column("client_type", sa.String(length=20), nullable=False),
sa.Column(
"redirect_uris", postgresql.JSONB(astext_type=sa.Text()), nullable=False
),
sa.Column(
"allowed_scopes", postgresql.JSONB(astext_type=sa.Text()), nullable=False
),
sa.Column("access_token_lifetime", sa.String(length=10), nullable=False),
sa.Column("refresh_token_lifetime", sa.String(length=10), nullable=False),
sa.Column("is_active", sa.Boolean(), nullable=False),
sa.Column("owner_user_id", sa.UUID(), nullable=True),
sa.Column("mcp_server_url", sa.String(length=2048), nullable=True),
sa.Column("id", sa.UUID(), nullable=False),
sa.Column("created_at", sa.DateTime(timezone=True), nullable=False),
sa.Column("updated_at", sa.DateTime(timezone=True), nullable=False),
sa.ForeignKeyConstraint(["owner_user_id"], ["users.id"], ondelete="SET NULL"),
sa.PrimaryKeyConstraint("id"),
)
op.create_index(
op.f("ix_oauth_clients_client_id"), "oauth_clients", ["client_id"], unique=True
)
op.create_index(
op.f("ix_oauth_clients_is_active"), "oauth_clients", ["is_active"], unique=False
)
op.create_table(
"user_organizations",
sa.Column("user_id", sa.UUID(), nullable=False),
sa.Column("organization_id", sa.UUID(), nullable=False),
sa.Column(
"role",
sa.Enum("OWNER", "ADMIN", "MEMBER", "GUEST", name="organizationrole"),
nullable=False,
),
sa.Column("is_active", sa.Boolean(), nullable=False),
sa.Column("custom_permissions", sa.String(length=500), nullable=True),
sa.Column("created_at", sa.DateTime(timezone=True), nullable=False),
sa.Column("updated_at", sa.DateTime(timezone=True), nullable=False),
sa.ForeignKeyConstraint(
["organization_id"], ["organizations.id"], ondelete="CASCADE"
),
sa.ForeignKeyConstraint(["user_id"], ["users.id"], ondelete="CASCADE"),
sa.PrimaryKeyConstraint("user_id", "organization_id"),
)
op.create_index(
"ix_user_org_org_active",
"user_organizations",
["organization_id", "is_active"],
unique=False,
)
op.create_index("ix_user_org_role", "user_organizations", ["role"], unique=False)
op.create_index(
"ix_user_org_user_active",
"user_organizations",
["user_id", "is_active"],
unique=False,
)
op.create_index(
op.f("ix_user_organizations_is_active"),
"user_organizations",
["is_active"],
unique=False,
)
op.create_table(
"user_sessions",
sa.Column("user_id", sa.UUID(), nullable=False),
sa.Column("refresh_token_jti", sa.String(length=255), nullable=False),
sa.Column("device_name", sa.String(length=255), nullable=True),
sa.Column("device_id", sa.String(length=255), nullable=True),
sa.Column("ip_address", sa.String(length=45), nullable=True),
sa.Column("user_agent", sa.String(length=500), nullable=True),
sa.Column("last_used_at", sa.DateTime(timezone=True), nullable=False),
sa.Column("expires_at", sa.DateTime(timezone=True), nullable=False),
sa.Column("is_active", sa.Boolean(), nullable=False),
sa.Column("location_city", sa.String(length=100), nullable=True),
sa.Column("location_country", sa.String(length=100), nullable=True),
sa.Column("id", sa.UUID(), nullable=False),
sa.Column("created_at", sa.DateTime(timezone=True), nullable=False),
sa.Column("updated_at", sa.DateTime(timezone=True), nullable=False),
sa.ForeignKeyConstraint(["user_id"], ["users.id"], ondelete="CASCADE"),
sa.PrimaryKeyConstraint("id"),
)
op.create_index(
op.f("ix_user_sessions_is_active"), "user_sessions", ["is_active"], unique=False
)
op.create_index(
"ix_user_sessions_jti_active",
"user_sessions",
["refresh_token_jti", "is_active"],
unique=False,
)
op.create_index(
op.f("ix_user_sessions_refresh_token_jti"),
"user_sessions",
["refresh_token_jti"],
unique=True,
)
op.create_index(
"ix_user_sessions_user_active",
"user_sessions",
["user_id", "is_active"],
unique=False,
)
op.create_index(
op.f("ix_user_sessions_user_id"), "user_sessions", ["user_id"], unique=False
)
op.create_table(
"oauth_authorization_codes",
sa.Column("code", sa.String(length=128), nullable=False),
sa.Column("client_id", sa.String(length=64), nullable=False),
sa.Column("user_id", sa.UUID(), nullable=False),
sa.Column("redirect_uri", sa.String(length=2048), nullable=False),
sa.Column("scope", sa.String(length=1000), nullable=False),
sa.Column("code_challenge", sa.String(length=128), nullable=True),
sa.Column("code_challenge_method", sa.String(length=10), nullable=True),
sa.Column("state", sa.String(length=256), nullable=True),
sa.Column("nonce", sa.String(length=256), nullable=True),
sa.Column("expires_at", sa.DateTime(timezone=True), nullable=False),
sa.Column("used", sa.Boolean(), nullable=False),
sa.Column("id", sa.UUID(), nullable=False),
sa.Column("created_at", sa.DateTime(timezone=True), nullable=False),
sa.Column("updated_at", sa.DateTime(timezone=True), nullable=False),
sa.ForeignKeyConstraint(
["client_id"], ["oauth_clients.client_id"], ondelete="CASCADE"
),
sa.ForeignKeyConstraint(["user_id"], ["users.id"], ondelete="CASCADE"),
sa.PrimaryKeyConstraint("id"),
)
op.create_index(
"ix_oauth_authorization_codes_client_user",
"oauth_authorization_codes",
["client_id", "user_id"],
unique=False,
)
op.create_index(
op.f("ix_oauth_authorization_codes_code"),
"oauth_authorization_codes",
["code"],
unique=True,
)
op.create_index(
"ix_oauth_authorization_codes_expires_at",
"oauth_authorization_codes",
["expires_at"],
unique=False,
)
op.create_table(
"oauth_consents",
sa.Column("user_id", sa.UUID(), nullable=False),
sa.Column("client_id", sa.String(length=64), nullable=False),
sa.Column("granted_scopes", sa.String(length=1000), nullable=False),
sa.Column("id", sa.UUID(), nullable=False),
sa.Column("created_at", sa.DateTime(timezone=True), nullable=False),
sa.Column("updated_at", sa.DateTime(timezone=True), nullable=False),
sa.ForeignKeyConstraint(
["client_id"], ["oauth_clients.client_id"], ondelete="CASCADE"
),
sa.ForeignKeyConstraint(["user_id"], ["users.id"], ondelete="CASCADE"),
sa.PrimaryKeyConstraint("id"),
)
op.create_index(
"ix_oauth_consents_user_client",
"oauth_consents",
["user_id", "client_id"],
unique=True,
)
op.create_table(
"oauth_provider_refresh_tokens",
sa.Column("token_hash", sa.String(length=64), nullable=False),
sa.Column("jti", sa.String(length=64), nullable=False),
sa.Column("client_id", sa.String(length=64), nullable=False),
sa.Column("user_id", sa.UUID(), nullable=False),
sa.Column("scope", sa.String(length=1000), nullable=False),
sa.Column("expires_at", sa.DateTime(timezone=True), nullable=False),
sa.Column("revoked", sa.Boolean(), nullable=False),
sa.Column("last_used_at", sa.DateTime(timezone=True), nullable=True),
sa.Column("device_info", sa.String(length=500), nullable=True),
sa.Column("ip_address", sa.String(length=45), nullable=True),
sa.Column("id", sa.UUID(), nullable=False),
sa.Column("created_at", sa.DateTime(timezone=True), nullable=False),
sa.Column("updated_at", sa.DateTime(timezone=True), nullable=False),
sa.ForeignKeyConstraint(
["client_id"], ["oauth_clients.client_id"], ondelete="CASCADE"
),
sa.ForeignKeyConstraint(["user_id"], ["users.id"], ondelete="CASCADE"),
sa.PrimaryKeyConstraint("id"),
)
op.create_index(
"ix_oauth_provider_refresh_tokens_client_user",
"oauth_provider_refresh_tokens",
["client_id", "user_id"],
unique=False,
)
op.create_index(
"ix_oauth_provider_refresh_tokens_expires_at",
"oauth_provider_refresh_tokens",
["expires_at"],
unique=False,
)
op.create_index(
op.f("ix_oauth_provider_refresh_tokens_jti"),
"oauth_provider_refresh_tokens",
["jti"],
unique=True,
)
op.create_index(
op.f("ix_oauth_provider_refresh_tokens_revoked"),
"oauth_provider_refresh_tokens",
["revoked"],
unique=False,
)
op.create_index(
op.f("ix_oauth_provider_refresh_tokens_token_hash"),
"oauth_provider_refresh_tokens",
["token_hash"],
unique=True,
)
op.create_index(
"ix_oauth_provider_refresh_tokens_user_revoked",
"oauth_provider_refresh_tokens",
["user_id", "revoked"],
unique=False,
) )
op.create_index('ix_oauth_provider_refresh_tokens_client_user', 'oauth_provider_refresh_tokens', ['client_id', 'user_id'], unique=False)
op.create_index('ix_oauth_provider_refresh_tokens_expires_at', 'oauth_provider_refresh_tokens', ['expires_at'], unique=False)
op.create_index(op.f('ix_oauth_provider_refresh_tokens_jti'), 'oauth_provider_refresh_tokens', ['jti'], unique=True)
op.create_index(op.f('ix_oauth_provider_refresh_tokens_revoked'), 'oauth_provider_refresh_tokens', ['revoked'], unique=False)
op.create_index(op.f('ix_oauth_provider_refresh_tokens_token_hash'), 'oauth_provider_refresh_tokens', ['token_hash'], unique=True)
op.create_index('ix_oauth_provider_refresh_tokens_user_revoked', 'oauth_provider_refresh_tokens', ['user_id', 'revoked'], unique=False)
# ### end Alembic commands ### # ### end Alembic commands ###
def downgrade() -> None: def downgrade() -> None:
# ### commands auto generated by Alembic - please adjust! ### # ### commands auto generated by Alembic - please adjust! ###
op.drop_index('ix_oauth_provider_refresh_tokens_user_revoked', table_name='oauth_provider_refresh_tokens') op.drop_index(
op.drop_index(op.f('ix_oauth_provider_refresh_tokens_token_hash'), table_name='oauth_provider_refresh_tokens') "ix_oauth_provider_refresh_tokens_user_revoked",
op.drop_index(op.f('ix_oauth_provider_refresh_tokens_revoked'), table_name='oauth_provider_refresh_tokens') table_name="oauth_provider_refresh_tokens",
op.drop_index(op.f('ix_oauth_provider_refresh_tokens_jti'), table_name='oauth_provider_refresh_tokens') )
op.drop_index('ix_oauth_provider_refresh_tokens_expires_at', table_name='oauth_provider_refresh_tokens') op.drop_index(
op.drop_index('ix_oauth_provider_refresh_tokens_client_user', table_name='oauth_provider_refresh_tokens') op.f("ix_oauth_provider_refresh_tokens_token_hash"),
op.drop_table('oauth_provider_refresh_tokens') table_name="oauth_provider_refresh_tokens",
op.drop_index('ix_oauth_consents_user_client', table_name='oauth_consents') )
op.drop_table('oauth_consents') op.drop_index(
op.drop_index('ix_oauth_authorization_codes_expires_at', table_name='oauth_authorization_codes') op.f("ix_oauth_provider_refresh_tokens_revoked"),
op.drop_index(op.f('ix_oauth_authorization_codes_code'), table_name='oauth_authorization_codes') table_name="oauth_provider_refresh_tokens",
op.drop_index('ix_oauth_authorization_codes_client_user', table_name='oauth_authorization_codes') )
op.drop_table('oauth_authorization_codes') op.drop_index(
op.drop_index(op.f('ix_user_sessions_user_id'), table_name='user_sessions') op.f("ix_oauth_provider_refresh_tokens_jti"),
op.drop_index('ix_user_sessions_user_active', table_name='user_sessions') table_name="oauth_provider_refresh_tokens",
op.drop_index(op.f('ix_user_sessions_refresh_token_jti'), table_name='user_sessions') )
op.drop_index('ix_user_sessions_jti_active', table_name='user_sessions') op.drop_index(
op.drop_index(op.f('ix_user_sessions_is_active'), table_name='user_sessions') "ix_oauth_provider_refresh_tokens_expires_at",
op.drop_table('user_sessions') table_name="oauth_provider_refresh_tokens",
op.drop_index(op.f('ix_user_organizations_is_active'), table_name='user_organizations') )
op.drop_index('ix_user_org_user_active', table_name='user_organizations') op.drop_index(
op.drop_index('ix_user_org_role', table_name='user_organizations') "ix_oauth_provider_refresh_tokens_client_user",
op.drop_index('ix_user_org_org_active', table_name='user_organizations') table_name="oauth_provider_refresh_tokens",
op.drop_table('user_organizations') )
op.drop_index(op.f('ix_oauth_clients_is_active'), table_name='oauth_clients') op.drop_table("oauth_provider_refresh_tokens")
op.drop_index(op.f('ix_oauth_clients_client_id'), table_name='oauth_clients') op.drop_index("ix_oauth_consents_user_client", table_name="oauth_consents")
op.drop_table('oauth_clients') op.drop_table("oauth_consents")
op.drop_index('ix_oauth_accounts_user_provider', table_name='oauth_accounts') op.drop_index(
op.drop_index(op.f('ix_oauth_accounts_user_id'), table_name='oauth_accounts') "ix_oauth_authorization_codes_expires_at",
op.drop_index(op.f('ix_oauth_accounts_provider_email'), table_name='oauth_accounts') table_name="oauth_authorization_codes",
op.drop_index(op.f('ix_oauth_accounts_provider'), table_name='oauth_accounts') )
op.drop_table('oauth_accounts') op.drop_index(
op.drop_index(op.f('ix_users_locale'), table_name='users') op.f("ix_oauth_authorization_codes_code"),
op.drop_index(op.f('ix_users_is_superuser'), table_name='users') table_name="oauth_authorization_codes",
op.drop_index(op.f('ix_users_is_active'), table_name='users') )
op.drop_index(op.f('ix_users_email'), table_name='users') op.drop_index(
op.drop_index(op.f('ix_users_deleted_at'), table_name='users') "ix_oauth_authorization_codes_client_user",
op.drop_table('users') table_name="oauth_authorization_codes",
op.drop_index('ix_organizations_slug_active', table_name='organizations') )
op.drop_index(op.f('ix_organizations_slug'), table_name='organizations') op.drop_table("oauth_authorization_codes")
op.drop_index('ix_organizations_name_active', table_name='organizations') op.drop_index(op.f("ix_user_sessions_user_id"), table_name="user_sessions")
op.drop_index(op.f('ix_organizations_name'), table_name='organizations') op.drop_index("ix_user_sessions_user_active", table_name="user_sessions")
op.drop_index(op.f('ix_organizations_is_active'), table_name='organizations') op.drop_index(
op.drop_table('organizations') op.f("ix_user_sessions_refresh_token_jti"), table_name="user_sessions"
op.drop_index(op.f('ix_oauth_states_state'), table_name='oauth_states') )
op.drop_table('oauth_states') op.drop_index("ix_user_sessions_jti_active", table_name="user_sessions")
op.drop_index(op.f("ix_user_sessions_is_active"), table_name="user_sessions")
op.drop_table("user_sessions")
op.drop_index(
op.f("ix_user_organizations_is_active"), table_name="user_organizations"
)
op.drop_index("ix_user_org_user_active", table_name="user_organizations")
op.drop_index("ix_user_org_role", table_name="user_organizations")
op.drop_index("ix_user_org_org_active", table_name="user_organizations")
op.drop_table("user_organizations")
op.drop_index(op.f("ix_oauth_clients_is_active"), table_name="oauth_clients")
op.drop_index(op.f("ix_oauth_clients_client_id"), table_name="oauth_clients")
op.drop_table("oauth_clients")
op.drop_index("ix_oauth_accounts_user_provider", table_name="oauth_accounts")
op.drop_index(op.f("ix_oauth_accounts_user_id"), table_name="oauth_accounts")
op.drop_index(op.f("ix_oauth_accounts_provider_email"), table_name="oauth_accounts")
op.drop_index(op.f("ix_oauth_accounts_provider"), table_name="oauth_accounts")
op.drop_table("oauth_accounts")
op.drop_index(op.f("ix_users_locale"), table_name="users")
op.drop_index(op.f("ix_users_is_superuser"), table_name="users")
op.drop_index(op.f("ix_users_is_active"), table_name="users")
op.drop_index(op.f("ix_users_email"), table_name="users")
op.drop_index(op.f("ix_users_deleted_at"), table_name="users")
op.drop_table("users")
op.drop_index("ix_organizations_slug_active", table_name="organizations")
op.drop_index(op.f("ix_organizations_slug"), table_name="organizations")
op.drop_index("ix_organizations_name_active", table_name="organizations")
op.drop_index(op.f("ix_organizations_name"), table_name="organizations")
op.drop_index(op.f("ix_organizations_is_active"), table_name="organizations")
op.drop_table("organizations")
op.drop_index(op.f("ix_oauth_states_state"), table_name="oauth_states")
op.drop_table("oauth_states")
# ### end Alembic commands ### # ### end Alembic commands ###

View File

@@ -114,8 +114,13 @@ def upgrade() -> None:
def downgrade() -> None: def downgrade() -> None:
# Drop indexes in reverse order # Drop indexes in reverse order
op.drop_index("ix_perf_oauth_auth_codes_expires", table_name="oauth_authorization_codes") op.drop_index(
op.drop_index("ix_perf_oauth_refresh_tokens_expires", table_name="oauth_provider_refresh_tokens") "ix_perf_oauth_auth_codes_expires", table_name="oauth_authorization_codes"
)
op.drop_index(
"ix_perf_oauth_refresh_tokens_expires",
table_name="oauth_provider_refresh_tokens",
)
op.drop_index("ix_perf_user_sessions_expires", table_name="user_sessions") op.drop_index("ix_perf_user_sessions_expires", table_name="user_sessions")
op.drop_index("ix_perf_organizations_slug_lower", table_name="organizations") op.drop_index("ix_perf_organizations_slug_lower", table_name="organizations")
op.drop_index("ix_perf_users_active", table_name="users") op.drop_index("ix_perf_users_active", table_name="users")

View File

@@ -0,0 +1,35 @@
"""rename oauth account token fields drop encrypted suffix
Revision ID: 0003
Revises: 0002
Create Date: 2026-02-27 01:03:18.869178
"""
from collections.abc import Sequence
from alembic import op
# revision identifiers, used by Alembic.
revision: str = "0003"
down_revision: str | None = "0002"
branch_labels: str | Sequence[str] | None = None
depends_on: str | Sequence[str] | None = None
def upgrade() -> None:
op.alter_column(
"oauth_accounts", "access_token_encrypted", new_column_name="access_token"
)
op.alter_column(
"oauth_accounts", "refresh_token_encrypted", new_column_name="refresh_token"
)
def downgrade() -> None:
op.alter_column(
"oauth_accounts", "access_token", new_column_name="access_token_encrypted"
)
op.alter_column(
"oauth_accounts", "refresh_token", new_column_name="refresh_token_encrypted"
)

View File

@@ -1,12 +1,12 @@
from fastapi import Depends, Header, HTTPException, status from fastapi import Depends, Header, HTTPException, status
from fastapi.security import OAuth2PasswordBearer from fastapi.security import OAuth2PasswordBearer
from fastapi.security.utils import get_authorization_scheme_param from fastapi.security.utils import get_authorization_scheme_param
from sqlalchemy import select
from sqlalchemy.ext.asyncio import AsyncSession from sqlalchemy.ext.asyncio import AsyncSession
from app.core.auth import TokenExpiredError, TokenInvalidError, get_token_data from app.core.auth import TokenExpiredError, TokenInvalidError, get_token_data
from app.core.database import get_db from app.core.database import get_db
from app.models.user import User from app.models.user import User
from app.repositories.user import user_repo
# OAuth2 configuration # OAuth2 configuration
oauth2_scheme = OAuth2PasswordBearer(tokenUrl="/api/v1/auth/login") oauth2_scheme = OAuth2PasswordBearer(tokenUrl="/api/v1/auth/login")
@@ -32,9 +32,8 @@ async def get_current_user(
# Decode token and get user ID # Decode token and get user ID
token_data = get_token_data(token) token_data = get_token_data(token)
# Get user from database # Get user from database via repository
result = await db.execute(select(User).where(User.id == token_data.user_id)) user = await user_repo.get(db, id=str(token_data.user_id))
user = result.scalar_one_or_none()
if not user: if not user:
raise HTTPException( raise HTTPException(
@@ -144,8 +143,7 @@ async def get_optional_current_user(
try: try:
token_data = get_token_data(token) token_data = get_token_data(token)
result = await db.execute(select(User).where(User.id == token_data.user_id)) user = await user_repo.get(db, id=str(token_data.user_id))
user = result.scalar_one_or_none()
if not user or not user.is_active: if not user or not user.is_active:
return None return None
return user return user

View File

@@ -15,9 +15,9 @@ from sqlalchemy.ext.asyncio import AsyncSession
from app.api.dependencies.auth import get_current_user from app.api.dependencies.auth import get_current_user
from app.core.database import get_db from app.core.database import get_db
from app.crud.organization import organization as organization_crud
from app.models.user import User from app.models.user import User
from app.models.user_organization import OrganizationRole from app.models.user_organization import OrganizationRole
from app.services.organization_service import organization_service
def require_superuser(current_user: User = Depends(get_current_user)) -> User: def require_superuser(current_user: User = Depends(get_current_user)) -> User:
@@ -81,7 +81,7 @@ class OrganizationPermission:
return current_user return current_user
# Get user's role in organization # Get user's role in organization
user_role = await organization_crud.get_user_role_in_org( user_role = await organization_service.get_user_role_in_org(
db, user_id=current_user.id, organization_id=organization_id db, user_id=current_user.id, organization_id=organization_id
) )
@@ -123,7 +123,7 @@ async def require_org_membership(
if current_user.is_superuser: if current_user.is_superuser:
return current_user return current_user
user_role = await organization_crud.get_user_role_in_org( user_role = await organization_service.get_user_role_in_org(
db, user_id=current_user.id, organization_id=organization_id db, user_id=current_user.id, organization_id=organization_id
) )

View File

@@ -0,0 +1,41 @@
# app/api/dependencies/services.py
"""FastAPI dependency functions for service singletons."""
from app.services import oauth_provider_service
from app.services.auth_service import AuthService
from app.services.oauth_service import OAuthService
from app.services.organization_service import OrganizationService, organization_service
from app.services.session_service import SessionService, session_service
from app.services.user_service import UserService, user_service
def get_auth_service() -> AuthService:
"""Return the AuthService singleton for dependency injection."""
from app.services.auth_service import AuthService as _AuthService
return _AuthService()
def get_user_service() -> UserService:
"""Return the UserService singleton for dependency injection."""
return user_service
def get_organization_service() -> OrganizationService:
"""Return the OrganizationService singleton for dependency injection."""
return organization_service
def get_session_service() -> SessionService:
"""Return the SessionService singleton for dependency injection."""
return session_service
def get_oauth_service() -> OAuthService:
"""Return OAuthService for dependency injection."""
return OAuthService()
def get_oauth_provider_service():
"""Return the oauth_provider_service module for dependency injection."""
return oauth_provider_service

View File

@@ -14,7 +14,6 @@ from uuid import UUID
from fastapi import APIRouter, Depends, Query, status from fastapi import APIRouter, Depends, Query, status
from pydantic import BaseModel, Field from pydantic import BaseModel, Field
from sqlalchemy import func, select
from sqlalchemy.ext.asyncio import AsyncSession from sqlalchemy.ext.asyncio import AsyncSession
from app.api.dependencies.permissions import require_superuser from app.api.dependencies.permissions import require_superuser
@@ -25,12 +24,9 @@ from app.core.exceptions import (
ErrorCode, ErrorCode,
NotFoundError, NotFoundError,
) )
from app.crud.organization import organization as organization_crud from app.core.repository_exceptions import DuplicateEntryError
from app.crud.session import session as session_crud
from app.crud.user import user as user_crud
from app.models.organization import Organization
from app.models.user import User from app.models.user import User
from app.models.user_organization import OrganizationRole, UserOrganization from app.models.user_organization import OrganizationRole
from app.schemas.common import ( from app.schemas.common import (
MessageResponse, MessageResponse,
PaginatedResponse, PaginatedResponse,
@@ -46,6 +42,9 @@ from app.schemas.organizations import (
) )
from app.schemas.sessions import AdminSessionResponse from app.schemas.sessions import AdminSessionResponse
from app.schemas.users import UserCreate, UserResponse, UserUpdate from app.schemas.users import UserCreate, UserResponse, UserUpdate
from app.services.organization_service import organization_service
from app.services.session_service import session_service
from app.services.user_service import user_service
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
@@ -66,7 +65,7 @@ class BulkUserAction(BaseModel):
action: BulkAction = Field(..., description="Action to perform on selected users") action: BulkAction = Field(..., description="Action to perform on selected users")
user_ids: list[UUID] = Field( user_ids: list[UUID] = Field(
..., min_items=1, max_items=100, description="List of user IDs (max 100)" ..., min_length=1, max_length=100, description="List of user IDs (max 100)"
) )
@@ -178,38 +177,29 @@ async def admin_get_stats(
"""Get admin dashboard statistics with real data from database.""" """Get admin dashboard statistics with real data from database."""
from app.core.config import settings from app.core.config import settings
# Check if we have any data stats = await user_service.get_stats(db)
total_users_query = select(func.count()).select_from(User) total_users = stats["total_users"]
total_users = (await db.execute(total_users_query)).scalar() or 0 active_count = stats["active_count"]
inactive_count = stats["inactive_count"]
all_users = stats["all_users"]
# If database is essentially empty (only admin user), return demo data # If database is essentially empty (only admin user), return demo data
if total_users <= 1 and settings.DEMO_MODE: # pragma: no cover if total_users <= 1 and settings.DEMO_MODE: # pragma: no cover
logger.info("Returning demo stats data (empty database in demo mode)") logger.info("Returning demo stats data (empty database in demo mode)")
return _generate_demo_stats() return _generate_demo_stats()
# 1. User Growth (Last 30 days) - Improved calculation # 1. User Growth (Last 30 days)
datetime.now(UTC) - timedelta(days=30)
# Get all users with their creation dates
all_users_query = select(User).order_by(User.created_at)
result = await db.execute(all_users_query)
all_users = result.scalars().all()
# Build cumulative counts per day
user_growth = [] user_growth = []
for i in range(29, -1, -1): for i in range(29, -1, -1):
date = datetime.now(UTC) - timedelta(days=i) date = datetime.now(UTC) - timedelta(days=i)
date_start = date.replace(hour=0, minute=0, second=0, microsecond=0, tzinfo=UTC) date_start = date.replace(hour=0, minute=0, second=0, microsecond=0, tzinfo=UTC)
date_end = date_start + timedelta(days=1) date_end = date_start + timedelta(days=1)
# Count all users created before end of this day
# Make comparison timezone-aware
total_users_on_date = sum( total_users_on_date = sum(
1 1
for u in all_users for u in all_users
if u.created_at and u.created_at.replace(tzinfo=UTC) < date_end if u.created_at and u.created_at.replace(tzinfo=UTC) < date_end
) )
# Count active users created before end of this day
active_users_on_date = sum( active_users_on_date = sum(
1 1
for u in all_users for u in all_users
@@ -227,27 +217,16 @@ async def admin_get_stats(
) )
# 2. Organization Distribution - Top 6 organizations by member count # 2. Organization Distribution - Top 6 organizations by member count
org_query = ( org_rows = await organization_service.get_org_distribution(db, limit=6)
select(Organization.name, func.count(UserOrganization.user_id).label("count")) org_dist = [OrgDistributionData(name=r["name"], value=r["value"]) for r in org_rows]
.join(UserOrganization, Organization.id == UserOrganization.organization_id)
.group_by(Organization.name)
.order_by(func.count(UserOrganization.user_id).desc())
.limit(6)
)
result = await db.execute(org_query)
org_dist = [
OrgDistributionData(name=row.name, value=row.count) for row in result.all()
]
# 3. User Registration Activity (Last 14 days) - NEW # 3. User Registration Activity (Last 14 days)
registration_activity = [] registration_activity = []
for i in range(13, -1, -1): for i in range(13, -1, -1):
date = datetime.now(UTC) - timedelta(days=i) date = datetime.now(UTC) - timedelta(days=i)
date_start = date.replace(hour=0, minute=0, second=0, microsecond=0, tzinfo=UTC) date_start = date.replace(hour=0, minute=0, second=0, microsecond=0, tzinfo=UTC)
date_end = date_start + timedelta(days=1) date_end = date_start + timedelta(days=1)
# Count users created on this specific day
# Make comparison timezone-aware
day_registrations = sum( day_registrations = sum(
1 1
for u in all_users for u in all_users
@@ -263,16 +242,8 @@ async def admin_get_stats(
) )
# 4. User Status - Active vs Inactive # 4. User Status - Active vs Inactive
active_query = select(func.count()).select_from(User).where(User.is_active)
inactive_query = (
select(func.count()).select_from(User).where(User.is_active.is_(False))
)
active_count = (await db.execute(active_query)).scalar() or 0
inactive_count = (await db.execute(inactive_query)).scalar() or 0
logger.info( logger.info(
f"User status counts - Active: {active_count}, Inactive: {inactive_count}" "User status counts - Active: %s, Inactive: %s", active_count, inactive_count
) )
user_status = [ user_status = [
@@ -321,7 +292,7 @@ async def admin_list_users(
filters["is_superuser"] = is_superuser filters["is_superuser"] = is_superuser
# Get users with search # Get users with search
users, total = await user_crud.get_multi_with_total( users, total = await user_service.list_users(
db, db,
skip=pagination.offset, skip=pagination.offset,
limit=pagination.limit, limit=pagination.limit,
@@ -341,7 +312,7 @@ async def admin_list_users(
return PaginatedResponse(data=users, pagination=pagination_meta) return PaginatedResponse(data=users, pagination=pagination_meta)
except Exception as e: except Exception as e:
logger.error(f"Error listing users (admin): {e!s}", exc_info=True) logger.exception("Error listing users (admin): %s", e)
raise raise
@@ -364,14 +335,14 @@ async def admin_create_user(
Allows setting is_superuser and other fields. Allows setting is_superuser and other fields.
""" """
try: try:
user = await user_crud.create(db, obj_in=user_in) user = await user_service.create_user(db, user_in)
logger.info(f"Admin {admin.email} created user {user.email}") logger.info("Admin %s created user %s", admin.email, user.email)
return user return user
except ValueError as e: except DuplicateEntryError as e:
logger.warning(f"Failed to create user: {e!s}") logger.warning("Failed to create user: %s", e)
raise NotFoundError(message=str(e), error_code=ErrorCode.USER_ALREADY_EXISTS) raise DuplicateError(message=str(e), error_code=ErrorCode.USER_ALREADY_EXISTS)
except Exception as e: except Exception as e:
logger.error(f"Error creating user (admin): {e!s}", exc_info=True) logger.exception("Error creating user (admin): %s", e)
raise raise
@@ -388,11 +359,7 @@ async def admin_get_user(
db: AsyncSession = Depends(get_db), db: AsyncSession = Depends(get_db),
) -> Any: ) -> Any:
"""Get detailed information about a specific user.""" """Get detailed information about a specific user."""
user = await user_crud.get(db, id=user_id) user = await user_service.get_user(db, str(user_id))
if not user:
raise NotFoundError(
message=f"User {user_id} not found", error_code=ErrorCode.USER_NOT_FOUND
)
return user return user
@@ -411,20 +378,13 @@ async def admin_update_user(
) -> Any: ) -> Any:
"""Update user information with admin privileges.""" """Update user information with admin privileges."""
try: try:
user = await user_crud.get(db, id=user_id) user = await user_service.get_user(db, str(user_id))
if not user: updated_user = await user_service.update_user(db, user=user, obj_in=user_in)
raise NotFoundError( logger.info("Admin %s updated user %s", admin.email, updated_user.email)
message=f"User {user_id} not found", error_code=ErrorCode.USER_NOT_FOUND
)
updated_user = await user_crud.update(db, db_obj=user, obj_in=user_in)
logger.info(f"Admin {admin.email} updated user {updated_user.email}")
return updated_user return updated_user
except NotFoundError:
raise
except Exception as e: except Exception as e:
logger.error(f"Error updating user (admin): {e!s}", exc_info=True) logger.exception("Error updating user (admin): %s", e)
raise raise
@@ -442,11 +402,7 @@ async def admin_delete_user(
) -> Any: ) -> Any:
"""Soft delete a user (sets deleted_at timestamp).""" """Soft delete a user (sets deleted_at timestamp)."""
try: try:
user = await user_crud.get(db, id=user_id) user = await user_service.get_user(db, str(user_id))
if not user:
raise NotFoundError(
message=f"User {user_id} not found", error_code=ErrorCode.USER_NOT_FOUND
)
# Prevent deleting yourself # Prevent deleting yourself
if user.id == admin.id: if user.id == admin.id:
@@ -456,17 +412,15 @@ async def admin_delete_user(
error_code=ErrorCode.OPERATION_FORBIDDEN, error_code=ErrorCode.OPERATION_FORBIDDEN,
) )
await user_crud.soft_delete(db, id=user_id) await user_service.soft_delete_user(db, str(user_id))
logger.info(f"Admin {admin.email} deleted user {user.email}") logger.info("Admin %s deleted user %s", admin.email, user.email)
return MessageResponse( return MessageResponse(
success=True, message=f"User {user.email} has been deleted" success=True, message=f"User {user.email} has been deleted"
) )
except NotFoundError:
raise
except Exception as e: except Exception as e:
logger.error(f"Error deleting user (admin): {e!s}", exc_info=True) logger.exception("Error deleting user (admin): %s", e)
raise raise
@@ -484,23 +438,16 @@ async def admin_activate_user(
) -> Any: ) -> Any:
"""Activate a user account.""" """Activate a user account."""
try: try:
user = await user_crud.get(db, id=user_id) user = await user_service.get_user(db, str(user_id))
if not user: await user_service.update_user(db, user=user, obj_in={"is_active": True})
raise NotFoundError( logger.info("Admin %s activated user %s", admin.email, user.email)
message=f"User {user_id} not found", error_code=ErrorCode.USER_NOT_FOUND
)
await user_crud.update(db, db_obj=user, obj_in={"is_active": True})
logger.info(f"Admin {admin.email} activated user {user.email}")
return MessageResponse( return MessageResponse(
success=True, message=f"User {user.email} has been activated" success=True, message=f"User {user.email} has been activated"
) )
except NotFoundError:
raise
except Exception as e: except Exception as e:
logger.error(f"Error activating user (admin): {e!s}", exc_info=True) logger.exception("Error activating user (admin): %s", e)
raise raise
@@ -518,11 +465,7 @@ async def admin_deactivate_user(
) -> Any: ) -> Any:
"""Deactivate a user account.""" """Deactivate a user account."""
try: try:
user = await user_crud.get(db, id=user_id) user = await user_service.get_user(db, str(user_id))
if not user:
raise NotFoundError(
message=f"User {user_id} not found", error_code=ErrorCode.USER_NOT_FOUND
)
# Prevent deactivating yourself # Prevent deactivating yourself
if user.id == admin.id: if user.id == admin.id:
@@ -532,17 +475,15 @@ async def admin_deactivate_user(
error_code=ErrorCode.OPERATION_FORBIDDEN, error_code=ErrorCode.OPERATION_FORBIDDEN,
) )
await user_crud.update(db, db_obj=user, obj_in={"is_active": False}) await user_service.update_user(db, user=user, obj_in={"is_active": False})
logger.info(f"Admin {admin.email} deactivated user {user.email}") logger.info("Admin %s deactivated user %s", admin.email, user.email)
return MessageResponse( return MessageResponse(
success=True, message=f"User {user.email} has been deactivated" success=True, message=f"User {user.email} has been deactivated"
) )
except NotFoundError:
raise
except Exception as e: except Exception as e:
logger.error(f"Error deactivating user (admin): {e!s}", exc_info=True) logger.exception("Error deactivating user (admin): %s", e)
raise raise
@@ -567,16 +508,16 @@ async def admin_bulk_user_action(
try: try:
# Use efficient bulk operations instead of loop # Use efficient bulk operations instead of loop
if bulk_action.action == BulkAction.ACTIVATE: if bulk_action.action == BulkAction.ACTIVATE:
affected_count = await user_crud.bulk_update_status( affected_count = await user_service.bulk_update_status(
db, user_ids=bulk_action.user_ids, is_active=True db, user_ids=bulk_action.user_ids, is_active=True
) )
elif bulk_action.action == BulkAction.DEACTIVATE: elif bulk_action.action == BulkAction.DEACTIVATE:
affected_count = await user_crud.bulk_update_status( affected_count = await user_service.bulk_update_status(
db, user_ids=bulk_action.user_ids, is_active=False db, user_ids=bulk_action.user_ids, is_active=False
) )
elif bulk_action.action == BulkAction.DELETE: elif bulk_action.action == BulkAction.DELETE:
# bulk_soft_delete automatically excludes the admin user # bulk_soft_delete automatically excludes the admin user
affected_count = await user_crud.bulk_soft_delete( affected_count = await user_service.bulk_soft_delete(
db, user_ids=bulk_action.user_ids, exclude_user_id=admin.id db, user_ids=bulk_action.user_ids, exclude_user_id=admin.id
) )
else: # pragma: no cover else: # pragma: no cover
@@ -587,8 +528,11 @@ async def admin_bulk_user_action(
failed_count = requested_count - affected_count failed_count = requested_count - affected_count
logger.info( logger.info(
f"Admin {admin.email} performed bulk {bulk_action.action.value} " "Admin %s performed bulk %s on %s users (%s skipped/failed)",
f"on {affected_count} users ({failed_count} skipped/failed)" admin.email,
bulk_action.action.value,
affected_count,
failed_count,
) )
return BulkActionResult( return BulkActionResult(
@@ -600,7 +544,7 @@ async def admin_bulk_user_action(
) )
except Exception as e: # pragma: no cover except Exception as e: # pragma: no cover
logger.error(f"Error in bulk user action: {e!s}", exc_info=True) logger.exception("Error in bulk user action: %s", e)
raise raise
@@ -624,7 +568,7 @@ async def admin_list_organizations(
"""List all organizations with filtering and search.""" """List all organizations with filtering and search."""
try: try:
# Use optimized method that gets member counts in single query (no N+1) # Use optimized method that gets member counts in single query (no N+1)
orgs_with_data, total = await organization_crud.get_multi_with_member_counts( orgs_with_data, total = await organization_service.get_multi_with_member_counts(
db, db,
skip=pagination.offset, skip=pagination.offset,
limit=pagination.limit, limit=pagination.limit,
@@ -661,7 +605,7 @@ async def admin_list_organizations(
return PaginatedResponse(data=orgs_with_count, pagination=pagination_meta) return PaginatedResponse(data=orgs_with_count, pagination=pagination_meta)
except Exception as e: except Exception as e:
logger.error(f"Error listing organizations (admin): {e!s}", exc_info=True) logger.exception("Error listing organizations (admin): %s", e)
raise raise
@@ -680,8 +624,8 @@ async def admin_create_organization(
) -> Any: ) -> Any:
"""Create a new organization.""" """Create a new organization."""
try: try:
org = await organization_crud.create(db, obj_in=org_in) org = await organization_service.create_organization(db, obj_in=org_in)
logger.info(f"Admin {admin.email} created organization {org.name}") logger.info("Admin %s created organization %s", admin.email, org.name)
# Add member count # Add member count
org_dict = { org_dict = {
@@ -697,11 +641,11 @@ async def admin_create_organization(
} }
return OrganizationResponse(**org_dict) return OrganizationResponse(**org_dict)
except ValueError as e: except DuplicateEntryError as e:
logger.warning(f"Failed to create organization: {e!s}") logger.warning("Failed to create organization: %s", e)
raise NotFoundError(message=str(e), error_code=ErrorCode.ALREADY_EXISTS) raise DuplicateError(message=str(e), error_code=ErrorCode.ALREADY_EXISTS)
except Exception as e: except Exception as e:
logger.error(f"Error creating organization (admin): {e!s}", exc_info=True) logger.exception("Error creating organization (admin): %s", e)
raise raise
@@ -718,12 +662,7 @@ async def admin_get_organization(
db: AsyncSession = Depends(get_db), db: AsyncSession = Depends(get_db),
) -> Any: ) -> Any:
"""Get detailed information about a specific organization.""" """Get detailed information about a specific organization."""
org = await organization_crud.get(db, id=org_id) org = await organization_service.get_organization(db, str(org_id))
if not org:
raise NotFoundError(
message=f"Organization {org_id} not found", error_code=ErrorCode.NOT_FOUND
)
org_dict = { org_dict = {
"id": org.id, "id": org.id,
"name": org.name, "name": org.name,
@@ -733,7 +672,7 @@ async def admin_get_organization(
"settings": org.settings, "settings": org.settings,
"created_at": org.created_at, "created_at": org.created_at,
"updated_at": org.updated_at, "updated_at": org.updated_at,
"member_count": await organization_crud.get_member_count( "member_count": await organization_service.get_member_count(
db, organization_id=org.id db, organization_id=org.id
), ),
} }
@@ -755,15 +694,11 @@ async def admin_update_organization(
) -> Any: ) -> Any:
"""Update organization information.""" """Update organization information."""
try: try:
org = await organization_crud.get(db, id=org_id) org = await organization_service.get_organization(db, str(org_id))
if not org: updated_org = await organization_service.update_organization(
raise NotFoundError( db, org=org, obj_in=org_in
message=f"Organization {org_id} not found", )
error_code=ErrorCode.NOT_FOUND, logger.info("Admin %s updated organization %s", admin.email, updated_org.name)
)
updated_org = await organization_crud.update(db, db_obj=org, obj_in=org_in)
logger.info(f"Admin {admin.email} updated organization {updated_org.name}")
org_dict = { org_dict = {
"id": updated_org.id, "id": updated_org.id,
@@ -774,16 +709,14 @@ async def admin_update_organization(
"settings": updated_org.settings, "settings": updated_org.settings,
"created_at": updated_org.created_at, "created_at": updated_org.created_at,
"updated_at": updated_org.updated_at, "updated_at": updated_org.updated_at,
"member_count": await organization_crud.get_member_count( "member_count": await organization_service.get_member_count(
db, organization_id=updated_org.id db, organization_id=updated_org.id
), ),
} }
return OrganizationResponse(**org_dict) return OrganizationResponse(**org_dict)
except NotFoundError:
raise
except Exception as e: except Exception as e:
logger.error(f"Error updating organization (admin): {e!s}", exc_info=True) logger.exception("Error updating organization (admin): %s", e)
raise raise
@@ -801,24 +734,16 @@ async def admin_delete_organization(
) -> Any: ) -> Any:
"""Delete an organization and all its relationships.""" """Delete an organization and all its relationships."""
try: try:
org = await organization_crud.get(db, id=org_id) org = await organization_service.get_organization(db, str(org_id))
if not org: await organization_service.remove_organization(db, str(org_id))
raise NotFoundError( logger.info("Admin %s deleted organization %s", admin.email, org.name)
message=f"Organization {org_id} not found",
error_code=ErrorCode.NOT_FOUND,
)
await organization_crud.remove(db, id=org_id)
logger.info(f"Admin {admin.email} deleted organization {org.name}")
return MessageResponse( return MessageResponse(
success=True, message=f"Organization {org.name} has been deleted" success=True, message=f"Organization {org.name} has been deleted"
) )
except NotFoundError:
raise
except Exception as e: except Exception as e:
logger.error(f"Error deleting organization (admin): {e!s}", exc_info=True) logger.exception("Error deleting organization (admin): %s", e)
raise raise
@@ -838,14 +763,8 @@ async def admin_list_organization_members(
) -> Any: ) -> Any:
"""List all members of an organization.""" """List all members of an organization."""
try: try:
org = await organization_crud.get(db, id=org_id) await organization_service.get_organization(db, str(org_id)) # validates exists
if not org: members, total = await organization_service.get_organization_members(
raise NotFoundError(
message=f"Organization {org_id} not found",
error_code=ErrorCode.NOT_FOUND,
)
members, total = await organization_crud.get_organization_members(
db, db,
organization_id=org_id, organization_id=org_id,
skip=pagination.offset, skip=pagination.offset,
@@ -868,9 +787,7 @@ async def admin_list_organization_members(
except NotFoundError: except NotFoundError:
raise raise
except Exception as e: except Exception as e:
logger.error( logger.exception("Error listing organization members (admin): %s", e)
f"Error listing organization members (admin): {e!s}", exc_info=True
)
raise raise
@@ -898,45 +815,32 @@ async def admin_add_organization_member(
) -> Any: ) -> Any:
"""Add a user to an organization.""" """Add a user to an organization."""
try: try:
org = await organization_crud.get(db, id=org_id) org = await organization_service.get_organization(db, str(org_id))
if not org: user = await user_service.get_user(db, str(request.user_id))
raise NotFoundError(
message=f"Organization {org_id} not found",
error_code=ErrorCode.NOT_FOUND,
)
user = await user_crud.get(db, id=request.user_id) await organization_service.add_member(
if not user:
raise NotFoundError(
message=f"User {request.user_id} not found",
error_code=ErrorCode.USER_NOT_FOUND,
)
await organization_crud.add_user(
db, organization_id=org_id, user_id=request.user_id, role=request.role db, organization_id=org_id, user_id=request.user_id, role=request.role
) )
logger.info( logger.info(
f"Admin {admin.email} added user {user.email} to organization {org.name} " "Admin %s added user %s to organization %s with role %s",
f"with role {request.role.value}" admin.email,
user.email,
org.name,
request.role.value,
) )
return MessageResponse( return MessageResponse(
success=True, message=f"User {user.email} added to organization {org.name}" success=True, message=f"User {user.email} added to organization {org.name}"
) )
except ValueError as e: except DuplicateEntryError as e:
logger.warning(f"Failed to add user to organization: {e!s}") logger.warning("Failed to add user to organization: %s", e)
# Use DuplicateError for "already exists" scenarios
raise DuplicateError( raise DuplicateError(
message=str(e), error_code=ErrorCode.USER_ALREADY_EXISTS, field="user_id" message=str(e), error_code=ErrorCode.USER_ALREADY_EXISTS, field="user_id"
) )
except NotFoundError:
raise
except Exception as e: except Exception as e:
logger.error( logger.exception("Error adding member to organization (admin): %s", e)
f"Error adding member to organization (admin): {e!s}", exc_info=True
)
raise raise
@@ -955,20 +859,10 @@ async def admin_remove_organization_member(
) -> Any: ) -> Any:
"""Remove a user from an organization.""" """Remove a user from an organization."""
try: try:
org = await organization_crud.get(db, id=org_id) org = await organization_service.get_organization(db, str(org_id))
if not org: user = await user_service.get_user(db, str(user_id))
raise NotFoundError(
message=f"Organization {org_id} not found",
error_code=ErrorCode.NOT_FOUND,
)
user = await user_crud.get(db, id=user_id) success = await organization_service.remove_member(
if not user:
raise NotFoundError(
message=f"User {user_id} not found", error_code=ErrorCode.USER_NOT_FOUND
)
success = await organization_crud.remove_user(
db, organization_id=org_id, user_id=user_id db, organization_id=org_id, user_id=user_id
) )
@@ -979,7 +873,10 @@ async def admin_remove_organization_member(
) )
logger.info( logger.info(
f"Admin {admin.email} removed user {user.email} from organization {org.name}" "Admin %s removed user %s from organization %s",
admin.email,
user.email,
org.name,
) )
return MessageResponse( return MessageResponse(
@@ -990,9 +887,7 @@ async def admin_remove_organization_member(
except NotFoundError: except NotFoundError:
raise raise
except Exception as e: # pragma: no cover except Exception as e: # pragma: no cover
logger.error( logger.exception("Error removing member from organization (admin): %s", e)
f"Error removing member from organization (admin): {e!s}", exc_info=True
)
raise raise
@@ -1022,7 +917,7 @@ async def admin_list_sessions(
"""List all sessions across all users with filtering and pagination.""" """List all sessions across all users with filtering and pagination."""
try: try:
# Get sessions with user info (eager loaded to prevent N+1) # Get sessions with user info (eager loaded to prevent N+1)
sessions, total = await session_crud.get_all_sessions( sessions, total = await session_service.get_all_sessions(
db, db,
skip=pagination.offset, skip=pagination.offset,
limit=pagination.limit, limit=pagination.limit,
@@ -1061,7 +956,10 @@ async def admin_list_sessions(
session_responses.append(session_response) session_responses.append(session_response)
logger.info( logger.info(
f"Admin {admin.email} listed {len(session_responses)} sessions (total: {total})" "Admin %s listed %s sessions (total: %s)",
admin.email,
len(session_responses),
total,
) )
pagination_meta = create_pagination_meta( pagination_meta = create_pagination_meta(
@@ -1074,5 +972,5 @@ async def admin_list_sessions(
return PaginatedResponse(data=session_responses, pagination=pagination_meta) return PaginatedResponse(data=session_responses, pagination=pagination_meta)
except Exception as e: # pragma: no cover except Exception as e: # pragma: no cover
logger.error(f"Error listing sessions (admin): {e!s}", exc_info=True) logger.exception("Error listing sessions (admin): %s", e)
raise raise

View File

@@ -15,16 +15,14 @@ from app.core.auth import (
TokenExpiredError, TokenExpiredError,
TokenInvalidError, TokenInvalidError,
decode_token, decode_token,
get_password_hash,
) )
from app.core.database import get_db from app.core.database import get_db
from app.core.exceptions import ( from app.core.exceptions import (
AuthenticationError as AuthError, AuthenticationError as AuthError,
DatabaseError, DatabaseError,
DuplicateError,
ErrorCode, ErrorCode,
) )
from app.crud.session import session as session_crud
from app.crud.user import user as user_crud
from app.models.user import User from app.models.user import User
from app.schemas.common import MessageResponse from app.schemas.common import MessageResponse
from app.schemas.sessions import LogoutRequest, SessionCreate from app.schemas.sessions import LogoutRequest, SessionCreate
@@ -39,6 +37,8 @@ from app.schemas.users import (
) )
from app.services.auth_service import AuthenticationError, AuthService from app.services.auth_service import AuthenticationError, AuthService
from app.services.email_service import email_service from app.services.email_service import email_service
from app.services.session_service import session_service
from app.services.user_service import user_service
from app.utils.device import extract_device_info from app.utils.device import extract_device_info
from app.utils.security import create_password_reset_token, verify_password_reset_token from app.utils.security import create_password_reset_token, verify_password_reset_token
@@ -91,17 +91,18 @@ async def _create_login_session(
location_country=device_info.location_country, location_country=device_info.location_country,
) )
await session_crud.create_session(db, obj_in=session_data) await session_service.create_session(db, obj_in=session_data)
logger.info( logger.info(
f"{login_type.capitalize()} successful: {user.email} from {device_info.device_name} " "%s successful: %s from %s (IP: %s)",
f"(IP: {device_info.ip_address})" login_type.capitalize(),
user.email,
device_info.device_name,
device_info.ip_address,
) )
except Exception as session_err: except Exception as session_err:
# Log but don't fail login if session creation fails # Log but don't fail login if session creation fails
logger.error( logger.exception("Failed to create session for %s: %s", user.email, session_err)
f"Failed to create session for {user.email}: {session_err!s}", exc_info=True
)
@router.post( @router.post(
@@ -123,15 +124,21 @@ async def register_user(
try: try:
user = await AuthService.create_user(db, user_data) user = await AuthService.create_user(db, user_data)
return user return user
except AuthenticationError as e: except DuplicateError:
# SECURITY: Don't reveal if email exists - generic error message # SECURITY: Don't reveal if email exists - generic error message
logger.warning(f"Registration failed: {e!s}") logger.warning("Registration failed: duplicate email %s", user_data.email)
raise HTTPException(
status_code=status.HTTP_400_BAD_REQUEST,
detail="Registration failed. Please check your information and try again.",
)
except AuthError as e:
logger.warning("Registration failed: %s", e)
raise HTTPException( raise HTTPException(
status_code=status.HTTP_400_BAD_REQUEST, status_code=status.HTTP_400_BAD_REQUEST,
detail="Registration failed. Please check your information and try again.", detail="Registration failed. Please check your information and try again.",
) )
except Exception as e: except Exception as e:
logger.error(f"Unexpected error during registration: {e!s}", exc_info=True) logger.exception("Unexpected error during registration: %s", e)
raise DatabaseError( raise DatabaseError(
message="An unexpected error occurred. Please try again later.", message="An unexpected error occurred. Please try again later.",
error_code=ErrorCode.INTERNAL_ERROR, error_code=ErrorCode.INTERNAL_ERROR,
@@ -159,7 +166,7 @@ async def login(
# Explicitly check for None result and raise correct exception # Explicitly check for None result and raise correct exception
if user is None: if user is None:
logger.warning(f"Invalid login attempt for: {login_data.email}") logger.warning("Invalid login attempt for: %s", login_data.email)
raise AuthError( raise AuthError(
message="Invalid email or password", message="Invalid email or password",
error_code=ErrorCode.INVALID_CREDENTIALS, error_code=ErrorCode.INVALID_CREDENTIALS,
@@ -175,14 +182,11 @@ async def login(
except AuthenticationError as e: except AuthenticationError as e:
# Handle specific authentication errors like inactive accounts # Handle specific authentication errors like inactive accounts
logger.warning(f"Authentication failed: {e!s}") logger.warning("Authentication failed: %s", e)
raise AuthError(message=str(e), error_code=ErrorCode.INVALID_CREDENTIALS) raise AuthError(message=str(e), error_code=ErrorCode.INVALID_CREDENTIALS)
except AuthError:
# Re-raise custom auth exceptions without modification
raise
except Exception as e: except Exception as e:
# Handle unexpected errors # Handle unexpected errors
logger.error(f"Unexpected error during login: {e!s}", exc_info=True) logger.exception("Unexpected error during login: %s", e)
raise DatabaseError( raise DatabaseError(
message="An unexpected error occurred. Please try again later.", message="An unexpected error occurred. Please try again later.",
error_code=ErrorCode.INTERNAL_ERROR, error_code=ErrorCode.INTERNAL_ERROR,
@@ -224,13 +228,10 @@ async def login_oauth(
# Return full token response with user data # Return full token response with user data
return tokens return tokens
except AuthenticationError as e: except AuthenticationError as e:
logger.warning(f"OAuth authentication failed: {e!s}") logger.warning("OAuth authentication failed: %s", e)
raise AuthError(message=str(e), error_code=ErrorCode.INVALID_CREDENTIALS) raise AuthError(message=str(e), error_code=ErrorCode.INVALID_CREDENTIALS)
except AuthError:
# Re-raise custom auth exceptions without modification
raise
except Exception as e: except Exception as e:
logger.error(f"Unexpected error during OAuth login: {e!s}", exc_info=True) logger.exception("Unexpected error during OAuth login: %s", e)
raise DatabaseError( raise DatabaseError(
message="An unexpected error occurred. Please try again later.", message="An unexpected error occurred. Please try again later.",
error_code=ErrorCode.INTERNAL_ERROR, error_code=ErrorCode.INTERNAL_ERROR,
@@ -259,11 +260,12 @@ async def refresh_token(
) )
# Check if session exists and is active # Check if session exists and is active
session = await session_crud.get_active_by_jti(db, jti=refresh_payload.jti) session = await session_service.get_active_by_jti(db, jti=refresh_payload.jti)
if not session: if not session:
logger.warning( logger.warning(
f"Refresh token used for inactive or non-existent session: {refresh_payload.jti}" "Refresh token used for inactive or non-existent session: %s",
refresh_payload.jti,
) )
raise HTTPException( raise HTTPException(
status_code=status.HTTP_401_UNAUTHORIZED, status_code=status.HTTP_401_UNAUTHORIZED,
@@ -279,16 +281,14 @@ async def refresh_token(
# Update session with new refresh token JTI and expiration # Update session with new refresh token JTI and expiration
try: try:
await session_crud.update_refresh_token( await session_service.update_refresh_token(
db, db,
session=session, session=session,
new_jti=new_refresh_payload.jti, new_jti=new_refresh_payload.jti,
new_expires_at=datetime.fromtimestamp(new_refresh_payload.exp, tz=UTC), new_expires_at=datetime.fromtimestamp(new_refresh_payload.exp, tz=UTC),
) )
except Exception as session_err: except Exception as session_err:
logger.error( logger.exception("Failed to update session %s: %s", session.id, session_err)
f"Failed to update session {session.id}: {session_err!s}", exc_info=True
)
# Continue anyway - tokens are already issued # Continue anyway - tokens are already issued
return tokens return tokens
@@ -311,7 +311,7 @@ async def refresh_token(
# Re-raise HTTP exceptions (like session revoked) # Re-raise HTTP exceptions (like session revoked)
raise raise
except Exception as e: except Exception as e:
logger.error(f"Unexpected error during token refresh: {e!s}") logger.error("Unexpected error during token refresh: %s", e)
raise HTTPException( raise HTTPException(
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
detail="An unexpected error occurred. Please try again later.", detail="An unexpected error occurred. Please try again later.",
@@ -347,7 +347,7 @@ async def request_password_reset(
""" """
try: try:
# Look up user by email # Look up user by email
user = await user_crud.get_by_email(db, email=reset_request.email) user = await user_service.get_by_email(db, email=reset_request.email)
# Only send email if user exists and is active # Only send email if user exists and is active
if user and user.is_active: if user and user.is_active:
@@ -358,11 +358,12 @@ async def request_password_reset(
await email_service.send_password_reset_email( await email_service.send_password_reset_email(
to_email=user.email, reset_token=reset_token, user_name=user.first_name to_email=user.email, reset_token=reset_token, user_name=user.first_name
) )
logger.info(f"Password reset requested for {user.email}") logger.info("Password reset requested for %s", user.email)
else: else:
# Log attempt but don't reveal if email exists # Log attempt but don't reveal if email exists
logger.warning( logger.warning(
f"Password reset requested for non-existent or inactive email: {reset_request.email}" "Password reset requested for non-existent or inactive email: %s",
reset_request.email,
) )
# Always return success to prevent email enumeration # Always return success to prevent email enumeration
@@ -371,7 +372,7 @@ async def request_password_reset(
message="If your email is registered, you will receive a password reset link shortly", message="If your email is registered, you will receive a password reset link shortly",
) )
except Exception as e: except Exception as e:
logger.error(f"Error processing password reset request: {e!s}", exc_info=True) logger.exception("Error processing password reset request: %s", e)
# Still return success to prevent information leakage # Still return success to prevent information leakage
return MessageResponse( return MessageResponse(
success=True, success=True,
@@ -412,40 +413,34 @@ async def confirm_password_reset(
detail="Invalid or expired password reset token", detail="Invalid or expired password reset token",
) )
# Look up user # Reset password via service (validates user exists and is active)
user = await user_crud.get_by_email(db, email=email) try:
user = await AuthService.reset_password(
if not user: db, email=email, new_password=reset_confirm.new_password
raise HTTPException(
status_code=status.HTTP_404_NOT_FOUND, detail="User not found"
) )
except AuthenticationError as e:
if not user.is_active: err_msg = str(e)
raise HTTPException( if "inactive" in err_msg.lower():
status_code=status.HTTP_400_BAD_REQUEST, raise HTTPException(
detail="User account is inactive", status_code=status.HTTP_400_BAD_REQUEST, detail=err_msg
) )
raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail=err_msg)
# Update password
user.password_hash = get_password_hash(reset_confirm.new_password)
db.add(user)
await db.commit()
# SECURITY: Invalidate all existing sessions after password reset # SECURITY: Invalidate all existing sessions after password reset
# This prevents stolen sessions from being used after password change # This prevents stolen sessions from being used after password change
from app.crud.session import session as session_crud
try: try:
deactivated_count = await session_crud.deactivate_all_user_sessions( deactivated_count = await session_service.deactivate_all_user_sessions(
db, user_id=str(user.id) db, user_id=str(user.id)
) )
logger.info( logger.info(
f"Password reset successful for {user.email}, invalidated {deactivated_count} sessions" "Password reset successful for %s, invalidated %s sessions",
user.email,
deactivated_count,
) )
except Exception as session_error: except Exception as session_error:
# Log but don't fail password reset if session invalidation fails # Log but don't fail password reset if session invalidation fails
logger.error( logger.error(
f"Failed to invalidate sessions after password reset: {session_error!s}" "Failed to invalidate sessions after password reset: %s", session_error
) )
return MessageResponse( return MessageResponse(
@@ -456,7 +451,7 @@ async def confirm_password_reset(
except HTTPException: except HTTPException:
raise raise
except Exception as e: except Exception as e:
logger.error(f"Error confirming password reset: {e!s}", exc_info=True) logger.exception("Error confirming password reset: %s", e)
await db.rollback() await db.rollback()
raise HTTPException( raise HTTPException(
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
@@ -506,19 +501,21 @@ async def logout(
) )
except (TokenExpiredError, TokenInvalidError) as e: except (TokenExpiredError, TokenInvalidError) as e:
# Even if token is expired/invalid, try to deactivate session # Even if token is expired/invalid, try to deactivate session
logger.warning(f"Logout with invalid/expired token: {e!s}") logger.warning("Logout with invalid/expired token: %s", e)
# Don't fail - return success anyway # Don't fail - return success anyway
return MessageResponse(success=True, message="Logged out successfully") return MessageResponse(success=True, message="Logged out successfully")
# Find the session by JTI # Find the session by JTI
session = await session_crud.get_by_jti(db, jti=refresh_payload.jti) session = await session_service.get_by_jti(db, jti=refresh_payload.jti)
if session: if session:
# Verify session belongs to current user (security check) # Verify session belongs to current user (security check)
if str(session.user_id) != str(current_user.id): if str(session.user_id) != str(current_user.id):
logger.warning( logger.warning(
f"User {current_user.id} attempted to logout session {session.id} " "User %s attempted to logout session %s belonging to user %s",
f"belonging to user {session.user_id}" current_user.id,
session.id,
session.user_id,
) )
raise HTTPException( raise HTTPException(
status_code=status.HTTP_403_FORBIDDEN, status_code=status.HTTP_403_FORBIDDEN,
@@ -526,17 +523,20 @@ async def logout(
) )
# Deactivate the session # Deactivate the session
await session_crud.deactivate(db, session_id=str(session.id)) await session_service.deactivate(db, session_id=str(session.id))
logger.info( logger.info(
f"User {current_user.id} logged out from {session.device_name} " "User %s logged out from %s (session %s)",
f"(session {session.id})" current_user.id,
session.device_name,
session.id,
) )
else: else:
# Session not found - maybe already deleted or never existed # Session not found - maybe already deleted or never existed
# Return success anyway (idempotent) # Return success anyway (idempotent)
logger.info( logger.info(
f"Logout requested for non-existent session (JTI: {refresh_payload.jti})" "Logout requested for non-existent session (JTI: %s)",
refresh_payload.jti,
) )
return MessageResponse(success=True, message="Logged out successfully") return MessageResponse(success=True, message="Logged out successfully")
@@ -544,9 +544,7 @@ async def logout(
except HTTPException: except HTTPException:
raise raise
except Exception as e: except Exception as e:
logger.error( logger.exception("Error during logout for user %s: %s", current_user.id, e)
f"Error during logout for user {current_user.id}: {e!s}", exc_info=True
)
# Don't expose error details # Don't expose error details
return MessageResponse(success=True, message="Logged out successfully") return MessageResponse(success=True, message="Logged out successfully")
@@ -584,12 +582,12 @@ async def logout_all(
""" """
try: try:
# Deactivate all sessions for this user # Deactivate all sessions for this user
count = await session_crud.deactivate_all_user_sessions( count = await session_service.deactivate_all_user_sessions(
db, user_id=str(current_user.id) db, user_id=str(current_user.id)
) )
logger.info( logger.info(
f"User {current_user.id} logged out from all devices ({count} sessions)" "User %s logged out from all devices (%s sessions)", current_user.id, count
) )
return MessageResponse( return MessageResponse(
@@ -598,9 +596,7 @@ async def logout_all(
) )
except Exception as e: except Exception as e:
logger.error( logger.exception("Error during logout-all for user %s: %s", current_user.id, e)
f"Error during logout-all for user {current_user.id}: {e!s}", exc_info=True
)
await db.rollback() await db.rollback()
raise HTTPException( raise HTTPException(
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,

View File

@@ -25,8 +25,6 @@ from app.core.auth import decode_token
from app.core.config import settings from app.core.config import settings
from app.core.database import get_db from app.core.database import get_db
from app.core.exceptions import AuthenticationError as AuthError from app.core.exceptions import AuthenticationError as AuthError
from app.crud import oauth_account
from app.crud.session import session as session_crud
from app.models.user import User from app.models.user import User
from app.schemas.oauth import ( from app.schemas.oauth import (
OAuthAccountsListResponse, OAuthAccountsListResponse,
@@ -38,6 +36,7 @@ from app.schemas.oauth import (
from app.schemas.sessions import SessionCreate from app.schemas.sessions import SessionCreate
from app.schemas.users import Token from app.schemas.users import Token
from app.services.oauth_service import OAuthService from app.services.oauth_service import OAuthService
from app.services.session_service import session_service
from app.utils.device import extract_device_info from app.utils.device import extract_device_info
router = APIRouter() router = APIRouter()
@@ -82,17 +81,19 @@ async def _create_oauth_login_session(
location_country=device_info.location_country, location_country=device_info.location_country,
) )
await session_crud.create_session(db, obj_in=session_data) await session_service.create_session(db, obj_in=session_data)
logger.info( logger.info(
f"OAuth login successful: {user.email} via {provider} " "OAuth login successful: %s via %s from %s (IP: %s)",
f"from {device_info.device_name} (IP: {device_info.ip_address})" user.email,
provider,
device_info.device_name,
device_info.ip_address,
) )
except Exception as session_err: except Exception as session_err:
# Log but don't fail login if session creation fails # Log but don't fail login if session creation fails
logger.error( logger.exception(
f"Failed to create session for OAuth login {user.email}: {session_err!s}", "Failed to create session for OAuth login %s: %s", user.email, session_err
exc_info=True,
) )
@@ -177,13 +178,13 @@ async def get_authorization_url(
} }
except AuthError as e: except AuthError as e:
logger.warning(f"OAuth authorization failed: {e!s}") logger.warning("OAuth authorization failed: %s", e)
raise HTTPException( raise HTTPException(
status_code=status.HTTP_400_BAD_REQUEST, status_code=status.HTTP_400_BAD_REQUEST,
detail=str(e), detail=str(e),
) )
except Exception as e: except Exception as e:
logger.error(f"OAuth authorization error: {e!s}", exc_info=True) logger.exception("OAuth authorization error: %s", e)
raise HTTPException( raise HTTPException(
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
detail="Failed to create authorization URL", detail="Failed to create authorization URL",
@@ -251,13 +252,13 @@ async def handle_callback(
return result return result
except AuthError as e: except AuthError as e:
logger.warning(f"OAuth callback failed: {e!s}") logger.warning("OAuth callback failed: %s", e)
raise HTTPException( raise HTTPException(
status_code=status.HTTP_401_UNAUTHORIZED, status_code=status.HTTP_401_UNAUTHORIZED,
detail=str(e), detail=str(e),
) )
except Exception as e: except Exception as e:
logger.error(f"OAuth callback error: {e!s}", exc_info=True) logger.exception("OAuth callback error: %s", e)
raise HTTPException( raise HTTPException(
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
detail="OAuth authentication failed", detail="OAuth authentication failed",
@@ -289,7 +290,7 @@ async def list_accounts(
Returns: Returns:
List of linked OAuth accounts List of linked OAuth accounts
""" """
accounts = await oauth_account.get_user_accounts(db, user_id=current_user.id) accounts = await OAuthService.get_user_accounts(db, user_id=current_user.id)
return OAuthAccountsListResponse(accounts=accounts) return OAuthAccountsListResponse(accounts=accounts)
@@ -338,13 +339,13 @@ async def unlink_account(
) )
except AuthError as e: except AuthError as e:
logger.warning(f"OAuth unlink failed for {current_user.email}: {e!s}") logger.warning("OAuth unlink failed for %s: %s", current_user.email, e)
raise HTTPException( raise HTTPException(
status_code=status.HTTP_400_BAD_REQUEST, status_code=status.HTTP_400_BAD_REQUEST,
detail=str(e), detail=str(e),
) )
except Exception as e: except Exception as e:
logger.error(f"OAuth unlink error: {e!s}", exc_info=True) logger.exception("OAuth unlink error: %s", e)
raise HTTPException( raise HTTPException(
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
detail="Failed to unlink OAuth account", detail="Failed to unlink OAuth account",
@@ -397,7 +398,7 @@ async def start_link(
) )
# Check if user already has this provider linked # Check if user already has this provider linked
existing = await oauth_account.get_user_account_by_provider( existing = await OAuthService.get_user_account_by_provider(
db, user_id=current_user.id, provider=provider db, user_id=current_user.id, provider=provider
) )
if existing: if existing:
@@ -420,13 +421,13 @@ async def start_link(
} }
except AuthError as e: except AuthError as e:
logger.warning(f"OAuth link authorization failed: {e!s}") logger.warning("OAuth link authorization failed: %s", e)
raise HTTPException( raise HTTPException(
status_code=status.HTTP_400_BAD_REQUEST, status_code=status.HTTP_400_BAD_REQUEST,
detail=str(e), detail=str(e),
) )
except Exception as e: except Exception as e:
logger.error(f"OAuth link error: {e!s}", exc_info=True) logger.exception("OAuth link error: %s", e)
raise HTTPException( raise HTTPException(
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
detail="Failed to create authorization URL", detail="Failed to create authorization URL",

View File

@@ -34,7 +34,6 @@ from app.api.dependencies.auth import (
) )
from app.core.config import settings from app.core.config import settings
from app.core.database import get_db from app.core.database import get_db
from app.crud import oauth_client as oauth_client_crud
from app.models.user import User from app.models.user import User
from app.schemas.oauth import ( from app.schemas.oauth import (
OAuthClientCreate, OAuthClientCreate,
@@ -453,7 +452,7 @@ async def token(
except Exception as e: except Exception as e:
# Log malformed Basic auth for security monitoring # Log malformed Basic auth for security monitoring
logger.warning( logger.warning(
f"Malformed Basic auth header in token request: {type(e).__name__}" "Malformed Basic auth header in token request: %s", type(e).__name__
) )
# Fall back to form body # Fall back to form body
@@ -564,7 +563,8 @@ async def revoke(
except Exception as e: except Exception as e:
# Log malformed Basic auth for security monitoring # Log malformed Basic auth for security monitoring
logger.warning( logger.warning(
f"Malformed Basic auth header in revoke request: {type(e).__name__}" "Malformed Basic auth header in revoke request: %s",
type(e).__name__,
) )
# Fall back to form body # Fall back to form body
@@ -586,7 +586,7 @@ async def revoke(
) )
except Exception as e: except Exception as e:
# Log but don't expose errors per RFC 7009 # Log but don't expose errors per RFC 7009
logger.warning(f"Token revocation error: {e}") logger.warning("Token revocation error: %s", e)
# Always return 200 OK per RFC 7009 # Always return 200 OK per RFC 7009
return {"status": "ok"} return {"status": "ok"}
@@ -635,7 +635,8 @@ async def introspect(
except Exception as e: except Exception as e:
# Log malformed Basic auth for security monitoring # Log malformed Basic auth for security monitoring
logger.warning( logger.warning(
f"Malformed Basic auth header in introspect request: {type(e).__name__}" "Malformed Basic auth header in introspect request: %s",
type(e).__name__,
) )
# Fall back to form body # Fall back to form body
@@ -655,8 +656,8 @@ async def introspect(
headers={"WWW-Authenticate": "Basic"}, headers={"WWW-Authenticate": "Basic"},
) )
except Exception as e: except Exception as e:
logger.warning(f"Token introspection error: {e}") logger.warning("Token introspection error: %s", e)
return OAuthTokenIntrospectionResponse(active=False) return OAuthTokenIntrospectionResponse(active=False) # pyright: ignore[reportCallIssue]
# ============================================================================ # ============================================================================
@@ -712,7 +713,7 @@ async def register_client(
client_type=client_type, client_type=client_type,
) )
client, secret = await oauth_client_crud.create_client(db, obj_in=client_data) client, secret = await provider_service.register_client(db, client_data)
# Update MCP server URL if provided # Update MCP server URL if provided
if mcp_server_url: if mcp_server_url:
@@ -750,7 +751,7 @@ async def list_clients(
current_user: User = Depends(get_current_superuser), current_user: User = Depends(get_current_superuser),
) -> list[OAuthClientResponse]: ) -> list[OAuthClientResponse]:
"""List all OAuth clients.""" """List all OAuth clients."""
clients = await oauth_client_crud.get_all_clients(db) clients = await provider_service.list_clients(db)
return [OAuthClientResponse.model_validate(c) for c in clients] return [OAuthClientResponse.model_validate(c) for c in clients]
@@ -776,7 +777,7 @@ async def delete_client(
detail="Client not found", detail="Client not found",
) )
await oauth_client_crud.delete_client(db, client_id=client_id) await provider_service.delete_client_by_id(db, client_id=client_id)
# ============================================================================ # ============================================================================
@@ -797,30 +798,7 @@ async def list_my_consents(
current_user: User = Depends(get_current_active_user), current_user: User = Depends(get_current_active_user),
) -> list[dict]: ) -> list[dict]:
"""List applications the user has authorized.""" """List applications the user has authorized."""
from sqlalchemy import select return await provider_service.list_user_consents(db, user_id=current_user.id)
from app.models.oauth_client import OAuthClient
from app.models.oauth_provider_token import OAuthConsent
result = await db.execute(
select(OAuthConsent, OAuthClient)
.join(OAuthClient, OAuthConsent.client_id == OAuthClient.client_id)
.where(OAuthConsent.user_id == current_user.id)
)
rows = result.all()
return [
{
"client_id": consent.client_id,
"client_name": client.client_name,
"client_description": client.client_description,
"granted_scopes": consent.granted_scopes.split()
if consent.granted_scopes
else [],
"granted_at": consent.created_at.isoformat(),
}
for consent, client in rows
]
@router.delete( @router.delete(

View File

@@ -15,8 +15,6 @@ from sqlalchemy.ext.asyncio import AsyncSession
from app.api.dependencies.auth import get_current_user from app.api.dependencies.auth import get_current_user
from app.api.dependencies.permissions import require_org_admin, require_org_membership from app.api.dependencies.permissions import require_org_admin, require_org_membership
from app.core.database import get_db from app.core.database import get_db
from app.core.exceptions import ErrorCode, NotFoundError
from app.crud.organization import organization as organization_crud
from app.models.user import User from app.models.user import User
from app.schemas.common import ( from app.schemas.common import (
PaginatedResponse, PaginatedResponse,
@@ -28,6 +26,7 @@ from app.schemas.organizations import (
OrganizationResponse, OrganizationResponse,
OrganizationUpdate, OrganizationUpdate,
) )
from app.services.organization_service import organization_service
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
@@ -54,7 +53,7 @@ async def get_my_organizations(
""" """
try: try:
# Get all org data in single query with JOIN and subquery # Get all org data in single query with JOIN and subquery
orgs_data = await organization_crud.get_user_organizations_with_details( orgs_data = await organization_service.get_user_organizations_with_details(
db, user_id=current_user.id, is_active=is_active db, user_id=current_user.id, is_active=is_active
) )
@@ -78,7 +77,7 @@ async def get_my_organizations(
return orgs_with_data return orgs_with_data
except Exception as e: except Exception as e:
logger.error(f"Error getting user organizations: {e!s}", exc_info=True) logger.exception("Error getting user organizations: %s", e)
raise raise
@@ -100,13 +99,7 @@ async def get_organization(
User must be a member of the organization. User must be a member of the organization.
""" """
try: try:
org = await organization_crud.get(db, id=organization_id) org = await organization_service.get_organization(db, str(organization_id))
if not org: # pragma: no cover - Permission check prevents this (see docs/UNREACHABLE_DEFENSIVE_CODE_ANALYSIS.md)
raise NotFoundError(
detail=f"Organization {organization_id} not found",
error_code=ErrorCode.NOT_FOUND,
)
org_dict = { org_dict = {
"id": org.id, "id": org.id,
"name": org.name, "name": org.name,
@@ -116,16 +109,14 @@ async def get_organization(
"settings": org.settings, "settings": org.settings,
"created_at": org.created_at, "created_at": org.created_at,
"updated_at": org.updated_at, "updated_at": org.updated_at,
"member_count": await organization_crud.get_member_count( "member_count": await organization_service.get_member_count(
db, organization_id=org.id db, organization_id=org.id
), ),
} }
return OrganizationResponse(**org_dict) return OrganizationResponse(**org_dict)
except NotFoundError: # pragma: no cover - See above
raise
except Exception as e: except Exception as e:
logger.error(f"Error getting organization: {e!s}", exc_info=True) logger.exception("Error getting organization: %s", e)
raise raise
@@ -149,7 +140,7 @@ async def get_organization_members(
User must be a member of the organization to view members. User must be a member of the organization to view members.
""" """
try: try:
members, total = await organization_crud.get_organization_members( members, total = await organization_service.get_organization_members(
db, db,
organization_id=organization_id, organization_id=organization_id,
skip=pagination.offset, skip=pagination.offset,
@@ -169,7 +160,7 @@ async def get_organization_members(
return PaginatedResponse(data=member_responses, pagination=pagination_meta) return PaginatedResponse(data=member_responses, pagination=pagination_meta)
except Exception as e: except Exception as e:
logger.error(f"Error getting organization members: {e!s}", exc_info=True) logger.exception("Error getting organization members: %s", e)
raise raise
@@ -192,16 +183,12 @@ async def update_organization(
Requires owner or admin role in the organization. Requires owner or admin role in the organization.
""" """
try: try:
org = await organization_crud.get(db, id=organization_id) org = await organization_service.get_organization(db, str(organization_id))
if not org: # pragma: no cover - Permission check prevents this (see docs/UNREACHABLE_DEFENSIVE_CODE_ANALYSIS.md) updated_org = await organization_service.update_organization(
raise NotFoundError( db, org=org, obj_in=org_in
detail=f"Organization {organization_id} not found", )
error_code=ErrorCode.NOT_FOUND,
)
updated_org = await organization_crud.update(db, db_obj=org, obj_in=org_in)
logger.info( logger.info(
f"User {current_user.email} updated organization {updated_org.name}" "User %s updated organization %s", current_user.email, updated_org.name
) )
org_dict = { org_dict = {
@@ -213,14 +200,12 @@ async def update_organization(
"settings": updated_org.settings, "settings": updated_org.settings,
"created_at": updated_org.created_at, "created_at": updated_org.created_at,
"updated_at": updated_org.updated_at, "updated_at": updated_org.updated_at,
"member_count": await organization_crud.get_member_count( "member_count": await organization_service.get_member_count(
db, organization_id=updated_org.id db, organization_id=updated_org.id
), ),
} }
return OrganizationResponse(**org_dict) return OrganizationResponse(**org_dict)
except NotFoundError: # pragma: no cover - See above
raise
except Exception as e: except Exception as e:
logger.error(f"Error updating organization: {e!s}", exc_info=True) logger.exception("Error updating organization: %s", e)
raise raise

View File

@@ -17,10 +17,10 @@ from app.api.dependencies.auth import get_current_user
from app.core.auth import decode_token from app.core.auth import decode_token
from app.core.database import get_db from app.core.database import get_db
from app.core.exceptions import AuthorizationError, ErrorCode, NotFoundError from app.core.exceptions import AuthorizationError, ErrorCode, NotFoundError
from app.crud.session import session as session_crud
from app.models.user import User from app.models.user import User
from app.schemas.common import MessageResponse from app.schemas.common import MessageResponse
from app.schemas.sessions import SessionListResponse, SessionResponse from app.schemas.sessions import SessionListResponse, SessionResponse
from app.services.session_service import session_service
router = APIRouter() router = APIRouter()
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
@@ -60,7 +60,7 @@ async def list_my_sessions(
""" """
try: try:
# Get all active sessions for user # Get all active sessions for user
sessions = await session_crud.get_user_sessions( sessions = await session_service.get_user_sessions(
db, user_id=str(current_user.id), active_only=True db, user_id=str(current_user.id), active_only=True
) )
@@ -74,9 +74,7 @@ async def list_my_sessions(
# For now, we'll mark current based on most recent activity # For now, we'll mark current based on most recent activity
except Exception as e: except Exception as e:
# Optional token parsing - silently ignore failures # Optional token parsing - silently ignore failures
logger.debug( logger.debug("Failed to decode access token for session marking: %s", e)
f"Failed to decode access token for session marking: {e!s}"
)
# Convert to response format # Convert to response format
session_responses = [] session_responses = []
@@ -98,7 +96,7 @@ async def list_my_sessions(
session_responses.append(session_response) session_responses.append(session_response)
logger.info( logger.info(
f"User {current_user.id} listed {len(session_responses)} active sessions" "User %s listed %s active sessions", current_user.id, len(session_responses)
) )
return SessionListResponse( return SessionListResponse(
@@ -106,9 +104,7 @@ async def list_my_sessions(
) )
except Exception as e: except Exception as e:
logger.error( logger.exception("Error listing sessions for user %s: %s", current_user.id, e)
f"Error listing sessions for user {current_user.id}: {e!s}", exc_info=True
)
raise HTTPException( raise HTTPException(
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
detail="Failed to retrieve sessions", detail="Failed to retrieve sessions",
@@ -150,7 +146,7 @@ async def revoke_session(
""" """
try: try:
# Get the session # Get the session
session = await session_crud.get(db, id=str(session_id)) session = await session_service.get_session(db, str(session_id))
if not session: if not session:
raise NotFoundError( raise NotFoundError(
@@ -161,8 +157,10 @@ async def revoke_session(
# Verify session belongs to current user # Verify session belongs to current user
if str(session.user_id) != str(current_user.id): if str(session.user_id) != str(current_user.id):
logger.warning( logger.warning(
f"User {current_user.id} attempted to revoke session {session_id} " "User %s attempted to revoke session %s belonging to user %s",
f"belonging to user {session.user_id}" current_user.id,
session_id,
session.user_id,
) )
raise AuthorizationError( raise AuthorizationError(
message="You can only revoke your own sessions", message="You can only revoke your own sessions",
@@ -170,11 +168,13 @@ async def revoke_session(
) )
# Deactivate the session # Deactivate the session
await session_crud.deactivate(db, session_id=str(session_id)) await session_service.deactivate(db, session_id=str(session_id))
logger.info( logger.info(
f"User {current_user.id} revoked session {session_id} " "User %s revoked session %s (%s)",
f"({session.device_name})" current_user.id,
session_id,
session.device_name,
) )
return MessageResponse( return MessageResponse(
@@ -185,7 +185,7 @@ async def revoke_session(
except (NotFoundError, AuthorizationError): except (NotFoundError, AuthorizationError):
raise raise
except Exception as e: except Exception as e:
logger.error(f"Error revoking session {session_id}: {e!s}", exc_info=True) logger.exception("Error revoking session %s: %s", session_id, e)
raise HTTPException( raise HTTPException(
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
detail="Failed to revoke session", detail="Failed to revoke session",
@@ -224,12 +224,12 @@ async def cleanup_expired_sessions(
""" """
try: try:
# Use optimized bulk DELETE instead of N individual deletes # Use optimized bulk DELETE instead of N individual deletes
deleted_count = await session_crud.cleanup_expired_for_user( deleted_count = await session_service.cleanup_expired_for_user(
db, user_id=str(current_user.id) db, user_id=str(current_user.id)
) )
logger.info( logger.info(
f"User {current_user.id} cleaned up {deleted_count} expired sessions" "User %s cleaned up %s expired sessions", current_user.id, deleted_count
) )
return MessageResponse( return MessageResponse(
@@ -237,9 +237,8 @@ async def cleanup_expired_sessions(
) )
except Exception as e: except Exception as e:
logger.error( logger.exception(
f"Error cleaning up sessions for user {current_user.id}: {e!s}", "Error cleaning up sessions for user %s: %s", current_user.id, e
exc_info=True,
) )
await db.rollback() await db.rollback()
raise HTTPException( raise HTTPException(

View File

@@ -1,5 +1,5 @@
""" """
User management endpoints for CRUD operations. User management endpoints for database operations.
""" """
import logging import logging
@@ -13,8 +13,7 @@ from sqlalchemy.ext.asyncio import AsyncSession
from app.api.dependencies.auth import get_current_superuser, get_current_user from app.api.dependencies.auth import get_current_superuser, get_current_user
from app.core.database import get_db from app.core.database import get_db
from app.core.exceptions import AuthorizationError, ErrorCode, NotFoundError from app.core.exceptions import AuthorizationError, ErrorCode
from app.crud.user import user as user_crud
from app.models.user import User from app.models.user import User
from app.schemas.common import ( from app.schemas.common import (
MessageResponse, MessageResponse,
@@ -25,6 +24,7 @@ from app.schemas.common import (
) )
from app.schemas.users import PasswordChange, UserResponse, UserUpdate from app.schemas.users import PasswordChange, UserResponse, UserUpdate
from app.services.auth_service import AuthenticationError, AuthService from app.services.auth_service import AuthenticationError, AuthService
from app.services.user_service import user_service
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
@@ -71,7 +71,7 @@ async def list_users(
filters["is_superuser"] = is_superuser filters["is_superuser"] = is_superuser
# Get paginated users with total count # Get paginated users with total count
users, total = await user_crud.get_multi_with_total( users, total = await user_service.list_users(
db, db,
skip=pagination.offset, skip=pagination.offset,
limit=pagination.limit, limit=pagination.limit,
@@ -90,7 +90,7 @@ async def list_users(
return PaginatedResponse(data=users, pagination=pagination_meta) return PaginatedResponse(data=users, pagination=pagination_meta)
except Exception as e: except Exception as e:
logger.error(f"Error listing users: {e!s}", exc_info=True) logger.exception("Error listing users: %s", e)
raise raise
@@ -107,7 +107,9 @@ async def list_users(
""", """,
operation_id="get_current_user_profile", operation_id="get_current_user_profile",
) )
def get_current_user_profile(current_user: User = Depends(get_current_user)) -> Any: async def get_current_user_profile(
current_user: User = Depends(get_current_user),
) -> Any:
"""Get current user's profile.""" """Get current user's profile."""
return current_user return current_user
@@ -138,18 +140,16 @@ async def update_current_user(
Users cannot elevate their own permissions (protected by UserUpdate schema validator). Users cannot elevate their own permissions (protected by UserUpdate schema validator).
""" """
try: try:
updated_user = await user_crud.update( updated_user = await user_service.update_user(
db, db_obj=current_user, obj_in=user_update db, user=current_user, obj_in=user_update
) )
logger.info(f"User {current_user.id} updated their profile") logger.info("User %s updated their profile", current_user.id)
return updated_user return updated_user
except ValueError as e: except ValueError as e:
logger.error(f"Error updating user {current_user.id}: {e!s}") logger.error("Error updating user %s: %s", current_user.id, e)
raise raise
except Exception as e: except Exception as e:
logger.error( logger.exception("Unexpected error updating user %s: %s", current_user.id, e)
f"Unexpected error updating user {current_user.id}: {e!s}", exc_info=True
)
raise raise
@@ -182,7 +182,9 @@ async def get_user_by_id(
# Check permissions # Check permissions
if str(user_id) != str(current_user.id) and not current_user.is_superuser: if str(user_id) != str(current_user.id) and not current_user.is_superuser:
logger.warning( logger.warning(
f"User {current_user.id} attempted to access user {user_id} without permission" "User %s attempted to access user %s without permission",
current_user.id,
user_id,
) )
raise AuthorizationError( raise AuthorizationError(
message="Not enough permissions to view this user", message="Not enough permissions to view this user",
@@ -190,13 +192,7 @@ async def get_user_by_id(
) )
# Get user # Get user
user = await user_crud.get(db, id=str(user_id)) user = await user_service.get_user(db, str(user_id))
if not user:
raise NotFoundError(
message=f"User with id {user_id} not found",
error_code=ErrorCode.USER_NOT_FOUND,
)
return user return user
@@ -233,7 +229,9 @@ async def update_user(
if not is_own_profile and not current_user.is_superuser: if not is_own_profile and not current_user.is_superuser:
logger.warning( logger.warning(
f"User {current_user.id} attempted to update user {user_id} without permission" "User %s attempted to update user %s without permission",
current_user.id,
user_id,
) )
raise AuthorizationError( raise AuthorizationError(
message="Not enough permissions to update this user", message="Not enough permissions to update this user",
@@ -241,22 +239,17 @@ async def update_user(
) )
# Get user # Get user
user = await user_crud.get(db, id=str(user_id)) user = await user_service.get_user(db, str(user_id))
if not user:
raise NotFoundError(
message=f"User with id {user_id} not found",
error_code=ErrorCode.USER_NOT_FOUND,
)
try: try:
updated_user = await user_crud.update(db, db_obj=user, obj_in=user_update) updated_user = await user_service.update_user(db, user=user, obj_in=user_update)
logger.info(f"User {user_id} updated by {current_user.id}") logger.info("User %s updated by %s", user_id, current_user.id)
return updated_user return updated_user
except ValueError as e: except ValueError as e:
logger.error(f"Error updating user {user_id}: {e!s}") logger.error("Error updating user %s: %s", user_id, e)
raise raise
except Exception as e: except Exception as e:
logger.error(f"Unexpected error updating user {user_id}: {e!s}", exc_info=True) logger.exception("Unexpected error updating user %s: %s", user_id, e)
raise raise
@@ -296,19 +289,19 @@ async def change_current_user_password(
) )
if success: if success:
logger.info(f"User {current_user.id} changed their password") logger.info("User %s changed their password", current_user.id)
return MessageResponse( return MessageResponse(
success=True, message="Password changed successfully" success=True, message="Password changed successfully"
) )
except AuthenticationError as e: except AuthenticationError as e:
logger.warning( logger.warning(
f"Failed password change attempt for user {current_user.id}: {e!s}" "Failed password change attempt for user %s: %s", current_user.id, e
) )
raise AuthorizationError( raise AuthorizationError(
message=str(e), error_code=ErrorCode.INVALID_CREDENTIALS message=str(e), error_code=ErrorCode.INVALID_CREDENTIALS
) )
except Exception as e: except Exception as e:
logger.error(f"Error changing password for user {current_user.id}: {e!s}") logger.error("Error changing password for user %s: %s", current_user.id, e)
raise raise
@@ -346,24 +339,19 @@ async def delete_user(
error_code=ErrorCode.INSUFFICIENT_PERMISSIONS, error_code=ErrorCode.INSUFFICIENT_PERMISSIONS,
) )
# Get user # Get user (raises NotFoundError if not found)
user = await user_crud.get(db, id=str(user_id)) await user_service.get_user(db, str(user_id))
if not user:
raise NotFoundError(
message=f"User with id {user_id} not found",
error_code=ErrorCode.USER_NOT_FOUND,
)
try: try:
# Use soft delete instead of hard delete # Use soft delete instead of hard delete
await user_crud.soft_delete(db, id=str(user_id)) await user_service.soft_delete_user(db, str(user_id))
logger.info(f"User {user_id} soft-deleted by {current_user.id}") logger.info("User %s soft-deleted by %s", user_id, current_user.id)
return MessageResponse( return MessageResponse(
success=True, message=f"User {user_id} deleted successfully" success=True, message=f"User {user_id} deleted successfully"
) )
except ValueError as e: except ValueError as e:
logger.error(f"Error deleting user {user_id}: {e!s}") logger.error("Error deleting user %s: %s", user_id, e)
raise raise
except Exception as e: except Exception as e:
logger.error(f"Unexpected error deleting user {user_id}: {e!s}", exc_info=True) logger.exception("Unexpected error deleting user %s: %s", user_id, e)
raise raise

View File

@@ -1,23 +1,21 @@
import asyncio import asyncio
import logging
import uuid import uuid
from datetime import UTC, datetime, timedelta from datetime import UTC, datetime, timedelta
from functools import partial from functools import partial
from typing import Any from typing import Any
from jose import JWTError, jwt import bcrypt
from passlib.context import CryptContext import jwt
from jwt.exceptions import (
ExpiredSignatureError,
InvalidTokenError,
MissingRequiredClaimError,
)
from pydantic import ValidationError from pydantic import ValidationError
from app.core.config import settings from app.core.config import settings
from app.schemas.users import TokenData, TokenPayload from app.schemas.users import TokenData, TokenPayload
# Suppress passlib bcrypt warnings about ident
logging.getLogger("passlib").setLevel(logging.ERROR)
# Password hashing context
pwd_context = CryptContext(schemes=["bcrypt"], deprecated="auto")
# Custom exceptions for auth # Custom exceptions for auth
class AuthError(Exception): class AuthError(Exception):
@@ -37,13 +35,16 @@ class TokenMissingClaimError(AuthError):
def verify_password(plain_password: str, hashed_password: str) -> bool: def verify_password(plain_password: str, hashed_password: str) -> bool:
"""Verify a password against a hash.""" """Verify a password against a bcrypt hash."""
return pwd_context.verify(plain_password, hashed_password) return bcrypt.checkpw(
plain_password.encode("utf-8"), hashed_password.encode("utf-8")
)
def get_password_hash(password: str) -> str: def get_password_hash(password: str) -> str:
"""Generate a password hash.""" """Generate a bcrypt password hash."""
return pwd_context.hash(password) salt = bcrypt.gensalt()
return bcrypt.hashpw(password.encode("utf-8"), salt).decode("utf-8")
async def verify_password_async(plain_password: str, hashed_password: str) -> bool: async def verify_password_async(plain_password: str, hashed_password: str) -> bool:
@@ -60,9 +61,9 @@ async def verify_password_async(plain_password: str, hashed_password: str) -> bo
Returns: Returns:
True if password matches, False otherwise True if password matches, False otherwise
""" """
loop = asyncio.get_event_loop() loop = asyncio.get_running_loop()
return await loop.run_in_executor( return await loop.run_in_executor(
None, partial(pwd_context.verify, plain_password, hashed_password) None, partial(verify_password, plain_password, hashed_password)
) )
@@ -80,8 +81,8 @@ async def get_password_hash_async(password: str) -> str:
Returns: Returns:
Hashed password string Hashed password string
""" """
loop = asyncio.get_event_loop() loop = asyncio.get_running_loop()
return await loop.run_in_executor(None, pwd_context.hash, password) return await loop.run_in_executor(None, get_password_hash, password)
def create_access_token( def create_access_token(
@@ -121,11 +122,7 @@ def create_access_token(
to_encode.update(claims) to_encode.update(claims)
# Create the JWT # Create the JWT
encoded_jwt = jwt.encode( return jwt.encode(to_encode, settings.SECRET_KEY, algorithm=settings.ALGORITHM)
to_encode, settings.SECRET_KEY, algorithm=settings.ALGORITHM
)
return encoded_jwt
def create_refresh_token( def create_refresh_token(
@@ -154,11 +151,7 @@ def create_refresh_token(
"type": "refresh", "type": "refresh",
} }
encoded_jwt = jwt.encode( return jwt.encode(to_encode, settings.SECRET_KEY, algorithm=settings.ALGORITHM)
to_encode, settings.SECRET_KEY, algorithm=settings.ALGORITHM
)
return encoded_jwt
def decode_token(token: str, verify_type: str | None = None) -> TokenPayload: def decode_token(token: str, verify_type: str | None = None) -> TokenPayload:
@@ -198,7 +191,7 @@ def decode_token(token: str, verify_type: str | None = None) -> TokenPayload:
# Reject weak or unexpected algorithms # Reject weak or unexpected algorithms
# NOTE: These are defensive checks that provide defense-in-depth. # NOTE: These are defensive checks that provide defense-in-depth.
# The python-jose library rejects these tokens BEFORE we reach here, # PyJWT rejects these tokens BEFORE we reach here,
# but we keep these checks in case the library changes or is misconfigured. # but we keep these checks in case the library changes or is misconfigured.
# Coverage: Marked as pragma since library catches first (see tests/core/test_auth_security.py) # Coverage: Marked as pragma since library catches first (see tests/core/test_auth_security.py)
if token_algorithm == "NONE": # pragma: no cover if token_algorithm == "NONE": # pragma: no cover
@@ -219,10 +212,11 @@ def decode_token(token: str, verify_type: str | None = None) -> TokenPayload:
token_data = TokenPayload(**payload) token_data = TokenPayload(**payload)
return token_data return token_data
except JWTError as e: except ExpiredSignatureError:
# Check if the error is due to an expired token raise TokenExpiredError("Token has expired")
if "expired" in str(e).lower(): except MissingRequiredClaimError as e:
raise TokenExpiredError("Token has expired") raise TokenMissingClaimError(f"Token missing required claim: {e}")
except InvalidTokenError:
raise TokenInvalidError("Invalid authentication token") raise TokenInvalidError("Invalid authentication token")
except ValidationError: except ValidationError:
raise TokenInvalidError("Invalid token payload") raise TokenInvalidError("Invalid token payload")

View File

@@ -128,8 +128,8 @@ async def async_transaction_scope() -> AsyncGenerator[AsyncSession, None]:
Usage: Usage:
async with async_transaction_scope() as db: async with async_transaction_scope() as db:
user = await user_crud.create(db, obj_in=user_create) user = await user_repo.create(db, obj_in=user_create)
profile = await profile_crud.create(db, obj_in=profile_create) profile = await profile_repo.create(db, obj_in=profile_create)
# Both operations committed together # Both operations committed together
""" """
async with SessionLocal() as session: async with SessionLocal() as session:
@@ -139,7 +139,7 @@ async def async_transaction_scope() -> AsyncGenerator[AsyncSession, None]:
logger.debug("Async transaction committed successfully") logger.debug("Async transaction committed successfully")
except Exception as e: except Exception as e:
await session.rollback() await session.rollback()
logger.error(f"Async transaction failed, rolling back: {e!s}") logger.error("Async transaction failed, rolling back: %s", e)
raise raise
finally: finally:
await session.close() await session.close()
@@ -155,7 +155,7 @@ async def check_async_database_health() -> bool:
await db.execute(text("SELECT 1")) await db.execute(text("SELECT 1"))
return True return True
except Exception as e: except Exception as e:
logger.error(f"Async database health check failed: {e!s}") logger.error("Async database health check failed: %s", e)
return False return False

View File

@@ -143,8 +143,11 @@ async def api_exception_handler(request: Request, exc: APIException) -> JSONResp
Returns a standardized error response with error code and message. Returns a standardized error response with error code and message.
""" """
logger.warning( logger.warning(
f"API exception: {exc.error_code} - {exc.message} " "API exception: %s - %s (status: %s, path: %s)",
f"(status: {exc.status_code}, path: {request.url.path})" exc.error_code,
exc.message,
exc.status_code,
request.url.path,
) )
error_response = ErrorResponse( error_response = ErrorResponse(
@@ -186,7 +189,9 @@ async def validation_exception_handler(
) )
) )
logger.warning(f"Validation error: {len(errors)} errors (path: {request.url.path})") logger.warning(
"Validation error: %s errors (path: %s)", len(errors), request.url.path
)
error_response = ErrorResponse(errors=errors) error_response = ErrorResponse(errors=errors)
@@ -218,11 +223,14 @@ async def http_exception_handler(request: Request, exc: HTTPException) -> JSONRe
) )
logger.warning( logger.warning(
f"HTTP exception: {exc.status_code} - {exc.detail} (path: {request.url.path})" "HTTP exception: %s - %s (path: %s)",
exc.status_code,
exc.detail,
request.url.path,
) )
error_response = ErrorResponse( error_response = ErrorResponse(
errors=[ErrorDetail(code=error_code, message=str(exc.detail))] errors=[ErrorDetail(code=error_code, message=str(exc.detail), field=None)]
) )
return JSONResponse( return JSONResponse(
@@ -239,10 +247,11 @@ async def unhandled_exception_handler(request: Request, exc: Exception) -> JSONR
Logs the full exception and returns a generic error response to avoid Logs the full exception and returns a generic error response to avoid
leaking sensitive information in production. leaking sensitive information in production.
""" """
logger.error( logger.exception(
f"Unhandled exception: {type(exc).__name__} - {exc!s} " "Unhandled exception: %s - %s (path: %s)",
f"(path: {request.url.path})", type(exc).__name__,
exc_info=True, exc,
request.url.path,
) )
# In production, don't expose internal error details # In production, don't expose internal error details
@@ -254,7 +263,7 @@ async def unhandled_exception_handler(request: Request, exc: Exception) -> JSONR
message = f"{type(exc).__name__}: {exc!s}" message = f"{type(exc).__name__}: {exc!s}"
error_response = ErrorResponse( error_response = ErrorResponse(
errors=[ErrorDetail(code=ErrorCode.INTERNAL_ERROR, message=message)] errors=[ErrorDetail(code=ErrorCode.INTERNAL_ERROR, message=message, field=None)]
) )
return JSONResponse( return JSONResponse(

View File

@@ -0,0 +1,26 @@
"""
Custom exceptions for the repository layer.
These exceptions allow services and routes to handle database-level errors
with proper semantics, without leaking SQLAlchemy internals.
"""
class RepositoryError(Exception):
"""Base for all repository-layer errors."""
class DuplicateEntryError(RepositoryError):
"""Raised on unique constraint violations. Maps to HTTP 409 Conflict."""
class IntegrityConstraintError(RepositoryError):
"""Raised on FK or check constraint violations."""
class RecordNotFoundError(RepositoryError):
"""Raised when an expected record doesn't exist."""
class InvalidInputError(RepositoryError):
"""Raised on bad pagination params, invalid UUIDs, or other invalid inputs."""

View File

@@ -1,14 +0,0 @@
# app/crud/__init__.py
from .oauth import oauth_account, oauth_client, oauth_state
from .organization import organization
from .session import session as session_crud
from .user import user
__all__ = [
"oauth_account",
"oauth_client",
"oauth_state",
"organization",
"session_crud",
"user",
]

View File

@@ -1,718 +0,0 @@
"""
Async CRUD operations for OAuth models using SQLAlchemy 2.0 patterns.
Provides operations for:
- OAuthAccount: Managing linked OAuth provider accounts
- OAuthState: CSRF protection state during OAuth flows
- OAuthClient: Registered OAuth clients (provider mode skeleton)
"""
import logging
import secrets
from datetime import UTC, datetime
from uuid import UUID
from pydantic import BaseModel
from sqlalchemy import and_, delete, select
from sqlalchemy.exc import IntegrityError
from sqlalchemy.ext.asyncio import AsyncSession
from sqlalchemy.orm import joinedload
from app.crud.base import CRUDBase
from app.models.oauth_account import OAuthAccount
from app.models.oauth_client import OAuthClient
from app.models.oauth_state import OAuthState
from app.schemas.oauth import OAuthAccountCreate, OAuthClientCreate, OAuthStateCreate
logger = logging.getLogger(__name__)
# ============================================================================
# OAuth Account CRUD
# ============================================================================
class EmptySchema(BaseModel):
"""Placeholder schema for CRUD operations that don't need update schemas."""
class CRUDOAuthAccount(CRUDBase[OAuthAccount, OAuthAccountCreate, EmptySchema]):
"""CRUD operations for OAuth account links."""
async def get_by_provider_id(
self,
db: AsyncSession,
*,
provider: str,
provider_user_id: str,
) -> OAuthAccount | None:
"""
Get OAuth account by provider and provider user ID.
Args:
db: Database session
provider: OAuth provider name (google, github)
provider_user_id: User ID from the OAuth provider
Returns:
OAuthAccount if found, None otherwise
"""
try:
result = await db.execute(
select(OAuthAccount)
.where(
and_(
OAuthAccount.provider == provider,
OAuthAccount.provider_user_id == provider_user_id,
)
)
.options(joinedload(OAuthAccount.user))
)
return result.scalar_one_or_none()
except Exception as e: # pragma: no cover # pragma: no cover
logger.error(
f"Error getting OAuth account for {provider}:{provider_user_id}: {e!s}"
)
raise
async def get_by_provider_email(
self,
db: AsyncSession,
*,
provider: str,
email: str,
) -> OAuthAccount | None:
"""
Get OAuth account by provider and email.
Used for auto-linking existing accounts by email.
Args:
db: Database session
provider: OAuth provider name
email: Email address from the OAuth provider
Returns:
OAuthAccount if found, None otherwise
"""
try:
result = await db.execute(
select(OAuthAccount)
.where(
and_(
OAuthAccount.provider == provider,
OAuthAccount.provider_email == email,
)
)
.options(joinedload(OAuthAccount.user))
)
return result.scalar_one_or_none()
except Exception as e: # pragma: no cover # pragma: no cover
logger.error(
f"Error getting OAuth account for {provider} email {email}: {e!s}"
)
raise
async def get_user_accounts(
self,
db: AsyncSession,
*,
user_id: str | UUID,
) -> list[OAuthAccount]:
"""
Get all OAuth accounts linked to a user.
Args:
db: Database session
user_id: User ID
Returns:
List of OAuthAccount objects
"""
try:
user_uuid = UUID(str(user_id)) if isinstance(user_id, str) else user_id
result = await db.execute(
select(OAuthAccount)
.where(OAuthAccount.user_id == user_uuid)
.order_by(OAuthAccount.created_at.desc())
)
return list(result.scalars().all())
except Exception as e: # pragma: no cover
logger.error(f"Error getting OAuth accounts for user {user_id}: {e!s}")
raise
async def get_user_account_by_provider(
self,
db: AsyncSession,
*,
user_id: str | UUID,
provider: str,
) -> OAuthAccount | None:
"""
Get a specific OAuth account for a user and provider.
Args:
db: Database session
user_id: User ID
provider: OAuth provider name
Returns:
OAuthAccount if found, None otherwise
"""
try:
user_uuid = UUID(str(user_id)) if isinstance(user_id, str) else user_id
result = await db.execute(
select(OAuthAccount).where(
and_(
OAuthAccount.user_id == user_uuid,
OAuthAccount.provider == provider,
)
)
)
return result.scalar_one_or_none()
except Exception as e: # pragma: no cover
logger.error(
f"Error getting OAuth account for user {user_id}, provider {provider}: {e!s}"
)
raise
async def create_account(
self, db: AsyncSession, *, obj_in: OAuthAccountCreate
) -> OAuthAccount:
"""
Create a new OAuth account link.
Args:
db: Database session
obj_in: OAuth account creation data
Returns:
Created OAuthAccount
Raises:
ValueError: If account already exists or creation fails
"""
try:
db_obj = OAuthAccount(
user_id=obj_in.user_id,
provider=obj_in.provider,
provider_user_id=obj_in.provider_user_id,
provider_email=obj_in.provider_email,
access_token_encrypted=obj_in.access_token_encrypted,
refresh_token_encrypted=obj_in.refresh_token_encrypted,
token_expires_at=obj_in.token_expires_at,
)
db.add(db_obj)
await db.commit()
await db.refresh(db_obj)
logger.info(
f"OAuth account created: {obj_in.provider} linked to user {obj_in.user_id}"
)
return db_obj
except IntegrityError as e: # pragma: no cover
await db.rollback()
error_msg = str(e.orig) if hasattr(e, "orig") else str(e)
if "uq_oauth_provider_user" in error_msg.lower():
logger.warning(
f"OAuth account already exists: {obj_in.provider}:{obj_in.provider_user_id}"
)
raise ValueError(
f"This {obj_in.provider} account is already linked to another user"
)
logger.error(f"Integrity error creating OAuth account: {error_msg}")
raise ValueError(f"Failed to create OAuth account: {error_msg}")
except Exception as e: # pragma: no cover
await db.rollback()
logger.error(f"Error creating OAuth account: {e!s}", exc_info=True)
raise
async def delete_account(
self,
db: AsyncSession,
*,
user_id: str | UUID,
provider: str,
) -> bool:
"""
Delete an OAuth account link.
Args:
db: Database session
user_id: User ID
provider: OAuth provider name
Returns:
True if deleted, False if not found
"""
try:
user_uuid = UUID(str(user_id)) if isinstance(user_id, str) else user_id
result = await db.execute(
delete(OAuthAccount).where(
and_(
OAuthAccount.user_id == user_uuid,
OAuthAccount.provider == provider,
)
)
)
await db.commit()
deleted = result.rowcount > 0
if deleted:
logger.info(
f"OAuth account deleted: {provider} unlinked from user {user_id}"
)
else:
logger.warning(
f"OAuth account not found for deletion: {provider} for user {user_id}"
)
return deleted
except Exception as e: # pragma: no cover
await db.rollback()
logger.error(
f"Error deleting OAuth account {provider} for user {user_id}: {e!s}"
)
raise
async def update_tokens(
self,
db: AsyncSession,
*,
account: OAuthAccount,
access_token_encrypted: str | None = None,
refresh_token_encrypted: str | None = None,
token_expires_at: datetime | None = None,
) -> OAuthAccount:
"""
Update OAuth tokens for an account.
Args:
db: Database session
account: OAuthAccount to update
access_token_encrypted: New encrypted access token
refresh_token_encrypted: New encrypted refresh token
token_expires_at: New token expiration time
Returns:
Updated OAuthAccount
"""
try:
if access_token_encrypted is not None:
account.access_token_encrypted = access_token_encrypted
if refresh_token_encrypted is not None:
account.refresh_token_encrypted = refresh_token_encrypted
if token_expires_at is not None:
account.token_expires_at = token_expires_at
db.add(account)
await db.commit()
await db.refresh(account)
return account
except Exception as e: # pragma: no cover
await db.rollback()
logger.error(f"Error updating OAuth tokens: {e!s}")
raise
# ============================================================================
# OAuth State CRUD
# ============================================================================
class CRUDOAuthState(CRUDBase[OAuthState, OAuthStateCreate, EmptySchema]):
"""CRUD operations for OAuth state (CSRF protection)."""
async def create_state(
self, db: AsyncSession, *, obj_in: OAuthStateCreate
) -> OAuthState:
"""
Create a new OAuth state for CSRF protection.
Args:
db: Database session
obj_in: OAuth state creation data
Returns:
Created OAuthState
"""
try:
db_obj = OAuthState(
state=obj_in.state,
code_verifier=obj_in.code_verifier,
nonce=obj_in.nonce,
provider=obj_in.provider,
redirect_uri=obj_in.redirect_uri,
user_id=obj_in.user_id,
expires_at=obj_in.expires_at,
)
db.add(db_obj)
await db.commit()
await db.refresh(db_obj)
logger.debug(f"OAuth state created for {obj_in.provider}")
return db_obj
except IntegrityError as e: # pragma: no cover
await db.rollback()
# State collision (extremely rare with cryptographic random)
error_msg = str(e.orig) if hasattr(e, "orig") else str(e)
logger.error(f"OAuth state collision: {error_msg}")
raise ValueError("Failed to create OAuth state, please retry")
except Exception as e: # pragma: no cover
await db.rollback()
logger.error(f"Error creating OAuth state: {e!s}", exc_info=True)
raise
async def get_and_consume_state(
self, db: AsyncSession, *, state: str
) -> OAuthState | None:
"""
Get and delete OAuth state (consume it).
This ensures each state can only be used once (replay protection).
Args:
db: Database session
state: State string to look up
Returns:
OAuthState if found and valid, None otherwise
"""
try:
# Get the state
result = await db.execute(
select(OAuthState).where(OAuthState.state == state)
)
db_obj = result.scalar_one_or_none()
if db_obj is None:
logger.warning(f"OAuth state not found: {state[:8]}...")
return None
# Check expiration
# Handle both timezone-aware and timezone-naive datetimes
now = datetime.now(UTC)
expires_at = db_obj.expires_at
if expires_at.tzinfo is None:
# SQLite returns naive datetimes, assume UTC
expires_at = expires_at.replace(tzinfo=UTC)
if expires_at < now:
logger.warning(f"OAuth state expired: {state[:8]}...")
await db.delete(db_obj)
await db.commit()
return None
# Delete it (consume)
await db.delete(db_obj)
await db.commit()
logger.debug(f"OAuth state consumed: {state[:8]}...")
return db_obj
except Exception as e: # pragma: no cover
await db.rollback()
logger.error(f"Error consuming OAuth state: {e!s}")
raise
async def cleanup_expired(self, db: AsyncSession) -> int:
"""
Clean up expired OAuth states.
Should be called periodically to remove stale states.
Args:
db: Database session
Returns:
Number of states deleted
"""
try:
now = datetime.now(UTC)
stmt = delete(OAuthState).where(OAuthState.expires_at < now)
result = await db.execute(stmt)
await db.commit()
count = result.rowcount
if count > 0:
logger.info(f"Cleaned up {count} expired OAuth states")
return count
except Exception as e: # pragma: no cover
await db.rollback()
logger.error(f"Error cleaning up expired OAuth states: {e!s}")
raise
# ============================================================================
# OAuth Client CRUD (Provider Mode - Skeleton)
# ============================================================================
class CRUDOAuthClient(CRUDBase[OAuthClient, OAuthClientCreate, EmptySchema]):
"""
CRUD operations for OAuth clients (provider mode).
This is a skeleton implementation for MCP client registration.
Full implementation can be expanded when needed.
"""
async def get_by_client_id(
self, db: AsyncSession, *, client_id: str
) -> OAuthClient | None:
"""
Get OAuth client by client_id.
Args:
db: Database session
client_id: OAuth client ID
Returns:
OAuthClient if found, None otherwise
"""
try:
result = await db.execute(
select(OAuthClient).where(
and_(
OAuthClient.client_id == client_id,
OAuthClient.is_active == True, # noqa: E712
)
)
)
return result.scalar_one_or_none()
except Exception as e: # pragma: no cover
logger.error(f"Error getting OAuth client {client_id}: {e!s}")
raise
async def create_client(
self,
db: AsyncSession,
*,
obj_in: OAuthClientCreate,
owner_user_id: UUID | None = None,
) -> tuple[OAuthClient, str | None]:
"""
Create a new OAuth client.
Args:
db: Database session
obj_in: OAuth client creation data
owner_user_id: Optional owner user ID
Returns:
Tuple of (created OAuthClient, client_secret or None for public clients)
"""
try:
# Generate client_id
client_id = secrets.token_urlsafe(32)
# Generate client_secret for confidential clients
client_secret = None
client_secret_hash = None
if obj_in.client_type == "confidential":
client_secret = secrets.token_urlsafe(48)
# SECURITY: Use bcrypt for secret storage (not SHA-256)
# bcrypt is computationally expensive, making brute-force attacks infeasible
from app.core.auth import get_password_hash
client_secret_hash = get_password_hash(client_secret)
db_obj = OAuthClient(
client_id=client_id,
client_secret_hash=client_secret_hash,
client_name=obj_in.client_name,
client_description=obj_in.client_description,
client_type=obj_in.client_type,
redirect_uris=obj_in.redirect_uris,
allowed_scopes=obj_in.allowed_scopes,
owner_user_id=owner_user_id,
is_active=True,
)
db.add(db_obj)
await db.commit()
await db.refresh(db_obj)
logger.info(
f"OAuth client created: {obj_in.client_name} ({client_id[:8]}...)"
)
return db_obj, client_secret
except IntegrityError as e: # pragma: no cover
await db.rollback()
error_msg = str(e.orig) if hasattr(e, "orig") else str(e)
logger.error(f"Error creating OAuth client: {error_msg}")
raise ValueError(f"Failed to create OAuth client: {error_msg}")
except Exception as e: # pragma: no cover
await db.rollback()
logger.error(f"Error creating OAuth client: {e!s}", exc_info=True)
raise
async def deactivate_client(
self, db: AsyncSession, *, client_id: str
) -> OAuthClient | None:
"""
Deactivate an OAuth client.
Args:
db: Database session
client_id: OAuth client ID
Returns:
Deactivated OAuthClient if found, None otherwise
"""
try:
client = await self.get_by_client_id(db, client_id=client_id)
if client is None:
return None
client.is_active = False
db.add(client)
await db.commit()
await db.refresh(client)
logger.info(f"OAuth client deactivated: {client.client_name}")
return client
except Exception as e: # pragma: no cover
await db.rollback()
logger.error(f"Error deactivating OAuth client {client_id}: {e!s}")
raise
async def validate_redirect_uri(
self, db: AsyncSession, *, client_id: str, redirect_uri: str
) -> bool:
"""
Validate that a redirect URI is allowed for a client.
Args:
db: Database session
client_id: OAuth client ID
redirect_uri: Redirect URI to validate
Returns:
True if valid, False otherwise
"""
try:
client = await self.get_by_client_id(db, client_id=client_id)
if client is None:
return False
return redirect_uri in (client.redirect_uris or [])
except Exception as e: # pragma: no cover
logger.error(f"Error validating redirect URI: {e!s}")
return False
async def verify_client_secret(
self, db: AsyncSession, *, client_id: str, client_secret: str
) -> bool:
"""
Verify client credentials.
Args:
db: Database session
client_id: OAuth client ID
client_secret: Client secret to verify
Returns:
True if valid, False otherwise
"""
try:
result = await db.execute(
select(OAuthClient).where(
and_(
OAuthClient.client_id == client_id,
OAuthClient.is_active == True, # noqa: E712
)
)
)
client = result.scalar_one_or_none()
if client is None or client.client_secret_hash is None:
return False
# SECURITY: Verify secret using bcrypt (not SHA-256)
# This supports both old SHA-256 hashes (for migration) and new bcrypt hashes
from app.core.auth import verify_password
stored_hash: str = str(client.client_secret_hash)
# Check if it's a bcrypt hash (starts with $2b$) or legacy SHA-256
if stored_hash.startswith("$2"):
# New bcrypt format
return verify_password(client_secret, stored_hash)
else:
# Legacy SHA-256 format - still support for migration
import hashlib
secret_hash = hashlib.sha256(client_secret.encode()).hexdigest()
return secrets.compare_digest(stored_hash, secret_hash)
except Exception as e: # pragma: no cover
logger.error(f"Error verifying client secret: {e!s}")
return False
async def get_all_clients(
self, db: AsyncSession, *, include_inactive: bool = False
) -> list[OAuthClient]:
"""
Get all OAuth clients.
Args:
db: Database session
include_inactive: Whether to include inactive clients
Returns:
List of OAuthClient objects
"""
try:
query = select(OAuthClient).order_by(OAuthClient.created_at.desc())
if not include_inactive:
query = query.where(OAuthClient.is_active == True) # noqa: E712
result = await db.execute(query)
return list(result.scalars().all())
except Exception as e: # pragma: no cover
logger.error(f"Error getting all OAuth clients: {e!s}")
raise
async def delete_client(self, db: AsyncSession, *, client_id: str) -> bool:
"""
Delete an OAuth client permanently.
Note: This will cascade delete related records (tokens, consents, etc.)
due to foreign key constraints.
Args:
db: Database session
client_id: OAuth client ID
Returns:
True if deleted, False if not found
"""
try:
result = await db.execute(
delete(OAuthClient).where(OAuthClient.client_id == client_id)
)
await db.commit()
deleted = result.rowcount > 0
if deleted:
logger.info(f"OAuth client deleted: {client_id}")
else:
logger.warning(f"OAuth client not found for deletion: {client_id}")
return deleted
except Exception as e: # pragma: no cover
await db.rollback()
logger.error(f"Error deleting OAuth client {client_id}: {e!s}")
raise
# ============================================================================
# Singleton instances
# ============================================================================
oauth_account = CRUDOAuthAccount(OAuthAccount)
oauth_state = CRUDOAuthState(OAuthState)
oauth_client = CRUDOAuthClient(OAuthClient)

View File

@@ -16,10 +16,10 @@ from sqlalchemy import select, text
from app.core.config import settings from app.core.config import settings
from app.core.database import SessionLocal, engine from app.core.database import SessionLocal, engine
from app.crud.user import user as user_crud
from app.models.organization import Organization from app.models.organization import Organization
from app.models.user import User from app.models.user import User
from app.models.user_organization import UserOrganization from app.models.user_organization import UserOrganization
from app.repositories.user import user_repo as user_repo
from app.schemas.users import UserCreate from app.schemas.users import UserCreate
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
@@ -44,16 +44,17 @@ async def init_db() -> User | None:
if not settings.FIRST_SUPERUSER_EMAIL or not settings.FIRST_SUPERUSER_PASSWORD: if not settings.FIRST_SUPERUSER_EMAIL or not settings.FIRST_SUPERUSER_PASSWORD:
logger.warning( logger.warning(
"First superuser credentials not configured in settings. " "First superuser credentials not configured in settings. "
f"Using defaults: {superuser_email}" "Using defaults: %s",
superuser_email,
) )
async with SessionLocal() as session: async with SessionLocal() as session:
try: try:
# Check if superuser already exists # Check if superuser already exists
existing_user = await user_crud.get_by_email(session, email=superuser_email) existing_user = await user_repo.get_by_email(session, email=superuser_email)
if existing_user: if existing_user:
logger.info(f"Superuser already exists: {existing_user.email}") logger.info("Superuser already exists: %s", existing_user.email)
return existing_user return existing_user
# Create superuser if doesn't exist # Create superuser if doesn't exist
@@ -65,11 +66,11 @@ async def init_db() -> User | None:
is_superuser=True, is_superuser=True,
) )
user = await user_crud.create(session, obj_in=user_in) user = await user_repo.create(session, obj_in=user_in)
await session.commit() await session.commit()
await session.refresh(user) await session.refresh(user)
logger.info(f"Created first superuser: {user.email}") logger.info("Created first superuser: %s", user.email)
# Create demo data if in demo mode # Create demo data if in demo mode
if settings.DEMO_MODE: if settings.DEMO_MODE:
@@ -79,7 +80,7 @@ async def init_db() -> User | None:
except Exception as e: except Exception as e:
await session.rollback() await session.rollback()
logger.error(f"Error initializing database: {e}") logger.error("Error initializing database: %s", e)
raise raise
@@ -92,7 +93,7 @@ async def load_demo_data(session):
"""Load demo data from JSON file.""" """Load demo data from JSON file."""
demo_data_path = Path(__file__).parent / "core" / "demo_data.json" demo_data_path = Path(__file__).parent / "core" / "demo_data.json"
if not demo_data_path.exists(): if not demo_data_path.exists():
logger.warning(f"Demo data file not found: {demo_data_path}") logger.warning("Demo data file not found: %s", demo_data_path)
return return
try: try:
@@ -119,7 +120,7 @@ async def load_demo_data(session):
session.add(org) session.add(org)
await session.flush() # Flush to get ID await session.flush() # Flush to get ID
org_map[org.slug] = org org_map[org.slug] = org
logger.info(f"Created demo organization: {org.name}") logger.info("Created demo organization: %s", org.name)
else: else:
# We can't easily get the ORM object from raw SQL result for map without querying again or mapping # We can't easily get the ORM object from raw SQL result for map without querying again or mapping
# So let's just query it properly if we need it for relationships # So let's just query it properly if we need it for relationships
@@ -135,7 +136,7 @@ async def load_demo_data(session):
# Create Users # Create Users
for user_data in data.get("users", []): for user_data in data.get("users", []):
existing_user = await user_crud.get_by_email( existing_user = await user_repo.get_by_email(
session, email=user_data["email"] session, email=user_data["email"]
) )
if not existing_user: if not existing_user:
@@ -148,7 +149,7 @@ async def load_demo_data(session):
is_superuser=user_data["is_superuser"], is_superuser=user_data["is_superuser"],
is_active=user_data.get("is_active", True), is_active=user_data.get("is_active", True),
) )
user = await user_crud.create(session, obj_in=user_in) user = await user_repo.create(session, obj_in=user_in)
# Randomize created_at for demo data (last 30 days) # Randomize created_at for demo data (last 30 days)
# This makes the charts look more realistic # This makes the charts look more realistic
@@ -174,7 +175,10 @@ async def load_demo_data(session):
) )
logger.info( logger.info(
f"Created demo user: {user.email} (created {days_ago} days ago, active={user_data.get('is_active', True)})" "Created demo user: %s (created %s days ago, active=%s)",
user.email,
days_ago,
user_data.get("is_active", True),
) )
# Add to organization if specified # Add to organization if specified
@@ -187,15 +191,15 @@ async def load_demo_data(session):
user_id=user.id, organization_id=org.id, role=role user_id=user.id, organization_id=org.id, role=role
) )
session.add(member) session.add(member)
logger.info(f"Added {user.email} to {org.name} as {role}") logger.info("Added %s to %s as %s", user.email, org.name, role)
else: else:
logger.info(f"Demo user already exists: {existing_user.email}") logger.info("Demo user already exists: %s", existing_user.email)
await session.commit() await session.commit()
logger.info("Demo data loaded successfully") logger.info("Demo data loaded successfully")
except Exception as e: except Exception as e:
logger.error(f"Error loading demo data: {e}") logger.error("Error loading demo data: %s", e)
raise raise

View File

@@ -1,7 +1,7 @@
import logging import logging
import os import os
from contextlib import asynccontextmanager from contextlib import asynccontextmanager
from datetime import datetime from datetime import UTC, datetime
from typing import Any from typing import Any
from apscheduler.schedulers.asyncio import AsyncIOScheduler from apscheduler.schedulers.asyncio import AsyncIOScheduler
@@ -16,7 +16,7 @@ from slowapi.util import get_remote_address
from app.api.main import api_router from app.api.main import api_router
from app.api.routes.oauth_provider import wellknown_router as oauth_wellknown_router from app.api.routes.oauth_provider import wellknown_router as oauth_wellknown_router
from app.core.config import settings from app.core.config import settings
from app.core.database import check_database_health from app.core.database import check_database_health, close_async_db
from app.core.exceptions import ( from app.core.exceptions import (
APIException, APIException,
api_exception_handler, api_exception_handler,
@@ -72,6 +72,7 @@ async def lifespan(app: FastAPI):
if os.getenv("IS_TEST", "False") != "True": if os.getenv("IS_TEST", "False") != "True":
scheduler.shutdown() scheduler.shutdown()
logger.info("Scheduled jobs stopped") logger.info("Scheduled jobs stopped")
await close_async_db()
logger.info("Starting app!!!") logger.info("Starting app!!!")
@@ -294,7 +295,7 @@ async def health_check() -> JSONResponse:
""" """
health_status: dict[str, Any] = { health_status: dict[str, Any] = {
"status": "healthy", "status": "healthy",
"timestamp": datetime.utcnow().isoformat() + "Z", "timestamp": datetime.now(UTC).isoformat().replace("+00:00", "Z"),
"version": settings.VERSION, "version": settings.VERSION,
"environment": settings.ENVIRONMENT, "environment": settings.ENVIRONMENT,
"checks": {}, "checks": {},
@@ -319,7 +320,7 @@ async def health_check() -> JSONResponse:
"message": f"Database connection failed: {e!s}", "message": f"Database connection failed: {e!s}",
} }
response_status = status.HTTP_503_SERVICE_UNAVAILABLE response_status = status.HTTP_503_SERVICE_UNAVAILABLE
logger.error(f"Health check failed - database error: {e}") logger.error("Health check failed - database error: %s", e)
return JSONResponse(status_code=response_status, content=health_status) return JSONResponse(status_code=response_status, content=health_status)

View File

@@ -36,9 +36,9 @@ class OAuthAccount(Base, UUIDMixin, TimestampMixin):
) # Email from provider (for reference) ) # Email from provider (for reference)
# Optional: store provider tokens for API access # Optional: store provider tokens for API access
# These should be encrypted at rest in production # TODO: Encrypt these at rest in production (requires key management infrastructure)
access_token_encrypted = Column(String(2048), nullable=True) access_token = Column(String(2048), nullable=True)
refresh_token_encrypted = Column(String(2048), nullable=True) refresh_token = Column(String(2048), nullable=True)
token_expires_at = Column(DateTime(timezone=True), nullable=True) token_expires_at = Column(DateTime(timezone=True), nullable=True)
# Relationship # Relationship

View File

@@ -92,7 +92,7 @@ class OAuthAuthorizationCode(Base, UUIDMixin, TimestampMixin):
# Handle both timezone-aware and naive datetimes from DB # Handle both timezone-aware and naive datetimes from DB
if expires_at.tzinfo is None: if expires_at.tzinfo is None:
expires_at = expires_at.replace(tzinfo=UTC) expires_at = expires_at.replace(tzinfo=UTC)
return now > expires_at return bool(now > expires_at)
@property @property
def is_valid(self) -> bool: def is_valid(self) -> bool:

View File

@@ -99,7 +99,7 @@ class OAuthProviderRefreshToken(Base, UUIDMixin, TimestampMixin):
# Handle both timezone-aware and naive datetimes from DB # Handle both timezone-aware and naive datetimes from DB
if expires_at.tzinfo is None: if expires_at.tzinfo is None:
expires_at = expires_at.replace(tzinfo=UTC) expires_at = expires_at.replace(tzinfo=UTC)
return now > expires_at return bool(now > expires_at)
@property @property
def is_valid(self) -> bool: def is_valid(self) -> bool:

View File

@@ -76,7 +76,11 @@ class UserSession(Base, UUIDMixin, TimestampMixin):
"""Check if session has expired.""" """Check if session has expired."""
from datetime import datetime from datetime import datetime
return self.expires_at < datetime.now(UTC) now = datetime.now(UTC)
expires_at = self.expires_at
if expires_at.tzinfo is None:
expires_at = expires_at.replace(tzinfo=UTC)
return bool(expires_at < now)
def to_dict(self): def to_dict(self):
"""Convert session to dictionary for serialization.""" """Convert session to dictionary for serialization."""

View File

@@ -0,0 +1,39 @@
# app/repositories/__init__.py
"""Repository layer — all database access goes through these classes."""
from app.repositories.oauth_account import OAuthAccountRepository, oauth_account_repo
from app.repositories.oauth_authorization_code import (
OAuthAuthorizationCodeRepository,
oauth_authorization_code_repo,
)
from app.repositories.oauth_client import OAuthClientRepository, oauth_client_repo
from app.repositories.oauth_consent import OAuthConsentRepository, oauth_consent_repo
from app.repositories.oauth_provider_token import (
OAuthProviderTokenRepository,
oauth_provider_token_repo,
)
from app.repositories.oauth_state import OAuthStateRepository, oauth_state_repo
from app.repositories.organization import OrganizationRepository, organization_repo
from app.repositories.session import SessionRepository, session_repo
from app.repositories.user import UserRepository, user_repo
__all__ = [
"OAuthAccountRepository",
"OAuthAuthorizationCodeRepository",
"OAuthClientRepository",
"OAuthConsentRepository",
"OAuthProviderTokenRepository",
"OAuthStateRepository",
"OrganizationRepository",
"SessionRepository",
"UserRepository",
"oauth_account_repo",
"oauth_authorization_code_repo",
"oauth_client_repo",
"oauth_consent_repo",
"oauth_provider_token_repo",
"oauth_state_repo",
"organization_repo",
"session_repo",
"user_repo",
]

View File

@@ -1,6 +1,6 @@
# app/crud/base_async.py # app/repositories/base.py
""" """
Async CRUD operations base class using SQLAlchemy 2.0 async patterns. Base repository class for async database operations using SQLAlchemy 2.0 async patterns.
Provides reusable create, read, update, and delete operations for all models. Provides reusable create, read, update, and delete operations for all models.
""" """
@@ -18,6 +18,11 @@ from sqlalchemy.ext.asyncio import AsyncSession
from sqlalchemy.orm import Load from sqlalchemy.orm import Load
from app.core.database import Base from app.core.database import Base
from app.core.repository_exceptions import (
DuplicateEntryError,
IntegrityConstraintError,
InvalidInputError,
)
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
@@ -26,16 +31,16 @@ CreateSchemaType = TypeVar("CreateSchemaType", bound=BaseModel)
UpdateSchemaType = TypeVar("UpdateSchemaType", bound=BaseModel) UpdateSchemaType = TypeVar("UpdateSchemaType", bound=BaseModel)
class CRUDBase[ class BaseRepository[
ModelType: Base, ModelType: Base,
CreateSchemaType: BaseModel, CreateSchemaType: BaseModel,
UpdateSchemaType: BaseModel, UpdateSchemaType: BaseModel,
]: ]:
"""Async CRUD operations for a model.""" """Async repository operations for a model."""
def __init__(self, model: type[ModelType]): def __init__(self, model: type[ModelType]):
""" """
CRUD object with default async methods to Create, Read, Update, Delete. Repository object with default async methods to Create, Read, Update, Delete.
Parameters: Parameters:
model: A SQLAlchemy model class model: A SQLAlchemy model class
@@ -56,26 +61,19 @@ class CRUDBase[
Returns: Returns:
Model instance or None if not found Model instance or None if not found
Example:
# Eager load user relationship
from sqlalchemy.orm import joinedload
session = await session_crud.get(db, id=session_id, options=[joinedload(UserSession.user)])
""" """
# Validate UUID format and convert to UUID object if string
try: try:
if isinstance(id, uuid.UUID): if isinstance(id, uuid.UUID):
uuid_obj = id uuid_obj = id
else: else:
uuid_obj = uuid.UUID(str(id)) uuid_obj = uuid.UUID(str(id))
except (ValueError, AttributeError, TypeError) as e: except (ValueError, AttributeError, TypeError) as e:
logger.warning(f"Invalid UUID format: {id} - {e!s}") logger.warning("Invalid UUID format: %s - %s", id, e)
return None return None
try: try:
query = select(self.model).where(self.model.id == uuid_obj) query = select(self.model).where(self.model.id == uuid_obj)
# Apply eager loading options if provided
if options: if options:
for option in options: for option in options:
query = query.options(option) query = query.options(option)
@@ -83,7 +81,9 @@ class CRUDBase[
result = await db.execute(query) result = await db.execute(query)
return result.scalar_one_or_none() return result.scalar_one_or_none()
except Exception as e: except Exception as e:
logger.error(f"Error retrieving {self.model.__name__} with id {id}: {e!s}") logger.error(
"Error retrieving %s with id %s: %s", self.model.__name__, id, e
)
raise raise
async def get_multi( async def get_multi(
@@ -96,28 +96,17 @@ class CRUDBase[
) -> list[ModelType]: ) -> list[ModelType]:
""" """
Get multiple records with pagination validation and optional eager loading. Get multiple records with pagination validation and optional eager loading.
Args:
db: Database session
skip: Number of records to skip
limit: Maximum number of records to return
options: Optional list of SQLAlchemy load options for eager loading
Returns:
List of model instances
""" """
# Validate pagination parameters
if skip < 0: if skip < 0:
raise ValueError("skip must be non-negative") raise InvalidInputError("skip must be non-negative")
if limit < 0: if limit < 0:
raise ValueError("limit must be non-negative") raise InvalidInputError("limit must be non-negative")
if limit > 1000: if limit > 1000:
raise ValueError("Maximum limit is 1000") raise InvalidInputError("Maximum limit is 1000")
try: try:
query = select(self.model).offset(skip).limit(limit) query = select(self.model).order_by(self.model.id).offset(skip).limit(limit)
# Apply eager loading options if provided
if options: if options:
for option in options: for option in options:
query = query.options(option) query = query.options(option)
@@ -126,7 +115,7 @@ class CRUDBase[
return list(result.scalars().all()) return list(result.scalars().all())
except Exception as e: except Exception as e:
logger.error( logger.error(
f"Error retrieving multiple {self.model.__name__} records: {e!s}" "Error retrieving multiple %s records: %s", self.model.__name__, e
) )
raise raise
@@ -136,9 +125,8 @@ class CRUDBase[
"""Create a new record with error handling. """Create a new record with error handling.
NOTE: This method is defensive code that's never called in practice. NOTE: This method is defensive code that's never called in practice.
All CRUD subclasses (CRUDUser, CRUDOrganization, CRUDSession) override this method All repository subclasses override this method with their own implementations.
with their own implementations, so the base implementation and its exception handlers Marked as pragma: no cover to avoid false coverage gaps.
are never executed. Marked as pragma: no cover to avoid false coverage gaps.
""" """
try: # pragma: no cover try: # pragma: no cover
obj_in_data = jsonable_encoder(obj_in) obj_in_data = jsonable_encoder(obj_in)
@@ -152,22 +140,24 @@ class CRUDBase[
error_msg = str(e.orig) if hasattr(e, "orig") else str(e) error_msg = str(e.orig) if hasattr(e, "orig") else str(e)
if "unique" in error_msg.lower() or "duplicate" in error_msg.lower(): if "unique" in error_msg.lower() or "duplicate" in error_msg.lower():
logger.warning( logger.warning(
f"Duplicate entry attempted for {self.model.__name__}: {error_msg}" "Duplicate entry attempted for %s: %s",
self.model.__name__,
error_msg,
) )
raise ValueError( raise DuplicateEntryError(
f"A {self.model.__name__} with this data already exists" f"A {self.model.__name__} with this data already exists"
) )
logger.error(f"Integrity error creating {self.model.__name__}: {error_msg}") logger.error(
raise ValueError(f"Database integrity error: {error_msg}") "Integrity error creating %s: %s", self.model.__name__, error_msg
)
raise IntegrityConstraintError(f"Database integrity error: {error_msg}")
except (OperationalError, DataError) as e: # pragma: no cover except (OperationalError, DataError) as e: # pragma: no cover
await db.rollback() await db.rollback()
logger.error(f"Database error creating {self.model.__name__}: {e!s}") logger.error("Database error creating %s: %s", self.model.__name__, e)
raise ValueError(f"Database operation failed: {e!s}") raise IntegrityConstraintError(f"Database operation failed: {e!s}")
except Exception as e: # pragma: no cover except Exception as e: # pragma: no cover
await db.rollback() await db.rollback()
logger.error( logger.exception("Unexpected error creating %s: %s", self.model.__name__, e)
f"Unexpected error creating {self.model.__name__}: {e!s}", exc_info=True
)
raise raise
async def update( async def update(
@@ -198,34 +188,35 @@ class CRUDBase[
error_msg = str(e.orig) if hasattr(e, "orig") else str(e) error_msg = str(e.orig) if hasattr(e, "orig") else str(e)
if "unique" in error_msg.lower() or "duplicate" in error_msg.lower(): if "unique" in error_msg.lower() or "duplicate" in error_msg.lower():
logger.warning( logger.warning(
f"Duplicate entry attempted for {self.model.__name__}: {error_msg}" "Duplicate entry attempted for %s: %s",
self.model.__name__,
error_msg,
) )
raise ValueError( raise DuplicateEntryError(
f"A {self.model.__name__} with this data already exists" f"A {self.model.__name__} with this data already exists"
) )
logger.error(f"Integrity error updating {self.model.__name__}: {error_msg}") logger.error(
raise ValueError(f"Database integrity error: {error_msg}") "Integrity error updating %s: %s", self.model.__name__, error_msg
)
raise IntegrityConstraintError(f"Database integrity error: {error_msg}")
except (OperationalError, DataError) as e: except (OperationalError, DataError) as e:
await db.rollback() await db.rollback()
logger.error(f"Database error updating {self.model.__name__}: {e!s}") logger.error("Database error updating %s: %s", self.model.__name__, e)
raise ValueError(f"Database operation failed: {e!s}") raise IntegrityConstraintError(f"Database operation failed: {e!s}")
except Exception as e: except Exception as e:
await db.rollback() await db.rollback()
logger.error( logger.exception("Unexpected error updating %s: %s", self.model.__name__, e)
f"Unexpected error updating {self.model.__name__}: {e!s}", exc_info=True
)
raise raise
async def remove(self, db: AsyncSession, *, id: str) -> ModelType | None: async def remove(self, db: AsyncSession, *, id: str) -> ModelType | None:
"""Delete a record with error handling and null check.""" """Delete a record with error handling and null check."""
# Validate UUID format and convert to UUID object if string
try: try:
if isinstance(id, uuid.UUID): if isinstance(id, uuid.UUID):
uuid_obj = id uuid_obj = id
else: else:
uuid_obj = uuid.UUID(str(id)) uuid_obj = uuid.UUID(str(id))
except (ValueError, AttributeError, TypeError) as e: except (ValueError, AttributeError, TypeError) as e:
logger.warning(f"Invalid UUID format for deletion: {id} - {e!s}") logger.warning("Invalid UUID format for deletion: %s - %s", id, e)
return None return None
try: try:
@@ -236,7 +227,7 @@ class CRUDBase[
if obj is None: if obj is None:
logger.warning( logger.warning(
f"{self.model.__name__} with id {id} not found for deletion" "%s with id %s not found for deletion", self.model.__name__, id
) )
return None return None
@@ -246,15 +237,16 @@ class CRUDBase[
except IntegrityError as e: except IntegrityError as e:
await db.rollback() await db.rollback()
error_msg = str(e.orig) if hasattr(e, "orig") else str(e) error_msg = str(e.orig) if hasattr(e, "orig") else str(e)
logger.error(f"Integrity error deleting {self.model.__name__}: {error_msg}") logger.error(
raise ValueError( "Integrity error deleting %s: %s", self.model.__name__, error_msg
)
raise IntegrityConstraintError(
f"Cannot delete {self.model.__name__}: referenced by other records" f"Cannot delete {self.model.__name__}: referenced by other records"
) )
except Exception as e: except Exception as e:
await db.rollback() await db.rollback()
logger.error( logger.exception(
f"Error deleting {self.model.__name__} with id {id}: {e!s}", "Error deleting %s with id %s: %s", self.model.__name__, id, e
exc_info=True,
) )
raise raise
@@ -272,57 +264,40 @@ class CRUDBase[
Get multiple records with total count, filtering, and sorting. Get multiple records with total count, filtering, and sorting.
NOTE: This method is defensive code that's never called in practice. NOTE: This method is defensive code that's never called in practice.
All CRUD subclasses (CRUDUser, CRUDOrganization, CRUDSession) override this method All repository subclasses override this method with their own implementations.
with their own implementations that include additional parameters like search.
Marked as pragma: no cover to avoid false coverage gaps. Marked as pragma: no cover to avoid false coverage gaps.
Args:
db: Database session
skip: Number of records to skip
limit: Maximum number of records to return
sort_by: Field name to sort by (must be a valid model attribute)
sort_order: Sort order ("asc" or "desc")
filters: Dictionary of filters (field_name: value)
Returns:
Tuple of (items, total_count)
""" """
# Validate pagination parameters
if skip < 0: if skip < 0:
raise ValueError("skip must be non-negative") raise InvalidInputError("skip must be non-negative")
if limit < 0: if limit < 0:
raise ValueError("limit must be non-negative") raise InvalidInputError("limit must be non-negative")
if limit > 1000: if limit > 1000:
raise ValueError("Maximum limit is 1000") raise InvalidInputError("Maximum limit is 1000")
try: try:
# Build base query
query = select(self.model) query = select(self.model)
# Exclude soft-deleted records by default
if hasattr(self.model, "deleted_at"): if hasattr(self.model, "deleted_at"):
query = query.where(self.model.deleted_at.is_(None)) query = query.where(self.model.deleted_at.is_(None))
# Apply filters
if filters: if filters:
for field, value in filters.items(): for field, value in filters.items():
if hasattr(self.model, field) and value is not None: if hasattr(self.model, field) and value is not None:
query = query.where(getattr(self.model, field) == value) query = query.where(getattr(self.model, field) == value)
# Get total count (before pagination)
count_query = select(func.count()).select_from(query.alias()) count_query = select(func.count()).select_from(query.alias())
count_result = await db.execute(count_query) count_result = await db.execute(count_query)
total = count_result.scalar_one() total = count_result.scalar_one()
# Apply sorting
if sort_by and hasattr(self.model, sort_by): if sort_by and hasattr(self.model, sort_by):
sort_column = getattr(self.model, sort_by) sort_column = getattr(self.model, sort_by)
if sort_order.lower() == "desc": if sort_order.lower() == "desc":
query = query.order_by(sort_column.desc()) query = query.order_by(sort_column.desc())
else: else:
query = query.order_by(sort_column.asc()) query = query.order_by(sort_column.asc())
else:
query = query.order_by(self.model.id)
# Apply pagination
query = query.offset(skip).limit(limit) query = query.offset(skip).limit(limit)
items_result = await db.execute(query) items_result = await db.execute(query)
items = list(items_result.scalars().all()) items = list(items_result.scalars().all())
@@ -330,7 +305,7 @@ class CRUDBase[
return items, total return items, total
except Exception as e: # pragma: no cover except Exception as e: # pragma: no cover
logger.error( logger.error(
f"Error retrieving paginated {self.model.__name__} records: {e!s}" "Error retrieving paginated %s records: %s", self.model.__name__, e
) )
raise raise
@@ -340,7 +315,7 @@ class CRUDBase[
result = await db.execute(select(func.count(self.model.id))) result = await db.execute(select(func.count(self.model.id)))
return result.scalar_one() return result.scalar_one()
except Exception as e: except Exception as e:
logger.error(f"Error counting {self.model.__name__} records: {e!s}") logger.error("Error counting %s records: %s", self.model.__name__, e)
raise raise
async def exists(self, db: AsyncSession, id: str) -> bool: async def exists(self, db: AsyncSession, id: str) -> bool:
@@ -356,14 +331,13 @@ class CRUDBase[
""" """
from datetime import datetime from datetime import datetime
# Validate UUID format and convert to UUID object if string
try: try:
if isinstance(id, uuid.UUID): if isinstance(id, uuid.UUID):
uuid_obj = id uuid_obj = id
else: else:
uuid_obj = uuid.UUID(str(id)) uuid_obj = uuid.UUID(str(id))
except (ValueError, AttributeError, TypeError) as e: except (ValueError, AttributeError, TypeError) as e:
logger.warning(f"Invalid UUID format for soft deletion: {id} - {e!s}") logger.warning("Invalid UUID format for soft deletion: %s - %s", id, e)
return None return None
try: try:
@@ -374,18 +348,16 @@ class CRUDBase[
if obj is None: if obj is None:
logger.warning( logger.warning(
f"{self.model.__name__} with id {id} not found for soft deletion" "%s with id %s not found for soft deletion", self.model.__name__, id
) )
return None return None
# Check if model supports soft deletes
if not hasattr(self.model, "deleted_at"): if not hasattr(self.model, "deleted_at"):
logger.error(f"{self.model.__name__} does not support soft deletes") logger.error("%s does not support soft deletes", self.model.__name__)
raise ValueError( raise InvalidInputError(
f"{self.model.__name__} does not have a deleted_at column" f"{self.model.__name__} does not have a deleted_at column"
) )
# Set deleted_at timestamp
obj.deleted_at = datetime.now(UTC) obj.deleted_at = datetime.now(UTC)
db.add(obj) db.add(obj)
await db.commit() await db.commit()
@@ -393,9 +365,8 @@ class CRUDBase[
return obj return obj
except Exception as e: except Exception as e:
await db.rollback() await db.rollback()
logger.error( logger.exception(
f"Error soft deleting {self.model.__name__} with id {id}: {e!s}", "Error soft deleting %s with id %s: %s", self.model.__name__, id, e
exc_info=True,
) )
raise raise
@@ -405,18 +376,16 @@ class CRUDBase[
Only works if the model has a 'deleted_at' column. Only works if the model has a 'deleted_at' column.
""" """
# Validate UUID format
try: try:
if isinstance(id, uuid.UUID): if isinstance(id, uuid.UUID):
uuid_obj = id uuid_obj = id
else: else:
uuid_obj = uuid.UUID(str(id)) uuid_obj = uuid.UUID(str(id))
except (ValueError, AttributeError, TypeError) as e: except (ValueError, AttributeError, TypeError) as e:
logger.warning(f"Invalid UUID format for restoration: {id} - {e!s}") logger.warning("Invalid UUID format for restoration: %s - %s", id, e)
return None return None
try: try:
# Find the soft-deleted record
if hasattr(self.model, "deleted_at"): if hasattr(self.model, "deleted_at"):
result = await db.execute( result = await db.execute(
select(self.model).where( select(self.model).where(
@@ -425,18 +394,19 @@ class CRUDBase[
) )
obj = result.scalar_one_or_none() obj = result.scalar_one_or_none()
else: else:
logger.error(f"{self.model.__name__} does not support soft deletes") logger.error("%s does not support soft deletes", self.model.__name__)
raise ValueError( raise InvalidInputError(
f"{self.model.__name__} does not have a deleted_at column" f"{self.model.__name__} does not have a deleted_at column"
) )
if obj is None: if obj is None:
logger.warning( logger.warning(
f"Soft-deleted {self.model.__name__} with id {id} not found for restoration" "Soft-deleted %s with id %s not found for restoration",
self.model.__name__,
id,
) )
return None return None
# Clear deleted_at timestamp
obj.deleted_at = None obj.deleted_at = None
db.add(obj) db.add(obj)
await db.commit() await db.commit()
@@ -444,8 +414,7 @@ class CRUDBase[
return obj return obj
except Exception as e: except Exception as e:
await db.rollback() await db.rollback()
logger.error( logger.exception(
f"Error restoring {self.model.__name__} with id {id}: {e!s}", "Error restoring %s with id %s: %s", self.model.__name__, id, e
exc_info=True,
) )
raise raise

View File

@@ -0,0 +1,249 @@
# app/repositories/oauth_account.py
"""Repository for OAuthAccount model async database operations."""
import logging
from datetime import datetime
from uuid import UUID
from pydantic import BaseModel
from sqlalchemy import and_, delete, select
from sqlalchemy.exc import IntegrityError
from sqlalchemy.ext.asyncio import AsyncSession
from sqlalchemy.orm import joinedload
from app.core.repository_exceptions import DuplicateEntryError
from app.models.oauth_account import OAuthAccount
from app.repositories.base import BaseRepository
from app.schemas.oauth import OAuthAccountCreate
logger = logging.getLogger(__name__)
class EmptySchema(BaseModel):
"""Placeholder schema for repository operations that don't need update schemas."""
class OAuthAccountRepository(
BaseRepository[OAuthAccount, OAuthAccountCreate, EmptySchema]
):
"""Repository for OAuth account links."""
async def get_by_provider_id(
self,
db: AsyncSession,
*,
provider: str,
provider_user_id: str,
) -> OAuthAccount | None:
"""Get OAuth account by provider and provider user ID."""
try:
result = await db.execute(
select(OAuthAccount)
.where(
and_(
OAuthAccount.provider == provider,
OAuthAccount.provider_user_id == provider_user_id,
)
)
.options(joinedload(OAuthAccount.user))
)
return result.scalar_one_or_none()
except Exception as e: # pragma: no cover
logger.error(
"Error getting OAuth account for %s:%s: %s",
provider,
provider_user_id,
e,
)
raise
async def get_by_provider_email(
self,
db: AsyncSession,
*,
provider: str,
email: str,
) -> OAuthAccount | None:
"""Get OAuth account by provider and email."""
try:
result = await db.execute(
select(OAuthAccount)
.where(
and_(
OAuthAccount.provider == provider,
OAuthAccount.provider_email == email,
)
)
.options(joinedload(OAuthAccount.user))
)
return result.scalar_one_or_none()
except Exception as e: # pragma: no cover
logger.error(
"Error getting OAuth account for %s email %s: %s", provider, email, e
)
raise
async def get_user_accounts(
self,
db: AsyncSession,
*,
user_id: str | UUID,
) -> list[OAuthAccount]:
"""Get all OAuth accounts linked to a user."""
try:
user_uuid = UUID(str(user_id)) if isinstance(user_id, str) else user_id
result = await db.execute(
select(OAuthAccount)
.where(OAuthAccount.user_id == user_uuid)
.order_by(OAuthAccount.created_at.desc())
)
return list(result.scalars().all())
except Exception as e: # pragma: no cover
logger.error("Error getting OAuth accounts for user %s: %s", user_id, e)
raise
async def get_user_account_by_provider(
self,
db: AsyncSession,
*,
user_id: str | UUID,
provider: str,
) -> OAuthAccount | None:
"""Get a specific OAuth account for a user and provider."""
try:
user_uuid = UUID(str(user_id)) if isinstance(user_id, str) else user_id
result = await db.execute(
select(OAuthAccount).where(
and_(
OAuthAccount.user_id == user_uuid,
OAuthAccount.provider == provider,
)
)
)
return result.scalar_one_or_none()
except Exception as e: # pragma: no cover
logger.error(
"Error getting OAuth account for user %s, provider %s: %s",
user_id,
provider,
e,
)
raise
async def create_account(
self, db: AsyncSession, *, obj_in: OAuthAccountCreate
) -> OAuthAccount:
"""Create a new OAuth account link."""
try:
db_obj = OAuthAccount(
user_id=obj_in.user_id,
provider=obj_in.provider,
provider_user_id=obj_in.provider_user_id,
provider_email=obj_in.provider_email,
access_token=obj_in.access_token,
refresh_token=obj_in.refresh_token,
token_expires_at=obj_in.token_expires_at,
)
db.add(db_obj)
await db.commit()
await db.refresh(db_obj)
logger.info(
"OAuth account created: %s linked to user %s",
obj_in.provider,
obj_in.user_id,
)
return db_obj
except IntegrityError as e: # pragma: no cover
await db.rollback()
error_msg = str(e.orig) if hasattr(e, "orig") else str(e)
if "uq_oauth_provider_user" in error_msg.lower():
logger.warning(
"OAuth account already exists: %s:%s",
obj_in.provider,
obj_in.provider_user_id,
)
raise DuplicateEntryError(
f"This {obj_in.provider} account is already linked to another user"
)
logger.error("Integrity error creating OAuth account: %s", error_msg)
raise DuplicateEntryError(f"Failed to create OAuth account: {error_msg}")
except Exception as e: # pragma: no cover
await db.rollback()
logger.exception("Error creating OAuth account: %s", e)
raise
async def delete_account(
self,
db: AsyncSession,
*,
user_id: str | UUID,
provider: str,
) -> bool:
"""Delete an OAuth account link."""
try:
user_uuid = UUID(str(user_id)) if isinstance(user_id, str) else user_id
result = await db.execute(
delete(OAuthAccount).where(
and_(
OAuthAccount.user_id == user_uuid,
OAuthAccount.provider == provider,
)
)
)
await db.commit()
deleted = result.rowcount > 0
if deleted:
logger.info(
"OAuth account deleted: %s unlinked from user %s", provider, user_id
)
else:
logger.warning(
"OAuth account not found for deletion: %s for user %s",
provider,
user_id,
)
return deleted
except Exception as e: # pragma: no cover
await db.rollback()
logger.error(
"Error deleting OAuth account %s for user %s: %s", provider, user_id, e
)
raise
async def update_tokens(
self,
db: AsyncSession,
*,
account: OAuthAccount,
access_token: str | None = None,
refresh_token: str | None = None,
token_expires_at: datetime | None = None,
) -> OAuthAccount:
"""Update OAuth tokens for an account."""
try:
if access_token is not None:
account.access_token = access_token
if refresh_token is not None:
account.refresh_token = refresh_token
if token_expires_at is not None:
account.token_expires_at = token_expires_at
db.add(account)
await db.commit()
await db.refresh(account)
return account
except Exception as e: # pragma: no cover
await db.rollback()
logger.error("Error updating OAuth tokens: %s", e)
raise
# Singleton instance
oauth_account_repo = OAuthAccountRepository(OAuthAccount)

View File

@@ -0,0 +1,108 @@
# app/repositories/oauth_authorization_code.py
"""Repository for OAuthAuthorizationCode model."""
import logging
from datetime import UTC, datetime
from uuid import UUID
from sqlalchemy import and_, delete, select, update
from sqlalchemy.ext.asyncio import AsyncSession
from app.models.oauth_authorization_code import OAuthAuthorizationCode
logger = logging.getLogger(__name__)
class OAuthAuthorizationCodeRepository:
"""Repository for OAuth 2.0 authorization codes."""
async def create_code(
self,
db: AsyncSession,
*,
code: str,
client_id: str,
user_id: UUID,
redirect_uri: str,
scope: str,
expires_at: datetime,
code_challenge: str | None = None,
code_challenge_method: str | None = None,
state: str | None = None,
nonce: str | None = None,
) -> OAuthAuthorizationCode:
"""Create and persist a new authorization code."""
auth_code = OAuthAuthorizationCode(
code=code,
client_id=client_id,
user_id=user_id,
redirect_uri=redirect_uri,
scope=scope,
code_challenge=code_challenge,
code_challenge_method=code_challenge_method,
state=state,
nonce=nonce,
expires_at=expires_at,
used=False,
)
db.add(auth_code)
await db.commit()
return auth_code
async def consume_code_atomically(
self, db: AsyncSession, *, code: str
) -> UUID | None:
"""
Atomically mark a code as used and return its UUID.
Returns the UUID if the code was found and not yet used, None otherwise.
This prevents race conditions per RFC 6749 Section 4.1.2.
"""
stmt = (
update(OAuthAuthorizationCode)
.where(
and_(
OAuthAuthorizationCode.code == code,
OAuthAuthorizationCode.used == False, # noqa: E712
)
)
.values(used=True)
.returning(OAuthAuthorizationCode.id)
)
result = await db.execute(stmt)
row_id = result.scalar_one_or_none()
if row_id is not None:
await db.commit()
return row_id
async def get_by_id(
self, db: AsyncSession, *, code_id: UUID
) -> OAuthAuthorizationCode | None:
"""Get authorization code by its UUID primary key."""
result = await db.execute(
select(OAuthAuthorizationCode).where(OAuthAuthorizationCode.id == code_id)
)
return result.scalar_one_or_none()
async def get_by_code(
self, db: AsyncSession, *, code: str
) -> OAuthAuthorizationCode | None:
"""Get authorization code by the code string value."""
result = await db.execute(
select(OAuthAuthorizationCode).where(OAuthAuthorizationCode.code == code)
)
return result.scalar_one_or_none()
async def cleanup_expired(self, db: AsyncSession) -> int:
"""Delete all expired authorization codes. Returns count deleted."""
result = await db.execute(
delete(OAuthAuthorizationCode).where(
OAuthAuthorizationCode.expires_at < datetime.now(UTC)
)
)
await db.commit()
return result.rowcount # type: ignore[attr-defined]
# Singleton instance
oauth_authorization_code_repo = OAuthAuthorizationCodeRepository()

View File

@@ -0,0 +1,201 @@
# app/repositories/oauth_client.py
"""Repository for OAuthClient model async database operations."""
import logging
import secrets
from uuid import UUID
from pydantic import BaseModel
from sqlalchemy import and_, delete, select
from sqlalchemy.exc import IntegrityError
from sqlalchemy.ext.asyncio import AsyncSession
from app.core.repository_exceptions import DuplicateEntryError
from app.models.oauth_client import OAuthClient
from app.repositories.base import BaseRepository
from app.schemas.oauth import OAuthClientCreate
logger = logging.getLogger(__name__)
class EmptySchema(BaseModel):
"""Placeholder schema for repository operations that don't need update schemas."""
class OAuthClientRepository(
BaseRepository[OAuthClient, OAuthClientCreate, EmptySchema]
):
"""Repository for OAuth clients (provider mode)."""
async def get_by_client_id(
self, db: AsyncSession, *, client_id: str
) -> OAuthClient | None:
"""Get OAuth client by client_id."""
try:
result = await db.execute(
select(OAuthClient).where(
and_(
OAuthClient.client_id == client_id,
OAuthClient.is_active == True, # noqa: E712
)
)
)
return result.scalar_one_or_none()
except Exception as e: # pragma: no cover
logger.error("Error getting OAuth client %s: %s", client_id, e)
raise
async def create_client(
self,
db: AsyncSession,
*,
obj_in: OAuthClientCreate,
owner_user_id: UUID | None = None,
) -> tuple[OAuthClient, str | None]:
"""Create a new OAuth client."""
try:
client_id = secrets.token_urlsafe(32)
client_secret = None
client_secret_hash = None
if obj_in.client_type == "confidential":
client_secret = secrets.token_urlsafe(48)
from app.core.auth import get_password_hash
client_secret_hash = get_password_hash(client_secret)
db_obj = OAuthClient(
client_id=client_id,
client_secret_hash=client_secret_hash,
client_name=obj_in.client_name,
client_description=obj_in.client_description,
client_type=obj_in.client_type,
redirect_uris=obj_in.redirect_uris,
allowed_scopes=obj_in.allowed_scopes,
owner_user_id=owner_user_id,
is_active=True,
)
db.add(db_obj)
await db.commit()
await db.refresh(db_obj)
logger.info(
"OAuth client created: %s (%s...)", obj_in.client_name, client_id[:8]
)
return db_obj, client_secret
except IntegrityError as e: # pragma: no cover
await db.rollback()
error_msg = str(e.orig) if hasattr(e, "orig") else str(e)
logger.error("Error creating OAuth client: %s", error_msg)
raise DuplicateEntryError(f"Failed to create OAuth client: {error_msg}")
except Exception as e: # pragma: no cover
await db.rollback()
logger.exception("Error creating OAuth client: %s", e)
raise
async def deactivate_client(
self, db: AsyncSession, *, client_id: str
) -> OAuthClient | None:
"""Deactivate an OAuth client."""
try:
client = await self.get_by_client_id(db, client_id=client_id)
if client is None:
return None
client.is_active = False
db.add(client)
await db.commit()
await db.refresh(client)
logger.info("OAuth client deactivated: %s", client.client_name)
return client
except Exception as e: # pragma: no cover
await db.rollback()
logger.error("Error deactivating OAuth client %s: %s", client_id, e)
raise
async def validate_redirect_uri(
self, db: AsyncSession, *, client_id: str, redirect_uri: str
) -> bool:
"""Validate that a redirect URI is allowed for a client."""
try:
client = await self.get_by_client_id(db, client_id=client_id)
if client is None:
return False
return redirect_uri in (client.redirect_uris or [])
except Exception as e: # pragma: no cover
logger.error("Error validating redirect URI: %s", e)
return False
async def verify_client_secret(
self, db: AsyncSession, *, client_id: str, client_secret: str
) -> bool:
"""Verify client credentials."""
try:
result = await db.execute(
select(OAuthClient).where(
and_(
OAuthClient.client_id == client_id,
OAuthClient.is_active == True, # noqa: E712
)
)
)
client = result.scalar_one_or_none()
if client is None or client.client_secret_hash is None:
return False
from app.core.auth import verify_password
stored_hash: str = str(client.client_secret_hash)
if stored_hash.startswith("$2"):
return verify_password(client_secret, stored_hash)
else:
import hashlib
secret_hash = hashlib.sha256(client_secret.encode()).hexdigest()
return secrets.compare_digest(stored_hash, secret_hash)
except Exception as e: # pragma: no cover
logger.error("Error verifying client secret: %s", e)
return False
async def get_all_clients(
self, db: AsyncSession, *, include_inactive: bool = False
) -> list[OAuthClient]:
"""Get all OAuth clients."""
try:
query = select(OAuthClient).order_by(OAuthClient.created_at.desc())
if not include_inactive:
query = query.where(OAuthClient.is_active == True) # noqa: E712
result = await db.execute(query)
return list(result.scalars().all())
except Exception as e: # pragma: no cover
logger.error("Error getting all OAuth clients: %s", e)
raise
async def delete_client(self, db: AsyncSession, *, client_id: str) -> bool:
"""Delete an OAuth client permanently."""
try:
result = await db.execute(
delete(OAuthClient).where(OAuthClient.client_id == client_id)
)
await db.commit()
deleted = result.rowcount > 0
if deleted:
logger.info("OAuth client deleted: %s", client_id)
else:
logger.warning("OAuth client not found for deletion: %s", client_id)
return deleted
except Exception as e: # pragma: no cover
await db.rollback()
logger.error("Error deleting OAuth client %s: %s", client_id, e)
raise
# Singleton instance
oauth_client_repo = OAuthClientRepository(OAuthClient)

View File

@@ -0,0 +1,113 @@
# app/repositories/oauth_consent.py
"""Repository for OAuthConsent model."""
import logging
from typing import Any
from uuid import UUID
from sqlalchemy import and_, delete, select
from sqlalchemy.ext.asyncio import AsyncSession
from app.models.oauth_client import OAuthClient
from app.models.oauth_provider_token import OAuthConsent
logger = logging.getLogger(__name__)
class OAuthConsentRepository:
"""Repository for OAuth consent records (user grants to clients)."""
async def get_consent(
self, db: AsyncSession, *, user_id: UUID, client_id: str
) -> OAuthConsent | None:
"""Get the consent record for a user-client pair, or None if not found."""
result = await db.execute(
select(OAuthConsent).where(
and_(
OAuthConsent.user_id == user_id,
OAuthConsent.client_id == client_id,
)
)
)
return result.scalar_one_or_none()
async def grant_consent(
self,
db: AsyncSession,
*,
user_id: UUID,
client_id: str,
scopes: list[str],
) -> OAuthConsent:
"""
Create or update consent for a user-client pair.
If consent already exists, the new scopes are merged with existing ones.
Returns the created or updated consent record.
"""
consent = await self.get_consent(db, user_id=user_id, client_id=client_id)
if consent:
existing = (
set(consent.granted_scopes.split()) if consent.granted_scopes else set()
)
merged = existing | set(scopes)
consent.granted_scopes = " ".join(sorted(merged)) # type: ignore[assignment]
else:
consent = OAuthConsent(
user_id=user_id,
client_id=client_id,
granted_scopes=" ".join(sorted(set(scopes))),
)
db.add(consent)
await db.commit()
await db.refresh(consent)
return consent
async def get_user_consents_with_clients(
self, db: AsyncSession, *, user_id: UUID
) -> list[dict[str, Any]]:
"""Get all consent records for a user joined with client details."""
result = await db.execute(
select(OAuthConsent, OAuthClient)
.join(OAuthClient, OAuthConsent.client_id == OAuthClient.client_id)
.where(OAuthConsent.user_id == user_id)
)
rows = result.all()
return [
{
"client_id": consent.client_id,
"client_name": client.client_name,
"client_description": client.client_description,
"granted_scopes": consent.granted_scopes.split()
if consent.granted_scopes
else [],
"granted_at": consent.created_at.isoformat(),
}
for consent, client in rows
]
async def revoke_consent(
self, db: AsyncSession, *, user_id: UUID, client_id: str
) -> bool:
"""
Delete the consent record for a user-client pair.
Returns True if a record was found and deleted.
Note: Callers are responsible for also revoking associated tokens.
"""
result = await db.execute(
delete(OAuthConsent).where(
and_(
OAuthConsent.user_id == user_id,
OAuthConsent.client_id == client_id,
)
)
)
await db.commit()
return result.rowcount > 0 # type: ignore[attr-defined]
# Singleton instance
oauth_consent_repo = OAuthConsentRepository()

View File

@@ -0,0 +1,142 @@
# app/repositories/oauth_provider_token.py
"""Repository for OAuthProviderRefreshToken model."""
import logging
from datetime import UTC, datetime, timedelta
from uuid import UUID
from sqlalchemy import and_, delete, select, update
from sqlalchemy.ext.asyncio import AsyncSession
from app.models.oauth_provider_token import OAuthProviderRefreshToken
logger = logging.getLogger(__name__)
class OAuthProviderTokenRepository:
"""Repository for OAuth provider refresh tokens."""
async def create_token(
self,
db: AsyncSession,
*,
token_hash: str,
jti: str,
client_id: str,
user_id: UUID,
scope: str,
expires_at: datetime,
device_info: str | None = None,
ip_address: str | None = None,
) -> OAuthProviderRefreshToken:
"""Create and persist a new refresh token record."""
token = OAuthProviderRefreshToken(
token_hash=token_hash,
jti=jti,
client_id=client_id,
user_id=user_id,
scope=scope,
expires_at=expires_at,
device_info=device_info,
ip_address=ip_address,
)
db.add(token)
await db.commit()
return token
async def get_by_token_hash(
self, db: AsyncSession, *, token_hash: str
) -> OAuthProviderRefreshToken | None:
"""Get refresh token record by SHA-256 token hash."""
result = await db.execute(
select(OAuthProviderRefreshToken).where(
OAuthProviderRefreshToken.token_hash == token_hash
)
)
return result.scalar_one_or_none()
async def get_by_jti(
self, db: AsyncSession, *, jti: str
) -> OAuthProviderRefreshToken | None:
"""Get refresh token record by JWT ID (JTI)."""
result = await db.execute(
select(OAuthProviderRefreshToken).where(
OAuthProviderRefreshToken.jti == jti
)
)
return result.scalar_one_or_none()
async def revoke(
self, db: AsyncSession, *, token: OAuthProviderRefreshToken
) -> None:
"""Mark a specific token record as revoked."""
token.revoked = True # type: ignore[assignment]
token.last_used_at = datetime.now(UTC) # type: ignore[assignment]
await db.commit()
async def revoke_all_for_user_client(
self, db: AsyncSession, *, user_id: UUID, client_id: str
) -> int:
"""
Revoke all active tokens for a specific user-client pair.
Used when security incidents are detected (e.g., authorization code reuse).
Returns the number of tokens revoked.
"""
result = await db.execute(
update(OAuthProviderRefreshToken)
.where(
and_(
OAuthProviderRefreshToken.user_id == user_id,
OAuthProviderRefreshToken.client_id == client_id,
OAuthProviderRefreshToken.revoked == False, # noqa: E712
)
)
.values(revoked=True)
)
count = result.rowcount # type: ignore[attr-defined]
if count > 0:
await db.commit()
return count
async def revoke_all_for_user(self, db: AsyncSession, *, user_id: UUID) -> int:
"""
Revoke all active tokens for a user across all clients.
Used when user changes password or logs out everywhere.
Returns the number of tokens revoked.
"""
result = await db.execute(
update(OAuthProviderRefreshToken)
.where(
and_(
OAuthProviderRefreshToken.user_id == user_id,
OAuthProviderRefreshToken.revoked == False, # noqa: E712
)
)
.values(revoked=True)
)
count = result.rowcount # type: ignore[attr-defined]
if count > 0:
await db.commit()
return count
async def cleanup_expired(self, db: AsyncSession, *, cutoff_days: int = 7) -> int:
"""
Delete expired refresh tokens older than cutoff_days.
Should be called periodically (e.g., daily).
Returns the number of tokens deleted.
"""
cutoff = datetime.now(UTC) - timedelta(days=cutoff_days)
result = await db.execute(
delete(OAuthProviderRefreshToken).where(
OAuthProviderRefreshToken.expires_at < cutoff
)
)
await db.commit()
return result.rowcount # type: ignore[attr-defined]
# Singleton instance
oauth_provider_token_repo = OAuthProviderTokenRepository()

View File

@@ -0,0 +1,113 @@
# app/repositories/oauth_state.py
"""Repository for OAuthState model async database operations."""
import logging
from datetime import UTC, datetime
from pydantic import BaseModel
from sqlalchemy import delete, select
from sqlalchemy.exc import IntegrityError
from sqlalchemy.ext.asyncio import AsyncSession
from app.core.repository_exceptions import DuplicateEntryError
from app.models.oauth_state import OAuthState
from app.repositories.base import BaseRepository
from app.schemas.oauth import OAuthStateCreate
logger = logging.getLogger(__name__)
class EmptySchema(BaseModel):
"""Placeholder schema for repository operations that don't need update schemas."""
class OAuthStateRepository(BaseRepository[OAuthState, OAuthStateCreate, EmptySchema]):
"""Repository for OAuth state (CSRF protection)."""
async def create_state(
self, db: AsyncSession, *, obj_in: OAuthStateCreate
) -> OAuthState:
"""Create a new OAuth state for CSRF protection."""
try:
db_obj = OAuthState(
state=obj_in.state,
code_verifier=obj_in.code_verifier,
nonce=obj_in.nonce,
provider=obj_in.provider,
redirect_uri=obj_in.redirect_uri,
user_id=obj_in.user_id,
expires_at=obj_in.expires_at,
)
db.add(db_obj)
await db.commit()
await db.refresh(db_obj)
logger.debug("OAuth state created for %s", obj_in.provider)
return db_obj
except IntegrityError as e: # pragma: no cover
await db.rollback()
error_msg = str(e.orig) if hasattr(e, "orig") else str(e)
logger.error("OAuth state collision: %s", error_msg)
raise DuplicateEntryError("Failed to create OAuth state, please retry")
except Exception as e: # pragma: no cover
await db.rollback()
logger.exception("Error creating OAuth state: %s", e)
raise
async def get_and_consume_state(
self, db: AsyncSession, *, state: str
) -> OAuthState | None:
"""Get and delete OAuth state (consume it)."""
try:
result = await db.execute(
select(OAuthState).where(OAuthState.state == state)
)
db_obj = result.scalar_one_or_none()
if db_obj is None:
logger.warning("OAuth state not found: %s...", state[:8])
return None
now = datetime.now(UTC)
expires_at = db_obj.expires_at
if expires_at.tzinfo is None:
expires_at = expires_at.replace(tzinfo=UTC)
if expires_at < now:
logger.warning("OAuth state expired: %s...", state[:8])
await db.delete(db_obj)
await db.commit()
return None
await db.delete(db_obj)
await db.commit()
logger.debug("OAuth state consumed: %s...", state[:8])
return db_obj
except Exception as e: # pragma: no cover
await db.rollback()
logger.error("Error consuming OAuth state: %s", e)
raise
async def cleanup_expired(self, db: AsyncSession) -> int:
"""Clean up expired OAuth states."""
try:
now = datetime.now(UTC)
stmt = delete(OAuthState).where(OAuthState.expires_at < now)
result = await db.execute(stmt)
await db.commit()
count = result.rowcount
if count > 0:
logger.info("Cleaned up %s expired OAuth states", count)
return count
except Exception as e: # pragma: no cover
await db.rollback()
logger.error("Error cleaning up expired OAuth states: %s", e)
raise
# Singleton instance
oauth_state_repo = OAuthStateRepository(OAuthState)

View File

@@ -1,5 +1,5 @@
# app/crud/organization_async.py # app/repositories/organization.py
"""Async CRUD operations for Organization model using SQLAlchemy 2.0 patterns.""" """Repository for Organization model async database operations using SQLAlchemy 2.0 patterns."""
import logging import logging
from typing import Any from typing import Any
@@ -9,10 +9,11 @@ from sqlalchemy import and_, case, func, or_, select
from sqlalchemy.exc import IntegrityError from sqlalchemy.exc import IntegrityError
from sqlalchemy.ext.asyncio import AsyncSession from sqlalchemy.ext.asyncio import AsyncSession
from app.crud.base import CRUDBase from app.core.repository_exceptions import DuplicateEntryError, IntegrityConstraintError
from app.models.organization import Organization from app.models.organization import Organization
from app.models.user import User from app.models.user import User
from app.models.user_organization import OrganizationRole, UserOrganization from app.models.user_organization import OrganizationRole, UserOrganization
from app.repositories.base import BaseRepository
from app.schemas.organizations import ( from app.schemas.organizations import (
OrganizationCreate, OrganizationCreate,
OrganizationUpdate, OrganizationUpdate,
@@ -21,8 +22,10 @@ from app.schemas.organizations import (
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
class CRUDOrganization(CRUDBase[Organization, OrganizationCreate, OrganizationUpdate]): class OrganizationRepository(
"""Async CRUD operations for Organization model.""" BaseRepository[Organization, OrganizationCreate, OrganizationUpdate]
):
"""Repository for Organization model."""
async def get_by_slug(self, db: AsyncSession, *, slug: str) -> Organization | None: async def get_by_slug(self, db: AsyncSession, *, slug: str) -> Organization | None:
"""Get organization by slug.""" """Get organization by slug."""
@@ -32,7 +35,7 @@ class CRUDOrganization(CRUDBase[Organization, OrganizationCreate, OrganizationUp
) )
return result.scalar_one_or_none() return result.scalar_one_or_none()
except Exception as e: except Exception as e:
logger.error(f"Error getting organization by slug {slug}: {e!s}") logger.error("Error getting organization by slug %s: %s", slug, e)
raise raise
async def create( async def create(
@@ -54,18 +57,20 @@ class CRUDOrganization(CRUDBase[Organization, OrganizationCreate, OrganizationUp
except IntegrityError as e: except IntegrityError as e:
await db.rollback() await db.rollback()
error_msg = str(e.orig) if hasattr(e, "orig") else str(e) error_msg = str(e.orig) if hasattr(e, "orig") else str(e)
if "slug" in error_msg.lower(): if (
logger.warning(f"Duplicate slug attempted: {obj_in.slug}") "slug" in error_msg.lower()
raise ValueError( or "unique" in error_msg.lower()
or "duplicate" in error_msg.lower()
):
logger.warning("Duplicate slug attempted: %s", obj_in.slug)
raise DuplicateEntryError(
f"Organization with slug '{obj_in.slug}' already exists" f"Organization with slug '{obj_in.slug}' already exists"
) )
logger.error(f"Integrity error creating organization: {error_msg}") logger.error("Integrity error creating organization: %s", error_msg)
raise ValueError(f"Database integrity error: {error_msg}") raise IntegrityConstraintError(f"Database integrity error: {error_msg}")
except Exception as e: except Exception as e:
await db.rollback() await db.rollback()
logger.error( logger.exception("Unexpected error creating organization: %s", e)
f"Unexpected error creating organization: {e!s}", exc_info=True
)
raise raise
async def get_multi_with_filters( async def get_multi_with_filters(
@@ -79,16 +84,10 @@ class CRUDOrganization(CRUDBase[Organization, OrganizationCreate, OrganizationUp
sort_by: str = "created_at", sort_by: str = "created_at",
sort_order: str = "desc", sort_order: str = "desc",
) -> tuple[list[Organization], int]: ) -> tuple[list[Organization], int]:
""" """Get multiple organizations with filtering, searching, and sorting."""
Get multiple organizations with filtering, searching, and sorting.
Returns:
Tuple of (organizations list, total count)
"""
try: try:
query = select(Organization) query = select(Organization)
# Apply filters
if is_active is not None: if is_active is not None:
query = query.where(Organization.is_active == is_active) query = query.where(Organization.is_active == is_active)
@@ -100,26 +99,23 @@ class CRUDOrganization(CRUDBase[Organization, OrganizationCreate, OrganizationUp
) )
query = query.where(search_filter) query = query.where(search_filter)
# Get total count before pagination
count_query = select(func.count()).select_from(query.alias()) count_query = select(func.count()).select_from(query.alias())
count_result = await db.execute(count_query) count_result = await db.execute(count_query)
total = count_result.scalar_one() total = count_result.scalar_one()
# Apply sorting
sort_column = getattr(Organization, sort_by, Organization.created_at) sort_column = getattr(Organization, sort_by, Organization.created_at)
if sort_order == "desc": if sort_order == "desc":
query = query.order_by(sort_column.desc()) query = query.order_by(sort_column.desc())
else: else:
query = query.order_by(sort_column.asc()) query = query.order_by(sort_column.asc())
# Apply pagination
query = query.offset(skip).limit(limit) query = query.offset(skip).limit(limit)
result = await db.execute(query) result = await db.execute(query)
organizations = list(result.scalars().all()) organizations = list(result.scalars().all())
return organizations, total return organizations, total
except Exception as e: except Exception as e:
logger.error(f"Error getting organizations with filters: {e!s}") logger.error("Error getting organizations with filters: %s", e)
raise raise
async def get_member_count(self, db: AsyncSession, *, organization_id: UUID) -> int: async def get_member_count(self, db: AsyncSession, *, organization_id: UUID) -> int:
@@ -136,7 +132,7 @@ class CRUDOrganization(CRUDBase[Organization, OrganizationCreate, OrganizationUp
return result.scalar_one() or 0 return result.scalar_one() or 0
except Exception as e: except Exception as e:
logger.error( logger.error(
f"Error getting member count for organization {organization_id}: {e!s}" "Error getting member count for organization %s: %s", organization_id, e
) )
raise raise
@@ -149,16 +145,8 @@ class CRUDOrganization(CRUDBase[Organization, OrganizationCreate, OrganizationUp
is_active: bool | None = None, is_active: bool | None = None,
search: str | None = None, search: str | None = None,
) -> tuple[list[dict[str, Any]], int]: ) -> tuple[list[dict[str, Any]], int]:
""" """Get organizations with member counts in a SINGLE QUERY using JOIN and GROUP BY."""
Get organizations with member counts in a SINGLE QUERY using JOIN and GROUP BY.
This eliminates the N+1 query problem.
Returns:
Tuple of (list of dicts with org and member_count, total count)
"""
try: try:
# Build base query with LEFT JOIN and GROUP BY
# Use CASE statement to count only active members
query = ( query = (
select( select(
Organization, Organization,
@@ -181,10 +169,10 @@ class CRUDOrganization(CRUDBase[Organization, OrganizationCreate, OrganizationUp
.group_by(Organization.id) .group_by(Organization.id)
) )
# Apply filters
if is_active is not None: if is_active is not None:
query = query.where(Organization.is_active == is_active) query = query.where(Organization.is_active == is_active)
search_filter = None
if search: if search:
search_filter = or_( search_filter = or_(
Organization.name.ilike(f"%{search}%"), Organization.name.ilike(f"%{search}%"),
@@ -193,17 +181,15 @@ class CRUDOrganization(CRUDBase[Organization, OrganizationCreate, OrganizationUp
) )
query = query.where(search_filter) query = query.where(search_filter)
# Get total count
count_query = select(func.count(Organization.id)) count_query = select(func.count(Organization.id))
if is_active is not None: if is_active is not None:
count_query = count_query.where(Organization.is_active == is_active) count_query = count_query.where(Organization.is_active == is_active)
if search: if search_filter is not None:
count_query = count_query.where(search_filter) count_query = count_query.where(search_filter)
count_result = await db.execute(count_query) count_result = await db.execute(count_query)
total = count_result.scalar_one() total = count_result.scalar_one()
# Apply pagination and ordering
query = ( query = (
query.order_by(Organization.created_at.desc()).offset(skip).limit(limit) query.order_by(Organization.created_at.desc()).offset(skip).limit(limit)
) )
@@ -211,7 +197,6 @@ class CRUDOrganization(CRUDBase[Organization, OrganizationCreate, OrganizationUp
result = await db.execute(query) result = await db.execute(query)
rows = result.all() rows = result.all()
# Convert to list of dicts
orgs_with_counts = [ orgs_with_counts = [
{"organization": org, "member_count": member_count} {"organization": org, "member_count": member_count}
for org, member_count in rows for org, member_count in rows
@@ -220,9 +205,7 @@ class CRUDOrganization(CRUDBase[Organization, OrganizationCreate, OrganizationUp
return orgs_with_counts, total return orgs_with_counts, total
except Exception as e: except Exception as e:
logger.error( logger.exception("Error getting organizations with member counts: %s", e)
f"Error getting organizations with member counts: {e!s}", exc_info=True
)
raise raise
async def add_user( async def add_user(
@@ -236,7 +219,6 @@ class CRUDOrganization(CRUDBase[Organization, OrganizationCreate, OrganizationUp
) -> UserOrganization: ) -> UserOrganization:
"""Add a user to an organization with a specific role.""" """Add a user to an organization with a specific role."""
try: try:
# Check if relationship already exists
result = await db.execute( result = await db.execute(
select(UserOrganization).where( select(UserOrganization).where(
and_( and_(
@@ -248,7 +230,6 @@ class CRUDOrganization(CRUDBase[Organization, OrganizationCreate, OrganizationUp
existing = result.scalar_one_or_none() existing = result.scalar_one_or_none()
if existing: if existing:
# Reactivate if inactive, or raise error if already active
if not existing.is_active: if not existing.is_active:
existing.is_active = True existing.is_active = True
existing.role = role existing.role = role
@@ -257,9 +238,10 @@ class CRUDOrganization(CRUDBase[Organization, OrganizationCreate, OrganizationUp
await db.refresh(existing) await db.refresh(existing)
return existing return existing
else: else:
raise ValueError("User is already a member of this organization") raise DuplicateEntryError(
"User is already a member of this organization"
)
# Create new relationship
user_org = UserOrganization( user_org = UserOrganization(
user_id=user_id, user_id=user_id,
organization_id=organization_id, organization_id=organization_id,
@@ -273,11 +255,11 @@ class CRUDOrganization(CRUDBase[Organization, OrganizationCreate, OrganizationUp
return user_org return user_org
except IntegrityError as e: except IntegrityError as e:
await db.rollback() await db.rollback()
logger.error(f"Integrity error adding user to organization: {e!s}") logger.error("Integrity error adding user to organization: %s", e)
raise ValueError("Failed to add user to organization") raise IntegrityConstraintError("Failed to add user to organization")
except Exception as e: except Exception as e:
await db.rollback() await db.rollback()
logger.error(f"Error adding user to organization: {e!s}", exc_info=True) logger.exception("Error adding user to organization: %s", e)
raise raise
async def remove_user( async def remove_user(
@@ -303,7 +285,7 @@ class CRUDOrganization(CRUDBase[Organization, OrganizationCreate, OrganizationUp
return True return True
except Exception as e: except Exception as e:
await db.rollback() await db.rollback()
logger.error(f"Error removing user from organization: {e!s}", exc_info=True) logger.exception("Error removing user from organization: %s", e)
raise raise
async def update_user_role( async def update_user_role(
@@ -338,7 +320,7 @@ class CRUDOrganization(CRUDBase[Organization, OrganizationCreate, OrganizationUp
return user_org return user_org
except Exception as e: except Exception as e:
await db.rollback() await db.rollback()
logger.error(f"Error updating user role: {e!s}", exc_info=True) logger.exception("Error updating user role: %s", e)
raise raise
async def get_organization_members( async def get_organization_members(
@@ -348,16 +330,10 @@ class CRUDOrganization(CRUDBase[Organization, OrganizationCreate, OrganizationUp
organization_id: UUID, organization_id: UUID,
skip: int = 0, skip: int = 0,
limit: int = 100, limit: int = 100,
is_active: bool = True, is_active: bool | None = True,
) -> tuple[list[dict[str, Any]], int]: ) -> tuple[list[dict[str, Any]], int]:
""" """Get members of an organization with user details."""
Get members of an organization with user details.
Returns:
Tuple of (members list with user details, total count)
"""
try: try:
# Build query with join
query = ( query = (
select(UserOrganization, User) select(UserOrganization, User)
.join(User, UserOrganization.user_id == User.id) .join(User, UserOrganization.user_id == User.id)
@@ -367,7 +343,6 @@ class CRUDOrganization(CRUDBase[Organization, OrganizationCreate, OrganizationUp
if is_active is not None: if is_active is not None:
query = query.where(UserOrganization.is_active == is_active) query = query.where(UserOrganization.is_active == is_active)
# Get total count
count_query = select(func.count()).select_from( count_query = select(func.count()).select_from(
select(UserOrganization) select(UserOrganization)
.where(UserOrganization.organization_id == organization_id) .where(UserOrganization.organization_id == organization_id)
@@ -381,7 +356,6 @@ class CRUDOrganization(CRUDBase[Organization, OrganizationCreate, OrganizationUp
count_result = await db.execute(count_query) count_result = await db.execute(count_query)
total = count_result.scalar_one() total = count_result.scalar_one()
# Apply ordering and pagination
query = ( query = (
query.order_by(UserOrganization.created_at.desc()) query.order_by(UserOrganization.created_at.desc())
.offset(skip) .offset(skip)
@@ -406,11 +380,11 @@ class CRUDOrganization(CRUDBase[Organization, OrganizationCreate, OrganizationUp
return members, total return members, total
except Exception as e: except Exception as e:
logger.error(f"Error getting organization members: {e!s}") logger.error("Error getting organization members: %s", e)
raise raise
async def get_user_organizations( async def get_user_organizations(
self, db: AsyncSession, *, user_id: UUID, is_active: bool = True self, db: AsyncSession, *, user_id: UUID, is_active: bool | None = True
) -> list[Organization]: ) -> list[Organization]:
"""Get all organizations a user belongs to.""" """Get all organizations a user belongs to."""
try: try:
@@ -429,21 +403,14 @@ class CRUDOrganization(CRUDBase[Organization, OrganizationCreate, OrganizationUp
result = await db.execute(query) result = await db.execute(query)
return list(result.scalars().all()) return list(result.scalars().all())
except Exception as e: except Exception as e:
logger.error(f"Error getting user organizations: {e!s}") logger.error("Error getting user organizations: %s", e)
raise raise
async def get_user_organizations_with_details( async def get_user_organizations_with_details(
self, db: AsyncSession, *, user_id: UUID, is_active: bool = True self, db: AsyncSession, *, user_id: UUID, is_active: bool | None = True
) -> list[dict[str, Any]]: ) -> list[dict[str, Any]]:
""" """Get user's organizations with role and member count in SINGLE QUERY."""
Get user's organizations with role and member count in SINGLE QUERY.
Eliminates N+1 problem by using subquery for member counts.
Returns:
List of dicts with organization, role, and member_count
"""
try: try:
# Subquery to get member counts for each organization
member_count_subq = ( member_count_subq = (
select( select(
UserOrganization.organization_id, UserOrganization.organization_id,
@@ -454,7 +421,6 @@ class CRUDOrganization(CRUDBase[Organization, OrganizationCreate, OrganizationUp
.subquery() .subquery()
) )
# Main query with JOIN to get org, role, and member count
query = ( query = (
select( select(
Organization, Organization,
@@ -486,9 +452,7 @@ class CRUDOrganization(CRUDBase[Organization, OrganizationCreate, OrganizationUp
] ]
except Exception as e: except Exception as e:
logger.error( logger.exception("Error getting user organizations with details: %s", e)
f"Error getting user organizations with details: {e!s}", exc_info=True
)
raise raise
async def get_user_role_in_org( async def get_user_role_in_org(
@@ -507,9 +471,9 @@ class CRUDOrganization(CRUDBase[Organization, OrganizationCreate, OrganizationUp
) )
user_org = result.scalar_one_or_none() user_org = result.scalar_one_or_none()
return user_org.role if user_org else None return user_org.role if user_org else None # pyright: ignore[reportReturnType]
except Exception as e: except Exception as e:
logger.error(f"Error getting user role in org: {e!s}") logger.error("Error getting user role in org: %s", e)
raise raise
async def is_user_org_owner( async def is_user_org_owner(
@@ -531,5 +495,5 @@ class CRUDOrganization(CRUDBase[Organization, OrganizationCreate, OrganizationUp
return role in [OrganizationRole.OWNER, OrganizationRole.ADMIN] return role in [OrganizationRole.OWNER, OrganizationRole.ADMIN]
# Create a singleton instance for use across the application # Singleton instance
organization = CRUDOrganization(Organization) organization_repo = OrganizationRepository(Organization)

View File

@@ -1,6 +1,5 @@
""" # app/repositories/session.py
Async CRUD operations for user sessions using SQLAlchemy 2.0 patterns. """Repository for UserSession model async database operations using SQLAlchemy 2.0 patterns."""
"""
import logging import logging
import uuid import uuid
@@ -11,49 +10,32 @@ from sqlalchemy import and_, delete, func, select, update
from sqlalchemy.ext.asyncio import AsyncSession from sqlalchemy.ext.asyncio import AsyncSession
from sqlalchemy.orm import joinedload from sqlalchemy.orm import joinedload
from app.crud.base import CRUDBase from app.core.repository_exceptions import IntegrityConstraintError, InvalidInputError
from app.models.user_session import UserSession from app.models.user_session import UserSession
from app.repositories.base import BaseRepository
from app.schemas.sessions import SessionCreate, SessionUpdate from app.schemas.sessions import SessionCreate, SessionUpdate
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
class CRUDSession(CRUDBase[UserSession, SessionCreate, SessionUpdate]): class SessionRepository(BaseRepository[UserSession, SessionCreate, SessionUpdate]):
"""Async CRUD operations for user sessions.""" """Repository for UserSession model."""
async def get_by_jti(self, db: AsyncSession, *, jti: str) -> UserSession | None: async def get_by_jti(self, db: AsyncSession, *, jti: str) -> UserSession | None:
""" """Get session by refresh token JTI."""
Get session by refresh token JTI.
Args:
db: Database session
jti: Refresh token JWT ID
Returns:
UserSession if found, None otherwise
"""
try: try:
result = await db.execute( result = await db.execute(
select(UserSession).where(UserSession.refresh_token_jti == jti) select(UserSession).where(UserSession.refresh_token_jti == jti)
) )
return result.scalar_one_or_none() return result.scalar_one_or_none()
except Exception as e: except Exception as e:
logger.error(f"Error getting session by JTI {jti}: {e!s}") logger.error("Error getting session by JTI %s: %s", jti, e)
raise raise
async def get_active_by_jti( async def get_active_by_jti(
self, db: AsyncSession, *, jti: str self, db: AsyncSession, *, jti: str
) -> UserSession | None: ) -> UserSession | None:
""" """Get active session by refresh token JTI."""
Get active session by refresh token JTI.
Args:
db: Database session
jti: Refresh token JWT ID
Returns:
Active UserSession if found, None otherwise
"""
try: try:
result = await db.execute( result = await db.execute(
select(UserSession).where( select(UserSession).where(
@@ -65,7 +47,7 @@ class CRUDSession(CRUDBase[UserSession, SessionCreate, SessionUpdate]):
) )
return result.scalar_one_or_none() return result.scalar_one_or_none()
except Exception as e: except Exception as e:
logger.error(f"Error getting active session by JTI {jti}: {e!s}") logger.error("Error getting active session by JTI %s: %s", jti, e)
raise raise
async def get_user_sessions( async def get_user_sessions(
@@ -76,25 +58,12 @@ class CRUDSession(CRUDBase[UserSession, SessionCreate, SessionUpdate]):
active_only: bool = True, active_only: bool = True,
with_user: bool = False, with_user: bool = False,
) -> list[UserSession]: ) -> list[UserSession]:
""" """Get all sessions for a user with optional eager loading."""
Get all sessions for a user with optional eager loading.
Args:
db: Database session
user_id: User ID
active_only: If True, return only active sessions
with_user: If True, eager load user relationship to prevent N+1
Returns:
List of UserSession objects
"""
try: try:
# Convert user_id string to UUID if needed
user_uuid = UUID(user_id) if isinstance(user_id, str) else user_id user_uuid = UUID(user_id) if isinstance(user_id, str) else user_id
query = select(UserSession).where(UserSession.user_id == user_uuid) query = select(UserSession).where(UserSession.user_id == user_uuid)
# Add eager loading if requested to prevent N+1 queries
if with_user: if with_user:
query = query.options(joinedload(UserSession.user)) query = query.options(joinedload(UserSession.user))
@@ -105,25 +74,13 @@ class CRUDSession(CRUDBase[UserSession, SessionCreate, SessionUpdate]):
result = await db.execute(query) result = await db.execute(query)
return list(result.scalars().all()) return list(result.scalars().all())
except Exception as e: except Exception as e:
logger.error(f"Error getting sessions for user {user_id}: {e!s}") logger.error("Error getting sessions for user %s: %s", user_id, e)
raise raise
async def create_session( async def create_session(
self, db: AsyncSession, *, obj_in: SessionCreate self, db: AsyncSession, *, obj_in: SessionCreate
) -> UserSession: ) -> UserSession:
""" """Create a new user session."""
Create a new user session.
Args:
db: Database session
obj_in: SessionCreate schema with session data
Returns:
Created UserSession
Raises:
ValueError: If session creation fails
"""
try: try:
db_obj = UserSession( db_obj = UserSession(
user_id=obj_in.user_id, user_id=obj_in.user_id,
@@ -143,33 +100,26 @@ class CRUDSession(CRUDBase[UserSession, SessionCreate, SessionUpdate]):
await db.refresh(db_obj) await db.refresh(db_obj)
logger.info( logger.info(
f"Session created for user {obj_in.user_id} from {obj_in.device_name} " "Session created for user %s from %s (IP: %s)",
f"(IP: {obj_in.ip_address})" obj_in.user_id,
obj_in.device_name,
obj_in.ip_address,
) )
return db_obj return db_obj
except Exception as e: except Exception as e:
await db.rollback() await db.rollback()
logger.error(f"Error creating session: {e!s}", exc_info=True) logger.exception("Error creating session: %s", e)
raise ValueError(f"Failed to create session: {e!s}") raise IntegrityConstraintError(f"Failed to create session: {e!s}")
async def deactivate( async def deactivate(
self, db: AsyncSession, *, session_id: str self, db: AsyncSession, *, session_id: str
) -> UserSession | None: ) -> UserSession | None:
""" """Deactivate a session (logout from device)."""
Deactivate a session (logout from device).
Args:
db: Database session
session_id: Session UUID
Returns:
Deactivated UserSession if found, None otherwise
"""
try: try:
session = await self.get(db, id=session_id) session = await self.get(db, id=session_id)
if not session: if not session:
logger.warning(f"Session {session_id} not found for deactivation") logger.warning("Session %s not found for deactivation", session_id)
return None return None
session.is_active = False session.is_active = False
@@ -178,31 +128,23 @@ class CRUDSession(CRUDBase[UserSession, SessionCreate, SessionUpdate]):
await db.refresh(session) await db.refresh(session)
logger.info( logger.info(
f"Session {session_id} deactivated for user {session.user_id} " "Session %s deactivated for user %s (%s)",
f"({session.device_name})" session_id,
session.user_id,
session.device_name,
) )
return session return session
except Exception as e: except Exception as e:
await db.rollback() await db.rollback()
logger.error(f"Error deactivating session {session_id}: {e!s}") logger.error("Error deactivating session %s: %s", session_id, e)
raise raise
async def deactivate_all_user_sessions( async def deactivate_all_user_sessions(
self, db: AsyncSession, *, user_id: str self, db: AsyncSession, *, user_id: str
) -> int: ) -> int:
""" """Deactivate all active sessions for a user (logout from all devices)."""
Deactivate all active sessions for a user (logout from all devices).
Args:
db: Database session
user_id: User ID
Returns:
Number of sessions deactivated
"""
try: try:
# Convert user_id string to UUID if needed
user_uuid = UUID(user_id) if isinstance(user_id, str) else user_id user_uuid = UUID(user_id) if isinstance(user_id, str) else user_id
stmt = ( stmt = (
@@ -216,27 +158,18 @@ class CRUDSession(CRUDBase[UserSession, SessionCreate, SessionUpdate]):
count = result.rowcount count = result.rowcount
logger.info(f"Deactivated {count} sessions for user {user_id}") logger.info("Deactivated %s sessions for user %s", count, user_id)
return count return count
except Exception as e: except Exception as e:
await db.rollback() await db.rollback()
logger.error(f"Error deactivating all sessions for user {user_id}: {e!s}") logger.error("Error deactivating all sessions for user %s: %s", user_id, e)
raise raise
async def update_last_used( async def update_last_used(
self, db: AsyncSession, *, session: UserSession self, db: AsyncSession, *, session: UserSession
) -> UserSession: ) -> UserSession:
""" """Update the last_used_at timestamp for a session."""
Update the last_used_at timestamp for a session.
Args:
db: Database session
session: UserSession object
Returns:
Updated UserSession
"""
try: try:
session.last_used_at = datetime.now(UTC) session.last_used_at = datetime.now(UTC)
db.add(session) db.add(session)
@@ -245,7 +178,7 @@ class CRUDSession(CRUDBase[UserSession, SessionCreate, SessionUpdate]):
return session return session
except Exception as e: except Exception as e:
await db.rollback() await db.rollback()
logger.error(f"Error updating last_used for session {session.id}: {e!s}") logger.error("Error updating last_used for session %s: %s", session.id, e)
raise raise
async def update_refresh_token( async def update_refresh_token(
@@ -256,20 +189,7 @@ class CRUDSession(CRUDBase[UserSession, SessionCreate, SessionUpdate]):
new_jti: str, new_jti: str,
new_expires_at: datetime, new_expires_at: datetime,
) -> UserSession: ) -> UserSession:
""" """Update session with new refresh token JTI and expiration."""
Update session with new refresh token JTI and expiration.
Called during token refresh.
Args:
db: Database session
session: UserSession object
new_jti: New refresh token JTI
new_expires_at: New expiration datetime
Returns:
Updated UserSession
"""
try: try:
session.refresh_token_jti = new_jti session.refresh_token_jti = new_jti
session.expires_at = new_expires_at session.expires_at = new_expires_at
@@ -281,32 +201,16 @@ class CRUDSession(CRUDBase[UserSession, SessionCreate, SessionUpdate]):
except Exception as e: except Exception as e:
await db.rollback() await db.rollback()
logger.error( logger.error(
f"Error updating refresh token for session {session.id}: {e!s}" "Error updating refresh token for session %s: %s", session.id, e
) )
raise raise
async def cleanup_expired(self, db: AsyncSession, *, keep_days: int = 30) -> int: async def cleanup_expired(self, db: AsyncSession, *, keep_days: int = 30) -> int:
""" """Clean up expired sessions using optimized bulk DELETE."""
Clean up expired sessions using optimized bulk DELETE.
Deletes sessions that are:
- Expired AND inactive
- Older than keep_days
Uses single DELETE query instead of N individual deletes for efficiency.
Args:
db: Database session
keep_days: Keep inactive sessions for this many days (for audit)
Returns:
Number of sessions deleted
"""
try: try:
cutoff_date = datetime.now(UTC) - timedelta(days=keep_days) cutoff_date = datetime.now(UTC) - timedelta(days=keep_days)
now = datetime.now(UTC) now = datetime.now(UTC)
# Use bulk DELETE with WHERE clause - single query
stmt = delete(UserSession).where( stmt = delete(UserSession).where(
and_( and_(
UserSession.is_active == False, # noqa: E712 UserSession.is_active == False, # noqa: E712
@@ -321,38 +225,25 @@ class CRUDSession(CRUDBase[UserSession, SessionCreate, SessionUpdate]):
count = result.rowcount count = result.rowcount
if count > 0: if count > 0:
logger.info(f"Cleaned up {count} expired sessions using bulk DELETE") logger.info("Cleaned up %s expired sessions using bulk DELETE", count)
return count return count
except Exception as e: except Exception as e:
await db.rollback() await db.rollback()
logger.error(f"Error cleaning up expired sessions: {e!s}") logger.error("Error cleaning up expired sessions: %s", e)
raise raise
async def cleanup_expired_for_user(self, db: AsyncSession, *, user_id: str) -> int: async def cleanup_expired_for_user(self, db: AsyncSession, *, user_id: str) -> int:
""" """Clean up expired and inactive sessions for a specific user."""
Clean up expired and inactive sessions for a specific user.
Uses single bulk DELETE query for efficiency instead of N individual deletes.
Args:
db: Database session
user_id: User ID to cleanup sessions for
Returns:
Number of sessions deleted
"""
try: try:
# Validate UUID
try: try:
uuid_obj = uuid.UUID(user_id) uuid_obj = uuid.UUID(user_id)
except (ValueError, AttributeError): except (ValueError, AttributeError):
logger.error(f"Invalid UUID format: {user_id}") logger.error("Invalid UUID format: %s", user_id)
raise ValueError(f"Invalid user ID format: {user_id}") raise InvalidInputError(f"Invalid user ID format: {user_id}")
now = datetime.now(UTC) now = datetime.now(UTC)
# Use bulk DELETE with WHERE clause - single query
stmt = delete(UserSession).where( stmt = delete(UserSession).where(
and_( and_(
UserSession.user_id == uuid_obj, UserSession.user_id == uuid_obj,
@@ -368,30 +259,22 @@ class CRUDSession(CRUDBase[UserSession, SessionCreate, SessionUpdate]):
if count > 0: if count > 0:
logger.info( logger.info(
f"Cleaned up {count} expired sessions for user {user_id} using bulk DELETE" "Cleaned up %s expired sessions for user %s using bulk DELETE",
count,
user_id,
) )
return count return count
except Exception as e: except Exception as e:
await db.rollback() await db.rollback()
logger.error( logger.error(
f"Error cleaning up expired sessions for user {user_id}: {e!s}" "Error cleaning up expired sessions for user %s: %s", user_id, e
) )
raise raise
async def get_user_session_count(self, db: AsyncSession, *, user_id: str) -> int: async def get_user_session_count(self, db: AsyncSession, *, user_id: str) -> int:
""" """Get count of active sessions for a user."""
Get count of active sessions for a user.
Args:
db: Database session
user_id: User ID
Returns:
Number of active sessions
"""
try: try:
# Convert user_id string to UUID if needed
user_uuid = UUID(user_id) if isinstance(user_id, str) else user_id user_uuid = UUID(user_id) if isinstance(user_id, str) else user_id
result = await db.execute( result = await db.execute(
@@ -401,7 +284,7 @@ class CRUDSession(CRUDBase[UserSession, SessionCreate, SessionUpdate]):
) )
return result.scalar_one() return result.scalar_one()
except Exception as e: except Exception as e:
logger.error(f"Error counting sessions for user {user_id}: {e!s}") logger.error("Error counting sessions for user %s: %s", user_id, e)
raise raise
async def get_all_sessions( async def get_all_sessions(
@@ -413,31 +296,16 @@ class CRUDSession(CRUDBase[UserSession, SessionCreate, SessionUpdate]):
active_only: bool = True, active_only: bool = True,
with_user: bool = True, with_user: bool = True,
) -> tuple[list[UserSession], int]: ) -> tuple[list[UserSession], int]:
""" """Get all sessions across all users with pagination (admin only)."""
Get all sessions across all users with pagination (admin only).
Args:
db: Database session
skip: Number of records to skip
limit: Maximum number of records to return
active_only: If True, return only active sessions
with_user: If True, eager load user relationship to prevent N+1
Returns:
Tuple of (list of UserSession objects, total count)
"""
try: try:
# Build query
query = select(UserSession) query = select(UserSession)
# Add eager loading if requested to prevent N+1 queries
if with_user: if with_user:
query = query.options(joinedload(UserSession.user)) query = query.options(joinedload(UserSession.user))
if active_only: if active_only:
query = query.where(UserSession.is_active) query = query.where(UserSession.is_active)
# Get total count
count_query = select(func.count(UserSession.id)) count_query = select(func.count(UserSession.id))
if active_only: if active_only:
count_query = count_query.where(UserSession.is_active) count_query = count_query.where(UserSession.is_active)
@@ -445,7 +313,6 @@ class CRUDSession(CRUDBase[UserSession, SessionCreate, SessionUpdate]):
count_result = await db.execute(count_query) count_result = await db.execute(count_query)
total = count_result.scalar_one() total = count_result.scalar_one()
# Apply pagination and ordering
query = ( query = (
query.order_by(UserSession.last_used_at.desc()) query.order_by(UserSession.last_used_at.desc())
.offset(skip) .offset(skip)
@@ -458,9 +325,9 @@ class CRUDSession(CRUDBase[UserSession, SessionCreate, SessionUpdate]):
return sessions, total return sessions, total
except Exception as e: except Exception as e:
logger.error(f"Error getting all sessions: {e!s}", exc_info=True) logger.exception("Error getting all sessions: %s", e)
raise raise
# Create singleton instance # Singleton instance
session = CRUDSession(UserSession) session_repo = SessionRepository(UserSession)

View File

@@ -1,5 +1,5 @@
# app/crud/user_async.py # app/repositories/user.py
"""Async CRUD operations for User model using SQLAlchemy 2.0 patterns.""" """Repository for User model async database operations using SQLAlchemy 2.0 patterns."""
import logging import logging
from datetime import UTC, datetime from datetime import UTC, datetime
@@ -11,15 +11,16 @@ from sqlalchemy.exc import IntegrityError
from sqlalchemy.ext.asyncio import AsyncSession from sqlalchemy.ext.asyncio import AsyncSession
from app.core.auth import get_password_hash_async from app.core.auth import get_password_hash_async
from app.crud.base import CRUDBase from app.core.repository_exceptions import DuplicateEntryError, InvalidInputError
from app.models.user import User from app.models.user import User
from app.repositories.base import BaseRepository
from app.schemas.users import UserCreate, UserUpdate from app.schemas.users import UserCreate, UserUpdate
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
class CRUDUser(CRUDBase[User, UserCreate, UserUpdate]): class UserRepository(BaseRepository[User, UserCreate, UserUpdate]):
"""Async CRUD operations for User model.""" """Repository for User model."""
async def get_by_email(self, db: AsyncSession, *, email: str) -> User | None: async def get_by_email(self, db: AsyncSession, *, email: str) -> User | None:
"""Get user by email address.""" """Get user by email address."""
@@ -27,13 +28,12 @@ class CRUDUser(CRUDBase[User, UserCreate, UserUpdate]):
result = await db.execute(select(User).where(User.email == email)) result = await db.execute(select(User).where(User.email == email))
return result.scalar_one_or_none() return result.scalar_one_or_none()
except Exception as e: except Exception as e:
logger.error(f"Error getting user by email {email}: {e!s}") logger.error("Error getting user by email %s: %s", email, e)
raise raise
async def create(self, db: AsyncSession, *, obj_in: UserCreate) -> User: async def create(self, db: AsyncSession, *, obj_in: UserCreate) -> User:
"""Create a new user with async password hashing and error handling.""" """Create a new user with async password hashing and error handling."""
try: try:
# Hash password asynchronously to avoid blocking event loop
password_hash = await get_password_hash_async(obj_in.password) password_hash = await get_password_hash_async(obj_in.password)
db_obj = User( db_obj = User(
@@ -57,13 +57,49 @@ class CRUDUser(CRUDBase[User, UserCreate, UserUpdate]):
await db.rollback() await db.rollback()
error_msg = str(e.orig) if hasattr(e, "orig") else str(e) error_msg = str(e.orig) if hasattr(e, "orig") else str(e)
if "email" in error_msg.lower(): if "email" in error_msg.lower():
logger.warning(f"Duplicate email attempted: {obj_in.email}") logger.warning("Duplicate email attempted: %s", obj_in.email)
raise ValueError(f"User with email {obj_in.email} already exists") raise DuplicateEntryError(
logger.error(f"Integrity error creating user: {error_msg}") f"User with email {obj_in.email} already exists"
raise ValueError(f"Database integrity error: {error_msg}") )
logger.error("Integrity error creating user: %s", error_msg)
raise DuplicateEntryError(f"Database integrity error: {error_msg}")
except Exception as e: except Exception as e:
await db.rollback() await db.rollback()
logger.error(f"Unexpected error creating user: {e!s}", exc_info=True) logger.exception("Unexpected error creating user: %s", e)
raise
async def create_oauth_user(
self,
db: AsyncSession,
*,
email: str,
first_name: str = "User",
last_name: str | None = None,
) -> User:
"""Create a new passwordless user for OAuth sign-in."""
try:
db_obj = User(
email=email,
password_hash=None, # OAuth-only user
first_name=first_name,
last_name=last_name,
is_active=True,
is_superuser=False,
)
db.add(db_obj)
await db.flush() # Get user.id without committing
return db_obj
except IntegrityError as e:
await db.rollback()
error_msg = str(e.orig) if hasattr(e, "orig") else str(e)
if "email" in error_msg.lower():
logger.warning("Duplicate email attempted: %s", email)
raise DuplicateEntryError(f"User with email {email} already exists")
logger.error("Integrity error creating OAuth user: %s", error_msg)
raise DuplicateEntryError(f"Database integrity error: {error_msg}")
except Exception as e:
await db.rollback()
logger.exception("Unexpected error creating OAuth user: %s", e)
raise raise
async def update( async def update(
@@ -75,8 +111,6 @@ class CRUDUser(CRUDBase[User, UserCreate, UserUpdate]):
else: else:
update_data = obj_in.model_dump(exclude_unset=True) update_data = obj_in.model_dump(exclude_unset=True)
# Handle password separately if it exists in update data
# Hash password asynchronously to avoid blocking event loop
if "password" in update_data: if "password" in update_data:
update_data["password_hash"] = await get_password_hash_async( update_data["password_hash"] = await get_password_hash_async(
update_data["password"] update_data["password"]
@@ -85,6 +119,15 @@ class CRUDUser(CRUDBase[User, UserCreate, UserUpdate]):
return await super().update(db, db_obj=db_obj, obj_in=update_data) return await super().update(db, db_obj=db_obj, obj_in=update_data)
async def update_password(
self, db: AsyncSession, *, user: User, password_hash: str
) -> User:
"""Set a new password hash on a user and commit."""
user.password_hash = password_hash
await db.commit()
await db.refresh(user)
return user
async def get_multi_with_total( async def get_multi_with_total(
self, self,
db: AsyncSession, db: AsyncSession,
@@ -96,43 +139,23 @@ class CRUDUser(CRUDBase[User, UserCreate, UserUpdate]):
filters: dict[str, Any] | None = None, filters: dict[str, Any] | None = None,
search: str | None = None, search: str | None = None,
) -> tuple[list[User], int]: ) -> tuple[list[User], int]:
""" """Get multiple users with total count, filtering, sorting, and search."""
Get multiple users with total count, filtering, sorting, and search.
Args:
db: Database session
skip: Number of records to skip
limit: Maximum number of records to return
sort_by: Field name to sort by
sort_order: Sort order ("asc" or "desc")
filters: Dictionary of filters (field_name: value)
search: Search term to match against email, first_name, last_name
Returns:
Tuple of (users list, total count)
"""
# Validate pagination
if skip < 0: if skip < 0:
raise ValueError("skip must be non-negative") raise InvalidInputError("skip must be non-negative")
if limit < 0: if limit < 0:
raise ValueError("limit must be non-negative") raise InvalidInputError("limit must be non-negative")
if limit > 1000: if limit > 1000:
raise ValueError("Maximum limit is 1000") raise InvalidInputError("Maximum limit is 1000")
try: try:
# Build base query
query = select(User) query = select(User)
# Exclude soft-deleted users
query = query.where(User.deleted_at.is_(None)) query = query.where(User.deleted_at.is_(None))
# Apply filters
if filters: if filters:
for field, value in filters.items(): for field, value in filters.items():
if hasattr(User, field) and value is not None: if hasattr(User, field) and value is not None:
query = query.where(getattr(User, field) == value) query = query.where(getattr(User, field) == value)
# Apply search
if search: if search:
search_filter = or_( search_filter = or_(
User.email.ilike(f"%{search}%"), User.email.ilike(f"%{search}%"),
@@ -141,14 +164,12 @@ class CRUDUser(CRUDBase[User, UserCreate, UserUpdate]):
) )
query = query.where(search_filter) query = query.where(search_filter)
# Get total count
from sqlalchemy import func from sqlalchemy import func
count_query = select(func.count()).select_from(query.alias()) count_query = select(func.count()).select_from(query.alias())
count_result = await db.execute(count_query) count_result = await db.execute(count_query)
total = count_result.scalar_one() total = count_result.scalar_one()
# Apply sorting
if sort_by and hasattr(User, sort_by): if sort_by and hasattr(User, sort_by):
sort_column = getattr(User, sort_by) sort_column = getattr(User, sort_by)
if sort_order.lower() == "desc": if sort_order.lower() == "desc":
@@ -156,7 +177,6 @@ class CRUDUser(CRUDBase[User, UserCreate, UserUpdate]):
else: else:
query = query.order_by(sort_column.asc()) query = query.order_by(sort_column.asc())
# Apply pagination
query = query.offset(skip).limit(limit) query = query.offset(skip).limit(limit)
result = await db.execute(query) result = await db.execute(query)
users = list(result.scalars().all()) users = list(result.scalars().all())
@@ -164,32 +184,21 @@ class CRUDUser(CRUDBase[User, UserCreate, UserUpdate]):
return users, total return users, total
except Exception as e: except Exception as e:
logger.error(f"Error retrieving paginated users: {e!s}") logger.error("Error retrieving paginated users: %s", e)
raise raise
async def bulk_update_status( async def bulk_update_status(
self, db: AsyncSession, *, user_ids: list[UUID], is_active: bool self, db: AsyncSession, *, user_ids: list[UUID], is_active: bool
) -> int: ) -> int:
""" """Bulk update is_active status for multiple users."""
Bulk update is_active status for multiple users.
Args:
db: Database session
user_ids: List of user IDs to update
is_active: New active status
Returns:
Number of users updated
"""
try: try:
if not user_ids: if not user_ids:
return 0 return 0
# Use UPDATE with WHERE IN for efficiency
stmt = ( stmt = (
update(User) update(User)
.where(User.id.in_(user_ids)) .where(User.id.in_(user_ids))
.where(User.deleted_at.is_(None)) # Don't update deleted users .where(User.deleted_at.is_(None))
.values(is_active=is_active, updated_at=datetime.now(UTC)) .values(is_active=is_active, updated_at=datetime.now(UTC))
) )
@@ -197,12 +206,14 @@ class CRUDUser(CRUDBase[User, UserCreate, UserUpdate]):
await db.commit() await db.commit()
updated_count = result.rowcount updated_count = result.rowcount
logger.info(f"Bulk updated {updated_count} users to is_active={is_active}") logger.info(
"Bulk updated %s users to is_active=%s", updated_count, is_active
)
return updated_count return updated_count
except Exception as e: except Exception as e:
await db.rollback() await db.rollback()
logger.error(f"Error bulk updating user status: {e!s}", exc_info=True) logger.exception("Error bulk updating user status: %s", e)
raise raise
async def bulk_soft_delete( async def bulk_soft_delete(
@@ -212,34 +223,20 @@ class CRUDUser(CRUDBase[User, UserCreate, UserUpdate]):
user_ids: list[UUID], user_ids: list[UUID],
exclude_user_id: UUID | None = None, exclude_user_id: UUID | None = None,
) -> int: ) -> int:
""" """Bulk soft delete multiple users."""
Bulk soft delete multiple users.
Args:
db: Database session
user_ids: List of user IDs to delete
exclude_user_id: Optional user ID to exclude (e.g., the admin performing the action)
Returns:
Number of users deleted
"""
try: try:
if not user_ids: if not user_ids:
return 0 return 0
# Remove excluded user from list
filtered_ids = [uid for uid in user_ids if uid != exclude_user_id] filtered_ids = [uid for uid in user_ids if uid != exclude_user_id]
if not filtered_ids: if not filtered_ids:
return 0 return 0
# Use UPDATE with WHERE IN for efficiency
stmt = ( stmt = (
update(User) update(User)
.where(User.id.in_(filtered_ids)) .where(User.id.in_(filtered_ids))
.where( .where(User.deleted_at.is_(None))
User.deleted_at.is_(None)
) # Don't re-delete already deleted users
.values( .values(
deleted_at=datetime.now(UTC), deleted_at=datetime.now(UTC),
is_active=False, is_active=False,
@@ -251,22 +248,22 @@ class CRUDUser(CRUDBase[User, UserCreate, UserUpdate]):
await db.commit() await db.commit()
deleted_count = result.rowcount deleted_count = result.rowcount
logger.info(f"Bulk soft deleted {deleted_count} users") logger.info("Bulk soft deleted %s users", deleted_count)
return deleted_count return deleted_count
except Exception as e: except Exception as e:
await db.rollback() await db.rollback()
logger.error(f"Error bulk deleting users: {e!s}", exc_info=True) logger.exception("Error bulk deleting users: %s", e)
raise raise
def is_active(self, user: User) -> bool: def is_active(self, user: User) -> bool:
"""Check if user is active.""" """Check if user is active."""
return user.is_active return bool(user.is_active)
def is_superuser(self, user: User) -> bool: def is_superuser(self, user: User) -> bool:
"""Check if user is a superuser.""" """Check if user is a superuser."""
return user.is_superuser return bool(user.is_superuser)
# Create a singleton instance for use across the application # Singleton instance
user = CRUDUser(User) user_repo = UserRepository(User)

View File

@@ -60,8 +60,8 @@ class OAuthAccountCreate(OAuthAccountBase):
user_id: UUID user_id: UUID
provider_user_id: str = Field(..., max_length=255) provider_user_id: str = Field(..., max_length=255)
access_token_encrypted: str | None = None access_token: str | None = None
refresh_token_encrypted: str | None = None refresh_token: str | None = None
token_expires_at: datetime | None = None token_expires_at: datetime | None = None

View File

@@ -48,7 +48,7 @@ class OrganizationCreate(OrganizationBase):
"""Schema for creating a new organization.""" """Schema for creating a new organization."""
name: str = Field(..., min_length=1, max_length=255) name: str = Field(..., min_length=1, max_length=255)
slug: str = Field(..., min_length=1, max_length=255) slug: str = Field(..., min_length=1, max_length=255) # pyright: ignore[reportIncompatibleVariableOverride]
class OrganizationUpdate(BaseModel): class OrganizationUpdate(BaseModel):

View File

@@ -1,5 +1,19 @@
# app/services/__init__.py # app/services/__init__.py
from . import oauth_provider_service
from .auth_service import AuthService from .auth_service import AuthService
from .oauth_service import OAuthService from .oauth_service import OAuthService
from .organization_service import OrganizationService, organization_service
from .session_service import SessionService, session_service
from .user_service import UserService, user_service
__all__ = ["AuthService", "OAuthService"] __all__ = [
"AuthService",
"OAuthService",
"OrganizationService",
"SessionService",
"UserService",
"oauth_provider_service",
"organization_service",
"session_service",
"user_service",
]

View File

@@ -2,7 +2,6 @@
import logging import logging
from uuid import UUID from uuid import UUID
from sqlalchemy import select
from sqlalchemy.ext.asyncio import AsyncSession from sqlalchemy.ext.asyncio import AsyncSession
from app.core.auth import ( from app.core.auth import (
@@ -14,12 +13,18 @@ from app.core.auth import (
verify_password_async, verify_password_async,
) )
from app.core.config import settings from app.core.config import settings
from app.core.exceptions import AuthenticationError from app.core.exceptions import AuthenticationError, DuplicateError
from app.core.repository_exceptions import DuplicateEntryError
from app.models.user import User from app.models.user import User
from app.repositories.user import user_repo
from app.schemas.users import Token, UserCreate, UserResponse from app.schemas.users import Token, UserCreate, UserResponse
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
# Pre-computed bcrypt hash used for constant-time comparison when user is not found,
# preventing timing attacks that could enumerate valid email addresses.
_DUMMY_HASH = "$2b$12$EixZaYVK1fsbw1ZfbX3OXePaWxn96p36zLFbnJHfxPSEFBzXKiHia"
class AuthService: class AuthService:
"""Service for handling authentication operations""" """Service for handling authentication operations"""
@@ -39,10 +44,12 @@ class AuthService:
Returns: Returns:
User if authenticated, None otherwise User if authenticated, None otherwise
""" """
result = await db.execute(select(User).where(User.email == email)) user = await user_repo.get_by_email(db, email=email)
user = result.scalar_one_or_none()
if not user: if not user:
# Perform a dummy verification to match timing of a real bcrypt check,
# preventing email enumeration via response-time differences.
await verify_password_async(password, _DUMMY_HASH)
return None return None
# Verify password asynchronously to avoid blocking event loop # Verify password asynchronously to avoid blocking event loop
@@ -71,40 +78,23 @@ class AuthService:
""" """
try: try:
# Check if user already exists # Check if user already exists
result = await db.execute(select(User).where(User.email == user_data.email)) existing_user = await user_repo.get_by_email(db, email=user_data.email)
existing_user = result.scalar_one_or_none()
if existing_user: if existing_user:
raise AuthenticationError("User with this email already exists") raise DuplicateError("User with this email already exists")
# Create new user with async password hashing # Delegate creation (hashing + commit) to the repository
# Hash password asynchronously to avoid blocking event loop user = await user_repo.create(db, obj_in=user_data)
hashed_password = await get_password_hash_async(user_data.password)
# Create user object from model logger.info("User created successfully: %s", user.email)
user = User(
email=user_data.email,
password_hash=hashed_password,
first_name=user_data.first_name,
last_name=user_data.last_name,
phone_number=user_data.phone_number,
is_active=True,
is_superuser=False,
)
db.add(user)
await db.commit()
await db.refresh(user)
logger.info(f"User created successfully: {user.email}")
return user return user
except AuthenticationError: except (AuthenticationError, DuplicateError):
# Re-raise authentication errors without rollback # Re-raise API exceptions without rollback
raise raise
except DuplicateEntryError as e:
raise DuplicateError(str(e))
except Exception as e: except Exception as e:
# Rollback on any database errors logger.exception("Error creating user: %s", e)
await db.rollback()
logger.error(f"Error creating user: {e!s}", exc_info=True)
raise AuthenticationError(f"Failed to create user: {e!s}") raise AuthenticationError(f"Failed to create user: {e!s}")
@staticmethod @staticmethod
@@ -168,8 +158,7 @@ class AuthService:
user_id = token_data.user_id user_id = token_data.user_id
# Get user from database # Get user from database
result = await db.execute(select(User).where(User.id == user_id)) user = await user_repo.get(db, id=str(user_id))
user = result.scalar_one_or_none()
if not user or not user.is_active: if not user or not user.is_active:
raise TokenInvalidError("Invalid user or inactive account") raise TokenInvalidError("Invalid user or inactive account")
@@ -177,7 +166,7 @@ class AuthService:
return AuthService.create_tokens(user) return AuthService.create_tokens(user)
except (TokenExpiredError, TokenInvalidError) as e: except (TokenExpiredError, TokenInvalidError) as e:
logger.warning(f"Token refresh failed: {e!s}") logger.warning("Token refresh failed: %s", e)
raise raise
@staticmethod @staticmethod
@@ -200,8 +189,7 @@ class AuthService:
AuthenticationError: If current password is incorrect or update fails AuthenticationError: If current password is incorrect or update fails
""" """
try: try:
result = await db.execute(select(User).where(User.id == user_id)) user = await user_repo.get(db, id=str(user_id))
user = result.scalar_one_or_none()
if not user: if not user:
raise AuthenticationError("User not found") raise AuthenticationError("User not found")
@@ -210,10 +198,10 @@ class AuthService:
raise AuthenticationError("Current password is incorrect") raise AuthenticationError("Current password is incorrect")
# Hash new password asynchronously to avoid blocking event loop # Hash new password asynchronously to avoid blocking event loop
user.password_hash = await get_password_hash_async(new_password) new_hash = await get_password_hash_async(new_password)
await db.commit() await user_repo.update_password(db, user=user, password_hash=new_hash)
logger.info(f"Password changed successfully for user {user_id}") logger.info("Password changed successfully for user %s", user_id)
return True return True
except AuthenticationError: except AuthenticationError:
@@ -222,7 +210,34 @@ class AuthService:
except Exception as e: except Exception as e:
# Rollback on any database errors # Rollback on any database errors
await db.rollback() await db.rollback()
logger.error( logger.exception("Error changing password for user %s: %s", user_id, e)
f"Error changing password for user {user_id}: {e!s}", exc_info=True
)
raise AuthenticationError(f"Failed to change password: {e!s}") raise AuthenticationError(f"Failed to change password: {e!s}")
@staticmethod
async def reset_password(
db: AsyncSession, *, email: str, new_password: str
) -> User:
"""
Reset a user's password without requiring the current password.
Args:
db: Database session
email: User email address
new_password: New password to set
Returns:
Updated user
Raises:
AuthenticationError: If user not found or inactive
"""
user = await user_repo.get_by_email(db, email=email)
if not user:
raise AuthenticationError("User not found")
if not user.is_active:
raise AuthenticationError("User account is inactive")
new_hash = await get_password_hash_async(new_password)
user = await user_repo.update_password(db, user=user, password_hash=new_hash)
logger.info("Password reset successfully for %s", email)
return user

View File

@@ -58,8 +58,8 @@ class ConsoleEmailBackend(EmailBackend):
logger.info("=" * 80) logger.info("=" * 80)
logger.info("EMAIL SENT (Console Backend)") logger.info("EMAIL SENT (Console Backend)")
logger.info("=" * 80) logger.info("=" * 80)
logger.info(f"To: {', '.join(to)}") logger.info("To: %s", ", ".join(to))
logger.info(f"Subject: {subject}") logger.info("Subject: %s", subject)
logger.info("-" * 80) logger.info("-" * 80)
if text_content: if text_content:
logger.info("Plain Text Content:") logger.info("Plain Text Content:")
@@ -199,7 +199,7 @@ The {settings.PROJECT_NAME} Team
text_content=text_content, text_content=text_content,
) )
except Exception as e: except Exception as e:
logger.error(f"Failed to send password reset email to {to_email}: {e!s}") logger.error("Failed to send password reset email to %s: %s", to_email, e)
return False return False
async def send_email_verification( async def send_email_verification(
@@ -287,7 +287,7 @@ The {settings.PROJECT_NAME} Team
text_content=text_content, text_content=text_content,
) )
except Exception as e: except Exception as e:
logger.error(f"Failed to send verification email to {to_email}: {e!s}") logger.error("Failed to send verification email to %s: %s", to_email, e)
return False return False

View File

@@ -25,15 +25,19 @@ from datetime import UTC, datetime, timedelta
from typing import Any from typing import Any
from uuid import UUID from uuid import UUID
from jose import jwt import jwt
from sqlalchemy import and_, delete, select from jwt.exceptions import ExpiredSignatureError, InvalidTokenError
from sqlalchemy.ext.asyncio import AsyncSession from sqlalchemy.ext.asyncio import AsyncSession
from app.core.config import settings from app.core.config import settings
from app.models.oauth_authorization_code import OAuthAuthorizationCode
from app.models.oauth_client import OAuthClient from app.models.oauth_client import OAuthClient
from app.models.oauth_provider_token import OAuthConsent, OAuthProviderRefreshToken
from app.models.user import User from app.models.user import User
from app.repositories.oauth_authorization_code import oauth_authorization_code_repo
from app.repositories.oauth_client import oauth_client_repo
from app.repositories.oauth_consent import oauth_consent_repo
from app.repositories.oauth_provider_token import oauth_provider_token_repo
from app.repositories.user import user_repo
from app.schemas.oauth import OAuthClientCreate
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
@@ -135,7 +139,7 @@ def verify_pkce(code_verifier: str, code_challenge: str, method: str) -> bool:
if method != "S256": if method != "S256":
# SECURITY: Reject any method other than S256 # SECURITY: Reject any method other than S256
# 'plain' method provides no security against code interception attacks # 'plain' method provides no security against code interception attacks
logger.warning(f"PKCE verification rejected for unsupported method: {method}") logger.warning("PKCE verification rejected for unsupported method: %s", method)
return False return False
# SHA-256 hash, then base64url encode (RFC 7636 Section 4.2) # SHA-256 hash, then base64url encode (RFC 7636 Section 4.2)
@@ -161,15 +165,7 @@ def join_scope(scopes: list[str]) -> str:
async def get_client(db: AsyncSession, client_id: str) -> OAuthClient | None: async def get_client(db: AsyncSession, client_id: str) -> OAuthClient | None:
"""Get OAuth client by client_id.""" """Get OAuth client by client_id."""
result = await db.execute( return await oauth_client_repo.get_by_client_id(db, client_id=client_id)
select(OAuthClient).where(
and_(
OAuthClient.client_id == client_id,
OAuthClient.is_active == True, # noqa: E712
)
)
)
return result.scalar_one_or_none()
async def validate_client( async def validate_client(
@@ -204,21 +200,19 @@ async def validate_client(
if not client.client_secret_hash: if not client.client_secret_hash:
raise InvalidClientError("Client not configured with secret") raise InvalidClientError("Client not configured with secret")
# SECURITY: Verify secret using bcrypt (not SHA-256) # SECURITY: Verify secret using bcrypt
# Supports both bcrypt and legacy SHA-256 hashes for migration
from app.core.auth import verify_password from app.core.auth import verify_password
stored_hash = str(client.client_secret_hash) stored_hash = str(client.client_secret_hash)
if stored_hash.startswith("$2"): if not stored_hash.startswith("$2"):
# New bcrypt format raise InvalidClientError(
if not verify_password(client_secret, stored_hash): "Client secret uses deprecated hash format. "
raise InvalidClientError("Invalid client secret") "Please regenerate your client credentials."
else: )
# Legacy SHA-256 format
computed_hash = hashlib.sha256(client_secret.encode()).hexdigest() if not verify_password(client_secret, stored_hash):
if not secrets.compare_digest(computed_hash, stored_hash): raise InvalidClientError("Invalid client secret")
raise InvalidClientError("Invalid client secret")
return client return client
@@ -263,7 +257,9 @@ def validate_scopes(client: OAuthClient, requested_scopes: list[str]) -> list[st
# Warn if some scopes were filtered out # Warn if some scopes were filtered out
invalid = requested - allowed invalid = requested - allowed
if invalid: if invalid:
logger.warning(f"Client {client.client_id} requested invalid scopes: {invalid}") logger.warning(
"Client %s requested invalid scopes: %s", client.client_id, invalid
)
return list(valid) return list(valid)
@@ -311,25 +307,24 @@ async def create_authorization_code(
minutes=AUTHORIZATION_CODE_EXPIRY_MINUTES minutes=AUTHORIZATION_CODE_EXPIRY_MINUTES
) )
auth_code = OAuthAuthorizationCode( await oauth_authorization_code_repo.create_code(
db,
code=code, code=code,
client_id=client.client_id, client_id=client.client_id,
user_id=user.id, user_id=user.id,
redirect_uri=redirect_uri, redirect_uri=redirect_uri,
scope=scope, scope=scope,
expires_at=expires_at,
code_challenge=code_challenge, code_challenge=code_challenge,
code_challenge_method=code_challenge_method, code_challenge_method=code_challenge_method,
state=state, state=state,
nonce=nonce, nonce=nonce,
expires_at=expires_at,
used=False,
) )
db.add(auth_code)
await db.commit()
logger.info( logger.info(
f"Created authorization code for user {user.id} and client {client.client_id}" "Created authorization code for user %s and client %s",
user.id,
client.client_id,
) )
return code return code
@@ -366,35 +361,20 @@ async def exchange_authorization_code(
""" """
# Atomically mark code as used and fetch it (prevents race condition) # Atomically mark code as used and fetch it (prevents race condition)
# RFC 6749 Section 4.1.2: Authorization codes MUST be single-use # RFC 6749 Section 4.1.2: Authorization codes MUST be single-use
from sqlalchemy import update updated_id = await oauth_authorization_code_repo.consume_code_atomically(
db, code=code
# First, atomically mark the code as used and get affected count
update_stmt = (
update(OAuthAuthorizationCode)
.where(
and_(
OAuthAuthorizationCode.code == code,
OAuthAuthorizationCode.used == False, # noqa: E712
)
)
.values(used=True)
.returning(OAuthAuthorizationCode.id)
) )
result = await db.execute(update_stmt)
updated_id = result.scalar_one_or_none()
if not updated_id: if not updated_id:
# Either code doesn't exist or was already used # Either code doesn't exist or was already used
# Check if it exists to provide appropriate error # Check if it exists to provide appropriate error
check_result = await db.execute( existing_code = await oauth_authorization_code_repo.get_by_code(db, code=code)
select(OAuthAuthorizationCode).where(OAuthAuthorizationCode.code == code)
)
existing_code = check_result.scalar_one_or_none()
if existing_code and existing_code.used: if existing_code and existing_code.used:
# Code reuse is a security incident - revoke all tokens for this grant # Code reuse is a security incident - revoke all tokens for this grant
logger.warning( logger.warning(
f"Authorization code reuse detected for client {existing_code.client_id}" "Authorization code reuse detected for client %s",
existing_code.client_id,
) )
await revoke_tokens_for_user_client( await revoke_tokens_for_user_client(
db, UUID(str(existing_code.user_id)), str(existing_code.client_id) db, UUID(str(existing_code.user_id)), str(existing_code.client_id)
@@ -404,11 +384,9 @@ async def exchange_authorization_code(
raise InvalidGrantError("Invalid authorization code") raise InvalidGrantError("Invalid authorization code")
# Now fetch the full auth code record # Now fetch the full auth code record
auth_code_result = await db.execute( auth_code = await oauth_authorization_code_repo.get_by_id(db, code_id=updated_id)
select(OAuthAuthorizationCode).where(OAuthAuthorizationCode.id == updated_id) if auth_code is None:
) raise InvalidGrantError("Authorization code not found after consumption")
auth_code = auth_code_result.scalar_one()
await db.commit()
if auth_code.is_expired: if auth_code.is_expired:
raise InvalidGrantError("Authorization code has expired") raise InvalidGrantError("Authorization code has expired")
@@ -452,8 +430,7 @@ async def exchange_authorization_code(
raise InvalidGrantError("PKCE required for public clients") raise InvalidGrantError("PKCE required for public clients")
# Get user # Get user
user_result = await db.execute(select(User).where(User.id == auth_code.user_id)) user = await user_repo.get(db, id=str(auth_code.user_id))
user = user_result.scalar_one_or_none()
if not user or not user.is_active: if not user or not user.is_active:
raise InvalidGrantError("User not found or inactive") raise InvalidGrantError("User not found or inactive")
@@ -543,7 +520,8 @@ async def create_tokens(
refresh_token_hash = hash_token(refresh_token) refresh_token_hash = hash_token(refresh_token)
# Store refresh token in database # Store refresh token in database
refresh_token_record = OAuthProviderRefreshToken( await oauth_provider_token_repo.create_token(
db,
token_hash=refresh_token_hash, token_hash=refresh_token_hash,
jti=jti, jti=jti,
client_id=client.client_id, client_id=client.client_id,
@@ -553,10 +531,8 @@ async def create_tokens(
device_info=device_info, device_info=device_info,
ip_address=ip_address, ip_address=ip_address,
) )
db.add(refresh_token_record)
await db.commit()
logger.info(f"Issued tokens for user {user.id} to client {client.client_id}") logger.info("Issued tokens for user %s to client %s", user.id, client.client_id)
return { return {
"access_token": access_token, "access_token": access_token,
@@ -599,12 +575,9 @@ async def refresh_tokens(
""" """
# Find refresh token # Find refresh token
token_hash = hash_token(refresh_token) token_hash = hash_token(refresh_token)
result = await db.execute( token_record = await oauth_provider_token_repo.get_by_token_hash(
select(OAuthProviderRefreshToken).where( db, token_hash=token_hash
OAuthProviderRefreshToken.token_hash == token_hash
)
) )
token_record: OAuthProviderRefreshToken | None = result.scalar_one_or_none()
if not token_record: if not token_record:
raise InvalidGrantError("Invalid refresh token") raise InvalidGrantError("Invalid refresh token")
@@ -612,7 +585,7 @@ async def refresh_tokens(
if token_record.revoked: if token_record.revoked:
# Token reuse after revocation - security incident # Token reuse after revocation - security incident
logger.warning( logger.warning(
f"Revoked refresh token reuse detected for client {token_record.client_id}" "Revoked refresh token reuse detected for client %s", token_record.client_id
) )
raise InvalidGrantError("Refresh token has been revoked") raise InvalidGrantError("Refresh token has been revoked")
@@ -631,8 +604,7 @@ async def refresh_tokens(
) )
# Get user # Get user
user_result = await db.execute(select(User).where(User.id == token_record.user_id)) user = await user_repo.get(db, id=str(token_record.user_id))
user = user_result.scalar_one_or_none()
if not user or not user.is_active: if not user or not user.is_active:
raise InvalidGrantError("User not found or inactive") raise InvalidGrantError("User not found or inactive")
@@ -648,9 +620,7 @@ async def refresh_tokens(
final_scope = token_scope final_scope = token_scope
# Revoke old refresh token (token rotation) # Revoke old refresh token (token rotation)
token_record.revoked = True # type: ignore[assignment] await oauth_provider_token_repo.revoke(db, token=token_record)
token_record.last_used_at = datetime.now(UTC) # type: ignore[assignment]
await db.commit()
# Issue new tokens # Issue new tokens
device = str(token_record.device_info) if token_record.device_info else None device = str(token_record.device_info) if token_record.device_info else None
@@ -697,28 +667,22 @@ async def revoke_token(
# Try as refresh token first (more likely) # Try as refresh token first (more likely)
if token_type_hint != "access_token": if token_type_hint != "access_token":
token_hash = hash_token(token) token_hash = hash_token(token)
result = await db.execute( refresh_record = await oauth_provider_token_repo.get_by_token_hash(
select(OAuthProviderRefreshToken).where( db, token_hash=token_hash
OAuthProviderRefreshToken.token_hash == token_hash
)
) )
refresh_record = result.scalar_one_or_none()
if refresh_record: if refresh_record:
# Validate client if provided # Validate client if provided
if client_id and refresh_record.client_id != client_id: if client_id and refresh_record.client_id != client_id:
raise InvalidClientError("Token was not issued to this client") raise InvalidClientError("Token was not issued to this client")
refresh_record.revoked = True # type: ignore[assignment] await oauth_provider_token_repo.revoke(db, token=refresh_record)
await db.commit() logger.info("Revoked refresh token %s...", refresh_record.jti[:8])
logger.info(f"Revoked refresh token {refresh_record.jti[:8]}...")
return True return True
# Try as access token (JWT) # Try as access token (JWT)
if token_type_hint != "refresh_token": if token_type_hint != "refresh_token":
try: try:
from jose.exceptions import JWTError
payload = jwt.decode( payload = jwt.decode(
token, token,
settings.SECRET_KEY, settings.SECRET_KEY,
@@ -731,22 +695,18 @@ async def revoke_token(
jti = payload.get("jti") jti = payload.get("jti")
if jti: if jti:
# Find and revoke the associated refresh token # Find and revoke the associated refresh token
result = await db.execute( refresh_record = await oauth_provider_token_repo.get_by_jti(db, jti=jti)
select(OAuthProviderRefreshToken).where(
OAuthProviderRefreshToken.jti == jti
)
)
refresh_record = result.scalar_one_or_none()
if refresh_record: if refresh_record:
if client_id and refresh_record.client_id != client_id: if client_id and refresh_record.client_id != client_id:
raise InvalidClientError("Token was not issued to this client") raise InvalidClientError("Token was not issued to this client")
refresh_record.revoked = True # type: ignore[assignment] await oauth_provider_token_repo.revoke(db, token=refresh_record)
await db.commit()
logger.info( logger.info(
f"Revoked refresh token via access token JTI {jti[:8]}..." "Revoked refresh token via access token JTI %s...", jti[:8]
) )
return True return True
except (JWTError, Exception): # noqa: S110 - Intentional: invalid JWT not an error except InvalidTokenError:
pass
except Exception: # noqa: S110 - Intentional: invalid JWT not an error
pass pass
return False return False
@@ -770,26 +730,13 @@ async def revoke_tokens_for_user_client(
Returns: Returns:
Number of tokens revoked Number of tokens revoked
""" """
result = await db.execute( count = await oauth_provider_token_repo.revoke_all_for_user_client(
select(OAuthProviderRefreshToken).where( db, user_id=user_id, client_id=client_id
and_(
OAuthProviderRefreshToken.user_id == user_id,
OAuthProviderRefreshToken.client_id == client_id,
OAuthProviderRefreshToken.revoked == False, # noqa: E712
)
)
) )
tokens = result.scalars().all()
count = 0
for token in tokens:
token.revoked = True # type: ignore[assignment]
count += 1
if count > 0: if count > 0:
await db.commit()
logger.warning( logger.warning(
f"Revoked {count} tokens for user {user_id} and client {client_id}" "Revoked %s tokens for user %s and client %s", count, user_id, client_id
) )
return count return count
@@ -808,24 +755,10 @@ async def revoke_all_user_tokens(db: AsyncSession, user_id: UUID) -> int:
Returns: Returns:
Number of tokens revoked Number of tokens revoked
""" """
result = await db.execute( count = await oauth_provider_token_repo.revoke_all_for_user(db, user_id=user_id)
select(OAuthProviderRefreshToken).where(
and_(
OAuthProviderRefreshToken.user_id == user_id,
OAuthProviderRefreshToken.revoked == False, # noqa: E712
)
)
)
tokens = result.scalars().all()
count = 0
for token in tokens:
token.revoked = True # type: ignore[assignment]
count += 1
if count > 0: if count > 0:
await db.commit() logger.info("Revoked %s OAuth provider tokens for user %s", count, user_id)
logger.info(f"Revoked {count} OAuth provider tokens for user {user_id}")
return count return count
@@ -864,8 +797,6 @@ async def introspect_token(
# Try as access token (JWT) first # Try as access token (JWT) first
if token_type_hint != "refresh_token": if token_type_hint != "refresh_token":
try: try:
from jose.exceptions import ExpiredSignatureError, JWTError
payload = jwt.decode( payload = jwt.decode(
token, token,
settings.SECRET_KEY, settings.SECRET_KEY,
@@ -878,12 +809,7 @@ async def introspect_token(
# Check if associated refresh token is revoked # Check if associated refresh token is revoked
jti = payload.get("jti") jti = payload.get("jti")
if jti: if jti:
result = await db.execute( refresh_record = await oauth_provider_token_repo.get_by_jti(db, jti=jti)
select(OAuthProviderRefreshToken).where(
OAuthProviderRefreshToken.jti == jti
)
)
refresh_record = result.scalar_one_or_none()
if refresh_record and refresh_record.revoked: if refresh_record and refresh_record.revoked:
return {"active": False} return {"active": False}
@@ -901,18 +827,17 @@ async def introspect_token(
} }
except ExpiredSignatureError: except ExpiredSignatureError:
return {"active": False} return {"active": False}
except (JWTError, Exception): # noqa: S110 - Intentional: invalid JWT falls through to refresh token check except InvalidTokenError:
pass
except Exception: # noqa: S110 - Intentional: invalid JWT falls through to refresh token check
pass pass
# Try as refresh token # Try as refresh token
if token_type_hint != "access_token": if token_type_hint != "access_token":
token_hash = hash_token(token) token_hash = hash_token(token)
result = await db.execute( refresh_record = await oauth_provider_token_repo.get_by_token_hash(
select(OAuthProviderRefreshToken).where( db, token_hash=token_hash
OAuthProviderRefreshToken.token_hash == token_hash
)
) )
refresh_record = result.scalar_one_or_none()
if refresh_record and refresh_record.is_valid: if refresh_record and refresh_record.is_valid:
return { return {
@@ -937,17 +862,11 @@ async def get_consent(
db: AsyncSession, db: AsyncSession,
user_id: UUID, user_id: UUID,
client_id: str, client_id: str,
) -> OAuthConsent | None: ):
"""Get existing consent record for user-client pair.""" """Get existing consent record for user-client pair."""
result = await db.execute( return await oauth_consent_repo.get_consent(
select(OAuthConsent).where( db, user_id=user_id, client_id=client_id
and_(
OAuthConsent.user_id == user_id,
OAuthConsent.client_id == client_id,
)
)
) )
return result.scalar_one_or_none()
async def check_consent( async def check_consent(
@@ -972,31 +891,15 @@ async def grant_consent(
user_id: UUID, user_id: UUID,
client_id: str, client_id: str,
scopes: list[str], scopes: list[str],
) -> OAuthConsent: ):
""" """
Grant or update consent for a user-client pair. Grant or update consent for a user-client pair.
If consent already exists, updates the granted scopes. If consent already exists, updates the granted scopes.
""" """
consent = await get_consent(db, user_id, client_id) return await oauth_consent_repo.grant_consent(
db, user_id=user_id, client_id=client_id, scopes=scopes
if consent: )
# Merge scopes
granted = str(consent.granted_scopes) if consent.granted_scopes else ""
existing = set(parse_scope(granted))
new_scopes = existing | set(scopes)
consent.granted_scopes = join_scope(list(new_scopes)) # type: ignore[assignment]
else:
consent = OAuthConsent(
user_id=user_id,
client_id=client_id,
granted_scopes=join_scope(scopes),
)
db.add(consent)
await db.commit()
await db.refresh(consent)
return consent
async def revoke_consent( async def revoke_consent(
@@ -1009,21 +912,13 @@ async def revoke_consent(
Returns True if consent was found and revoked. Returns True if consent was found and revoked.
""" """
# Delete consent record # Revoke all tokens first
result = await db.execute(
delete(OAuthConsent).where(
and_(
OAuthConsent.user_id == user_id,
OAuthConsent.client_id == client_id,
)
)
)
# Revoke all tokens
await revoke_tokens_for_user_client(db, user_id, client_id) await revoke_tokens_for_user_client(db, user_id, client_id)
await db.commit() # Delete consent record
return result.rowcount > 0 # type: ignore[attr-defined] return await oauth_consent_repo.revoke_consent(
db, user_id=user_id, client_id=client_id
)
# ============================================================================ # ============================================================================
@@ -1031,6 +926,26 @@ async def revoke_consent(
# ============================================================================ # ============================================================================
async def register_client(db: AsyncSession, client_data: OAuthClientCreate) -> tuple:
"""Create a new OAuth client. Returns (client, secret)."""
return await oauth_client_repo.create_client(db, obj_in=client_data)
async def list_clients(db: AsyncSession) -> list:
"""List all registered OAuth clients."""
return await oauth_client_repo.get_all_clients(db)
async def delete_client_by_id(db: AsyncSession, client_id: str) -> None:
"""Delete an OAuth client by client_id."""
await oauth_client_repo.delete_client(db, client_id=client_id)
async def list_user_consents(db: AsyncSession, user_id: UUID) -> list[dict]:
"""Get all OAuth consents for a user with client details."""
return await oauth_consent_repo.get_user_consents_with_clients(db, user_id=user_id)
async def cleanup_expired_codes(db: AsyncSession) -> int: async def cleanup_expired_codes(db: AsyncSession) -> int:
""" """
Delete expired authorization codes. Delete expired authorization codes.
@@ -1040,13 +955,7 @@ async def cleanup_expired_codes(db: AsyncSession) -> int:
Returns: Returns:
Number of codes deleted Number of codes deleted
""" """
result = await db.execute( return await oauth_authorization_code_repo.cleanup_expired(db)
delete(OAuthAuthorizationCode).where(
OAuthAuthorizationCode.expires_at < datetime.now(UTC)
)
)
await db.commit()
return result.rowcount # type: ignore[attr-defined]
async def cleanup_expired_tokens(db: AsyncSession) -> int: async def cleanup_expired_tokens(db: AsyncSession) -> int:
@@ -1058,12 +967,4 @@ async def cleanup_expired_tokens(db: AsyncSession) -> int:
Returns: Returns:
Number of tokens deleted Number of tokens deleted
""" """
# Delete tokens that are both expired AND revoked (or just very old) return await oauth_provider_token_repo.cleanup_expired(db, cutoff_days=7)
cutoff = datetime.now(UTC) - timedelta(days=7)
result = await db.execute(
delete(OAuthProviderRefreshToken).where(
OAuthProviderRefreshToken.expires_at < cutoff
)
)
await db.commit()
return result.rowcount # type: ignore[attr-defined]

View File

@@ -19,14 +19,15 @@ from typing import TypedDict, cast
from uuid import UUID from uuid import UUID
from authlib.integrations.httpx_client import AsyncOAuth2Client from authlib.integrations.httpx_client import AsyncOAuth2Client
from sqlalchemy import select
from sqlalchemy.ext.asyncio import AsyncSession from sqlalchemy.ext.asyncio import AsyncSession
from app.core.auth import create_access_token, create_refresh_token from app.core.auth import create_access_token, create_refresh_token
from app.core.config import settings from app.core.config import settings
from app.core.exceptions import AuthenticationError from app.core.exceptions import AuthenticationError
from app.crud import oauth_account, oauth_state
from app.models.user import User from app.models.user import User
from app.repositories.oauth_account import oauth_account_repo as oauth_account
from app.repositories.oauth_state import oauth_state_repo as oauth_state
from app.repositories.user import user_repo
from app.schemas.oauth import ( from app.schemas.oauth import (
OAuthAccountCreate, OAuthAccountCreate,
OAuthCallbackResponse, OAuthCallbackResponse,
@@ -38,19 +39,22 @@ from app.schemas.oauth import (
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
class OAuthProviderConfig(TypedDict, total=False): class _OAuthProviderConfigRequired(TypedDict):
"""Type definition for OAuth provider configuration."""
name: str name: str
icon: str icon: str
authorize_url: str authorize_url: str
token_url: str token_url: str
userinfo_url: str userinfo_url: str
email_url: str # Optional, GitHub-only
scopes: list[str] scopes: list[str]
supports_pkce: bool supports_pkce: bool
class OAuthProviderConfig(_OAuthProviderConfigRequired, total=False):
"""Type definition for OAuth provider configuration."""
email_url: str # Optional, GitHub-only
# Provider configurations # Provider configurations
OAUTH_PROVIDERS: dict[str, OAuthProviderConfig] = { OAUTH_PROVIDERS: dict[str, OAuthProviderConfig] = {
"google": { "google": {
@@ -215,7 +219,7 @@ class OAuthService:
**auth_params, **auth_params,
) )
logger.info(f"OAuth authorization URL created for {provider}") logger.info("OAuth authorization URL created for %s", provider)
return url, state return url, state
@staticmethod @staticmethod
@@ -250,8 +254,9 @@ class OAuthService:
# This prevents authorization code injection attacks (RFC 6749 Section 10.6) # This prevents authorization code injection attacks (RFC 6749 Section 10.6)
if state_record.redirect_uri != redirect_uri: if state_record.redirect_uri != redirect_uri:
logger.warning( logger.warning(
f"OAuth redirect_uri mismatch: expected {state_record.redirect_uri}, " "OAuth redirect_uri mismatch: expected %s, got %s",
f"got {redirect_uri}" state_record.redirect_uri,
redirect_uri,
) )
raise AuthenticationError("Redirect URI mismatch") raise AuthenticationError("Redirect URI mismatch")
@@ -295,7 +300,7 @@ class OAuthService:
except AuthenticationError: except AuthenticationError:
raise raise
except Exception as e: except Exception as e:
logger.error(f"OAuth token exchange failed: {e!s}") logger.error("OAuth token exchange failed: %s", e)
raise AuthenticationError("Failed to exchange authorization code") raise AuthenticationError("Failed to exchange authorization code")
# Get user info from provider # Get user info from provider
@@ -308,7 +313,7 @@ class OAuthService:
client, provider, config, access_token client, provider, config, access_token
) )
except Exception as e: except Exception as e:
logger.error(f"Failed to get user info: {e!s}") logger.error("Failed to get user info: %s", e)
raise AuthenticationError( raise AuthenticationError(
"Failed to get user information from provider" "Failed to get user information from provider"
) )
@@ -343,18 +348,17 @@ class OAuthService:
await oauth_account.update_tokens( await oauth_account.update_tokens(
db, db,
account=existing_oauth, account=existing_oauth,
access_token_encrypted=token.get("access_token"), refresh_token_encrypted=token.get("refresh_token"), token_expires_at=datetime.now(UTC) access_token=token.get("access_token"),
refresh_token=token.get("refresh_token"),
token_expires_at=datetime.now(UTC)
+ timedelta(seconds=token.get("expires_in", 3600)), + timedelta(seconds=token.get("expires_in", 3600)),
) )
logger.info(f"OAuth login successful for {user.email} via {provider}") logger.info("OAuth login successful for %s via %s", user.email, provider)
elif state_record.user_id: elif state_record.user_id:
# Account linking flow (user is already logged in) # Account linking flow (user is already logged in)
result = await db.execute( user = await user_repo.get(db, id=str(state_record.user_id))
select(User).where(User.id == state_record.user_id)
)
user = result.scalar_one_or_none()
if not user: if not user:
raise AuthenticationError("User not found for account linking") raise AuthenticationError("User not found for account linking")
@@ -375,24 +379,23 @@ class OAuthService:
provider=provider, provider=provider,
provider_user_id=provider_user_id, provider_user_id=provider_user_id,
provider_email=provider_email, provider_email=provider_email,
access_token_encrypted=token.get("access_token"), refresh_token_encrypted=token.get("refresh_token"), token_expires_at=datetime.now(UTC) access_token=token.get("access_token"),
refresh_token=token.get("refresh_token"),
token_expires_at=datetime.now(UTC)
+ timedelta(seconds=token.get("expires_in", 3600)) + timedelta(seconds=token.get("expires_in", 3600))
if token.get("expires_in") if token.get("expires_in")
else None, else None,
) )
await oauth_account.create_account(db, obj_in=oauth_create) await oauth_account.create_account(db, obj_in=oauth_create)
logger.info(f"OAuth account linked: {provider} -> {user.email}") logger.info("OAuth account linked: %s -> %s", provider, user.email)
else: else:
# New OAuth login - check for existing user by email # New OAuth login - check for existing user by email
user = None user = None
if provider_email and settings.OAUTH_AUTO_LINK_BY_EMAIL: if provider_email and settings.OAUTH_AUTO_LINK_BY_EMAIL:
result = await db.execute( user = await user_repo.get_by_email(db, email=provider_email)
select(User).where(User.email == provider_email)
)
user = result.scalar_one_or_none()
if user: if user:
# Auto-link to existing user # Auto-link to existing user
@@ -407,7 +410,9 @@ class OAuthService:
if existing_provider: if existing_provider:
# This shouldn't happen if we got here, but safety check # This shouldn't happen if we got here, but safety check
logger.warning( logger.warning(
f"OAuth account already linked (race condition?): {provider} -> {user.email}" "OAuth account already linked (race condition?): %s -> %s",
provider,
user.email,
) )
else: else:
# Create OAuth account link # Create OAuth account link
@@ -416,8 +421,8 @@ class OAuthService:
provider=provider, provider=provider,
provider_user_id=provider_user_id, provider_user_id=provider_user_id,
provider_email=provider_email, provider_email=provider_email,
access_token_encrypted=token.get("access_token"), access_token=token.get("access_token"),
refresh_token_encrypted=token.get("refresh_token"), refresh_token=token.get("refresh_token"),
token_expires_at=datetime.now(UTC) token_expires_at=datetime.now(UTC)
+ timedelta(seconds=token.get("expires_in", 3600)) + timedelta(seconds=token.get("expires_in", 3600))
if token.get("expires_in") if token.get("expires_in")
@@ -425,7 +430,9 @@ class OAuthService:
) )
await oauth_account.create_account(db, obj_in=oauth_create) await oauth_account.create_account(db, obj_in=oauth_create)
logger.info(f"OAuth auto-linked by email: {provider} -> {user.email}") logger.info(
"OAuth auto-linked by email: %s -> %s", provider, user.email
)
else: else:
# Create new user # Create new user
@@ -445,7 +452,7 @@ class OAuthService:
) )
is_new_user = True is_new_user = True
logger.info(f"New user created via OAuth: {user.email} ({provider})") logger.info("New user created via OAuth: %s (%s)", user.email, provider)
# Generate JWT tokens # Generate JWT tokens
claims = { claims = {
@@ -486,7 +493,7 @@ class OAuthService:
# GitHub requires separate request for email # GitHub requires separate request for email
if provider == "github" and not user_info.get("email"): if provider == "github" and not user_info.get("email"):
email_resp = await client.get( email_resp = await client.get(
config["email_url"], config["email_url"], # pyright: ignore[reportTypedDictNotRequiredAccess]
headers=headers, headers=headers,
) )
email_resp.raise_for_status() email_resp.raise_for_status()
@@ -530,8 +537,9 @@ class OAuthService:
AuthenticationError: If verification fails AuthenticationError: If verification fails
""" """
import httpx import httpx
from jose import jwt as jose_jwt import jwt as pyjwt
from jose.exceptions import JWTError from jwt.algorithms import RSAAlgorithm
from jwt.exceptions import InvalidTokenError
try: try:
# Fetch Google's public keys (JWKS) # Fetch Google's public keys (JWKS)
@@ -545,24 +553,27 @@ class OAuthService:
jwks = jwks_response.json() jwks = jwks_response.json()
# Get the key ID from the token header # Get the key ID from the token header
unverified_header = jose_jwt.get_unverified_header(id_token) unverified_header = pyjwt.get_unverified_header(id_token)
kid = unverified_header.get("kid") kid = unverified_header.get("kid")
if not kid: if not kid:
raise AuthenticationError("ID token missing key ID (kid)") raise AuthenticationError("ID token missing key ID (kid)")
# Find the matching public key # Find the matching public key
public_key = None jwk_data = None
for key in jwks.get("keys", []): for key in jwks.get("keys", []):
if key.get("kid") == kid: if key.get("kid") == kid:
public_key = key jwk_data = key
break break
if not public_key: if not jwk_data:
raise AuthenticationError("ID token signed with unknown key") raise AuthenticationError("ID token signed with unknown key")
# Convert JWK to a public key object for PyJWT
public_key = RSAAlgorithm.from_jwk(jwk_data)
# Verify the token signature and decode claims # Verify the token signature and decode claims
# jose library will verify signature against the JWK # PyJWT will verify signature against the RSA public key
payload = jose_jwt.decode( payload = pyjwt.decode(
id_token, id_token,
public_key, public_key,
algorithms=["RS256"], # Google uses RS256 algorithms=["RS256"], # Google uses RS256
@@ -581,23 +592,24 @@ class OAuthService:
token_nonce = payload.get("nonce") token_nonce = payload.get("nonce")
if token_nonce != expected_nonce: if token_nonce != expected_nonce:
logger.warning( logger.warning(
f"OAuth ID token nonce mismatch: expected {expected_nonce}, " "OAuth ID token nonce mismatch: expected %s, got %s",
f"got {token_nonce}" expected_nonce,
token_nonce,
) )
raise AuthenticationError("Invalid ID token nonce") raise AuthenticationError("Invalid ID token nonce")
logger.debug("Google ID token verified successfully") logger.debug("Google ID token verified successfully")
return payload return payload
except JWTError as e: except InvalidTokenError as e:
logger.warning(f"Google ID token verification failed: {e}") logger.warning("Google ID token verification failed: %s", e)
raise AuthenticationError("Invalid ID token signature") raise AuthenticationError("Invalid ID token signature")
except httpx.HTTPError as e: except httpx.HTTPError as e:
logger.error(f"Failed to fetch Google JWKS: {e}") logger.error("Failed to fetch Google JWKS: %s", e)
# If we can't verify the ID token, fail closed for security # If we can't verify the ID token, fail closed for security
raise AuthenticationError("Failed to verify ID token") raise AuthenticationError("Failed to verify ID token")
except Exception as e: except Exception as e:
logger.error(f"Unexpected error verifying Google ID token: {e}") logger.error("Unexpected error verifying Google ID token: %s", e)
raise AuthenticationError("ID token verification error") raise AuthenticationError("ID token verification error")
@staticmethod @staticmethod
@@ -644,14 +656,15 @@ class OAuthService:
provider=provider, provider=provider,
provider_user_id=provider_user_id, provider_user_id=provider_user_id,
provider_email=email, provider_email=email,
access_token_encrypted=token.get("access_token"), refresh_token_encrypted=token.get("refresh_token"), token_expires_at=datetime.now(UTC) access_token=token.get("access_token"),
refresh_token=token.get("refresh_token"),
token_expires_at=datetime.now(UTC)
+ timedelta(seconds=token.get("expires_in", 3600)) + timedelta(seconds=token.get("expires_in", 3600))
if token.get("expires_in") if token.get("expires_in")
else None, else None,
) )
await oauth_account.create_account(db, obj_in=oauth_create) await oauth_account.create_account(db, obj_in=oauth_create)
await db.commit()
await db.refresh(user) await db.refresh(user)
return user return user
@@ -698,9 +711,23 @@ class OAuthService:
if not deleted: if not deleted:
raise AuthenticationError(f"No {provider} account found to unlink") raise AuthenticationError(f"No {provider} account found to unlink")
logger.info(f"OAuth provider unlinked: {provider} from {user.email}") logger.info("OAuth provider unlinked: %s from %s", provider, user.email)
return True return True
@staticmethod
async def get_user_accounts(db: AsyncSession, *, user_id: UUID) -> list:
"""Get all OAuth accounts linked to a user."""
return await oauth_account.get_user_accounts(db, user_id=user_id)
@staticmethod
async def get_user_account_by_provider(
db: AsyncSession, *, user_id: UUID, provider: str
):
"""Get a specific OAuth account for a user and provider."""
return await oauth_account.get_user_account_by_provider(
db, user_id=user_id, provider=provider
)
@staticmethod @staticmethod
async def cleanup_expired_states(db: AsyncSession) -> int: async def cleanup_expired_states(db: AsyncSession) -> int:
""" """

View File

@@ -0,0 +1,155 @@
# app/services/organization_service.py
"""Service layer for organization operations — delegates to OrganizationRepository."""
import logging
from typing import Any
from uuid import UUID
from sqlalchemy.ext.asyncio import AsyncSession
from app.core.exceptions import NotFoundError
from app.models.organization import Organization
from app.models.user_organization import OrganizationRole, UserOrganization
from app.repositories.organization import OrganizationRepository, organization_repo
from app.schemas.organizations import OrganizationCreate, OrganizationUpdate
logger = logging.getLogger(__name__)
class OrganizationService:
"""Service for organization management operations."""
def __init__(
self, organization_repository: OrganizationRepository | None = None
) -> None:
self._repo = organization_repository or organization_repo
async def get_organization(self, db: AsyncSession, org_id: str) -> Organization:
"""Get organization by ID, raising NotFoundError if not found."""
org = await self._repo.get(db, id=org_id)
if not org:
raise NotFoundError(f"Organization {org_id} not found")
return org
async def create_organization(
self, db: AsyncSession, *, obj_in: OrganizationCreate
) -> Organization:
"""Create a new organization."""
return await self._repo.create(db, obj_in=obj_in)
async def update_organization(
self,
db: AsyncSession,
*,
org: Organization,
obj_in: OrganizationUpdate | dict[str, Any],
) -> Organization:
"""Update an existing organization."""
return await self._repo.update(db, db_obj=org, obj_in=obj_in)
async def remove_organization(self, db: AsyncSession, org_id: str) -> None:
"""Permanently delete an organization by ID."""
await self._repo.remove(db, id=org_id)
async def get_member_count(self, db: AsyncSession, *, organization_id: UUID) -> int:
"""Get number of active members in an organization."""
return await self._repo.get_member_count(db, organization_id=organization_id)
async def get_multi_with_member_counts(
self,
db: AsyncSession,
*,
skip: int = 0,
limit: int = 100,
is_active: bool | None = None,
search: str | None = None,
) -> tuple[list[dict[str, Any]], int]:
"""List organizations with member counts and pagination."""
return await self._repo.get_multi_with_member_counts(
db, skip=skip, limit=limit, is_active=is_active, search=search
)
async def get_user_organizations_with_details(
self,
db: AsyncSession,
*,
user_id: UUID,
is_active: bool | None = None,
) -> list[dict[str, Any]]:
"""Get all organizations a user belongs to, with membership details."""
return await self._repo.get_user_organizations_with_details(
db, user_id=user_id, is_active=is_active
)
async def get_organization_members(
self,
db: AsyncSession,
*,
organization_id: UUID,
skip: int = 0,
limit: int = 100,
is_active: bool | None = True,
) -> tuple[list[dict[str, Any]], int]:
"""Get members of an organization with pagination."""
return await self._repo.get_organization_members(
db,
organization_id=organization_id,
skip=skip,
limit=limit,
is_active=is_active,
)
async def add_member(
self,
db: AsyncSession,
*,
organization_id: UUID,
user_id: UUID,
role: OrganizationRole = OrganizationRole.MEMBER,
) -> UserOrganization:
"""Add a user to an organization."""
return await self._repo.add_user(
db, organization_id=organization_id, user_id=user_id, role=role
)
async def remove_member(
self,
db: AsyncSession,
*,
organization_id: UUID,
user_id: UUID,
) -> bool:
"""Remove a user from an organization. Returns True if found and removed."""
return await self._repo.remove_user(
db, organization_id=organization_id, user_id=user_id
)
async def get_user_role_in_org(
self, db: AsyncSession, *, user_id: UUID, organization_id: UUID
) -> OrganizationRole | None:
"""Get the role of a user in an organization."""
return await self._repo.get_user_role_in_org(
db, user_id=user_id, organization_id=organization_id
)
async def get_org_distribution(
self, db: AsyncSession, *, limit: int = 6
) -> list[dict[str, Any]]:
"""Return top organizations by member count for admin dashboard."""
from sqlalchemy import func, select
result = await db.execute(
select(
Organization.name,
func.count(UserOrganization.user_id).label("count"),
)
.join(UserOrganization, Organization.id == UserOrganization.organization_id)
.group_by(Organization.name)
.order_by(func.count(UserOrganization.user_id).desc())
.limit(limit)
)
return [{"name": row.name, "value": row.count} for row in result.all()]
# Default singleton
organization_service = OrganizationService()

View File

@@ -8,7 +8,7 @@ import logging
from datetime import UTC, datetime from datetime import UTC, datetime
from app.core.database import SessionLocal from app.core.database import SessionLocal
from app.crud.session import session as session_crud from app.repositories.session import session_repo as session_repo
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
@@ -32,15 +32,15 @@ async def cleanup_expired_sessions(keep_days: int = 30) -> int:
async with SessionLocal() as db: async with SessionLocal() as db:
try: try:
# Use CRUD method to cleanup # Use repository method to cleanup
count = await session_crud.cleanup_expired(db, keep_days=keep_days) count = await session_repo.cleanup_expired(db, keep_days=keep_days)
logger.info(f"Session cleanup complete: {count} sessions deleted") logger.info("Session cleanup complete: %s sessions deleted", count)
return count return count
except Exception as e: except Exception as e:
logger.error(f"Error during session cleanup: {e!s}", exc_info=True) logger.exception("Error during session cleanup: %s", e)
return 0 return 0
@@ -79,10 +79,10 @@ async def get_session_statistics() -> dict:
"expired": expired_sessions, "expired": expired_sessions,
} }
logger.info(f"Session statistics: {stats}") logger.info("Session statistics: %s", stats)
return stats return stats
except Exception as e: except Exception as e:
logger.error(f"Error getting session statistics: {e!s}", exc_info=True) logger.exception("Error getting session statistics: %s", e)
return {} return {}

View File

@@ -0,0 +1,97 @@
# app/services/session_service.py
"""Service layer for session operations — delegates to SessionRepository."""
import logging
from datetime import datetime
from sqlalchemy.ext.asyncio import AsyncSession
from app.models.user_session import UserSession
from app.repositories.session import SessionRepository, session_repo
from app.schemas.sessions import SessionCreate
logger = logging.getLogger(__name__)
class SessionService:
"""Service for user session management operations."""
def __init__(self, session_repository: SessionRepository | None = None) -> None:
self._repo = session_repository or session_repo
async def create_session(
self, db: AsyncSession, *, obj_in: SessionCreate
) -> UserSession:
"""Create a new session record."""
return await self._repo.create_session(db, obj_in=obj_in)
async def get_session(
self, db: AsyncSession, session_id: str
) -> UserSession | None:
"""Get session by ID."""
return await self._repo.get(db, id=session_id)
async def get_user_sessions(
self, db: AsyncSession, *, user_id: str, active_only: bool = True
) -> list[UserSession]:
"""Get all sessions for a user."""
return await self._repo.get_user_sessions(
db, user_id=user_id, active_only=active_only
)
async def get_active_by_jti(
self, db: AsyncSession, *, jti: str
) -> UserSession | None:
"""Get active session by refresh token JTI."""
return await self._repo.get_active_by_jti(db, jti=jti)
async def get_by_jti(self, db: AsyncSession, *, jti: str) -> UserSession | None:
"""Get session by refresh token JTI (active or inactive)."""
return await self._repo.get_by_jti(db, jti=jti)
async def deactivate(
self, db: AsyncSession, *, session_id: str
) -> UserSession | None:
"""Deactivate a session (logout from device)."""
return await self._repo.deactivate(db, session_id=session_id)
async def deactivate_all_user_sessions(
self, db: AsyncSession, *, user_id: str
) -> int:
"""Deactivate all sessions for a user. Returns count deactivated."""
return await self._repo.deactivate_all_user_sessions(db, user_id=user_id)
async def update_refresh_token(
self,
db: AsyncSession,
*,
session: UserSession,
new_jti: str,
new_expires_at: datetime,
) -> UserSession:
"""Update session with a rotated refresh token."""
return await self._repo.update_refresh_token(
db, session=session, new_jti=new_jti, new_expires_at=new_expires_at
)
async def cleanup_expired_for_user(self, db: AsyncSession, *, user_id: str) -> int:
"""Remove expired sessions for a user. Returns count removed."""
return await self._repo.cleanup_expired_for_user(db, user_id=user_id)
async def get_all_sessions(
self,
db: AsyncSession,
*,
skip: int = 0,
limit: int = 100,
active_only: bool = True,
with_user: bool = True,
) -> tuple[list[UserSession], int]:
"""Get all sessions with pagination (admin only)."""
return await self._repo.get_all_sessions(
db, skip=skip, limit=limit, active_only=active_only, with_user=with_user
)
# Default singleton
session_service = SessionService()

View File

@@ -0,0 +1,120 @@
# app/services/user_service.py
"""Service layer for user operations — delegates to UserRepository."""
import logging
from typing import Any
from uuid import UUID
from sqlalchemy.ext.asyncio import AsyncSession
from app.core.exceptions import NotFoundError
from app.models.user import User
from app.repositories.user import UserRepository, user_repo
from app.schemas.users import UserCreate, UserUpdate
logger = logging.getLogger(__name__)
class UserService:
"""Service for user management operations."""
def __init__(self, user_repository: UserRepository | None = None) -> None:
self._repo = user_repository or user_repo
async def get_user(self, db: AsyncSession, user_id: str) -> User:
"""Get user by ID, raising NotFoundError if not found."""
user = await self._repo.get(db, id=user_id)
if not user:
raise NotFoundError(f"User {user_id} not found")
return user
async def get_by_email(self, db: AsyncSession, email: str) -> User | None:
"""Get user by email address."""
return await self._repo.get_by_email(db, email=email)
async def create_user(self, db: AsyncSession, user_data: UserCreate) -> User:
"""Create a new user."""
return await self._repo.create(db, obj_in=user_data)
async def update_user(
self, db: AsyncSession, *, user: User, obj_in: UserUpdate | dict[str, Any]
) -> User:
"""Update an existing user."""
return await self._repo.update(db, db_obj=user, obj_in=obj_in)
async def soft_delete_user(self, db: AsyncSession, user_id: str) -> None:
"""Soft-delete a user by ID."""
await self._repo.soft_delete(db, id=user_id)
async def list_users(
self,
db: AsyncSession,
*,
skip: int = 0,
limit: int = 100,
sort_by: str | None = None,
sort_order: str = "asc",
filters: dict[str, Any] | None = None,
search: str | None = None,
) -> tuple[list[User], int]:
"""List users with pagination, sorting, filtering, and search."""
return await self._repo.get_multi_with_total(
db,
skip=skip,
limit=limit,
sort_by=sort_by,
sort_order=sort_order,
filters=filters,
search=search,
)
async def bulk_update_status(
self, db: AsyncSession, *, user_ids: list[UUID], is_active: bool
) -> int:
"""Bulk update active status for multiple users. Returns count updated."""
return await self._repo.bulk_update_status(
db, user_ids=user_ids, is_active=is_active
)
async def bulk_soft_delete(
self,
db: AsyncSession,
*,
user_ids: list[UUID],
exclude_user_id: UUID | None = None,
) -> int:
"""Bulk soft-delete multiple users. Returns count deleted."""
return await self._repo.bulk_soft_delete(
db, user_ids=user_ids, exclude_user_id=exclude_user_id
)
async def get_stats(self, db: AsyncSession) -> dict[str, Any]:
"""Return user stats needed for the admin dashboard."""
from sqlalchemy import func, select
total_users = (
await db.execute(select(func.count()).select_from(User))
).scalar() or 0
active_count = (
await db.execute(
select(func.count()).select_from(User).where(User.is_active)
)
).scalar() or 0
inactive_count = (
await db.execute(
select(func.count()).select_from(User).where(User.is_active.is_(False))
)
).scalar() or 0
all_users = list(
(await db.execute(select(User).order_by(User.created_at))).scalars().all()
)
return {
"total_users": total_users,
"active_count": active_count,
"inactive_count": inactive_count,
"all_users": all_users,
}
# Default singleton
user_service = UserService()

View File

@@ -65,10 +65,10 @@ async def setup_async_test_db():
async with test_engine.begin() as conn: async with test_engine.begin() as conn:
await conn.run_sync(Base.metadata.create_all) await conn.run_sync(Base.metadata.create_all)
AsyncTestingSessionLocal = sessionmaker( AsyncTestingSessionLocal = sessionmaker( # pyright: ignore[reportCallIssue]
autocommit=False, autocommit=False,
autoflush=False, autoflush=False,
bind=test_engine, bind=test_engine, # pyright: ignore[reportArgumentType]
expire_on_commit=False, expire_on_commit=False,
class_=AsyncSession, class_=AsyncSession,
) )

View File

@@ -79,12 +79,13 @@ This FastAPI backend application follows a **clean layered architecture** patter
### Authentication & Security ### Authentication & Security
- **python-jose**: JWT token generation and validation - **PyJWT**: JWT token generation and validation
- Cryptographic signing - Cryptographic signing (HS256, RS256)
- Token expiration handling - Token expiration handling
- Claims validation - Claims validation
- JWK support for Google ID token verification
- **passlib + bcrypt**: Password hashing - **bcrypt**: Password hashing
- Industry-standard bcrypt algorithm - Industry-standard bcrypt algorithm
- Configurable cost factor - Configurable cost factor
- Salt generation - Salt generation
@@ -117,7 +118,8 @@ backend/
│ ├── api/ # API layer │ ├── api/ # API layer
│ │ ├── dependencies/ # Dependency injection │ │ ├── dependencies/ # Dependency injection
│ │ │ ├── auth.py # Authentication dependencies │ │ │ ├── auth.py # Authentication dependencies
│ │ │ ── permissions.py # Authorization dependencies │ │ │ ── permissions.py # Authorization dependencies
│ │ │ └── services.py # Service singleton injection
│ │ ├── routes/ # API endpoints │ │ ├── routes/ # API endpoints
│ │ │ ├── auth.py # Authentication routes │ │ │ ├── auth.py # Authentication routes
│ │ │ ├── users.py # User management routes │ │ │ ├── users.py # User management routes
@@ -131,13 +133,14 @@ backend/
│ │ ├── config.py # Application configuration │ │ ├── config.py # Application configuration
│ │ ├── database.py # Database connection │ │ ├── database.py # Database connection
│ │ ├── exceptions.py # Custom exception classes │ │ ├── exceptions.py # Custom exception classes
│ │ ├── repository_exceptions.py # Repository-level exception hierarchy
│ │ └── middleware.py # Custom middleware │ │ └── middleware.py # Custom middleware
│ │ │ │
│ ├── crud/ # Database operations │ ├── repositories/ # Data access layer
│ │ ├── base.py # Generic CRUD base class │ │ ├── base.py # Generic repository base class
│ │ ├── user.py # User CRUD operations │ │ ├── user.py # User repository
│ │ ├── session.py # Session CRUD operations │ │ ├── session.py # Session repository
│ │ └── organization.py # Organization CRUD │ │ └── organization.py # Organization repository
│ │ │ │
│ ├── models/ # SQLAlchemy models │ ├── models/ # SQLAlchemy models
│ │ ├── base.py # Base model with mixins │ │ ├── base.py # Base model with mixins
@@ -153,8 +156,11 @@ backend/
│ │ ├── sessions.py # Session schemas │ │ ├── sessions.py # Session schemas
│ │ └── organizations.py # Organization schemas │ │ └── organizations.py # Organization schemas
│ │ │ │
│ ├── services/ # Business logic │ ├── services/ # Business logic layer
│ │ ├── auth_service.py # Authentication service │ │ ├── auth_service.py # Authentication service
│ │ ├── user_service.py # User management service
│ │ ├── session_service.py # Session management service
│ │ ├── organization_service.py # Organization service
│ │ ├── email_service.py # Email service │ │ ├── email_service.py # Email service
│ │ └── session_cleanup.py # Background cleanup │ │ └── session_cleanup.py # Background cleanup
│ │ │ │
@@ -168,20 +174,25 @@ backend/
├── tests/ # Test suite ├── tests/ # Test suite
│ ├── api/ # Integration tests │ ├── api/ # Integration tests
│ ├── crud/ # CRUD tests │ ├── repositories/ # Repository unit tests
│ ├── services/ # Service unit tests
│ ├── models/ # Model tests │ ├── models/ # Model tests
│ ├── services/ # Service tests
│ └── conftest.py # Test configuration │ └── conftest.py # Test configuration
├── docs/ # Documentation ├── docs/ # Documentation
│ ├── ARCHITECTURE.md # This file │ ├── ARCHITECTURE.md # This file
│ ├── CODING_STANDARDS.md # Coding standards │ ├── CODING_STANDARDS.md # Coding standards
│ ├── COMMON_PITFALLS.md # Common mistakes to avoid
│ ├── E2E_TESTING.md # E2E testing guide
│ └── FEATURE_EXAMPLE.md # Feature implementation guide │ └── FEATURE_EXAMPLE.md # Feature implementation guide
├── requirements.txt # Python dependencies ├── pyproject.toml # Dependencies, tool configs (Ruff, pytest, coverage, Pyright)
├── pytest.ini # Pytest configuration ├── uv.lock # Locked dependency versions (commit to git)
├── .coveragerc # Coverage configuration ├── Makefile # Development commands (quality, security, testing)
── alembic.ini # Alembic configuration ── .pre-commit-config.yaml # Pre-commit hook configuration
├── .secrets.baseline # detect-secrets baseline (known false positives)
├── alembic.ini # Alembic configuration
└── migrate.py # Migration helper script
``` ```
## Layered Architecture ## Layered Architecture
@@ -214,11 +225,11 @@ The application follows a strict 5-layer architecture:
└──────────────────────────┬──────────────────────────────────┘ └──────────────────────────┬──────────────────────────────────┘
│ calls │ calls
┌──────────────────────────▼──────────────────────────────────┐ ┌──────────────────────────▼──────────────────────────────────┐
CRUD Layer (crud/) Repository Layer (repositories/)
│ - Database operations │ │ - Database operations │
│ - Query building │ │ - Query building │
│ - Transaction management │ - Custom repository exceptions
│ - Error handling │ - No business logic
└──────────────────────────┬──────────────────────────────────┘ └──────────────────────────┬──────────────────────────────────┘
│ uses │ uses
┌──────────────────────────▼──────────────────────────────────┐ ┌──────────────────────────▼──────────────────────────────────┐
@@ -262,7 +273,7 @@ async def get_current_user_info(
**Rules**: **Rules**:
- Should NOT contain business logic - Should NOT contain business logic
- Should NOT directly perform database operations (use CRUD or services) - Should NOT directly call repositories (use services injected via `dependencies/services.py`)
- Must validate all input via Pydantic schemas - Must validate all input via Pydantic schemas
- Must specify response models - Must specify response models
- Should apply appropriate rate limits - Should apply appropriate rate limits
@@ -279,9 +290,9 @@ async def get_current_user_info(
**Example**: **Example**:
```python ```python
def get_current_user( async def get_current_user(
token: str = Depends(oauth2_scheme), token: str = Depends(oauth2_scheme),
db: Session = Depends(get_db) db: AsyncSession = Depends(get_db)
) -> User: ) -> User:
""" """
Extract and validate user from JWT token. Extract and validate user from JWT token.
@@ -295,7 +306,7 @@ def get_current_user(
except Exception: except Exception:
raise AuthenticationError("Invalid authentication credentials") raise AuthenticationError("Invalid authentication credentials")
user = user_crud.get(db, id=user_id) user = await user_repo.get(db, id=user_id)
if not user: if not user:
raise AuthenticationError("User not found") raise AuthenticationError("User not found")
@@ -313,7 +324,7 @@ def get_current_user(
**Responsibility**: Implement complex business logic **Responsibility**: Implement complex business logic
**Key Functions**: **Key Functions**:
- Orchestrate multiple CRUD operations - Orchestrate multiple repository operations
- Implement business rules - Implement business rules
- Handle external service integration - Handle external service integration
- Coordinate transactions - Coordinate transactions
@@ -323,9 +334,9 @@ def get_current_user(
class AuthService: class AuthService:
"""Authentication service with business logic.""" """Authentication service with business logic."""
def login( async def login(
self, self,
db: Session, db: AsyncSession,
email: str, email: str,
password: str, password: str,
request: Request request: Request
@@ -339,8 +350,8 @@ class AuthService:
3. Generate tokens 3. Generate tokens
4. Return tokens and user info 4. Return tokens and user info
""" """
# Validate credentials # Validate credentials via repository
user = user_crud.get_by_email(db, email=email) user = await user_repo.get_by_email(db, email=email)
if not user or not verify_password(password, user.hashed_password): if not user or not verify_password(password, user.hashed_password):
raise AuthenticationError("Invalid credentials") raise AuthenticationError("Invalid credentials")
@@ -350,11 +361,10 @@ class AuthService:
# Extract device info # Extract device info
device_info = extract_device_info(request) device_info = extract_device_info(request)
# Create session # Create session via repository
session = session_crud.create_session( session = await session_repo.create(
db, db,
user_id=user.id, obj_in=SessionCreate(user_id=user.id, **device_info)
device_info=device_info
) )
# Generate tokens # Generate tokens
@@ -373,75 +383,60 @@ class AuthService:
**Rules**: **Rules**:
- Contains business logic, not just data operations - Contains business logic, not just data operations
- Can call multiple CRUD operations - Can call multiple repository operations
- Should handle complex workflows - Should handle complex workflows
- Must maintain data consistency - Must maintain data consistency
- Should use transactions when needed - Should use transactions when needed
#### 4. CRUD Layer (`app/crud/`) #### 4. Repository Layer (`app/repositories/`)
**Responsibility**: Database operations and queries **Responsibility**: Database operations and queries — no business logic
**Key Functions**: **Key Functions**:
- Create, read, update, delete operations - Create, read, update, delete operations
- Build database queries - Build database queries
- Handle database errors - Raise custom repository exceptions (`DuplicateEntryError`, `IntegrityConstraintError`)
- Manage soft deletes - Manage soft deletes
- Implement pagination and filtering - Implement pagination and filtering
**Example**: **Example**:
```python ```python
class CRUDSession(CRUDBase[UserSession, SessionCreate, SessionUpdate]): class SessionRepository(RepositoryBase[UserSession, SessionCreate, SessionUpdate]):
"""CRUD operations for user sessions.""" """Repository for user sessions — database operations only."""
def get_by_jti(self, db: Session, jti: UUID) -> Optional[UserSession]: async def get_by_jti(self, db: AsyncSession, *, jti: str) -> UserSession | None:
"""Get session by refresh token JTI.""" """Get session by refresh token JTI."""
try: result = await db.execute(
return ( select(UserSession).where(UserSession.refresh_token_jti == jti)
db.query(UserSession) )
.filter(UserSession.refresh_token_jti == jti) return result.scalar_one_or_none()
.first()
)
except Exception as e:
logger.error(f"Error getting session by JTI: {str(e)}")
return None
def get_active_by_jti( async def deactivate(self, db: AsyncSession, *, session_id: UUID) -> bool:
self,
db: Session,
jti: UUID
) -> Optional[UserSession]:
"""Get active session by refresh token JTI."""
session = self.get_by_jti(db, jti=jti)
if session and session.is_active and not session.is_expired:
return session
return None
def deactivate(self, db: Session, session_id: UUID) -> bool:
"""Deactivate a session (logout).""" """Deactivate a session (logout)."""
try: try:
session = self.get(db, id=session_id) session = await self.get(db, id=session_id)
if not session: if not session:
return False return False
session.is_active = False session.is_active = False
db.commit() await db.commit()
logger.info(f"Session {session_id} deactivated") logger.info(f"Session {session_id} deactivated")
return True return True
except Exception as e: except Exception as e:
db.rollback() await db.rollback()
logger.error(f"Error deactivating session: {str(e)}") logger.error(f"Error deactivating session: {str(e)}")
return False return False
``` ```
**Rules**: **Rules**:
- Should NOT contain business logic - Should NOT contain business logic
- Must handle database exceptions - Must raise custom repository exceptions (not raw `ValueError`/`IntegrityError`)
- Must use parameterized queries (SQLAlchemy does this) - Must use async SQLAlchemy 2.0 `select()` API (never `db.query()`)
- Should log all database errors - Should log all database errors
- Must rollback on errors - Must rollback on errors
- Should use soft deletes when possible - Should use soft deletes when possible
- **Never imported directly by routes** — always called through services
#### 5. Data Layer (`app/models/` + `app/schemas/`) #### 5. Data Layer (`app/models/` + `app/schemas/`)
@@ -546,51 +541,23 @@ SessionLocal = sessionmaker(
#### Dependency Injection Pattern #### Dependency Injection Pattern
```python ```python
def get_db() -> Generator[Session, None, None]: async def get_db() -> AsyncGenerator[AsyncSession, None]:
""" """
Database session dependency for FastAPI routes. Async database session dependency for FastAPI routes.
Automatically commits on success, rolls back on error. The session is passed to service methods; commit/rollback is
managed inside service or repository methods.
""" """
db = SessionLocal() async with AsyncSessionLocal() as db:
try:
yield db yield db
finally:
db.close()
# Usage in routes # Usage in routes — always through a service, never direct repository
@router.get("/users") @router.get("/users")
def list_users(db: Session = Depends(get_db)): async def list_users(
return user_crud.get_multi(db) user_service: UserService = Depends(get_user_service),
``` db: AsyncSession = Depends(get_db),
):
#### Context Manager Pattern return await user_service.get_users(db)
```python
@contextmanager
def transaction_scope() -> Generator[Session, None, None]:
"""
Context manager for database transactions.
Use for complex operations requiring multiple steps.
Automatically commits on success, rolls back on error.
"""
db = SessionLocal()
try:
yield db
db.commit()
except Exception:
db.rollback()
raise
finally:
db.close()
# Usage in services
def complex_operation():
with transaction_scope() as db:
user = user_crud.create(db, obj_in=user_data)
session = session_crud.create(db, session_data)
return user, session
``` ```
### Model Mixins ### Model Mixins
@@ -782,22 +749,15 @@ def get_profile(
```python ```python
@router.delete("/sessions/{session_id}") @router.delete("/sessions/{session_id}")
def revoke_session( async def revoke_session(
session_id: UUID, session_id: UUID,
current_user: User = Depends(get_current_user), current_user: User = Depends(get_current_user),
db: Session = Depends(get_db) session_service: SessionService = Depends(get_session_service),
db: AsyncSession = Depends(get_db),
): ):
"""Users can only revoke their own sessions.""" """Users can only revoke their own sessions."""
session = session_crud.get(db, id=session_id) # SessionService verifies ownership and raises NotFoundError / AuthorizationError
await session_service.revoke_session(db, session_id=session_id, user_id=current_user.id)
if not session:
raise NotFoundError("Session not found")
# Check ownership
if session.user_id != current_user.id:
raise AuthorizationError("You can only revoke your own sessions")
session_crud.deactivate(db, session_id=session_id)
return MessageResponse(success=True, message="Session revoked") return MessageResponse(success=True, message="Session revoked")
``` ```
@@ -1061,23 +1021,27 @@ from app.services.session_cleanup import cleanup_expired_sessions
scheduler = AsyncIOScheduler() scheduler = AsyncIOScheduler()
@app.on_event("startup") @asynccontextmanager
async def startup_event(): async def lifespan(app: FastAPI):
"""Start background jobs on application startup.""" """Application lifespan context manager."""
if not settings.IS_TEST: # Don't run in tests # Startup
if os.getenv("IS_TEST", "False") != "True":
scheduler.add_job( scheduler.add_job(
cleanup_expired_sessions, cleanup_expired_sessions,
"cron", "cron",
hour=2, # Run at 2 AM daily hour=2, # Run at 2 AM daily
id="cleanup_expired_sessions" id="cleanup_expired_sessions",
replace_existing=True,
) )
scheduler.start() scheduler.start()
logger.info("Background jobs started") logger.info("Background jobs started")
@app.on_event("shutdown") yield
async def shutdown_event():
"""Stop background jobs on application shutdown.""" # Shutdown
scheduler.shutdown() if os.getenv("IS_TEST", "False") != "True":
scheduler.shutdown()
await close_async_db() # Dispose database engine connections
``` ```
### Job Implementation ### Job Implementation
@@ -1092,8 +1056,8 @@ async def cleanup_expired_sessions():
Runs daily at 2 AM. Removes sessions expired for more than 30 days. Runs daily at 2 AM. Removes sessions expired for more than 30 days.
""" """
try: try:
with transaction_scope() as db: async with AsyncSessionLocal() as db:
count = session_crud.cleanup_expired(db, keep_days=30) count = await session_repo.cleanup_expired(db, keep_days=30)
logger.info(f"Cleaned up {count} expired sessions") logger.info(f"Cleaned up {count} expired sessions")
except Exception as e: except Exception as e:
logger.error(f"Error cleaning up sessions: {str(e)}", exc_info=True) logger.error(f"Error cleaning up sessions: {str(e)}", exc_info=True)
@@ -1110,7 +1074,7 @@ async def cleanup_expired_sessions():
│Integration │ ← API endpoint tests │Integration │ ← API endpoint tests
│ Tests │ │ Tests │
├─────────────┤ ├─────────────┤
│ Unit │ ← CRUD, services, utilities │ Unit │ ← repositories, services, utilities
│ Tests │ │ Tests │
└─────────────┘ └─────────────┘
``` ```
@@ -1205,6 +1169,8 @@ app.add_middleware(
## Performance Considerations ## Performance Considerations
> 📖 For the full benchmarking guide (how to run, read results, write new benchmarks, and manage baselines), see **[BENCHMARKS.md](BENCHMARKS.md)**.
### Database Connection Pooling ### Database Connection Pooling
- Pool size: 20 connections - Pool size: 20 connections

311
backend/docs/BENCHMARKS.md Normal file
View File

@@ -0,0 +1,311 @@
# Performance Benchmarks Guide
Automated performance benchmarking infrastructure using **pytest-benchmark** to detect latency regressions in critical API endpoints.
## Table of Contents
- [Why Benchmark?](#why-benchmark)
- [Quick Start](#quick-start)
- [How It Works](#how-it-works)
- [Understanding Results](#understanding-results)
- [Test Organization](#test-organization)
- [Writing Benchmark Tests](#writing-benchmark-tests)
- [Baseline Management](#baseline-management)
- [CI/CD Integration](#cicd-integration)
- [Troubleshooting](#troubleshooting)
---
## Why Benchmark?
Performance regressions are silent bugs — they don't break tests or cause errors, but they degrade the user experience over time. Common causes include:
- **Unintended N+1 queries** after adding a relationship
- **Heavier serialization** after adding new fields to a response model
- **Middleware overhead** from new security headers or logging
- **Dependency upgrades** that introduce slower code paths
Without automated benchmarks, these regressions go unnoticed until users complain. Performance benchmarks serve as an **early warning system** — they measure endpoint latency on every run and flag significant deviations from an established baseline.
### What benchmarks give you
| Benefit | Description |
|---------|-------------|
| **Regression detection** | Automatically flags when an endpoint becomes significantly slower |
| **Baseline tracking** | Stores known-good performance numbers for comparison |
| **Confidence in refactors** | Verify that code changes don't degrade response times |
| **Visibility** | Makes performance a first-class, measurable quality attribute |
---
## Quick Start
```bash
# Run benchmarks (no comparison, just see current numbers)
make benchmark
# Save current results as the baseline
make benchmark-save
# Run benchmarks and compare against the saved baseline
make benchmark-check
```
---
## How It Works
The benchmarking system has three layers:
### 1. pytest-benchmark integration
[pytest-benchmark](https://pytest-benchmark.readthedocs.io/) is a pytest plugin that provides a `benchmark` fixture. It handles:
- **Calibration**: Automatically determines how many iterations to run for statistical significance
- **Timing**: Uses `time.perf_counter` for high-resolution measurements
- **Statistics**: Computes min, max, mean, median, standard deviation, IQR, and outlier detection
- **Comparison**: Compares current results against saved baselines and flags regressions
### 2. Benchmark types
The test suite includes two categories of performance tests:
| Type | How it works | Examples |
|------|-------------|----------|
| **pytest-benchmark tests** | Uses the `benchmark` fixture for precise, multi-round timing | `test_health_endpoint_performance`, `test_openapi_schema_performance`, `test_password_hashing_performance`, `test_password_verification_performance`, `test_access_token_creation_performance`, `test_refresh_token_creation_performance`, `test_token_decode_performance` |
| **Manual latency tests** | Uses `time.perf_counter` with explicit thresholds (for async endpoints that pytest-benchmark doesn't support natively) | `test_login_latency`, `test_get_current_user_latency`, `test_register_latency`, `test_token_refresh_latency`, `test_sessions_list_latency`, `test_user_profile_update_latency` |
### 3. Regression detection
When running `make benchmark-check`, the system:
1. Runs all benchmark tests
2. Compares results against the saved baseline (`.benchmarks/` directory)
3. **Fails the build** if any test's mean time exceeds **200%** of the baseline (i.e., 3× slower)
The `200%` threshold in `--benchmark-compare-fail=mean:200%` means "fail if the mean increased by more than 200% relative to the baseline." This is deliberately generous to avoid false positives from normal run-to-run variance while still catching real regressions.
---
## Understanding Results
A typical benchmark output looks like this:
```
--------------------------------------------------------------------------------------- benchmark: 2 tests --------------------------------------------------------------------------------------
Name (time in ms) Min Max Mean StdDev Median IQR Outliers OPS Rounds Iterations
-------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------
test_health_endpoint_performance 0.9841 (1.0) 1.5513 (1.0) 1.1390 (1.0) 0.1098 (1.0) 1.1151 (1.0) 0.1672 (1.0) 39;2 877.9666 (1.0) 133 1
test_openapi_schema_performance 1.6523 (1.68) 2.0892 (1.35) 1.7843 (1.57) 0.1553 (1.41) 1.7200 (1.54) 0.1727 (1.03) 2;0 560.4471 (0.64) 10 1
-------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------
```
### Column reference
| Column | Meaning |
|--------|---------|
| **Min** | Fastest single execution |
| **Max** | Slowest single execution |
| **Mean** | Average across all rounds — the primary metric for regression detection |
| **StdDev** | How much results vary between rounds (lower = more stable) |
| **Median** | Middle value, less sensitive to outliers than mean |
| **IQR** | Interquartile range — spread of the middle 50% of results |
| **Outliers** | Format `A;B` — A = within 1 StdDev, B = within 1.5 IQR from quartiles |
| **OPS** | Operations per second (`1 / Mean`) |
| **Rounds** | How many times the test was executed (auto-calibrated) |
| **Iterations** | Iterations per round (usually 1 for ms-scale tests) |
### The ratio numbers `(1.0)`, `(1.68)`, etc.
These show how each test compares **to the best result in that column**. The fastest test is always `(1.0)`, and others show their relative factor. For example, `(1.68)` means "1.68× slower than the fastest."
### Color coding
- **Green**: The fastest (best) value in each column
- **Red**: The slowest (worst) value in each column
This is a **relative ranking within the current run** — red does NOT mean the test failed or that performance is bad. It simply highlights which endpoint is the slower one in the group.
### What's "normal"?
For this project's current endpoints:
| Test | Expected range | Why |
|------|---------------|-----|
| `GET /health` | ~11.5ms | Minimal logic, mocked DB check |
| `GET /api/v1/openapi.json` | ~1.52.5ms | Serializes entire API schema |
| `get_password_hash` | ~200ms | CPU-bound bcrypt hashing |
| `verify_password` | ~200ms | CPU-bound bcrypt verification |
| `create_access_token` | ~1720µs | JWT encoding with HMAC-SHA256 |
| `create_refresh_token` | ~1720µs | JWT encoding with HMAC-SHA256 |
| `decode_token` | ~2025µs | JWT decoding and claim validation |
| `POST /api/v1/auth/login` | < 500ms threshold | Includes bcrypt password verification |
| `POST /api/v1/auth/register` | < 500ms threshold | Includes bcrypt password hashing |
| `POST /api/v1/auth/refresh` | < 200ms threshold | Token rotation + DB session update |
| `GET /api/v1/users/me` | < 200ms threshold | DB lookup + token validation |
| `GET /api/v1/sessions/me` | < 200ms threshold | Session list query + token validation |
| `PATCH /api/v1/users/me` | < 200ms threshold | DB update + token validation |
---
## Test Organization
```
backend/tests/
├── benchmarks/
│ └── test_endpoint_performance.py # All performance benchmark tests
backend/.benchmarks/ # Saved baselines (auto-generated)
└── Linux-CPython-3.12-64bit/
└── 0001_baseline.json # Platform-specific baseline file
```
### Test markers
All benchmark tests use the `@pytest.mark.benchmark` marker. The `--benchmark-only` flag ensures that only tests using the `benchmark` fixture are executed during benchmark runs, while manual latency tests (async) are skipped.
---
## Writing Benchmark Tests
### Stateless endpoint (using pytest-benchmark fixture)
```python
import pytest
from fastapi.testclient import TestClient
def test_my_endpoint_performance(sync_client, benchmark):
"""Benchmark: GET /my-endpoint should respond within acceptable latency."""
result = benchmark(sync_client.get, "/my-endpoint")
assert result.status_code == 200
```
The `benchmark` fixture handles all timing, calibration, and statistics automatically. Just pass it the callable and arguments.
### Async / DB-dependent endpoint (manual timing)
For async endpoints that require database access, use manual timing with an explicit threshold:
```python
import time
import pytest
MAX_RESPONSE_MS = 300
@pytest.mark.asyncio
async def test_my_async_endpoint_latency(client, setup_fixture):
"""Performance: endpoint must respond under threshold."""
iterations = 5
total_ms = 0.0
for _ in range(iterations):
start = time.perf_counter()
response = await client.get("/api/v1/my-endpoint")
elapsed_ms = (time.perf_counter() - start) * 1000
total_ms += elapsed_ms
assert response.status_code == 200
mean_ms = total_ms / iterations
assert mean_ms < MAX_RESPONSE_MS, (
f"Latency regression: {mean_ms:.1f}ms exceeds {MAX_RESPONSE_MS}ms threshold"
)
```
### Guidelines for new benchmarks
1. **Benchmark critical paths** — endpoints users hit frequently or where latency matters most
2. **Mock external dependencies** for stateless tests to isolate endpoint overhead
3. **Set generous thresholds** for manual tests — account for CI variability
4. **Keep benchmarks fast** — they run on every check, so avoid heavy setup
---
## Baseline Management
### Saving a baseline
```bash
make benchmark-save
```
This runs all benchmarks and saves results to `.benchmarks/<platform>/0001_baseline.json`. The baseline captures:
- Mean, min, max, median, stddev for each test
- Machine info (CPU, OS, Python version)
- Timestamp
### Comparing against baseline
```bash
make benchmark-check
```
If no baseline exists, this command automatically creates one and prints a warning. On subsequent runs, it compares current results against the saved baseline.
### When to update the baseline
- **After intentional performance changes** (e.g., you optimized an endpoint — save the new, faster baseline)
- **After infrastructure changes** (e.g., new CI runner, different hardware)
- **After adding new benchmark tests** (the new tests need a baseline entry)
```bash
# Update the baseline after intentional changes
make benchmark-save
```
### Version control
The `.benchmarks/` directory can be committed to version control so that CI pipelines can compare against a known-good baseline. However, since benchmark results are machine-specific, you may prefer to generate baselines in CI rather than committing local results.
---
## CI/CD Integration
Add benchmark checking to your CI pipeline to catch regressions on every PR:
```yaml
# Example GitHub Actions step
- name: Performance regression check
run: |
cd backend
make benchmark-save # Create baseline from main branch
# ... apply PR changes ...
make benchmark-check # Compare PR against baseline
```
A more robust approach:
1. Save the baseline on the `main` branch after each merge
2. On PR branches, run `make benchmark-check` against the `main` baseline
3. The pipeline fails if any endpoint regresses beyond the 200% threshold
---
## Troubleshooting
### "No benchmark baseline found" warning
```
⚠️ No benchmark baseline found. Run 'make benchmark-save' first to create one.
```
This means no baseline file exists yet. The command will auto-create one. Future runs of `make benchmark-check` will compare against it.
### Machine info mismatch warning
```
WARNING: benchmark machine_info is different
```
This is expected when comparing baselines generated on a different machine or OS. The comparison still works, but absolute numbers may differ. Re-save the baseline on the current machine if needed.
### High variance (large StdDev)
If StdDev is high relative to the Mean, results may be unreliable. Common causes:
- System under load during benchmark run
- Garbage collection interference
- Thermal throttling
Try running benchmarks on an idle system or increasing `min_rounds` in `pyproject.toml`.
### Only 7 of 13 tests run
The async tests (`test_login_latency`, `test_get_current_user_latency`, `test_register_latency`, `test_token_refresh_latency`, `test_sessions_list_latency`, `test_user_profile_update_latency`) are skipped during `--benchmark-only` runs because they don't use the `benchmark` fixture. They run as part of the normal test suite (`make test`) with manual threshold assertions.

View File

@@ -75,15 +75,14 @@ def create_user(db: Session, user_in: UserCreate) -> User:
### 4. Code Formatting ### 4. Code Formatting
Use automated formatters: Use automated formatters:
- **Black**: Code formatting - **Ruff**: Code formatting and linting (replaces Black, isort, flake8)
- **isort**: Import sorting - **pyright**: Static type checking
- **flake8**: Linting
Run before committing: Run before committing (or use `make validate`):
```bash ```bash
black app tests uv run ruff format app tests
isort app tests uv run ruff check app tests
flake8 app tests uv run pyright app
``` ```
## Code Organization ## Code Organization
@@ -94,19 +93,17 @@ Follow the 5-layer architecture strictly:
``` ```
API Layer (routes/) API Layer (routes/)
↓ calls ↓ calls (via service injected from dependencies/services.py)
Dependencies (dependencies/)
↓ injects
Service Layer (services/) Service Layer (services/)
↓ calls ↓ calls
CRUD Layer (crud/) Repository Layer (repositories/)
↓ uses ↓ uses
Models & Schemas (models/, schemas/) Models & Schemas (models/, schemas/)
``` ```
**Rules:** **Rules:**
- Routes should NOT directly call CRUD operations (use services when business logic is needed) - Routes must NEVER import repositories directly — always use a service
- CRUD operations should NOT contain business logic - Services call repositories; repositories contain only database operations
- Models should NOT import from higher layers - Models should NOT import from higher layers
- Each layer should only depend on the layer directly below it - Each layer should only depend on the layer directly below it
@@ -125,7 +122,7 @@ from sqlalchemy.orm import Session
# 3. Local application imports # 3. Local application imports
from app.api.dependencies.auth import get_current_user from app.api.dependencies.auth import get_current_user
from app.crud import user_crud from app.api.dependencies.services import get_user_service
from app.models.user import User from app.models.user import User
from app.schemas.users import UserResponse, UserCreate from app.schemas.users import UserResponse, UserCreate
``` ```
@@ -217,7 +214,7 @@ if not user:
### Error Handling Pattern ### Error Handling Pattern
Always follow this pattern in CRUD operations (Async version): Always follow this pattern in repository operations (Async version):
```python ```python
from sqlalchemy.exc import IntegrityError, OperationalError, DataError from sqlalchemy.exc import IntegrityError, OperationalError, DataError
@@ -430,7 +427,7 @@ backend/app/alembic/versions/
## Database Operations ## Database Operations
### Async CRUD Pattern ### Async Repository Pattern
**IMPORTANT**: This application uses **async SQLAlchemy** with modern patterns for better performance and testability. **IMPORTANT**: This application uses **async SQLAlchemy** with modern patterns for better performance and testability.
@@ -442,19 +439,19 @@ backend/app/alembic/versions/
4. **Testability**: Easy to mock and test 4. **Testability**: Easy to mock and test
5. **Consistent Ordering**: Always order queries for pagination 5. **Consistent Ordering**: Always order queries for pagination
### Use the Async CRUD Base Class ### Use the Async Repository Base Class
Always inherit from `CRUDBase` for database operations: Always inherit from `RepositoryBase` for database operations:
```python ```python
from sqlalchemy.ext.asyncio import AsyncSession from sqlalchemy.ext.asyncio import AsyncSession
from sqlalchemy import select from sqlalchemy import select
from app.crud.base import CRUDBase from app.repositories.base import RepositoryBase
from app.models.user import User from app.models.user import User
from app.schemas.users import UserCreate, UserUpdate from app.schemas.users import UserCreate, UserUpdate
class CRUDUser(CRUDBase[User, UserCreate, UserUpdate]): class UserRepository(RepositoryBase[User, UserCreate, UserUpdate]):
"""CRUD operations for User model.""" """Repository for User model — database operations only."""
async def get_by_email( async def get_by_email(
self, self,
@@ -467,7 +464,7 @@ class CRUDUser(CRUDBase[User, UserCreate, UserUpdate]):
) )
return result.scalar_one_or_none() return result.scalar_one_or_none()
user_crud = CRUDUser(User) user_repo = UserRepository(User)
``` ```
**Key Points:** **Key Points:**
@@ -476,6 +473,7 @@ user_crud = CRUDUser(User)
- Use `await db.execute()` for queries - Use `await db.execute()` for queries
- Use `.scalar_one_or_none()` instead of `.first()` - Use `.scalar_one_or_none()` instead of `.first()`
- Use `T | None` instead of `Optional[T]` - Use `T | None` instead of `Optional[T]`
- Repository instances are used internally by services — never import them in routes
### Modern SQLAlchemy Patterns ### Modern SQLAlchemy Patterns
@@ -563,13 +561,13 @@ async def create_user(
The database session is automatically managed by FastAPI. The database session is automatically managed by FastAPI.
Commit on success, rollback on error. Commit on success, rollback on error.
""" """
return await user_crud.create(db, obj_in=user_in) return await user_service.create_user(db, obj_in=user_in)
``` ```
**Key Points:** **Key Points:**
- Route functions must be `async def` - Route functions must be `async def`
- Database parameter is `AsyncSession` - Database parameter is `AsyncSession`
- Always `await` CRUD operations - Always `await` repository operations
#### In Services (Multiple Operations) #### In Services (Multiple Operations)
@@ -582,12 +580,11 @@ async def complex_operation(
""" """
Perform multiple database operations atomically. Perform multiple database operations atomically.
The session automatically commits on success or rolls back on error. Services call repositories; commit/rollback is handled inside
each repository method.
""" """
user = await user_crud.create(db, obj_in=user_data) user = await user_repo.create(db, obj_in=user_data)
session = await session_crud.create(db, obj_in=session_data) session = await session_repo.create(db, obj_in=session_data)
# Commit is handled by the route's dependency
return user, session return user, session
``` ```
@@ -597,10 +594,10 @@ Prefer soft deletes over hard deletes for audit trails:
```python ```python
# Good - Soft delete (sets deleted_at) # Good - Soft delete (sets deleted_at)
await user_crud.soft_delete(db, id=user_id) await user_repo.soft_delete(db, id=user_id)
# Acceptable only when required - Hard delete # Acceptable only when required - Hard delete
user_crud.remove(db, id=user_id) await user_repo.remove(db, id=user_id)
``` ```
### Query Patterns ### Query Patterns
@@ -740,9 +737,10 @@ Always implement pagination for list endpoints:
from app.schemas.common import PaginationParams, PaginatedResponse from app.schemas.common import PaginationParams, PaginatedResponse
@router.get("/users", response_model=PaginatedResponse[UserResponse]) @router.get("/users", response_model=PaginatedResponse[UserResponse])
def list_users( async def list_users(
pagination: PaginationParams = Depends(), pagination: PaginationParams = Depends(),
db: Session = Depends(get_db) user_service: UserService = Depends(get_user_service),
db: AsyncSession = Depends(get_db),
): ):
""" """
List all users with pagination. List all users with pagination.
@@ -750,10 +748,8 @@ def list_users(
Default page size: 20 Default page size: 20
Maximum page size: 100 Maximum page size: 100
""" """
users, total = user_crud.get_multi_with_total( users, total = await user_service.get_users(
db, db, skip=pagination.offset, limit=pagination.limit
skip=pagination.offset,
limit=pagination.limit
) )
return PaginatedResponse(data=users, pagination=pagination.create_meta(total)) return PaginatedResponse(data=users, pagination=pagination.create_meta(total))
``` ```
@@ -816,19 +812,17 @@ def admin_route(
pass pass
# Check ownership # Check ownership
def delete_resource( async def delete_resource(
resource_id: UUID, resource_id: UUID,
current_user: User = Depends(get_current_user), current_user: User = Depends(get_current_user),
db: Session = Depends(get_db) resource_service: ResourceService = Depends(get_resource_service),
db: AsyncSession = Depends(get_db),
): ):
resource = resource_crud.get(db, id=resource_id) # Service handles ownership check and raises appropriate errors
if not resource: await resource_service.delete_resource(
raise NotFoundError("Resource not found") db, resource_id=resource_id, user_id=current_user.id,
is_superuser=current_user.is_superuser,
if resource.user_id != current_user.id and not current_user.is_superuser: )
raise AuthorizationError("You can only delete your own resources")
resource_crud.remove(db, id=resource_id)
``` ```
### Input Validation ### Input Validation
@@ -862,9 +856,9 @@ tests/
├── api/ # Integration tests ├── api/ # Integration tests
│ ├── test_users.py │ ├── test_users.py
│ └── test_auth.py │ └── test_auth.py
├── crud/ # Unit tests for CRUD ├── repositories/ # Unit tests for repositories
├── models/ # Model tests ├── services/ # Unit tests for services
└── services/ # Service tests └── models/ # Model tests
``` ```
### Async Testing with pytest-asyncio ### Async Testing with pytest-asyncio
@@ -927,7 +921,7 @@ async def test_user(db_session: AsyncSession) -> User:
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_get_user(db_session: AsyncSession, test_user: User): async def test_get_user(db_session: AsyncSession, test_user: User):
"""Test retrieving a user by ID.""" """Test retrieving a user by ID."""
user = await user_crud.get(db_session, id=test_user.id) user = await user_repo.get(db_session, id=test_user.id)
assert user is not None assert user is not None
assert user.email == test_user.email assert user.email == test_user.email
``` ```

View File

@@ -334,14 +334,14 @@ def login(request: Request, credentials: OAuth2PasswordRequestForm):
# ❌ WRONG - Returns password hash! # ❌ WRONG - Returns password hash!
@router.get("/users/{user_id}") @router.get("/users/{user_id}")
def get_user(user_id: UUID, db: Session = Depends(get_db)) -> User: def get_user(user_id: UUID, db: Session = Depends(get_db)) -> User:
return user_crud.get(db, id=user_id) # Returns ORM model with ALL fields! return user_repo.get(db, id=user_id) # Returns ORM model with ALL fields!
``` ```
```python ```python
# ✅ CORRECT - Use response schema # ✅ CORRECT - Use response schema
@router.get("/users/{user_id}", response_model=UserResponse) @router.get("/users/{user_id}", response_model=UserResponse)
def get_user(user_id: UUID, db: Session = Depends(get_db)): def get_user(user_id: UUID, db: Session = Depends(get_db)):
user = user_crud.get(db, id=user_id) user = user_repo.get(db, id=user_id)
if not user: if not user:
raise HTTPException(status_code=404, detail="User not found") raise HTTPException(status_code=404, detail="User not found")
return user # Pydantic filters to only UserResponse fields return user # Pydantic filters to only UserResponse fields
@@ -506,8 +506,8 @@ def revoke_session(
current_user: User = Depends(get_current_user), current_user: User = Depends(get_current_user),
db: Session = Depends(get_db) db: Session = Depends(get_db)
): ):
session = session_crud.get(db, id=session_id) session = session_repo.get(db, id=session_id)
session_crud.deactivate(db, session_id=session_id) session_repo.deactivate(db, session_id=session_id)
# BUG: User can revoke ANYONE'S session! # BUG: User can revoke ANYONE'S session!
return {"message": "Session revoked"} return {"message": "Session revoked"}
``` ```
@@ -520,7 +520,7 @@ def revoke_session(
current_user: User = Depends(get_current_user), current_user: User = Depends(get_current_user),
db: Session = Depends(get_db) db: Session = Depends(get_db)
): ):
session = session_crud.get(db, id=session_id) session = session_repo.get(db, id=session_id)
if not session: if not session:
raise NotFoundError("Session not found") raise NotFoundError("Session not found")
@@ -529,7 +529,7 @@ def revoke_session(
if session.user_id != current_user.id: if session.user_id != current_user.id:
raise AuthorizationError("You can only revoke your own sessions") raise AuthorizationError("You can only revoke your own sessions")
session_crud.deactivate(db, session_id=session_id) session_repo.deactivate(db, session_id=session_id)
return {"message": "Session revoked"} return {"message": "Session revoked"}
``` ```
@@ -616,7 +616,43 @@ def create_user(
return user return user
``` ```
**Rule**: Add type hints to ALL functions. Use `mypy` to enforce type checking. **Rule**: Add type hints to ALL functions. Use `pyright` to enforce type checking (`make type-check`).
---
---
### ❌ PITFALL #19: Importing Repositories Directly in Routes
**Issue**: Routes should never call repositories directly. The layered architecture requires all business operations to go through the service layer.
```python
# ❌ WRONG - Route bypasses service layer
from app.repositories.session import session_repo
@router.get("/sessions/me")
async def list_sessions(
current_user: User = Depends(get_current_active_user),
db: AsyncSession = Depends(get_db),
):
return await session_repo.get_user_sessions(db, user_id=current_user.id)
```
```python
# ✅ CORRECT - Route calls service injected via dependency
from app.api.dependencies.services import get_session_service
from app.services.session_service import SessionService
@router.get("/sessions/me")
async def list_sessions(
current_user: User = Depends(get_current_active_user),
session_service: SessionService = Depends(get_session_service),
db: AsyncSession = Depends(get_db),
):
return await session_service.get_user_sessions(db, user_id=current_user.id)
```
**Rule**: Routes import from `app.api.dependencies.services`, never from `app.repositories.*`. Services are the only callers of repositories.
--- ---
@@ -649,6 +685,11 @@ Use this checklist to catch issues before code review:
- [ ] Resource ownership verification - [ ] Resource ownership verification
- [ ] CORS configured (no wildcards in production) - [ ] CORS configured (no wildcards in production)
### Architecture
- [ ] Routes never import repositories directly (only services)
- [ ] Services call repositories; repositories call database only
- [ ] New service registered in `app/api/dependencies/services.py`
### Python ### Python
- [ ] Use `==` not `is` for value comparison - [ ] Use `==` not `is` for value comparison
- [ ] No mutable default arguments - [ ] No mutable default arguments
@@ -661,21 +702,18 @@ Use this checklist to catch issues before code review:
### Pre-commit Checks ### Pre-commit Checks
Add these to your development workflow: Add these to your development workflow (or use `make validate`):
```bash ```bash
# Format code # Format + lint (Ruff replaces Black, isort, flake8)
black app tests uv run ruff format app tests
isort app tests uv run ruff check app tests
# Type checking # Type checking
mypy app --strict uv run pyright app
# Linting
flake8 app tests
# Run tests # Run tests
pytest --cov=app --cov-report=term-missing IS_TEST=True uv run pytest --cov=app --cov-report=term-missing
# Check coverage (should be 80%+) # Check coverage (should be 80%+)
coverage report --fail-under=80 coverage report --fail-under=80
@@ -693,6 +731,6 @@ Add new entries when:
--- ---
**Last Updated**: 2025-10-31 **Last Updated**: 2026-02-28
**Issues Cataloged**: 18 common pitfalls **Issues Cataloged**: 19 common pitfalls
**Remember**: This document exists because these issues HAVE occurred. Don't skip it. **Remember**: This document exists because these issues HAVE occurred. Don't skip it.

View File

@@ -99,7 +99,7 @@ backend/tests/
│ └── test_database_workflows.py # PostgreSQL workflow tests │ └── test_database_workflows.py # PostgreSQL workflow tests
├── api/ # Integration tests (SQLite, fast) ├── api/ # Integration tests (SQLite, fast)
├── crud/ # Unit tests ├── repositories/ # Repository unit tests
└── conftest.py # Standard fixtures └── conftest.py # Standard fixtures
``` ```

File diff suppressed because it is too large Load Diff

View File

@@ -1,4 +1,4 @@
#!/bin/bash #!/bin/sh
set -e set -e
echo "Starting Backend" echo "Starting Backend"

View File

@@ -20,43 +20,36 @@ dependencies = [
"uvicorn>=0.34.0", "uvicorn>=0.34.0",
"pydantic>=2.10.6", "pydantic>=2.10.6",
"pydantic-settings>=2.2.1", "pydantic-settings>=2.2.1",
"python-multipart>=0.0.19", "python-multipart>=0.0.22",
"fastapi-utils==0.8.0", "fastapi-utils==0.8.0",
# Database # Database
"sqlalchemy>=2.0.29", "sqlalchemy>=2.0.29",
"alembic>=1.14.1", "alembic>=1.14.1",
"psycopg2-binary>=2.9.9", "psycopg2-binary>=2.9.9",
"asyncpg>=0.29.0", "asyncpg>=0.29.0",
"aiosqlite==0.21.0", "aiosqlite==0.21.0",
# Environment configuration # Environment configuration
"python-dotenv>=1.0.1", "python-dotenv>=1.0.1",
# API utilities # API utilities
"email-validator>=2.1.0.post1", "email-validator>=2.1.0.post1",
"ujson>=5.9.0", "ujson>=5.9.0",
# CORS and security # CORS and security
"starlette>=0.40.0", "starlette>=0.40.0",
"starlette-csrf>=1.4.5", "starlette-csrf>=1.4.5",
"slowapi>=0.1.9", "slowapi>=0.1.9",
# Utilities # Utilities
"httpx>=0.27.0", "httpx>=0.27.0",
"tenacity>=8.2.3", "tenacity>=8.2.3",
"pytz>=2024.1", "pytz>=2024.1",
"pillow>=10.3.0", "pillow>=12.1.1",
"apscheduler==3.11.0", "apscheduler==3.11.0",
# Security and authentication
# Security and authentication (pinned for reproducibility) "PyJWT>=2.9.0",
"python-jose==3.4.0",
"passlib==1.7.4",
"bcrypt==4.2.1", "bcrypt==4.2.1",
"cryptography==44.0.1", "cryptography>=46.0.5",
# OAuth authentication # OAuth authentication
"authlib>=1.3.0", "authlib>=1.6.6",
"urllib3>=2.6.3",
] ]
# Development dependencies # Development dependencies
@@ -72,7 +65,18 @@ dev = [
# Development tools # Development tools
"ruff>=0.8.0", # All-in-one: linting, formatting, import sorting "ruff>=0.8.0", # All-in-one: linting, formatting, import sorting
"mypy>=1.8.0", # Type checking "pyright>=1.1.390", # Type checking
# Security auditing
"pip-audit>=2.7.0", # Dependency vulnerability scanning (PyPA/OSV)
"pip-licenses>=4.0.0", # License compliance checking
"detect-secrets>=1.5.0", # Hardcoded secrets detection
# Performance benchmarking
"pytest-benchmark>=4.0.0", # Performance regression detection
# Pre-commit hooks
"pre-commit>=4.0.0", # Git pre-commit hook framework
] ]
# E2E testing with real PostgreSQL (requires Docker) # E2E testing with real PostgreSQL (requires Docker)
@@ -131,6 +135,8 @@ select = [
"RUF", # Ruff-specific "RUF", # Ruff-specific
"ASYNC", # flake8-async "ASYNC", # flake8-async
"S", # flake8-bandit (security) "S", # flake8-bandit (security)
"G", # flake8-logging-format (logging best practices)
"T20", # flake8-print (no print statements in production code)
] ]
# Ignore specific rules # Ignore specific rules
@@ -154,11 +160,13 @@ unfixable = []
[tool.ruff.lint.per-file-ignores] [tool.ruff.lint.per-file-ignores]
"app/alembic/env.py" = ["E402", "F403", "F405"] # Alembic requires specific import order "app/alembic/env.py" = ["E402", "F403", "F405"] # Alembic requires specific import order
"app/alembic/versions/*.py" = ["E402"] # Migration files have specific structure "app/alembic/versions/*.py" = ["E402"] # Migration files have specific structure
"tests/**/*.py" = ["S101", "N806", "B017", "N817", "S110", "ASYNC251", "RUF043"] # pytest: asserts, CamelCase fixtures, blind exceptions, try-pass patterns, and async test helpers are intentional "tests/**/*.py" = ["S101", "N806", "B017", "N817", "ASYNC251", "RUF043", "T20"] # pytest: asserts, CamelCase fixtures, blind exceptions, async test helpers, and print for debugging are intentional
"app/models/__init__.py" = ["F401"] # __init__ files re-export modules "app/models/__init__.py" = ["F401"] # __init__ files re-export modules
"app/models/base.py" = ["F401"] # Re-exports Base for use by other models "app/models/base.py" = ["F401"] # Re-exports Base for use by other models
"app/utils/test_utils.py" = ["N806"] # SQLAlchemy session factories use CamelCase convention "app/utils/test_utils.py" = ["N806"] # SQLAlchemy session factories use CamelCase convention
"app/main.py" = ["N806"] # Constants use UPPER_CASE convention "app/main.py" = ["N806"] # Constants use UPPER_CASE convention
"app/init_db.py" = ["T20"] # CLI script uses print for user-facing output
"migrate.py" = ["T20"] # CLI script uses print for user-facing output
# ============================================================================ # ============================================================================
# Ruff Import Sorting (isort replacement) # Ruff Import Sorting (isort replacement)
@@ -185,120 +193,6 @@ indent-style = "space"
skip-magic-trailing-comma = false skip-magic-trailing-comma = false
line-ending = "lf" line-ending = "lf"
# ============================================================================
# mypy Configuration - Type Checking
# ============================================================================
[tool.mypy]
python_version = "3.12"
warn_return_any = false # SQLAlchemy queries return Any - overly strict
warn_unused_configs = true
disallow_untyped_defs = false # Gradual typing - enable later
disallow_incomplete_defs = false
check_untyped_defs = true
no_implicit_optional = true
warn_redundant_casts = true
warn_unused_ignores = true
warn_no_return = true
strict_equality = true
ignore_missing_imports = false
explicit_package_bases = true
namespace_packages = true
# Pydantic plugin for better validation
plugins = ["pydantic.mypy"]
# Per-module options
[[tool.mypy.overrides]]
module = "alembic.*"
ignore_errors = true
[[tool.mypy.overrides]]
module = "app.alembic.*"
ignore_errors = true
[[tool.mypy.overrides]]
module = "sqlalchemy.*"
ignore_missing_imports = true
[[tool.mypy.overrides]]
module = "fastapi_utils.*"
ignore_missing_imports = true
[[tool.mypy.overrides]]
module = "slowapi.*"
ignore_missing_imports = true
[[tool.mypy.overrides]]
module = "jose.*"
ignore_missing_imports = true
[[tool.mypy.overrides]]
module = "passlib.*"
ignore_missing_imports = true
[[tool.mypy.overrides]]
module = "pydantic_settings.*"
ignore_missing_imports = true
[[tool.mypy.overrides]]
module = "fastapi.*"
ignore_missing_imports = true
[[tool.mypy.overrides]]
module = "apscheduler.*"
ignore_missing_imports = true
[[tool.mypy.overrides]]
module = "starlette.*"
ignore_missing_imports = true
[[tool.mypy.overrides]]
module = "authlib.*"
ignore_missing_imports = true
# SQLAlchemy ORM models - Column descriptors cause type confusion
[[tool.mypy.overrides]]
module = "app.models.*"
disable_error_code = ["assignment", "arg-type", "return-value"]
# CRUD operations - Generic ModelType and SQLAlchemy Result issues
[[tool.mypy.overrides]]
module = "app.crud.*"
disable_error_code = ["attr-defined", "assignment", "arg-type", "return-value"]
# API routes - SQLAlchemy Column to Pydantic schema conversions
[[tool.mypy.overrides]]
module = "app.api.routes.*"
disable_error_code = ["arg-type", "call-arg", "call-overload", "assignment"]
# API dependencies - Similar SQLAlchemy Column issues
[[tool.mypy.overrides]]
module = "app.api.dependencies.*"
disable_error_code = ["arg-type"]
# FastAPI exception handlers have correct signatures despite mypy warnings
[[tool.mypy.overrides]]
module = "app.main"
disable_error_code = ["arg-type"]
# Auth service - SQLAlchemy Column issues
[[tool.mypy.overrides]]
module = "app.services.auth_service"
disable_error_code = ["assignment", "arg-type"]
# Test utils - Testing patterns
[[tool.mypy.overrides]]
module = "app.utils.auth_test_utils"
disable_error_code = ["assignment", "arg-type"]
# ============================================================================
# Pydantic mypy plugin configuration
# ============================================================================
[tool.pydantic-mypy]
init_forbid_extra = true
init_typed = true
warn_required_dynamic_aliases = true
# ============================================================================ # ============================================================================
# Pytest Configuration # Pytest Configuration
# ============================================================================ # ============================================================================
@@ -315,12 +209,15 @@ addopts = [
"--cov=app", "--cov=app",
"--cov-report=term-missing", "--cov-report=term-missing",
"--cov-report=html", "--cov-report=html",
"--ignore=tests/benchmarks", # benchmarks are incompatible with xdist; run via 'make benchmark'
"-p", "no:benchmark", # disable pytest-benchmark plugin during normal runs (conflicts with xdist)
] ]
markers = [ markers = [
"sqlite: marks tests that should run on SQLite (mocked).", "sqlite: marks tests that should run on SQLite (mocked).",
"postgres: marks tests that require a real PostgreSQL database.", "postgres: marks tests that require a real PostgreSQL database.",
"e2e: marks end-to-end tests requiring Docker containers.", "e2e: marks end-to-end tests requiring Docker containers.",
"schemathesis: marks Schemathesis-generated API tests.", "schemathesis: marks Schemathesis-generated API tests.",
"benchmark: marks performance benchmark tests.",
] ]
asyncio_default_fixture_loop_scope = "function" asyncio_default_fixture_loop_scope = "function"

View File

@@ -0,0 +1,23 @@
{
"include": ["app"],
"exclude": ["app/alembic"],
"pythonVersion": "3.12",
"venvPath": ".",
"venv": ".venv",
"typeCheckingMode": "standard",
"reportMissingImports": true,
"reportMissingTypeStubs": false,
"reportUnknownMemberType": false,
"reportUnknownVariableType": false,
"reportUnknownArgumentType": false,
"reportUnknownParameterType": false,
"reportUnknownLambdaType": false,
"reportReturnType": true,
"reportUnusedImport": false,
"reportGeneralTypeIssues": false,
"reportAttributeAccessIssue": false,
"reportArgumentType": false,
"strictListInference": false,
"strictDictionaryInference": false,
"strictSetInference": false
}

View File

@@ -147,7 +147,7 @@ class TestAdminCreateUser:
headers={"Authorization": f"Bearer {superuser_token}"}, headers={"Authorization": f"Bearer {superuser_token}"},
) )
assert response.status_code == status.HTTP_404_NOT_FOUND assert response.status_code == status.HTTP_409_CONFLICT
class TestAdminGetUser: class TestAdminGetUser:
@@ -565,7 +565,7 @@ class TestAdminCreateOrganization:
headers={"Authorization": f"Bearer {superuser_token}"}, headers={"Authorization": f"Bearer {superuser_token}"},
) )
assert response.status_code == status.HTTP_404_NOT_FOUND assert response.status_code == status.HTTP_409_CONFLICT
class TestAdminGetOrganization: class TestAdminGetOrganization:

View File

@@ -45,7 +45,7 @@ class TestAdminListUsersFilters:
async def test_list_users_database_error_propagates(self, client, superuser_token): async def test_list_users_database_error_propagates(self, client, superuser_token):
"""Test that database errors propagate correctly (covers line 118-120).""" """Test that database errors propagate correctly (covers line 118-120)."""
with patch( with patch(
"app.api.routes.admin.user_crud.get_multi_with_total", "app.api.routes.admin.user_service.list_users",
side_effect=Exception("DB error"), side_effect=Exception("DB error"),
): ):
with pytest.raises(Exception): with pytest.raises(Exception):
@@ -74,8 +74,8 @@ class TestAdminCreateUserErrors:
}, },
) )
# Should get error for duplicate email # Should get conflict for duplicate email
assert response.status_code == status.HTTP_404_NOT_FOUND assert response.status_code == status.HTTP_409_CONFLICT
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_create_user_unexpected_error_propagates( async def test_create_user_unexpected_error_propagates(
@@ -83,7 +83,7 @@ class TestAdminCreateUserErrors:
): ):
"""Test unexpected errors during user creation (covers line 151-153).""" """Test unexpected errors during user creation (covers line 151-153)."""
with patch( with patch(
"app.api.routes.admin.user_crud.create", "app.api.routes.admin.user_service.create_user",
side_effect=RuntimeError("Unexpected error"), side_effect=RuntimeError("Unexpected error"),
): ):
with pytest.raises(RuntimeError): with pytest.raises(RuntimeError):
@@ -135,7 +135,7 @@ class TestAdminUpdateUserErrors:
): ):
"""Test unexpected errors during user update (covers line 206-208).""" """Test unexpected errors during user update (covers line 206-208)."""
with patch( with patch(
"app.api.routes.admin.user_crud.update", "app.api.routes.admin.user_service.update_user",
side_effect=RuntimeError("Update failed"), side_effect=RuntimeError("Update failed"),
): ):
with pytest.raises(RuntimeError): with pytest.raises(RuntimeError):
@@ -166,7 +166,7 @@ class TestAdminDeleteUserErrors:
): ):
"""Test unexpected errors during user deletion (covers line 238-240).""" """Test unexpected errors during user deletion (covers line 238-240)."""
with patch( with patch(
"app.api.routes.admin.user_crud.soft_delete", "app.api.routes.admin.user_service.soft_delete_user",
side_effect=Exception("Delete failed"), side_effect=Exception("Delete failed"),
): ):
with pytest.raises(Exception): with pytest.raises(Exception):
@@ -196,7 +196,7 @@ class TestAdminActivateUserErrors:
): ):
"""Test unexpected errors during user activation (covers line 282-284).""" """Test unexpected errors during user activation (covers line 282-284)."""
with patch( with patch(
"app.api.routes.admin.user_crud.update", "app.api.routes.admin.user_service.update_user",
side_effect=Exception("Activation failed"), side_effect=Exception("Activation failed"),
): ):
with pytest.raises(Exception): with pytest.raises(Exception):
@@ -238,7 +238,7 @@ class TestAdminDeactivateUserErrors:
): ):
"""Test unexpected errors during user deactivation (covers line 326-328).""" """Test unexpected errors during user deactivation (covers line 326-328)."""
with patch( with patch(
"app.api.routes.admin.user_crud.update", "app.api.routes.admin.user_service.update_user",
side_effect=Exception("Deactivation failed"), side_effect=Exception("Deactivation failed"),
): ):
with pytest.raises(Exception): with pytest.raises(Exception):
@@ -258,7 +258,7 @@ class TestAdminListOrganizationsErrors:
async def test_list_organizations_database_error(self, client, superuser_token): async def test_list_organizations_database_error(self, client, superuser_token):
"""Test list organizations with database error (covers line 427-456).""" """Test list organizations with database error (covers line 427-456)."""
with patch( with patch(
"app.api.routes.admin.organization_crud.get_multi_with_member_counts", "app.api.routes.admin.organization_service.get_multi_with_member_counts",
side_effect=Exception("DB error"), side_effect=Exception("DB error"),
): ):
with pytest.raises(Exception): with pytest.raises(Exception):
@@ -299,14 +299,14 @@ class TestAdminCreateOrganizationErrors:
}, },
) )
# Should get error for duplicate slug # Should get conflict for duplicate slug
assert response.status_code == status.HTTP_404_NOT_FOUND assert response.status_code == status.HTTP_409_CONFLICT
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_create_organization_unexpected_error(self, client, superuser_token): async def test_create_organization_unexpected_error(self, client, superuser_token):
"""Test unexpected errors during organization creation (covers line 484-485).""" """Test unexpected errors during organization creation (covers line 484-485)."""
with patch( with patch(
"app.api.routes.admin.organization_crud.create", "app.api.routes.admin.organization_service.create_organization",
side_effect=RuntimeError("Creation failed"), side_effect=RuntimeError("Creation failed"),
): ):
with pytest.raises(RuntimeError): with pytest.raises(RuntimeError):
@@ -367,7 +367,7 @@ class TestAdminUpdateOrganizationErrors:
org_id = org.id org_id = org.id
with patch( with patch(
"app.api.routes.admin.organization_crud.update", "app.api.routes.admin.organization_service.update_organization",
side_effect=Exception("Update failed"), side_effect=Exception("Update failed"),
): ):
with pytest.raises(Exception): with pytest.raises(Exception):
@@ -412,7 +412,7 @@ class TestAdminDeleteOrganizationErrors:
org_id = org.id org_id = org.id
with patch( with patch(
"app.api.routes.admin.organization_crud.remove", "app.api.routes.admin.organization_service.remove_organization",
side_effect=Exception("Delete failed"), side_effect=Exception("Delete failed"),
): ):
with pytest.raises(Exception): with pytest.raises(Exception):
@@ -456,7 +456,7 @@ class TestAdminListOrganizationMembersErrors:
org_id = org.id org_id = org.id
with patch( with patch(
"app.api.routes.admin.organization_crud.get_organization_members", "app.api.routes.admin.organization_service.get_organization_members",
side_effect=Exception("DB error"), side_effect=Exception("DB error"),
): ):
with pytest.raises(Exception): with pytest.raises(Exception):
@@ -531,7 +531,7 @@ class TestAdminAddOrganizationMemberErrors:
org_id = org.id org_id = org.id
with patch( with patch(
"app.api.routes.admin.organization_crud.add_user", "app.api.routes.admin.organization_service.add_member",
side_effect=Exception("Add failed"), side_effect=Exception("Add failed"),
): ):
with pytest.raises(Exception): with pytest.raises(Exception):
@@ -587,7 +587,7 @@ class TestAdminRemoveOrganizationMemberErrors:
org_id = org.id org_id = org.id
with patch( with patch(
"app.api.routes.admin.organization_crud.remove_user", "app.api.routes.admin.organization_service.remove_member",
side_effect=Exception("Remove failed"), side_effect=Exception("Remove failed"),
): ):
with pytest.raises(Exception): with pytest.raises(Exception):

View File

@@ -19,7 +19,7 @@ class TestLoginSessionCreationFailure:
"""Test that login succeeds even if session creation fails.""" """Test that login succeeds even if session creation fails."""
# Mock session creation to fail # Mock session creation to fail
with patch( with patch(
"app.api.routes.auth.session_crud.create_session", "app.api.routes.auth.session_service.create_session",
side_effect=Exception("Session creation failed"), side_effect=Exception("Session creation failed"),
): ):
response = await client.post( response = await client.post(
@@ -43,7 +43,7 @@ class TestOAuthLoginSessionCreationFailure:
): ):
"""Test OAuth login succeeds even if session creation fails.""" """Test OAuth login succeeds even if session creation fails."""
with patch( with patch(
"app.api.routes.auth.session_crud.create_session", "app.api.routes.auth.session_service.create_session",
side_effect=Exception("Session failed"), side_effect=Exception("Session failed"),
): ):
response = await client.post( response = await client.post(
@@ -76,7 +76,7 @@ class TestRefreshTokenSessionUpdateFailure:
# Mock session update to fail # Mock session update to fail
with patch( with patch(
"app.api.routes.auth.session_crud.update_refresh_token", "app.api.routes.auth.session_service.update_refresh_token",
side_effect=Exception("Update failed"), side_effect=Exception("Update failed"),
): ):
response = await client.post( response = await client.post(
@@ -130,7 +130,7 @@ class TestLogoutWithNonExistentSession:
tokens = response.json() tokens = response.json()
# Mock session lookup to return None # Mock session lookup to return None
with patch("app.api.routes.auth.session_crud.get_by_jti", return_value=None): with patch("app.api.routes.auth.session_service.get_by_jti", return_value=None):
response = await client.post( response = await client.post(
"/api/v1/auth/logout", "/api/v1/auth/logout",
headers={"Authorization": f"Bearer {tokens['access_token']}"}, headers={"Authorization": f"Bearer {tokens['access_token']}"},
@@ -157,7 +157,7 @@ class TestLogoutUnexpectedError:
# Mock to raise unexpected error # Mock to raise unexpected error
with patch( with patch(
"app.api.routes.auth.session_crud.get_by_jti", "app.api.routes.auth.session_service.get_by_jti",
side_effect=Exception("Unexpected error"), side_effect=Exception("Unexpected error"),
): ):
response = await client.post( response = await client.post(
@@ -186,7 +186,7 @@ class TestLogoutAllUnexpectedError:
# Mock to raise database error # Mock to raise database error
with patch( with patch(
"app.api.routes.auth.session_crud.deactivate_all_user_sessions", "app.api.routes.auth.session_service.deactivate_all_user_sessions",
side_effect=Exception("DB error"), side_effect=Exception("DB error"),
): ):
response = await client.post( response = await client.post(
@@ -212,7 +212,7 @@ class TestPasswordResetConfirmSessionInvalidation:
# Mock session invalidation to fail # Mock session invalidation to fail
with patch( with patch(
"app.api.routes.auth.session_crud.deactivate_all_user_sessions", "app.api.routes.auth.session_service.deactivate_all_user_sessions",
side_effect=Exception("Invalidation failed"), side_effect=Exception("Invalidation failed"),
): ):
response = await client.post( response = await client.post(

View File

@@ -334,7 +334,7 @@ class TestPasswordResetConfirm:
token = create_password_reset_token(async_test_user.email) token = create_password_reset_token(async_test_user.email)
# Mock the database commit to raise an exception # Mock the database commit to raise an exception
with patch("app.api.routes.auth.user_crud.get_by_email") as mock_get: with patch("app.services.auth_service.user_repo.get_by_email") as mock_get:
mock_get.side_effect = Exception("Database error") mock_get.side_effect = Exception("Database error")
response = await client.post( response = await client.post(

View File

@@ -12,8 +12,8 @@ These tests prevent real-world attack scenarios.
import pytest import pytest
from httpx import AsyncClient from httpx import AsyncClient
from app.crud.session import session as session_crud
from app.models.user import User from app.models.user import User
from app.repositories.session import session_repo as session_repo
class TestRevokedSessionSecurity: class TestRevokedSessionSecurity:
@@ -117,7 +117,7 @@ class TestRevokedSessionSecurity:
async with SessionLocal() as session: async with SessionLocal() as session:
# Find and delete the session # Find and delete the session
db_session = await session_crud.get_by_jti(session, jti=jti) db_session = await session_repo.get_by_jti(session, jti=jti)
if db_session: if db_session:
await session.delete(db_session) await session.delete(db_session)
await session.commit() await session.commit()

View File

@@ -8,7 +8,7 @@ from uuid import uuid4
import pytest import pytest
from app.crud.oauth import oauth_account from app.repositories.oauth_account import oauth_account_repo as oauth_account
from app.schemas.oauth import OAuthAccountCreate from app.schemas.oauth import OAuthAccountCreate
@@ -349,7 +349,7 @@ class TestOAuthProviderEndpoints:
_test_engine, AsyncTestingSessionLocal = async_test_db _test_engine, AsyncTestingSessionLocal = async_test_db
# Create a test client # Create a test client
from app.crud.oauth import oauth_client from app.repositories.oauth_client import oauth_client_repo as oauth_client
from app.schemas.oauth import OAuthClientCreate from app.schemas.oauth import OAuthClientCreate
async with AsyncTestingSessionLocal() as session: async with AsyncTestingSessionLocal() as session:
@@ -386,7 +386,7 @@ class TestOAuthProviderEndpoints:
_test_engine, AsyncTestingSessionLocal = async_test_db _test_engine, AsyncTestingSessionLocal = async_test_db
# Create a test client # Create a test client
from app.crud.oauth import oauth_client from app.repositories.oauth_client import oauth_client_repo as oauth_client
from app.schemas.oauth import OAuthClientCreate from app.schemas.oauth import OAuthClientCreate
async with AsyncTestingSessionLocal() as session: async with AsyncTestingSessionLocal() as session:

View File

@@ -537,7 +537,7 @@ class TestOrganizationExceptionHandlers:
): ):
"""Test generic exception handler in get_my_organizations (covers lines 81-83).""" """Test generic exception handler in get_my_organizations (covers lines 81-83)."""
with patch( with patch(
"app.crud.organization.organization.get_user_organizations_with_details", "app.api.routes.organizations.organization_service.get_user_organizations_with_details",
side_effect=Exception("Database connection lost"), side_effect=Exception("Database connection lost"),
): ):
# The exception handler logs and re-raises, so we expect the exception # The exception handler logs and re-raises, so we expect the exception
@@ -554,7 +554,7 @@ class TestOrganizationExceptionHandlers:
): ):
"""Test generic exception handler in get_organization (covers lines 124-128).""" """Test generic exception handler in get_organization (covers lines 124-128)."""
with patch( with patch(
"app.crud.organization.organization.get", "app.api.routes.organizations.organization_service.get_organization",
side_effect=Exception("Database timeout"), side_effect=Exception("Database timeout"),
): ):
with pytest.raises(Exception, match="Database timeout"): with pytest.raises(Exception, match="Database timeout"):
@@ -569,7 +569,7 @@ class TestOrganizationExceptionHandlers:
): ):
"""Test generic exception handler in get_organization_members (covers lines 170-172).""" """Test generic exception handler in get_organization_members (covers lines 170-172)."""
with patch( with patch(
"app.crud.organization.organization.get_organization_members", "app.api.routes.organizations.organization_service.get_organization_members",
side_effect=Exception("Connection pool exhausted"), side_effect=Exception("Connection pool exhausted"),
): ):
with pytest.raises(Exception, match="Connection pool exhausted"): with pytest.raises(Exception, match="Connection pool exhausted"):
@@ -591,11 +591,11 @@ class TestOrganizationExceptionHandlers:
admin_token = login_response.json()["access_token"] admin_token = login_response.json()["access_token"]
with patch( with patch(
"app.crud.organization.organization.get", "app.api.routes.organizations.organization_service.get_organization",
return_value=test_org_with_user_admin, return_value=test_org_with_user_admin,
): ):
with patch( with patch(
"app.crud.organization.organization.update", "app.api.routes.organizations.organization_service.update_organization",
side_effect=Exception("Write lock timeout"), side_effect=Exception("Write lock timeout"),
): ):
with pytest.raises(Exception, match="Write lock timeout"): with pytest.raises(Exception, match="Write lock timeout"):

View File

@@ -11,9 +11,9 @@ These tests prevent unauthorized access and privilege escalation.
import pytest import pytest
from httpx import AsyncClient from httpx import AsyncClient
from app.crud.user import user as user_crud
from app.models.organization import Organization from app.models.organization import Organization
from app.models.user import User from app.models.user import User
from app.repositories.user import user_repo as user_repo
class TestInactiveUserBlocking: class TestInactiveUserBlocking:
@@ -50,7 +50,7 @@ class TestInactiveUserBlocking:
# Step 2: Admin deactivates the user # Step 2: Admin deactivates the user
async with SessionLocal() as session: async with SessionLocal() as session:
user = await user_crud.get(session, id=async_test_user.id) user = await user_repo.get(session, id=async_test_user.id)
user.is_active = False user.is_active = False
await session.commit() await session.commit()
@@ -80,7 +80,7 @@ class TestInactiveUserBlocking:
# Deactivate user # Deactivate user
async with SessionLocal() as session: async with SessionLocal() as session:
user = await user_crud.get(session, id=async_test_user.id) user = await user_repo.get(session, id=async_test_user.id)
user.is_active = False user.is_active = False
await session.commit() await session.commit()

View File

@@ -39,7 +39,7 @@ async def async_test_user2(async_test_db):
_test_engine, SessionLocal = async_test_db _test_engine, SessionLocal = async_test_db
async with SessionLocal() as session: async with SessionLocal() as session:
from app.crud.user import user as user_crud from app.repositories.user import user_repo as user_repo
from app.schemas.users import UserCreate from app.schemas.users import UserCreate
user_data = UserCreate( user_data = UserCreate(
@@ -48,7 +48,7 @@ async def async_test_user2(async_test_db):
first_name="Test", first_name="Test",
last_name="User2", last_name="User2",
) )
user = await user_crud.create(session, obj_in=user_data) user = await user_repo.create(session, obj_in=user_data)
await session.commit() await session.commit()
await session.refresh(user) await session.refresh(user)
return user return user
@@ -191,9 +191,9 @@ class TestRevokeSession:
# Verify session is deactivated # Verify session is deactivated
async with SessionLocal() as session: async with SessionLocal() as session:
from app.crud.session import session as session_crud from app.repositories.session import session_repo as session_repo
revoked_session = await session_crud.get(session, id=str(session_id)) revoked_session = await session_repo.get(session, id=str(session_id))
assert revoked_session.is_active is False assert revoked_session.is_active is False
@pytest.mark.asyncio @pytest.mark.asyncio
@@ -267,8 +267,8 @@ class TestCleanupExpiredSessions:
"""Test successfully cleaning up expired sessions.""" """Test successfully cleaning up expired sessions."""
_test_engine, SessionLocal = async_test_db _test_engine, SessionLocal = async_test_db
# Create expired and active sessions using CRUD to avoid greenlet issues # Create expired and active sessions using repository to avoid greenlet issues
from app.crud.session import session as session_crud from app.repositories.session import session_repo as session_repo
from app.schemas.sessions import SessionCreate from app.schemas.sessions import SessionCreate
async with SessionLocal() as db: async with SessionLocal() as db:
@@ -282,7 +282,7 @@ class TestCleanupExpiredSessions:
expires_at=datetime.now(UTC) - timedelta(days=1), expires_at=datetime.now(UTC) - timedelta(days=1),
last_used_at=datetime.now(UTC) - timedelta(days=2), last_used_at=datetime.now(UTC) - timedelta(days=2),
) )
e1 = await session_crud.create_session(db, obj_in=e1_data) e1 = await session_repo.create_session(db, obj_in=e1_data)
e1.is_active = False e1.is_active = False
db.add(e1) db.add(e1)
@@ -296,7 +296,7 @@ class TestCleanupExpiredSessions:
expires_at=datetime.now(UTC) - timedelta(hours=1), expires_at=datetime.now(UTC) - timedelta(hours=1),
last_used_at=datetime.now(UTC) - timedelta(hours=2), last_used_at=datetime.now(UTC) - timedelta(hours=2),
) )
e2 = await session_crud.create_session(db, obj_in=e2_data) e2 = await session_repo.create_session(db, obj_in=e2_data)
e2.is_active = False e2.is_active = False
db.add(e2) db.add(e2)
@@ -310,7 +310,7 @@ class TestCleanupExpiredSessions:
expires_at=datetime.now(UTC) + timedelta(days=7), expires_at=datetime.now(UTC) + timedelta(days=7),
last_used_at=datetime.now(UTC), last_used_at=datetime.now(UTC),
) )
await session_crud.create_session(db, obj_in=a1_data) await session_repo.create_session(db, obj_in=a1_data)
await db.commit() await db.commit()
# Cleanup expired sessions # Cleanup expired sessions
@@ -333,8 +333,8 @@ class TestCleanupExpiredSessions:
"""Test cleanup when no sessions are expired.""" """Test cleanup when no sessions are expired."""
_test_engine, SessionLocal = async_test_db _test_engine, SessionLocal = async_test_db
# Create only active sessions using CRUD # Create only active sessions using repository
from app.crud.session import session as session_crud from app.repositories.session import session_repo as session_repo
from app.schemas.sessions import SessionCreate from app.schemas.sessions import SessionCreate
async with SessionLocal() as db: async with SessionLocal() as db:
@@ -347,7 +347,7 @@ class TestCleanupExpiredSessions:
expires_at=datetime.now(UTC) + timedelta(days=7), expires_at=datetime.now(UTC) + timedelta(days=7),
last_used_at=datetime.now(UTC), last_used_at=datetime.now(UTC),
) )
await session_crud.create_session(db, obj_in=a1_data) await session_repo.create_session(db, obj_in=a1_data)
await db.commit() await db.commit()
response = await client.delete( response = await client.delete(
@@ -384,7 +384,7 @@ class TestSessionsAdditionalCases:
# Create multiple sessions # Create multiple sessions
async with SessionLocal() as session: async with SessionLocal() as session:
from app.crud.session import session as session_crud from app.repositories.session import session_repo as session_repo
from app.schemas.sessions import SessionCreate from app.schemas.sessions import SessionCreate
for i in range(5): for i in range(5):
@@ -397,7 +397,7 @@ class TestSessionsAdditionalCases:
expires_at=datetime.now(UTC) + timedelta(days=7), expires_at=datetime.now(UTC) + timedelta(days=7),
last_used_at=datetime.now(UTC), last_used_at=datetime.now(UTC),
) )
await session_crud.create_session(session, obj_in=session_data) await session_repo.create_session(session, obj_in=session_data)
await session.commit() await session.commit()
response = await client.get( response = await client.get(
@@ -431,7 +431,7 @@ class TestSessionsAdditionalCases:
"""Test cleanup with mix of active/inactive and expired/not-expired sessions.""" """Test cleanup with mix of active/inactive and expired/not-expired sessions."""
_test_engine, SessionLocal = async_test_db _test_engine, SessionLocal = async_test_db
from app.crud.session import session as session_crud from app.repositories.session import session_repo as session_repo
from app.schemas.sessions import SessionCreate from app.schemas.sessions import SessionCreate
async with SessionLocal() as db: async with SessionLocal() as db:
@@ -445,7 +445,7 @@ class TestSessionsAdditionalCases:
expires_at=datetime.now(UTC) - timedelta(days=1), expires_at=datetime.now(UTC) - timedelta(days=1),
last_used_at=datetime.now(UTC) - timedelta(days=2), last_used_at=datetime.now(UTC) - timedelta(days=2),
) )
e1 = await session_crud.create_session(db, obj_in=e1_data) e1 = await session_repo.create_session(db, obj_in=e1_data)
e1.is_active = False e1.is_active = False
db.add(e1) db.add(e1)
@@ -459,7 +459,7 @@ class TestSessionsAdditionalCases:
expires_at=datetime.now(UTC) - timedelta(hours=1), expires_at=datetime.now(UTC) - timedelta(hours=1),
last_used_at=datetime.now(UTC) - timedelta(hours=2), last_used_at=datetime.now(UTC) - timedelta(hours=2),
) )
await session_crud.create_session(db, obj_in=e2_data) await session_repo.create_session(db, obj_in=e2_data)
await db.commit() await db.commit()
@@ -502,10 +502,10 @@ class TestSessionExceptionHandlers:
"""Test list_sessions handles database errors (covers lines 104-106).""" """Test list_sessions handles database errors (covers lines 104-106)."""
from unittest.mock import patch from unittest.mock import patch
from app.crud import session as session_module from app.repositories import session as session_module
with patch.object( with patch.object(
session_module.session, session_module.session_repo,
"get_user_sessions", "get_user_sessions",
side_effect=Exception("Database error"), side_effect=Exception("Database error"),
): ):
@@ -527,10 +527,10 @@ class TestSessionExceptionHandlers:
from unittest.mock import patch from unittest.mock import patch
from uuid import uuid4 from uuid import uuid4
from app.crud import session as session_module from app.repositories import session as session_module
# First create a session to revoke # First create a session to revoke
from app.crud.session import session as session_crud from app.repositories.session import session_repo as session_repo
from app.schemas.sessions import SessionCreate from app.schemas.sessions import SessionCreate
_test_engine, AsyncTestingSessionLocal = async_test_db _test_engine, AsyncTestingSessionLocal = async_test_db
@@ -545,12 +545,12 @@ class TestSessionExceptionHandlers:
last_used_at=datetime.now(UTC), last_used_at=datetime.now(UTC),
expires_at=datetime.now(UTC) + timedelta(days=60), expires_at=datetime.now(UTC) + timedelta(days=60),
) )
user_session = await session_crud.create_session(db, obj_in=session_in) user_session = await session_repo.create_session(db, obj_in=session_in)
session_id = user_session.id session_id = user_session.id
# Mock the deactivate method to raise an exception # Mock the deactivate method to raise an exception
with patch.object( with patch.object(
session_module.session, session_module.session_repo,
"deactivate", "deactivate",
side_effect=Exception("Database connection lost"), side_effect=Exception("Database connection lost"),
): ):
@@ -568,10 +568,10 @@ class TestSessionExceptionHandlers:
"""Test cleanup_expired_sessions handles database errors (covers lines 233-236).""" """Test cleanup_expired_sessions handles database errors (covers lines 233-236)."""
from unittest.mock import patch from unittest.mock import patch
from app.crud import session as session_module from app.repositories import session as session_module
with patch.object( with patch.object(
session_module.session, session_module.session_repo,
"cleanup_expired_for_user", "cleanup_expired_for_user",
side_effect=Exception("Cleanup failed"), side_effect=Exception("Cleanup failed"),
): ):

View File

@@ -157,7 +157,7 @@ class TestListUsers:
response = await client.get("/api/v1/users") response = await client.get("/api/v1/users")
assert response.status_code == status.HTTP_401_UNAUTHORIZED assert response.status_code == status.HTTP_401_UNAUTHORIZED
# Note: Removed test_list_users_unexpected_error because mocking at CRUD level # Note: Removed test_list_users_unexpected_error because mocking at repository level
# causes the exception to be raised before FastAPI can handle it properly # causes the exception to be raised before FastAPI can handle it properly

View File

@@ -99,7 +99,8 @@ class TestUpdateCurrentUser:
from unittest.mock import patch from unittest.mock import patch
with patch( with patch(
"app.api.routes.users.user_crud.update", side_effect=Exception("DB error") "app.api.routes.users.user_service.update_user",
side_effect=Exception("DB error"),
): ):
with pytest.raises(Exception): with pytest.raises(Exception):
await client.patch( await client.patch(
@@ -134,7 +135,7 @@ class TestUpdateCurrentUser:
from unittest.mock import patch from unittest.mock import patch
with patch( with patch(
"app.api.routes.users.user_crud.update", "app.api.routes.users.user_service.update_user",
side_effect=ValueError("Invalid value"), side_effect=ValueError("Invalid value"),
): ):
with pytest.raises(ValueError): with pytest.raises(ValueError):
@@ -224,7 +225,8 @@ class TestUpdateUserById:
from unittest.mock import patch from unittest.mock import patch
with patch( with patch(
"app.api.routes.users.user_crud.update", side_effect=ValueError("Invalid") "app.api.routes.users.user_service.update_user",
side_effect=ValueError("Invalid"),
): ):
with pytest.raises(ValueError): with pytest.raises(ValueError):
await client.patch( await client.patch(
@@ -241,7 +243,8 @@ class TestUpdateUserById:
from unittest.mock import patch from unittest.mock import patch
with patch( with patch(
"app.api.routes.users.user_crud.update", side_effect=Exception("Unexpected") "app.api.routes.users.user_service.update_user",
side_effect=Exception("Unexpected"),
): ):
with pytest.raises(Exception): with pytest.raises(Exception):
await client.patch( await client.patch(
@@ -354,7 +357,7 @@ class TestDeleteUserById:
from unittest.mock import patch from unittest.mock import patch
with patch( with patch(
"app.api.routes.users.user_crud.soft_delete", "app.api.routes.users.user_service.soft_delete_user",
side_effect=ValueError("Cannot delete"), side_effect=ValueError("Cannot delete"),
): ):
with pytest.raises(ValueError): with pytest.raises(ValueError):
@@ -371,7 +374,7 @@ class TestDeleteUserById:
from unittest.mock import patch from unittest.mock import patch
with patch( with patch(
"app.api.routes.users.user_crud.soft_delete", "app.api.routes.users.user_service.soft_delete_user",
side_effect=Exception("Unexpected"), side_effect=Exception("Unexpected"),
): ):
with pytest.raises(Exception): with pytest.raises(Exception):

View File

@@ -0,0 +1,327 @@
"""
Performance Benchmark Tests.
These tests establish baseline performance metrics for critical API endpoints
and core operations, detecting regressions when response times degrade.
Usage:
make benchmark # Run benchmarks and save baseline
make benchmark-check # Run benchmarks and compare against saved baseline
Baselines are stored in .benchmarks/ and should be committed to version control
so CI can detect performance regressions across commits.
"""
import time
import uuid
from unittest.mock import patch
import pytest
import pytest_asyncio
from fastapi.testclient import TestClient
from app.core.auth import (
create_access_token,
create_refresh_token,
decode_token,
get_password_hash,
verify_password,
)
from app.main import app
pytestmark = [pytest.mark.benchmark]
# Pre-computed hash for sync benchmarks (avoids hashing in every iteration)
_BENCH_PASSWORD = "BenchPass123!"
_BENCH_HASH = get_password_hash(_BENCH_PASSWORD)
# =============================================================================
# Fixtures
# =============================================================================
@pytest.fixture
def sync_client():
"""Create a FastAPI test client with mocked database for stateless endpoints."""
with patch("app.main.check_database_health") as mock_health_check:
mock_health_check.return_value = True
yield TestClient(app)
# =============================================================================
# Stateless Endpoint Benchmarks (no DB required)
# =============================================================================
def test_health_endpoint_performance(sync_client, benchmark):
"""Benchmark: GET /health should respond within acceptable latency."""
result = benchmark(sync_client.get, "/health")
assert result.status_code == 200
def test_openapi_schema_performance(sync_client, benchmark):
"""Benchmark: OpenAPI schema generation should not regress."""
result = benchmark(sync_client.get, "/api/v1/openapi.json")
assert result.status_code == 200
# =============================================================================
# Core Crypto & Token Benchmarks (no DB required)
#
# These benchmark the CPU-intensive operations that underpin auth:
# password hashing, verification, and JWT creation/decoding.
# =============================================================================
def test_password_hashing_performance(benchmark):
"""Benchmark: bcrypt password hashing (CPU-bound, ~100ms expected)."""
result = benchmark(get_password_hash, _BENCH_PASSWORD)
assert result.startswith("$2b$")
def test_password_verification_performance(benchmark):
"""Benchmark: bcrypt password verification against a known hash."""
result = benchmark(verify_password, _BENCH_PASSWORD, _BENCH_HASH)
assert result is True
def test_access_token_creation_performance(benchmark):
"""Benchmark: JWT access token generation."""
user_id = str(uuid.uuid4())
token = benchmark(create_access_token, user_id)
assert isinstance(token, str)
assert len(token) > 0
def test_refresh_token_creation_performance(benchmark):
"""Benchmark: JWT refresh token generation."""
user_id = str(uuid.uuid4())
token = benchmark(create_refresh_token, user_id)
assert isinstance(token, str)
assert len(token) > 0
def test_token_decode_performance(benchmark):
"""Benchmark: JWT token decoding and validation."""
user_id = str(uuid.uuid4())
token = create_access_token(user_id)
payload = benchmark(decode_token, token, "access")
assert payload.sub == user_id
# =============================================================================
# Database-dependent Endpoint Benchmarks (async, manual timing)
#
# pytest-benchmark does not support async functions natively. These tests
# measure latency manually and assert against a maximum threshold (in ms)
# to catch performance regressions.
# =============================================================================
MAX_LOGIN_MS = 500
MAX_GET_USER_MS = 200
MAX_REGISTER_MS = 500
MAX_TOKEN_REFRESH_MS = 200
MAX_SESSIONS_LIST_MS = 200
MAX_USER_UPDATE_MS = 200
@pytest_asyncio.fixture
async def bench_user(async_test_db):
"""Create a test user for benchmark tests."""
from app.models.user import User
_test_engine, AsyncTestingSessionLocal = async_test_db
async with AsyncTestingSessionLocal() as session:
user = User(
id=uuid.uuid4(),
email="bench@example.com",
password_hash=get_password_hash("BenchPass123!"),
first_name="Bench",
last_name="User",
is_active=True,
is_superuser=False,
)
session.add(user)
await session.commit()
await session.refresh(user)
return user
@pytest_asyncio.fixture
async def bench_token(client, bench_user):
"""Get an auth token for the benchmark user."""
response = await client.post(
"/api/v1/auth/login",
json={"email": "bench@example.com", "password": "BenchPass123!"},
)
assert response.status_code == 200, f"Login failed: {response.text}"
return response.json()["access_token"]
@pytest_asyncio.fixture
async def bench_refresh_token(client, bench_user):
"""Get a refresh token for the benchmark user."""
response = await client.post(
"/api/v1/auth/login",
json={"email": "bench@example.com", "password": "BenchPass123!"},
)
assert response.status_code == 200, f"Login failed: {response.text}"
return response.json()["refresh_token"]
@pytest.mark.asyncio
async def test_login_latency(client, bench_user):
"""Performance: POST /api/v1/auth/login must respond under threshold."""
iterations = 5
total_ms = 0.0
for _ in range(iterations):
start = time.perf_counter()
response = await client.post(
"/api/v1/auth/login",
json={"email": "bench@example.com", "password": "BenchPass123!"},
)
elapsed_ms = (time.perf_counter() - start) * 1000
total_ms += elapsed_ms
assert response.status_code == 200
mean_ms = total_ms / iterations
print(f"\n Login mean latency: {mean_ms:.1f}ms (threshold: {MAX_LOGIN_MS}ms)")
assert mean_ms < MAX_LOGIN_MS, (
f"Login latency regression: {mean_ms:.1f}ms exceeds {MAX_LOGIN_MS}ms threshold"
)
@pytest.mark.asyncio
async def test_get_current_user_latency(client, bench_token):
"""Performance: GET /api/v1/users/me must respond under threshold."""
iterations = 10
total_ms = 0.0
for _ in range(iterations):
start = time.perf_counter()
response = await client.get(
"/api/v1/users/me",
headers={"Authorization": f"Bearer {bench_token}"},
)
elapsed_ms = (time.perf_counter() - start) * 1000
total_ms += elapsed_ms
assert response.status_code == 200
mean_ms = total_ms / iterations
print(
f"\n Get user mean latency: {mean_ms:.1f}ms (threshold: {MAX_GET_USER_MS}ms)"
)
assert mean_ms < MAX_GET_USER_MS, (
f"Get user latency regression: {mean_ms:.1f}ms exceeds {MAX_GET_USER_MS}ms threshold"
)
@pytest.mark.asyncio
async def test_register_latency(client):
"""Performance: POST /api/v1/auth/register must respond under threshold."""
iterations = 3
total_ms = 0.0
for i in range(iterations):
start = time.perf_counter()
response = await client.post(
"/api/v1/auth/register",
json={
"email": f"benchreg{i}@example.com",
"password": "BenchRegPass123!",
"first_name": "Bench",
"last_name": "Register",
},
)
elapsed_ms = (time.perf_counter() - start) * 1000
total_ms += elapsed_ms
assert response.status_code == 201, f"Register failed: {response.text}"
mean_ms = total_ms / iterations
print(
f"\n Register mean latency: {mean_ms:.1f}ms (threshold: {MAX_REGISTER_MS}ms)"
)
assert mean_ms < MAX_REGISTER_MS, (
f"Register latency regression: {mean_ms:.1f}ms exceeds {MAX_REGISTER_MS}ms threshold"
)
@pytest.mark.asyncio
async def test_token_refresh_latency(client, bench_refresh_token):
"""Performance: POST /api/v1/auth/refresh must respond under threshold."""
iterations = 5
total_ms = 0.0
for _ in range(iterations):
start = time.perf_counter()
response = await client.post(
"/api/v1/auth/refresh",
json={"refresh_token": bench_refresh_token},
)
elapsed_ms = (time.perf_counter() - start) * 1000
total_ms += elapsed_ms
assert response.status_code == 200, f"Refresh failed: {response.text}"
# Use the new refresh token for the next iteration
bench_refresh_token = response.json()["refresh_token"]
mean_ms = total_ms / iterations
print(
f"\n Token refresh mean latency: {mean_ms:.1f}ms (threshold: {MAX_TOKEN_REFRESH_MS}ms)"
)
assert mean_ms < MAX_TOKEN_REFRESH_MS, (
f"Token refresh latency regression: {mean_ms:.1f}ms exceeds {MAX_TOKEN_REFRESH_MS}ms threshold"
)
@pytest.mark.asyncio
async def test_sessions_list_latency(client, bench_token):
"""Performance: GET /api/v1/sessions must respond under threshold."""
iterations = 10
total_ms = 0.0
for _ in range(iterations):
start = time.perf_counter()
response = await client.get(
"/api/v1/sessions/me",
headers={"Authorization": f"Bearer {bench_token}"},
)
elapsed_ms = (time.perf_counter() - start) * 1000
total_ms += elapsed_ms
assert response.status_code == 200
mean_ms = total_ms / iterations
print(
f"\n Sessions list mean latency: {mean_ms:.1f}ms (threshold: {MAX_SESSIONS_LIST_MS}ms)"
)
assert mean_ms < MAX_SESSIONS_LIST_MS, (
f"Sessions list latency regression: {mean_ms:.1f}ms exceeds {MAX_SESSIONS_LIST_MS}ms threshold"
)
@pytest.mark.asyncio
async def test_user_profile_update_latency(client, bench_token):
"""Performance: PATCH /api/v1/users/me must respond under threshold."""
iterations = 5
total_ms = 0.0
for i in range(iterations):
start = time.perf_counter()
response = await client.patch(
"/api/v1/users/me",
headers={"Authorization": f"Bearer {bench_token}"},
json={"first_name": f"Bench{i}"},
)
elapsed_ms = (time.perf_counter() - start) * 1000
total_ms += elapsed_ms
assert response.status_code == 200, f"Update failed: {response.text}"
mean_ms = total_ms / iterations
print(
f"\n User update mean latency: {mean_ms:.1f}ms (threshold: {MAX_USER_UPDATE_MS}ms)"
)
assert mean_ms < MAX_USER_UPDATE_MS, (
f"User update latency regression: {mean_ms:.1f}ms exceeds {MAX_USER_UPDATE_MS}ms threshold"
)

View File

@@ -2,8 +2,8 @@
import uuid import uuid
from datetime import UTC, datetime, timedelta from datetime import UTC, datetime, timedelta
import jwt
import pytest import pytest
from jose import jwt
from app.core.auth import ( from app.core.auth import (
TokenExpiredError, TokenExpiredError,
@@ -215,6 +215,7 @@ class TestTokenDecoding:
payload = { payload = {
"sub": 123, # sub should be a string, not an integer "sub": 123, # sub should be a string, not an integer
"exp": int((now + timedelta(minutes=30)).timestamp()), "exp": int((now + timedelta(minutes=30)).timestamp()),
"iat": int(now.timestamp()),
} }
token = jwt.encode(payload, settings.SECRET_KEY, algorithm=settings.ALGORITHM) token = jwt.encode(payload, settings.SECRET_KEY, algorithm=settings.ALGORITHM)

View File

@@ -9,8 +9,8 @@ Critical security tests covering:
These tests cover critical security vulnerabilities that could be exploited. These tests cover critical security vulnerabilities that could be exploited.
""" """
import jwt
import pytest import pytest
from jose import jwt
from app.core.auth import TokenInvalidError, create_access_token, decode_token from app.core.auth import TokenInvalidError, create_access_token, decode_token
from app.core.config import settings from app.core.config import settings
@@ -38,8 +38,8 @@ class TestJWTAlgorithmSecurityAttacks:
Attacker creates a token with "alg: none" to bypass signature verification. Attacker creates a token with "alg: none" to bypass signature verification.
NOTE: Lines 209 and 212 in auth.py are DEFENSIVE CODE that's never reached NOTE: Lines 209 and 212 in auth.py are DEFENSIVE CODE that's never reached
because python-jose library rejects "none" algorithm tokens BEFORE we get there. because PyJWT rejects "none" algorithm tokens BEFORE we get there.
This is good for security! The library throws JWTError which becomes TokenInvalidError. This is good for security! The library throws InvalidTokenError which becomes TokenInvalidError.
This test verifies the overall protection works, even though our defensive This test verifies the overall protection works, even though our defensive
checks at lines 209-212 don't execute because the library catches it first. checks at lines 209-212 don't execute because the library catches it first.
@@ -108,36 +108,33 @@ class TestJWTAlgorithmSecurityAttacks:
Test that tokens with wrong algorithm are rejected. Test that tokens with wrong algorithm are rejected.
Attack Scenario: Attack Scenario:
Attacker changes algorithm from HS256 to RS256, attempting to use Attacker changes the "alg" header to RS256 while keeping an HMAC
the public key as the HMAC secret. This could allow token forgery. signature, attempting algorithm confusion to forge tokens.
Reference: https://www.nccgroup.com/us/about-us/newsroom-and-events/blog/2019/january/jwt-algorithm-confusion/ Reference: https://www.nccgroup.com/us/about-us/newsroom-and-events/blog/2019/january/jwt-algorithm-confusion/
NOTE: Like the "none" algorithm test, python-jose library catches this
before our defensive checks at line 212. This is good for security!
""" """
import base64
import json
import time import time
now = int(time.time()) now = int(time.time())
# Create a valid payload
payload = {"sub": "user123", "exp": now + 3600, "iat": now, "type": "access"} payload = {"sub": "user123", "exp": now + 3600, "iat": now, "type": "access"}
# Encode with wrong algorithm (RS256 instead of HS256) # Hand-craft a token claiming RS256 in the header — PyJWT cannot encode
# This simulates an attacker trying algorithm substitution # RS256 with an HMAC key, so we craft the header manually (same technique
wrong_algorithm = "RS256" if settings.ALGORITHM == "HS256" else "HS256" # as the "alg: none" tests) to produce a token that actually reaches decode_token.
header = {"alg": "RS256", "typ": "JWT"}
header_encoded = (
base64.urlsafe_b64encode(json.dumps(header).encode()).decode().rstrip("=")
)
payload_encoded = (
base64.urlsafe_b64encode(json.dumps(payload).encode()).decode().rstrip("=")
)
# Attach a fake signature to form a complete (but invalid) JWT
malicious_token = f"{header_encoded}.{payload_encoded}.fakesignature"
try: with pytest.raises(TokenInvalidError):
malicious_token = jwt.encode( decode_token(malicious_token)
payload, settings.SECRET_KEY, algorithm=wrong_algorithm
)
# Should reject the token (library catches mismatch)
with pytest.raises(TokenInvalidError):
decode_token(malicious_token)
except Exception:
# If encoding fails, that's also acceptable (library protection)
pass
def test_reject_hs384_when_hs256_expected(self): def test_reject_hs384_when_hs256_expected(self):
""" """
@@ -151,17 +148,11 @@ class TestJWTAlgorithmSecurityAttacks:
payload = {"sub": "user123", "exp": now + 3600, "iat": now, "type": "access"} payload = {"sub": "user123", "exp": now + 3600, "iat": now, "type": "access"}
# Create token with HS384 instead of HS256 # Create token with HS384 instead of HS256 (HMAC key works with HS384)
try: malicious_token = jwt.encode(payload, settings.SECRET_KEY, algorithm="HS384")
malicious_token = jwt.encode(
payload, settings.SECRET_KEY, algorithm="HS384"
)
with pytest.raises(TokenInvalidError): with pytest.raises(TokenInvalidError):
decode_token(malicious_token) decode_token(malicious_token)
except Exception:
# If encoding fails, that's also fine
pass
def test_valid_token_with_correct_algorithm_accepted(self): def test_valid_token_with_correct_algorithm_accepted(self):
""" """

View File

@@ -46,7 +46,7 @@ async def login_user(client, email: str, password: str = "SecurePassword123!"):
async def create_superuser(e2e_db_session, email: str, password: str): async def create_superuser(e2e_db_session, email: str, password: str):
"""Create a superuser directly in the database.""" """Create a superuser directly in the database."""
from app.crud.user import user as user_crud from app.repositories.user import user_repo as user_repo
from app.schemas.users import UserCreate from app.schemas.users import UserCreate
user_in = UserCreate( user_in = UserCreate(
@@ -56,7 +56,7 @@ async def create_superuser(e2e_db_session, email: str, password: str):
last_name="User", last_name="User",
is_superuser=True, is_superuser=True,
) )
user = await user_crud.create(e2e_db_session, obj_in=user_in) user = await user_repo.create(e2e_db_session, obj_in=user_in)
return user return user

View File

@@ -27,13 +27,16 @@ except ImportError:
pytestmark = [ pytestmark = [
pytest.mark.e2e, pytest.mark.e2e,
pytest.mark.schemathesis, pytest.mark.schemathesis,
pytest.mark.skipif(
not SCHEMATHESIS_AVAILABLE,
reason="schemathesis not installed - run: make install-e2e",
),
] ]
if not SCHEMATHESIS_AVAILABLE:
def test_schemathesis_compatibility():
"""Gracefully handle missing schemathesis dependency."""
pytest.skip("schemathesis not installed - run: make install-e2e")
if SCHEMATHESIS_AVAILABLE: if SCHEMATHESIS_AVAILABLE:
from app.main import app from app.main import app

View File

@@ -46,7 +46,7 @@ async def register_and_login(client, email: str, password: str = "SecurePassword
async def create_superuser_and_login(client, db_session): async def create_superuser_and_login(client, db_session):
"""Helper to create a superuser directly in DB and login.""" """Helper to create a superuser directly in DB and login."""
from app.crud.user import user as user_crud from app.repositories.user import user_repo as user_repo
from app.schemas.users import UserCreate from app.schemas.users import UserCreate
email = f"admin-{uuid4().hex[:8]}@example.com" email = f"admin-{uuid4().hex[:8]}@example.com"
@@ -60,7 +60,7 @@ async def create_superuser_and_login(client, db_session):
last_name="User", last_name="User",
is_superuser=True, is_superuser=True,
) )
await user_crud.create(db_session, obj_in=user_in) await user_repo.create(db_session, obj_in=user_in)
# Login # Login
login_resp = await client.post( login_resp = await client.post(

View File

View File

@@ -1,6 +1,6 @@
# tests/crud/test_base.py # tests/repositories/test_base.py
""" """
Comprehensive tests for CRUDBase class covering all error paths and edge cases. Comprehensive tests for BaseRepository class covering all error paths and edge cases.
""" """
from datetime import UTC from datetime import UTC
@@ -11,11 +11,16 @@ import pytest
from sqlalchemy.exc import DataError, IntegrityError, OperationalError from sqlalchemy.exc import DataError, IntegrityError, OperationalError
from sqlalchemy.orm import joinedload from sqlalchemy.orm import joinedload
from app.crud.user import user as user_crud from app.core.repository_exceptions import (
DuplicateEntryError,
IntegrityConstraintError,
InvalidInputError,
)
from app.repositories.user import user_repo as user_repo
from app.schemas.users import UserCreate, UserUpdate from app.schemas.users import UserCreate, UserUpdate
class TestCRUDBaseGet: class TestRepositoryBaseGet:
"""Tests for get method covering UUID validation and options.""" """Tests for get method covering UUID validation and options."""
@pytest.mark.asyncio @pytest.mark.asyncio
@@ -24,7 +29,7 @@ class TestCRUDBaseGet:
_test_engine, SessionLocal = async_test_db _test_engine, SessionLocal = async_test_db
async with SessionLocal() as session: async with SessionLocal() as session:
result = await user_crud.get(session, id="invalid-uuid") result = await user_repo.get(session, id="invalid-uuid")
assert result is None assert result is None
@pytest.mark.asyncio @pytest.mark.asyncio
@@ -33,7 +38,7 @@ class TestCRUDBaseGet:
_test_engine, SessionLocal = async_test_db _test_engine, SessionLocal = async_test_db
async with SessionLocal() as session: async with SessionLocal() as session:
result = await user_crud.get(session, id=12345) # int instead of UUID result = await user_repo.get(session, id=12345) # int instead of UUID
assert result is None assert result is None
@pytest.mark.asyncio @pytest.mark.asyncio
@@ -43,7 +48,7 @@ class TestCRUDBaseGet:
async with SessionLocal() as session: async with SessionLocal() as session:
# Pass UUID object directly # Pass UUID object directly
result = await user_crud.get(session, id=async_test_user.id) result = await user_repo.get(session, id=async_test_user.id)
assert result is not None assert result is not None
assert result.id == async_test_user.id assert result.id == async_test_user.id
@@ -55,7 +60,7 @@ class TestCRUDBaseGet:
async with SessionLocal() as session: async with SessionLocal() as session:
# Test that options parameter is accepted and doesn't error # Test that options parameter is accepted and doesn't error
# We pass an empty list which still tests the code path # We pass an empty list which still tests the code path
result = await user_crud.get( result = await user_repo.get(
session, id=str(async_test_user.id), options=[] session, id=str(async_test_user.id), options=[]
) )
assert result is not None assert result is not None
@@ -69,10 +74,10 @@ class TestCRUDBaseGet:
# Mock execute to raise an exception # Mock execute to raise an exception
with patch.object(session, "execute", side_effect=Exception("DB error")): with patch.object(session, "execute", side_effect=Exception("DB error")):
with pytest.raises(Exception, match="DB error"): with pytest.raises(Exception, match="DB error"):
await user_crud.get(session, id=str(uuid4())) await user_repo.get(session, id=str(uuid4()))
class TestCRUDBaseGetMulti: class TestRepositoryBaseGetMulti:
"""Tests for get_multi method covering pagination validation and options.""" """Tests for get_multi method covering pagination validation and options."""
@pytest.mark.asyncio @pytest.mark.asyncio
@@ -81,8 +86,8 @@ class TestCRUDBaseGetMulti:
_test_engine, SessionLocal = async_test_db _test_engine, SessionLocal = async_test_db
async with SessionLocal() as session: async with SessionLocal() as session:
with pytest.raises(ValueError, match="skip must be non-negative"): with pytest.raises(InvalidInputError, match="skip must be non-negative"):
await user_crud.get_multi(session, skip=-1) await user_repo.get_multi(session, skip=-1)
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_get_multi_negative_limit(self, async_test_db): async def test_get_multi_negative_limit(self, async_test_db):
@@ -90,8 +95,8 @@ class TestCRUDBaseGetMulti:
_test_engine, SessionLocal = async_test_db _test_engine, SessionLocal = async_test_db
async with SessionLocal() as session: async with SessionLocal() as session:
with pytest.raises(ValueError, match="limit must be non-negative"): with pytest.raises(InvalidInputError, match="limit must be non-negative"):
await user_crud.get_multi(session, limit=-1) await user_repo.get_multi(session, limit=-1)
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_get_multi_limit_too_large(self, async_test_db): async def test_get_multi_limit_too_large(self, async_test_db):
@@ -99,8 +104,8 @@ class TestCRUDBaseGetMulti:
_test_engine, SessionLocal = async_test_db _test_engine, SessionLocal = async_test_db
async with SessionLocal() as session: async with SessionLocal() as session:
with pytest.raises(ValueError, match="Maximum limit is 1000"): with pytest.raises(InvalidInputError, match="Maximum limit is 1000"):
await user_crud.get_multi(session, limit=1001) await user_repo.get_multi(session, limit=1001)
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_get_multi_with_options(self, async_test_db, async_test_user): async def test_get_multi_with_options(self, async_test_db, async_test_user):
@@ -109,7 +114,7 @@ class TestCRUDBaseGetMulti:
async with SessionLocal() as session: async with SessionLocal() as session:
# Test that options parameter is accepted # Test that options parameter is accepted
results = await user_crud.get_multi(session, skip=0, limit=10, options=[]) results = await user_repo.get_multi(session, skip=0, limit=10, options=[])
assert isinstance(results, list) assert isinstance(results, list)
@pytest.mark.asyncio @pytest.mark.asyncio
@@ -120,10 +125,10 @@ class TestCRUDBaseGetMulti:
async with SessionLocal() as session: async with SessionLocal() as session:
with patch.object(session, "execute", side_effect=Exception("DB error")): with patch.object(session, "execute", side_effect=Exception("DB error")):
with pytest.raises(Exception, match="DB error"): with pytest.raises(Exception, match="DB error"):
await user_crud.get_multi(session) await user_repo.get_multi(session)
class TestCRUDBaseCreate: class TestRepositoryBaseCreate:
"""Tests for create method covering various error conditions.""" """Tests for create method covering various error conditions."""
@pytest.mark.asyncio @pytest.mark.asyncio
@@ -140,8 +145,8 @@ class TestCRUDBaseCreate:
last_name="Duplicate", last_name="Duplicate",
) )
with pytest.raises(ValueError, match="already exists"): with pytest.raises(DuplicateEntryError, match="already exists"):
await user_crud.create(session, obj_in=user_data) await user_repo.create(session, obj_in=user_data)
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_create_integrity_error_non_duplicate(self, async_test_db): async def test_create_integrity_error_non_duplicate(self, async_test_db):
@@ -165,12 +170,14 @@ class TestCRUDBaseCreate:
last_name="User", last_name="User",
) )
with pytest.raises(ValueError, match="Database integrity error"): with pytest.raises(
await user_crud.create(session, obj_in=user_data) DuplicateEntryError, match="Database integrity error"
):
await user_repo.create(session, obj_in=user_data)
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_create_operational_error(self, async_test_db): async def test_create_operational_error(self, async_test_db):
"""Test create with OperationalError (user CRUD catches as generic Exception).""" """Test create with OperationalError (user repository catches as generic Exception)."""
_test_engine, SessionLocal = async_test_db _test_engine, SessionLocal = async_test_db
async with SessionLocal() as session: async with SessionLocal() as session:
@@ -188,13 +195,13 @@ class TestCRUDBaseCreate:
last_name="User", last_name="User",
) )
# User CRUD catches this as generic Exception and re-raises # User repository catches this as generic Exception and re-raises
with pytest.raises(OperationalError): with pytest.raises(OperationalError):
await user_crud.create(session, obj_in=user_data) await user_repo.create(session, obj_in=user_data)
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_create_data_error(self, async_test_db): async def test_create_data_error(self, async_test_db):
"""Test create with DataError (user CRUD catches as generic Exception).""" """Test create with DataError (user repository catches as generic Exception)."""
_test_engine, SessionLocal = async_test_db _test_engine, SessionLocal = async_test_db
async with SessionLocal() as session: async with SessionLocal() as session:
@@ -210,9 +217,9 @@ class TestCRUDBaseCreate:
last_name="User", last_name="User",
) )
# User CRUD catches this as generic Exception and re-raises # User repository catches this as generic Exception and re-raises
with pytest.raises(DataError): with pytest.raises(DataError):
await user_crud.create(session, obj_in=user_data) await user_repo.create(session, obj_in=user_data)
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_create_unexpected_error(self, async_test_db): async def test_create_unexpected_error(self, async_test_db):
@@ -231,10 +238,10 @@ class TestCRUDBaseCreate:
) )
with pytest.raises(RuntimeError, match="Unexpected error"): with pytest.raises(RuntimeError, match="Unexpected error"):
await user_crud.create(session, obj_in=user_data) await user_repo.create(session, obj_in=user_data)
class TestCRUDBaseUpdate: class TestRepositoryBaseUpdate:
"""Tests for update method covering error conditions.""" """Tests for update method covering error conditions."""
@pytest.mark.asyncio @pytest.mark.asyncio
@@ -244,7 +251,7 @@ class TestCRUDBaseUpdate:
# Create another user # Create another user
async with SessionLocal() as session: async with SessionLocal() as session:
from app.crud.user import user as user_crud from app.repositories.user import user_repo as user_repo
user2_data = UserCreate( user2_data = UserCreate(
email="user2@example.com", email="user2@example.com",
@@ -252,12 +259,12 @@ class TestCRUDBaseUpdate:
first_name="User", first_name="User",
last_name="Two", last_name="Two",
) )
user2 = await user_crud.create(session, obj_in=user2_data) user2 = await user_repo.create(session, obj_in=user2_data)
await session.commit() await session.commit()
# Try to update user2 with user1's email # Try to update user2 with user1's email
async with SessionLocal() as session: async with SessionLocal() as session:
user2_obj = await user_crud.get(session, id=str(user2.id)) user2_obj = await user_repo.get(session, id=str(user2.id))
with patch.object( with patch.object(
session, session,
@@ -268,8 +275,8 @@ class TestCRUDBaseUpdate:
): ):
update_data = UserUpdate(email=async_test_user.email) update_data = UserUpdate(email=async_test_user.email)
with pytest.raises(ValueError, match="already exists"): with pytest.raises(DuplicateEntryError, match="already exists"):
await user_crud.update( await user_repo.update(
session, db_obj=user2_obj, obj_in=update_data session, db_obj=user2_obj, obj_in=update_data
) )
@@ -279,10 +286,10 @@ class TestCRUDBaseUpdate:
_test_engine, SessionLocal = async_test_db _test_engine, SessionLocal = async_test_db
async with SessionLocal() as session: async with SessionLocal() as session:
user = await user_crud.get(session, id=str(async_test_user.id)) user = await user_repo.get(session, id=str(async_test_user.id))
# Update with dict (tests lines 164-165) # Update with dict (tests lines 164-165)
updated = await user_crud.update( updated = await user_repo.update(
session, db_obj=user, obj_in={"first_name": "UpdatedName"} session, db_obj=user, obj_in={"first_name": "UpdatedName"}
) )
assert updated.first_name == "UpdatedName" assert updated.first_name == "UpdatedName"
@@ -293,7 +300,7 @@ class TestCRUDBaseUpdate:
_test_engine, SessionLocal = async_test_db _test_engine, SessionLocal = async_test_db
async with SessionLocal() as session: async with SessionLocal() as session:
user = await user_crud.get(session, id=str(async_test_user.id)) user = await user_repo.get(session, id=str(async_test_user.id))
with patch.object( with patch.object(
session, session,
@@ -302,8 +309,10 @@ class TestCRUDBaseUpdate:
"statement", {}, Exception("constraint failed") "statement", {}, Exception("constraint failed")
), ),
): ):
with pytest.raises(ValueError, match="Database integrity error"): with pytest.raises(
await user_crud.update( IntegrityConstraintError, match="Database integrity error"
):
await user_repo.update(
session, db_obj=user, obj_in={"first_name": "Test"} session, db_obj=user, obj_in={"first_name": "Test"}
) )
@@ -313,7 +322,7 @@ class TestCRUDBaseUpdate:
_test_engine, SessionLocal = async_test_db _test_engine, SessionLocal = async_test_db
async with SessionLocal() as session: async with SessionLocal() as session:
user = await user_crud.get(session, id=str(async_test_user.id)) user = await user_repo.get(session, id=str(async_test_user.id))
with patch.object( with patch.object(
session, session,
@@ -322,8 +331,10 @@ class TestCRUDBaseUpdate:
"statement", {}, Exception("connection error") "statement", {}, Exception("connection error")
), ),
): ):
with pytest.raises(ValueError, match="Database operation failed"): with pytest.raises(
await user_crud.update( IntegrityConstraintError, match="Database operation failed"
):
await user_repo.update(
session, db_obj=user, obj_in={"first_name": "Test"} session, db_obj=user, obj_in={"first_name": "Test"}
) )
@@ -333,18 +344,18 @@ class TestCRUDBaseUpdate:
_test_engine, SessionLocal = async_test_db _test_engine, SessionLocal = async_test_db
async with SessionLocal() as session: async with SessionLocal() as session:
user = await user_crud.get(session, id=str(async_test_user.id)) user = await user_repo.get(session, id=str(async_test_user.id))
with patch.object( with patch.object(
session, "commit", side_effect=RuntimeError("Unexpected") session, "commit", side_effect=RuntimeError("Unexpected")
): ):
with pytest.raises(RuntimeError): with pytest.raises(RuntimeError):
await user_crud.update( await user_repo.update(
session, db_obj=user, obj_in={"first_name": "Test"} session, db_obj=user, obj_in={"first_name": "Test"}
) )
class TestCRUDBaseRemove: class TestRepositoryBaseRemove:
"""Tests for remove method covering UUID validation and error conditions.""" """Tests for remove method covering UUID validation and error conditions."""
@pytest.mark.asyncio @pytest.mark.asyncio
@@ -353,7 +364,7 @@ class TestCRUDBaseRemove:
_test_engine, SessionLocal = async_test_db _test_engine, SessionLocal = async_test_db
async with SessionLocal() as session: async with SessionLocal() as session:
result = await user_crud.remove(session, id="invalid-uuid") result = await user_repo.remove(session, id="invalid-uuid")
assert result is None assert result is None
@pytest.mark.asyncio @pytest.mark.asyncio
@@ -369,13 +380,13 @@ class TestCRUDBaseRemove:
first_name="To", first_name="To",
last_name="Delete", last_name="Delete",
) )
user = await user_crud.create(session, obj_in=user_data) user = await user_repo.create(session, obj_in=user_data)
user_id = user.id user_id = user.id
await session.commit() await session.commit()
# Delete with UUID object # Delete with UUID object
async with SessionLocal() as session: async with SessionLocal() as session:
result = await user_crud.remove(session, id=user_id) # UUID object result = await user_repo.remove(session, id=user_id) # UUID object
assert result is not None assert result is not None
assert result.id == user_id assert result.id == user_id
@@ -385,7 +396,7 @@ class TestCRUDBaseRemove:
_test_engine, SessionLocal = async_test_db _test_engine, SessionLocal = async_test_db
async with SessionLocal() as session: async with SessionLocal() as session:
result = await user_crud.remove(session, id=str(uuid4())) result = await user_repo.remove(session, id=str(uuid4()))
assert result is None assert result is None
@pytest.mark.asyncio @pytest.mark.asyncio
@@ -403,9 +414,10 @@ class TestCRUDBaseRemove:
), ),
): ):
with pytest.raises( with pytest.raises(
ValueError, match="Cannot delete.*referenced by other records" IntegrityConstraintError,
match="Cannot delete.*referenced by other records",
): ):
await user_crud.remove(session, id=str(async_test_user.id)) await user_repo.remove(session, id=str(async_test_user.id))
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_remove_unexpected_error(self, async_test_db, async_test_user): async def test_remove_unexpected_error(self, async_test_db, async_test_user):
@@ -417,10 +429,10 @@ class TestCRUDBaseRemove:
session, "commit", side_effect=RuntimeError("Unexpected") session, "commit", side_effect=RuntimeError("Unexpected")
): ):
with pytest.raises(RuntimeError): with pytest.raises(RuntimeError):
await user_crud.remove(session, id=str(async_test_user.id)) await user_repo.remove(session, id=str(async_test_user.id))
class TestCRUDBaseGetMultiWithTotal: class TestRepositoryBaseGetMultiWithTotal:
"""Tests for get_multi_with_total method covering pagination, filtering, sorting.""" """Tests for get_multi_with_total method covering pagination, filtering, sorting."""
@pytest.mark.asyncio @pytest.mark.asyncio
@@ -429,7 +441,7 @@ class TestCRUDBaseGetMultiWithTotal:
_test_engine, SessionLocal = async_test_db _test_engine, SessionLocal = async_test_db
async with SessionLocal() as session: async with SessionLocal() as session:
items, total = await user_crud.get_multi_with_total( items, total = await user_repo.get_multi_with_total(
session, skip=0, limit=10 session, skip=0, limit=10
) )
assert isinstance(items, list) assert isinstance(items, list)
@@ -442,8 +454,8 @@ class TestCRUDBaseGetMultiWithTotal:
_test_engine, SessionLocal = async_test_db _test_engine, SessionLocal = async_test_db
async with SessionLocal() as session: async with SessionLocal() as session:
with pytest.raises(ValueError, match="skip must be non-negative"): with pytest.raises(InvalidInputError, match="skip must be non-negative"):
await user_crud.get_multi_with_total(session, skip=-1) await user_repo.get_multi_with_total(session, skip=-1)
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_get_multi_with_total_negative_limit(self, async_test_db): async def test_get_multi_with_total_negative_limit(self, async_test_db):
@@ -451,8 +463,8 @@ class TestCRUDBaseGetMultiWithTotal:
_test_engine, SessionLocal = async_test_db _test_engine, SessionLocal = async_test_db
async with SessionLocal() as session: async with SessionLocal() as session:
with pytest.raises(ValueError, match="limit must be non-negative"): with pytest.raises(InvalidInputError, match="limit must be non-negative"):
await user_crud.get_multi_with_total(session, limit=-1) await user_repo.get_multi_with_total(session, limit=-1)
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_get_multi_with_total_limit_too_large(self, async_test_db): async def test_get_multi_with_total_limit_too_large(self, async_test_db):
@@ -460,8 +472,8 @@ class TestCRUDBaseGetMultiWithTotal:
_test_engine, SessionLocal = async_test_db _test_engine, SessionLocal = async_test_db
async with SessionLocal() as session: async with SessionLocal() as session:
with pytest.raises(ValueError, match="Maximum limit is 1000"): with pytest.raises(InvalidInputError, match="Maximum limit is 1000"):
await user_crud.get_multi_with_total(session, limit=1001) await user_repo.get_multi_with_total(session, limit=1001)
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_get_multi_with_total_with_filters( async def test_get_multi_with_total_with_filters(
@@ -472,7 +484,7 @@ class TestCRUDBaseGetMultiWithTotal:
async with SessionLocal() as session: async with SessionLocal() as session:
filters = {"email": async_test_user.email} filters = {"email": async_test_user.email}
items, total = await user_crud.get_multi_with_total( items, total = await user_repo.get_multi_with_total(
session, filters=filters session, filters=filters
) )
assert total == 1 assert total == 1
@@ -500,12 +512,12 @@ class TestCRUDBaseGetMultiWithTotal:
first_name="ZZZ", first_name="ZZZ",
last_name="User", last_name="User",
) )
await user_crud.create(session, obj_in=user_data1) await user_repo.create(session, obj_in=user_data1)
await user_crud.create(session, obj_in=user_data2) await user_repo.create(session, obj_in=user_data2)
await session.commit() await session.commit()
async with SessionLocal() as session: async with SessionLocal() as session:
items, total = await user_crud.get_multi_with_total( items, total = await user_repo.get_multi_with_total(
session, sort_by="email", sort_order="asc" session, sort_by="email", sort_order="asc"
) )
assert total >= 3 assert total >= 3
@@ -533,12 +545,12 @@ class TestCRUDBaseGetMultiWithTotal:
first_name="CCC", first_name="CCC",
last_name="User", last_name="User",
) )
await user_crud.create(session, obj_in=user_data1) await user_repo.create(session, obj_in=user_data1)
await user_crud.create(session, obj_in=user_data2) await user_repo.create(session, obj_in=user_data2)
await session.commit() await session.commit()
async with SessionLocal() as session: async with SessionLocal() as session:
items, _total = await user_crud.get_multi_with_total( items, _total = await user_repo.get_multi_with_total(
session, sort_by="email", sort_order="desc", limit=1 session, sort_by="email", sort_order="desc", limit=1
) )
assert len(items) == 1 assert len(items) == 1
@@ -558,19 +570,19 @@ class TestCRUDBaseGetMultiWithTotal:
first_name=f"User{i}", first_name=f"User{i}",
last_name="Test", last_name="Test",
) )
await user_crud.create(session, obj_in=user_data) await user_repo.create(session, obj_in=user_data)
await session.commit() await session.commit()
async with SessionLocal() as session: async with SessionLocal() as session:
# Get first page # Get first page
items1, total = await user_crud.get_multi_with_total( items1, total = await user_repo.get_multi_with_total(
session, skip=0, limit=2 session, skip=0, limit=2
) )
assert len(items1) == 2 assert len(items1) == 2
assert total >= 3 assert total >= 3
# Get second page # Get second page
items2, total2 = await user_crud.get_multi_with_total( items2, total2 = await user_repo.get_multi_with_total(
session, skip=2, limit=2 session, skip=2, limit=2
) )
assert len(items2) >= 1 assert len(items2) >= 1
@@ -582,7 +594,7 @@ class TestCRUDBaseGetMultiWithTotal:
assert ids1.isdisjoint(ids2) assert ids1.isdisjoint(ids2)
class TestCRUDBaseCount: class TestRepositoryBaseCount:
"""Tests for count method.""" """Tests for count method."""
@pytest.mark.asyncio @pytest.mark.asyncio
@@ -591,7 +603,7 @@ class TestCRUDBaseCount:
_test_engine, SessionLocal = async_test_db _test_engine, SessionLocal = async_test_db
async with SessionLocal() as session: async with SessionLocal() as session:
count = await user_crud.count(session) count = await user_repo.count(session)
assert isinstance(count, int) assert isinstance(count, int)
assert count >= 1 # At least the test user assert count >= 1 # At least the test user
@@ -602,7 +614,7 @@ class TestCRUDBaseCount:
# Create additional users # Create additional users
async with SessionLocal() as session: async with SessionLocal() as session:
initial_count = await user_crud.count(session) initial_count = await user_repo.count(session)
user_data1 = UserCreate( user_data1 = UserCreate(
email="count1@example.com", email="count1@example.com",
@@ -616,12 +628,12 @@ class TestCRUDBaseCount:
first_name="Count", first_name="Count",
last_name="Two", last_name="Two",
) )
await user_crud.create(session, obj_in=user_data1) await user_repo.create(session, obj_in=user_data1)
await user_crud.create(session, obj_in=user_data2) await user_repo.create(session, obj_in=user_data2)
await session.commit() await session.commit()
async with SessionLocal() as session: async with SessionLocal() as session:
new_count = await user_crud.count(session) new_count = await user_repo.count(session)
assert new_count == initial_count + 2 assert new_count == initial_count + 2
@pytest.mark.asyncio @pytest.mark.asyncio
@@ -632,10 +644,10 @@ class TestCRUDBaseCount:
async with SessionLocal() as session: async with SessionLocal() as session:
with patch.object(session, "execute", side_effect=Exception("DB error")): with patch.object(session, "execute", side_effect=Exception("DB error")):
with pytest.raises(Exception, match="DB error"): with pytest.raises(Exception, match="DB error"):
await user_crud.count(session) await user_repo.count(session)
class TestCRUDBaseExists: class TestRepositoryBaseExists:
"""Tests for exists method.""" """Tests for exists method."""
@pytest.mark.asyncio @pytest.mark.asyncio
@@ -644,7 +656,7 @@ class TestCRUDBaseExists:
_test_engine, SessionLocal = async_test_db _test_engine, SessionLocal = async_test_db
async with SessionLocal() as session: async with SessionLocal() as session:
result = await user_crud.exists(session, id=str(async_test_user.id)) result = await user_repo.exists(session, id=str(async_test_user.id))
assert result is True assert result is True
@pytest.mark.asyncio @pytest.mark.asyncio
@@ -653,7 +665,7 @@ class TestCRUDBaseExists:
_test_engine, SessionLocal = async_test_db _test_engine, SessionLocal = async_test_db
async with SessionLocal() as session: async with SessionLocal() as session:
result = await user_crud.exists(session, id=str(uuid4())) result = await user_repo.exists(session, id=str(uuid4()))
assert result is False assert result is False
@pytest.mark.asyncio @pytest.mark.asyncio
@@ -662,11 +674,11 @@ class TestCRUDBaseExists:
_test_engine, SessionLocal = async_test_db _test_engine, SessionLocal = async_test_db
async with SessionLocal() as session: async with SessionLocal() as session:
result = await user_crud.exists(session, id="invalid-uuid") result = await user_repo.exists(session, id="invalid-uuid")
assert result is False assert result is False
class TestCRUDBaseSoftDelete: class TestRepositoryBaseSoftDelete:
"""Tests for soft_delete method.""" """Tests for soft_delete method."""
@pytest.mark.asyncio @pytest.mark.asyncio
@@ -682,13 +694,13 @@ class TestCRUDBaseSoftDelete:
first_name="Soft", first_name="Soft",
last_name="Delete", last_name="Delete",
) )
user = await user_crud.create(session, obj_in=user_data) user = await user_repo.create(session, obj_in=user_data)
user_id = user.id user_id = user.id
await session.commit() await session.commit()
# Soft delete the user # Soft delete the user
async with SessionLocal() as session: async with SessionLocal() as session:
deleted = await user_crud.soft_delete(session, id=str(user_id)) deleted = await user_repo.soft_delete(session, id=str(user_id))
assert deleted is not None assert deleted is not None
assert deleted.deleted_at is not None assert deleted.deleted_at is not None
@@ -698,7 +710,7 @@ class TestCRUDBaseSoftDelete:
_test_engine, SessionLocal = async_test_db _test_engine, SessionLocal = async_test_db
async with SessionLocal() as session: async with SessionLocal() as session:
result = await user_crud.soft_delete(session, id="invalid-uuid") result = await user_repo.soft_delete(session, id="invalid-uuid")
assert result is None assert result is None
@pytest.mark.asyncio @pytest.mark.asyncio
@@ -707,7 +719,7 @@ class TestCRUDBaseSoftDelete:
_test_engine, SessionLocal = async_test_db _test_engine, SessionLocal = async_test_db
async with SessionLocal() as session: async with SessionLocal() as session:
result = await user_crud.soft_delete(session, id=str(uuid4())) result = await user_repo.soft_delete(session, id=str(uuid4()))
assert result is None assert result is None
@pytest.mark.asyncio @pytest.mark.asyncio
@@ -723,18 +735,18 @@ class TestCRUDBaseSoftDelete:
first_name="Soft", first_name="Soft",
last_name="Delete2", last_name="Delete2",
) )
user = await user_crud.create(session, obj_in=user_data) user = await user_repo.create(session, obj_in=user_data)
user_id = user.id user_id = user.id
await session.commit() await session.commit()
# Soft delete with UUID object # Soft delete with UUID object
async with SessionLocal() as session: async with SessionLocal() as session:
deleted = await user_crud.soft_delete(session, id=user_id) # UUID object deleted = await user_repo.soft_delete(session, id=user_id) # UUID object
assert deleted is not None assert deleted is not None
assert deleted.deleted_at is not None assert deleted.deleted_at is not None
class TestCRUDBaseRestore: class TestRepositoryBaseRestore:
"""Tests for restore method.""" """Tests for restore method."""
@pytest.mark.asyncio @pytest.mark.asyncio
@@ -750,16 +762,16 @@ class TestCRUDBaseRestore:
first_name="Restore", first_name="Restore",
last_name="Test", last_name="Test",
) )
user = await user_crud.create(session, obj_in=user_data) user = await user_repo.create(session, obj_in=user_data)
user_id = user.id user_id = user.id
await session.commit() await session.commit()
async with SessionLocal() as session: async with SessionLocal() as session:
await user_crud.soft_delete(session, id=str(user_id)) await user_repo.soft_delete(session, id=str(user_id))
# Restore the user # Restore the user
async with SessionLocal() as session: async with SessionLocal() as session:
restored = await user_crud.restore(session, id=str(user_id)) restored = await user_repo.restore(session, id=str(user_id))
assert restored is not None assert restored is not None
assert restored.deleted_at is None assert restored.deleted_at is None
@@ -769,7 +781,7 @@ class TestCRUDBaseRestore:
_test_engine, SessionLocal = async_test_db _test_engine, SessionLocal = async_test_db
async with SessionLocal() as session: async with SessionLocal() as session:
result = await user_crud.restore(session, id="invalid-uuid") result = await user_repo.restore(session, id="invalid-uuid")
assert result is None assert result is None
@pytest.mark.asyncio @pytest.mark.asyncio
@@ -778,7 +790,7 @@ class TestCRUDBaseRestore:
_test_engine, SessionLocal = async_test_db _test_engine, SessionLocal = async_test_db
async with SessionLocal() as session: async with SessionLocal() as session:
result = await user_crud.restore(session, id=str(uuid4())) result = await user_repo.restore(session, id=str(uuid4()))
assert result is None assert result is None
@pytest.mark.asyncio @pytest.mark.asyncio
@@ -788,7 +800,7 @@ class TestCRUDBaseRestore:
async with SessionLocal() as session: async with SessionLocal() as session:
# Try to restore a user that's not deleted # Try to restore a user that's not deleted
result = await user_crud.restore(session, id=str(async_test_user.id)) result = await user_repo.restore(session, id=str(async_test_user.id))
assert result is None assert result is None
@pytest.mark.asyncio @pytest.mark.asyncio
@@ -804,21 +816,21 @@ class TestCRUDBaseRestore:
first_name="Restore", first_name="Restore",
last_name="Test2", last_name="Test2",
) )
user = await user_crud.create(session, obj_in=user_data) user = await user_repo.create(session, obj_in=user_data)
user_id = user.id user_id = user.id
await session.commit() await session.commit()
async with SessionLocal() as session: async with SessionLocal() as session:
await user_crud.soft_delete(session, id=str(user_id)) await user_repo.soft_delete(session, id=str(user_id))
# Restore with UUID object # Restore with UUID object
async with SessionLocal() as session: async with SessionLocal() as session:
restored = await user_crud.restore(session, id=user_id) # UUID object restored = await user_repo.restore(session, id=user_id) # UUID object
assert restored is not None assert restored is not None
assert restored.deleted_at is None assert restored.deleted_at is None
class TestCRUDBasePaginationValidation: class TestRepositoryBasePaginationValidation:
"""Tests for pagination parameter validation (covers lines 254-260).""" """Tests for pagination parameter validation (covers lines 254-260)."""
@pytest.mark.asyncio @pytest.mark.asyncio
@@ -827,8 +839,8 @@ class TestCRUDBasePaginationValidation:
_test_engine, SessionLocal = async_test_db _test_engine, SessionLocal = async_test_db
async with SessionLocal() as session: async with SessionLocal() as session:
with pytest.raises(ValueError, match="skip must be non-negative"): with pytest.raises(InvalidInputError, match="skip must be non-negative"):
await user_crud.get_multi_with_total(session, skip=-1, limit=10) await user_repo.get_multi_with_total(session, skip=-1, limit=10)
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_get_multi_with_total_negative_limit(self, async_test_db): async def test_get_multi_with_total_negative_limit(self, async_test_db):
@@ -836,8 +848,8 @@ class TestCRUDBasePaginationValidation:
_test_engine, SessionLocal = async_test_db _test_engine, SessionLocal = async_test_db
async with SessionLocal() as session: async with SessionLocal() as session:
with pytest.raises(ValueError, match="limit must be non-negative"): with pytest.raises(InvalidInputError, match="limit must be non-negative"):
await user_crud.get_multi_with_total(session, skip=0, limit=-1) await user_repo.get_multi_with_total(session, skip=0, limit=-1)
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_get_multi_with_total_limit_too_large(self, async_test_db): async def test_get_multi_with_total_limit_too_large(self, async_test_db):
@@ -845,8 +857,8 @@ class TestCRUDBasePaginationValidation:
_test_engine, SessionLocal = async_test_db _test_engine, SessionLocal = async_test_db
async with SessionLocal() as session: async with SessionLocal() as session:
with pytest.raises(ValueError, match="Maximum limit is 1000"): with pytest.raises(InvalidInputError, match="Maximum limit is 1000"):
await user_crud.get_multi_with_total(session, skip=0, limit=1001) await user_repo.get_multi_with_total(session, skip=0, limit=1001)
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_get_multi_with_total_with_filters( async def test_get_multi_with_total_with_filters(
@@ -856,7 +868,7 @@ class TestCRUDBasePaginationValidation:
_test_engine, SessionLocal = async_test_db _test_engine, SessionLocal = async_test_db
async with SessionLocal() as session: async with SessionLocal() as session:
users, total = await user_crud.get_multi_with_total( users, total = await user_repo.get_multi_with_total(
session, skip=0, limit=10, filters={"is_active": True} session, skip=0, limit=10, filters={"is_active": True}
) )
assert isinstance(users, list) assert isinstance(users, list)
@@ -868,7 +880,7 @@ class TestCRUDBasePaginationValidation:
_test_engine, SessionLocal = async_test_db _test_engine, SessionLocal = async_test_db
async with SessionLocal() as session: async with SessionLocal() as session:
users, _total = await user_crud.get_multi_with_total( users, _total = await user_repo.get_multi_with_total(
session, skip=0, limit=10, sort_by="created_at", sort_order="desc" session, skip=0, limit=10, sort_by="created_at", sort_order="desc"
) )
assert isinstance(users, list) assert isinstance(users, list)
@@ -879,13 +891,13 @@ class TestCRUDBasePaginationValidation:
_test_engine, SessionLocal = async_test_db _test_engine, SessionLocal = async_test_db
async with SessionLocal() as session: async with SessionLocal() as session:
users, _total = await user_crud.get_multi_with_total( users, _total = await user_repo.get_multi_with_total(
session, skip=0, limit=10, sort_by="created_at", sort_order="asc" session, skip=0, limit=10, sort_by="created_at", sort_order="asc"
) )
assert isinstance(users, list) assert isinstance(users, list)
class TestCRUDBaseModelsWithoutSoftDelete: class TestRepositoryBaseModelsWithoutSoftDelete:
""" """
Test soft_delete and restore on models without deleted_at column. Test soft_delete and restore on models without deleted_at column.
Covers lines 342-343, 383-384 - error handling for unsupported models. Covers lines 342-343, 383-384 - error handling for unsupported models.
@@ -899,8 +911,8 @@ class TestCRUDBaseModelsWithoutSoftDelete:
_test_engine, SessionLocal = async_test_db _test_engine, SessionLocal = async_test_db
# Create an organization (which doesn't have deleted_at) # Create an organization (which doesn't have deleted_at)
from app.crud.organization import organization as org_crud
from app.models.organization import Organization from app.models.organization import Organization
from app.repositories.organization import organization_repo as org_repo
async with SessionLocal() as session: async with SessionLocal() as session:
org = Organization(name="Test Org", slug="test-org") org = Organization(name="Test Org", slug="test-org")
@@ -910,8 +922,10 @@ class TestCRUDBaseModelsWithoutSoftDelete:
# Try to soft delete organization (should fail) # Try to soft delete organization (should fail)
async with SessionLocal() as session: async with SessionLocal() as session:
with pytest.raises(ValueError, match="does not have a deleted_at column"): with pytest.raises(
await org_crud.soft_delete(session, id=str(org_id)) InvalidInputError, match="does not have a deleted_at column"
):
await org_repo.soft_delete(session, id=str(org_id))
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_restore_model_without_deleted_at(self, async_test_db): async def test_restore_model_without_deleted_at(self, async_test_db):
@@ -919,8 +933,8 @@ class TestCRUDBaseModelsWithoutSoftDelete:
_test_engine, SessionLocal = async_test_db _test_engine, SessionLocal = async_test_db
# Create an organization (which doesn't have deleted_at) # Create an organization (which doesn't have deleted_at)
from app.crud.organization import organization as org_crud
from app.models.organization import Organization from app.models.organization import Organization
from app.repositories.organization import organization_repo as org_repo
async with SessionLocal() as session: async with SessionLocal() as session:
org = Organization(name="Restore Test", slug="restore-test") org = Organization(name="Restore Test", slug="restore-test")
@@ -930,11 +944,13 @@ class TestCRUDBaseModelsWithoutSoftDelete:
# Try to restore organization (should fail) # Try to restore organization (should fail)
async with SessionLocal() as session: async with SessionLocal() as session:
with pytest.raises(ValueError, match="does not have a deleted_at column"): with pytest.raises(
await org_crud.restore(session, id=str(org_id)) InvalidInputError, match="does not have a deleted_at column"
):
await org_repo.restore(session, id=str(org_id))
class TestCRUDBaseEagerLoadingWithRealOptions: class TestRepositoryBaseEagerLoadingWithRealOptions:
""" """
Test eager loading with actual SQLAlchemy load options. Test eager loading with actual SQLAlchemy load options.
Covers lines 77-78, 119-120 - options loop execution. Covers lines 77-78, 119-120 - options loop execution.
@@ -950,8 +966,8 @@ class TestCRUDBaseEagerLoadingWithRealOptions:
_test_engine, SessionLocal = async_test_db _test_engine, SessionLocal = async_test_db
# Create a session for the user # Create a session for the user
from app.crud.session import session as session_crud
from app.models.user_session import UserSession from app.models.user_session import UserSession
from app.repositories.session import session_repo as session_repo
async with SessionLocal() as session: async with SessionLocal() as session:
user_session = UserSession( user_session = UserSession(
@@ -969,7 +985,7 @@ class TestCRUDBaseEagerLoadingWithRealOptions:
# Get session with eager loading of user relationship # Get session with eager loading of user relationship
async with SessionLocal() as session: async with SessionLocal() as session:
result = await session_crud.get( result = await session_repo.get(
session, session,
id=str(session_id), id=str(session_id),
options=[joinedload(UserSession.user)], # Real option, not empty list options=[joinedload(UserSession.user)], # Real option, not empty list
@@ -989,8 +1005,8 @@ class TestCRUDBaseEagerLoadingWithRealOptions:
_test_engine, SessionLocal = async_test_db _test_engine, SessionLocal = async_test_db
# Create multiple sessions for the user # Create multiple sessions for the user
from app.crud.session import session as session_crud
from app.models.user_session import UserSession from app.models.user_session import UserSession
from app.repositories.session import session_repo as session_repo
async with SessionLocal() as session: async with SessionLocal() as session:
for i in range(3): for i in range(3):
@@ -1008,7 +1024,7 @@ class TestCRUDBaseEagerLoadingWithRealOptions:
# Get sessions with eager loading # Get sessions with eager loading
async with SessionLocal() as session: async with SessionLocal() as session:
results = await session_crud.get_multi( results = await session_repo.get_multi(
session, session,
skip=0, skip=0,
limit=10, limit=10,

View File

@@ -1,6 +1,6 @@
# tests/crud/test_base_db_failures.py # tests/repositories/test_base_db_failures.py
""" """
Comprehensive tests for base CRUD database failure scenarios. Comprehensive tests for base repository database failure scenarios.
Tests exception handling, rollbacks, and error messages. Tests exception handling, rollbacks, and error messages.
""" """
@@ -10,16 +10,17 @@ from uuid import uuid4
import pytest import pytest
from sqlalchemy.exc import DataError, OperationalError from sqlalchemy.exc import DataError, OperationalError
from app.crud.user import user as user_crud from app.core.repository_exceptions import IntegrityConstraintError
from app.repositories.user import user_repo as user_repo
from app.schemas.users import UserCreate from app.schemas.users import UserCreate
class TestBaseCRUDCreateFailures: class TestBaseRepositoryCreateFailures:
"""Test base CRUD create method exception handling.""" """Test base repository create method exception handling."""
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_create_operational_error_triggers_rollback(self, async_test_db): async def test_create_operational_error_triggers_rollback(self, async_test_db):
"""Test that OperationalError triggers rollback (User CRUD catches as Exception).""" """Test that OperationalError triggers rollback (User repository catches as Exception)."""
_test_engine, SessionLocal = async_test_db _test_engine, SessionLocal = async_test_db
async with SessionLocal() as session: async with SessionLocal() as session:
@@ -40,16 +41,16 @@ class TestBaseCRUDCreateFailures:
last_name="User", last_name="User",
) )
# User CRUD catches this as generic Exception and re-raises # User repository catches this as generic Exception and re-raises
with pytest.raises(OperationalError): with pytest.raises(OperationalError):
await user_crud.create(session, obj_in=user_data) await user_repo.create(session, obj_in=user_data)
# Verify rollback was called # Verify rollback was called
mock_rollback.assert_called_once() mock_rollback.assert_called_once()
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_create_data_error_triggers_rollback(self, async_test_db): async def test_create_data_error_triggers_rollback(self, async_test_db):
"""Test that DataError triggers rollback (User CRUD catches as Exception).""" """Test that DataError triggers rollback (User repository catches as Exception)."""
_test_engine, SessionLocal = async_test_db _test_engine, SessionLocal = async_test_db
async with SessionLocal() as session: async with SessionLocal() as session:
@@ -68,9 +69,9 @@ class TestBaseCRUDCreateFailures:
last_name="User", last_name="User",
) )
# User CRUD catches this as generic Exception and re-raises # User repository catches this as generic Exception and re-raises
with pytest.raises(DataError): with pytest.raises(DataError):
await user_crud.create(session, obj_in=user_data) await user_repo.create(session, obj_in=user_data)
mock_rollback.assert_called_once() mock_rollback.assert_called_once()
@@ -96,13 +97,13 @@ class TestBaseCRUDCreateFailures:
) )
with pytest.raises(RuntimeError, match="Unexpected database error"): with pytest.raises(RuntimeError, match="Unexpected database error"):
await user_crud.create(session, obj_in=user_data) await user_repo.create(session, obj_in=user_data)
mock_rollback.assert_called_once() mock_rollback.assert_called_once()
class TestBaseCRUDUpdateFailures: class TestBaseRepositoryUpdateFailures:
"""Test base CRUD update method exception handling.""" """Test base repository update method exception handling."""
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_update_operational_error(self, async_test_db, async_test_user): async def test_update_operational_error(self, async_test_db, async_test_user):
@@ -110,7 +111,7 @@ class TestBaseCRUDUpdateFailures:
_test_engine, SessionLocal = async_test_db _test_engine, SessionLocal = async_test_db
async with SessionLocal() as session: async with SessionLocal() as session:
user = await user_crud.get(session, id=str(async_test_user.id)) user = await user_repo.get(session, id=str(async_test_user.id))
async def mock_commit(): async def mock_commit():
raise OperationalError("Connection timeout", {}, Exception("Timeout")) raise OperationalError("Connection timeout", {}, Exception("Timeout"))
@@ -119,8 +120,10 @@ class TestBaseCRUDUpdateFailures:
with patch.object( with patch.object(
session, "rollback", new_callable=AsyncMock session, "rollback", new_callable=AsyncMock
) as mock_rollback: ) as mock_rollback:
with pytest.raises(ValueError, match="Database operation failed"): with pytest.raises(
await user_crud.update( IntegrityConstraintError, match="Database operation failed"
):
await user_repo.update(
session, db_obj=user, obj_in={"first_name": "Updated"} session, db_obj=user, obj_in={"first_name": "Updated"}
) )
@@ -132,7 +135,7 @@ class TestBaseCRUDUpdateFailures:
_test_engine, SessionLocal = async_test_db _test_engine, SessionLocal = async_test_db
async with SessionLocal() as session: async with SessionLocal() as session:
user = await user_crud.get(session, id=str(async_test_user.id)) user = await user_repo.get(session, id=str(async_test_user.id))
async def mock_commit(): async def mock_commit():
raise DataError("Invalid data", {}, Exception("Data type mismatch")) raise DataError("Invalid data", {}, Exception("Data type mismatch"))
@@ -141,8 +144,10 @@ class TestBaseCRUDUpdateFailures:
with patch.object( with patch.object(
session, "rollback", new_callable=AsyncMock session, "rollback", new_callable=AsyncMock
) as mock_rollback: ) as mock_rollback:
with pytest.raises(ValueError, match="Database operation failed"): with pytest.raises(
await user_crud.update( IntegrityConstraintError, match="Database operation failed"
):
await user_repo.update(
session, db_obj=user, obj_in={"first_name": "Updated"} session, db_obj=user, obj_in={"first_name": "Updated"}
) )
@@ -154,7 +159,7 @@ class TestBaseCRUDUpdateFailures:
_test_engine, SessionLocal = async_test_db _test_engine, SessionLocal = async_test_db
async with SessionLocal() as session: async with SessionLocal() as session:
user = await user_crud.get(session, id=str(async_test_user.id)) user = await user_repo.get(session, id=str(async_test_user.id))
async def mock_commit(): async def mock_commit():
raise KeyError("Unexpected error") raise KeyError("Unexpected error")
@@ -164,15 +169,15 @@ class TestBaseCRUDUpdateFailures:
session, "rollback", new_callable=AsyncMock session, "rollback", new_callable=AsyncMock
) as mock_rollback: ) as mock_rollback:
with pytest.raises(KeyError): with pytest.raises(KeyError):
await user_crud.update( await user_repo.update(
session, db_obj=user, obj_in={"first_name": "Updated"} session, db_obj=user, obj_in={"first_name": "Updated"}
) )
mock_rollback.assert_called_once() mock_rollback.assert_called_once()
class TestBaseCRUDRemoveFailures: class TestBaseRepositoryRemoveFailures:
"""Test base CRUD remove method exception handling.""" """Test base repository remove method exception handling."""
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_remove_unexpected_error_triggers_rollback( async def test_remove_unexpected_error_triggers_rollback(
@@ -191,12 +196,12 @@ class TestBaseCRUDRemoveFailures:
session, "rollback", new_callable=AsyncMock session, "rollback", new_callable=AsyncMock
) as mock_rollback: ) as mock_rollback:
with pytest.raises(RuntimeError, match="Database write failed"): with pytest.raises(RuntimeError, match="Database write failed"):
await user_crud.remove(session, id=str(async_test_user.id)) await user_repo.remove(session, id=str(async_test_user.id))
mock_rollback.assert_called_once() mock_rollback.assert_called_once()
class TestBaseCRUDGetMultiWithTotalFailures: class TestBaseRepositoryGetMultiWithTotalFailures:
"""Test get_multi_with_total exception handling.""" """Test get_multi_with_total exception handling."""
@pytest.mark.asyncio @pytest.mark.asyncio
@@ -212,10 +217,10 @@ class TestBaseCRUDGetMultiWithTotalFailures:
with patch.object(session, "execute", side_effect=mock_execute): with patch.object(session, "execute", side_effect=mock_execute):
with pytest.raises(OperationalError): with pytest.raises(OperationalError):
await user_crud.get_multi_with_total(session, skip=0, limit=10) await user_repo.get_multi_with_total(session, skip=0, limit=10)
class TestBaseCRUDCountFailures: class TestBaseRepositoryCountFailures:
"""Test count method exception handling.""" """Test count method exception handling."""
@pytest.mark.asyncio @pytest.mark.asyncio
@@ -230,10 +235,10 @@ class TestBaseCRUDCountFailures:
with patch.object(session, "execute", side_effect=mock_execute): with patch.object(session, "execute", side_effect=mock_execute):
with pytest.raises(OperationalError): with pytest.raises(OperationalError):
await user_crud.count(session) await user_repo.count(session)
class TestBaseCRUDSoftDeleteFailures: class TestBaseRepositorySoftDeleteFailures:
"""Test soft_delete method exception handling.""" """Test soft_delete method exception handling."""
@pytest.mark.asyncio @pytest.mark.asyncio
@@ -253,12 +258,12 @@ class TestBaseCRUDSoftDeleteFailures:
session, "rollback", new_callable=AsyncMock session, "rollback", new_callable=AsyncMock
) as mock_rollback: ) as mock_rollback:
with pytest.raises(RuntimeError, match="Soft delete failed"): with pytest.raises(RuntimeError, match="Soft delete failed"):
await user_crud.soft_delete(session, id=str(async_test_user.id)) await user_repo.soft_delete(session, id=str(async_test_user.id))
mock_rollback.assert_called_once() mock_rollback.assert_called_once()
class TestBaseCRUDRestoreFailures: class TestBaseRepositoryRestoreFailures:
"""Test restore method exception handling.""" """Test restore method exception handling."""
@pytest.mark.asyncio @pytest.mark.asyncio
@@ -274,12 +279,12 @@ class TestBaseCRUDRestoreFailures:
first_name="Restore", first_name="Restore",
last_name="Test", last_name="Test",
) )
user = await user_crud.create(session, obj_in=user_data) user = await user_repo.create(session, obj_in=user_data)
user_id = user.id user_id = user.id
await session.commit() await session.commit()
async with SessionLocal() as session: async with SessionLocal() as session:
await user_crud.soft_delete(session, id=str(user_id)) await user_repo.soft_delete(session, id=str(user_id))
# Now test restore failure # Now test restore failure
async with SessionLocal() as session: async with SessionLocal() as session:
@@ -292,12 +297,12 @@ class TestBaseCRUDRestoreFailures:
session, "rollback", new_callable=AsyncMock session, "rollback", new_callable=AsyncMock
) as mock_rollback: ) as mock_rollback:
with pytest.raises(RuntimeError, match="Restore failed"): with pytest.raises(RuntimeError, match="Restore failed"):
await user_crud.restore(session, id=str(user_id)) await user_repo.restore(session, id=str(user_id))
mock_rollback.assert_called_once() mock_rollback.assert_called_once()
class TestBaseCRUDGetFailures: class TestBaseRepositoryGetFailures:
"""Test get method exception handling.""" """Test get method exception handling."""
@pytest.mark.asyncio @pytest.mark.asyncio
@@ -312,10 +317,10 @@ class TestBaseCRUDGetFailures:
with patch.object(session, "execute", side_effect=mock_execute): with patch.object(session, "execute", side_effect=mock_execute):
with pytest.raises(OperationalError): with pytest.raises(OperationalError):
await user_crud.get(session, id=str(uuid4())) await user_repo.get(session, id=str(uuid4()))
class TestBaseCRUDGetMultiFailures: class TestBaseRepositoryGetMultiFailures:
"""Test get_multi method exception handling.""" """Test get_multi method exception handling."""
@pytest.mark.asyncio @pytest.mark.asyncio
@@ -330,4 +335,4 @@ class TestBaseCRUDGetMultiFailures:
with patch.object(session, "execute", side_effect=mock_execute): with patch.object(session, "execute", side_effect=mock_execute):
with pytest.raises(OperationalError): with pytest.raises(OperationalError):
await user_crud.get_multi(session, skip=0, limit=10) await user_repo.get_multi(session, skip=0, limit=10)

View File

@@ -1,18 +1,21 @@
# tests/crud/test_oauth.py # tests/repositories/test_oauth.py
""" """
Comprehensive tests for OAuth CRUD operations. Comprehensive tests for OAuth repository operations.
""" """
from datetime import UTC, datetime, timedelta from datetime import UTC, datetime, timedelta
import pytest import pytest
from app.crud.oauth import oauth_account, oauth_client, oauth_state from app.core.repository_exceptions import DuplicateEntryError
from app.repositories.oauth_account import oauth_account_repo as oauth_account
from app.repositories.oauth_client import oauth_client_repo as oauth_client
from app.repositories.oauth_state import oauth_state_repo as oauth_state
from app.schemas.oauth import OAuthAccountCreate, OAuthClientCreate, OAuthStateCreate from app.schemas.oauth import OAuthAccountCreate, OAuthClientCreate, OAuthStateCreate
class TestOAuthAccountCRUD: class TestOAuthAccountRepository:
"""Tests for OAuth account CRUD operations.""" """Tests for OAuth account repository operations."""
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_create_account(self, async_test_db, async_test_user): async def test_create_account(self, async_test_db, async_test_user):
@@ -60,7 +63,8 @@ class TestOAuthAccountCRUD:
# SQLite returns different error message than PostgreSQL # SQLite returns different error message than PostgreSQL
with pytest.raises( with pytest.raises(
ValueError, match="(already linked|UNIQUE constraint failed)" DuplicateEntryError,
match="(already linked|UNIQUE constraint failed|Failed to create)",
): ):
await oauth_account.create_account(session, obj_in=account_data2) await oauth_account.create_account(session, obj_in=account_data2)
@@ -256,17 +260,17 @@ class TestOAuthAccountCRUD:
updated = await oauth_account.update_tokens( updated = await oauth_account.update_tokens(
session, session,
account=account, account=account,
access_token_encrypted="new_access_token", access_token="new_access_token",
refresh_token_encrypted="new_refresh_token", refresh_token="new_refresh_token",
token_expires_at=new_expires, token_expires_at=new_expires,
) )
assert updated.access_token_encrypted == "new_access_token" assert updated.access_token == "new_access_token"
assert updated.refresh_token_encrypted == "new_refresh_token" assert updated.refresh_token == "new_refresh_token"
class TestOAuthStateCRUD: class TestOAuthStateRepository:
"""Tests for OAuth state CRUD operations.""" """Tests for OAuth state repository operations."""
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_create_state(self, async_test_db): async def test_create_state(self, async_test_db):
@@ -372,8 +376,8 @@ class TestOAuthStateCRUD:
assert result is not None assert result is not None
class TestOAuthClientCRUD: class TestOAuthClientRepository:
"""Tests for OAuth client CRUD operations (provider mode).""" """Tests for OAuth client repository operations (provider mode)."""
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_create_public_client(self, async_test_db): async def test_create_public_client(self, async_test_db):

View File

@@ -1,6 +1,6 @@
# tests/crud/test_organization_async.py # tests/repositories/test_organization_async.py
""" """
Comprehensive tests for async organization CRUD operations. Comprehensive tests for async organization repository operations.
""" """
from unittest.mock import AsyncMock, MagicMock, patch from unittest.mock import AsyncMock, MagicMock, patch
@@ -9,9 +9,10 @@ from uuid import uuid4
import pytest import pytest
from sqlalchemy import select from sqlalchemy import select
from app.crud.organization import organization as organization_crud from app.core.repository_exceptions import DuplicateEntryError, IntegrityConstraintError
from app.models.organization import Organization from app.models.organization import Organization
from app.models.user_organization import OrganizationRole, UserOrganization from app.models.user_organization import OrganizationRole, UserOrganization
from app.repositories.organization import organization_repo as organization_repo
from app.schemas.organizations import OrganizationCreate from app.schemas.organizations import OrganizationCreate
@@ -34,7 +35,7 @@ class TestGetBySlug:
# Get by slug # Get by slug
async with AsyncTestingSessionLocal() as session: async with AsyncTestingSessionLocal() as session:
result = await organization_crud.get_by_slug(session, slug="test-org") result = await organization_repo.get_by_slug(session, slug="test-org")
assert result is not None assert result is not None
assert result.id == org_id assert result.id == org_id
assert result.slug == "test-org" assert result.slug == "test-org"
@@ -45,7 +46,7 @@ class TestGetBySlug:
_test_engine, AsyncTestingSessionLocal = async_test_db _test_engine, AsyncTestingSessionLocal = async_test_db
async with AsyncTestingSessionLocal() as session: async with AsyncTestingSessionLocal() as session:
result = await organization_crud.get_by_slug(session, slug="nonexistent") result = await organization_repo.get_by_slug(session, slug="nonexistent")
assert result is None assert result is None
@@ -54,7 +55,7 @@ class TestCreate:
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_create_success(self, async_test_db): async def test_create_success(self, async_test_db):
"""Test successfully creating an organization_crud.""" """Test successfully creating an organization_repo."""
_test_engine, AsyncTestingSessionLocal = async_test_db _test_engine, AsyncTestingSessionLocal = async_test_db
async with AsyncTestingSessionLocal() as session: async with AsyncTestingSessionLocal() as session:
@@ -65,7 +66,7 @@ class TestCreate:
is_active=True, is_active=True,
settings={"key": "value"}, settings={"key": "value"},
) )
result = await organization_crud.create(session, obj_in=org_in) result = await organization_repo.create(session, obj_in=org_in)
assert result.name == "New Org" assert result.name == "New Org"
assert result.slug == "new-org" assert result.slug == "new-org"
@@ -87,8 +88,8 @@ class TestCreate:
# Try to create second with same slug # Try to create second with same slug
async with AsyncTestingSessionLocal() as session: async with AsyncTestingSessionLocal() as session:
org_in = OrganizationCreate(name="Org 2", slug="duplicate-slug") org_in = OrganizationCreate(name="Org 2", slug="duplicate-slug")
with pytest.raises(ValueError, match="already exists"): with pytest.raises(DuplicateEntryError, match="already exists"):
await organization_crud.create(session, obj_in=org_in) await organization_repo.create(session, obj_in=org_in)
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_create_without_settings(self, async_test_db): async def test_create_without_settings(self, async_test_db):
@@ -97,7 +98,7 @@ class TestCreate:
async with AsyncTestingSessionLocal() as session: async with AsyncTestingSessionLocal() as session:
org_in = OrganizationCreate(name="No Settings Org", slug="no-settings") org_in = OrganizationCreate(name="No Settings Org", slug="no-settings")
result = await organization_crud.create(session, obj_in=org_in) result = await organization_repo.create(session, obj_in=org_in)
assert result.settings == {} assert result.settings == {}
@@ -118,7 +119,7 @@ class TestGetMultiWithFilters:
await session.commit() await session.commit()
async with AsyncTestingSessionLocal() as session: async with AsyncTestingSessionLocal() as session:
orgs, total = await organization_crud.get_multi_with_filters(session) orgs, total = await organization_repo.get_multi_with_filters(session)
assert total == 5 assert total == 5
assert len(orgs) == 5 assert len(orgs) == 5
@@ -134,7 +135,7 @@ class TestGetMultiWithFilters:
await session.commit() await session.commit()
async with AsyncTestingSessionLocal() as session: async with AsyncTestingSessionLocal() as session:
orgs, total = await organization_crud.get_multi_with_filters( orgs, total = await organization_repo.get_multi_with_filters(
session, is_active=True session, is_active=True
) )
assert total == 1 assert total == 1
@@ -156,7 +157,7 @@ class TestGetMultiWithFilters:
await session.commit() await session.commit()
async with AsyncTestingSessionLocal() as session: async with AsyncTestingSessionLocal() as session:
orgs, total = await organization_crud.get_multi_with_filters( orgs, total = await organization_repo.get_multi_with_filters(
session, search="tech" session, search="tech"
) )
assert total == 1 assert total == 1
@@ -174,7 +175,7 @@ class TestGetMultiWithFilters:
await session.commit() await session.commit()
async with AsyncTestingSessionLocal() as session: async with AsyncTestingSessionLocal() as session:
orgs, total = await organization_crud.get_multi_with_filters( orgs, total = await organization_repo.get_multi_with_filters(
session, skip=2, limit=3 session, skip=2, limit=3
) )
assert total == 10 assert total == 10
@@ -192,7 +193,7 @@ class TestGetMultiWithFilters:
await session.commit() await session.commit()
async with AsyncTestingSessionLocal() as session: async with AsyncTestingSessionLocal() as session:
orgs, _total = await organization_crud.get_multi_with_filters( orgs, _total = await organization_repo.get_multi_with_filters(
session, sort_by="name", sort_order="asc" session, sort_by="name", sort_order="asc"
) )
assert orgs[0].name == "A Org" assert orgs[0].name == "A Org"
@@ -204,7 +205,7 @@ class TestGetMemberCount:
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_get_member_count_success(self, async_test_db, async_test_user): async def test_get_member_count_success(self, async_test_db, async_test_user):
"""Test getting member count for organization_crud.""" """Test getting member count for organization_repo."""
_test_engine, AsyncTestingSessionLocal = async_test_db _test_engine, AsyncTestingSessionLocal = async_test_db
async with AsyncTestingSessionLocal() as session: async with AsyncTestingSessionLocal() as session:
@@ -224,7 +225,7 @@ class TestGetMemberCount:
org_id = org.id org_id = org.id
async with AsyncTestingSessionLocal() as session: async with AsyncTestingSessionLocal() as session:
count = await organization_crud.get_member_count( count = await organization_repo.get_member_count(
session, organization_id=org_id session, organization_id=org_id
) )
assert count == 1 assert count == 1
@@ -241,7 +242,7 @@ class TestGetMemberCount:
org_id = org.id org_id = org.id
async with AsyncTestingSessionLocal() as session: async with AsyncTestingSessionLocal() as session:
count = await organization_crud.get_member_count( count = await organization_repo.get_member_count(
session, organization_id=org_id session, organization_id=org_id
) )
assert count == 0 assert count == 0
@@ -252,7 +253,7 @@ class TestAddUser:
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_add_user_success(self, async_test_db, async_test_user): async def test_add_user_success(self, async_test_db, async_test_user):
"""Test successfully adding a user to organization_crud.""" """Test successfully adding a user to organization_repo."""
_test_engine, AsyncTestingSessionLocal = async_test_db _test_engine, AsyncTestingSessionLocal = async_test_db
async with AsyncTestingSessionLocal() as session: async with AsyncTestingSessionLocal() as session:
@@ -262,7 +263,7 @@ class TestAddUser:
org_id = org.id org_id = org.id
async with AsyncTestingSessionLocal() as session: async with AsyncTestingSessionLocal() as session:
result = await organization_crud.add_user( result = await organization_repo.add_user(
session, session,
organization_id=org_id, organization_id=org_id,
user_id=async_test_user.id, user_id=async_test_user.id,
@@ -295,8 +296,8 @@ class TestAddUser:
org_id = org.id org_id = org.id
async with AsyncTestingSessionLocal() as session: async with AsyncTestingSessionLocal() as session:
with pytest.raises(ValueError, match="already a member"): with pytest.raises(DuplicateEntryError, match="already a member"):
await organization_crud.add_user( await organization_repo.add_user(
session, organization_id=org_id, user_id=async_test_user.id session, organization_id=org_id, user_id=async_test_user.id
) )
@@ -321,7 +322,7 @@ class TestAddUser:
org_id = org.id org_id = org.id
async with AsyncTestingSessionLocal() as session: async with AsyncTestingSessionLocal() as session:
result = await organization_crud.add_user( result = await organization_repo.add_user(
session, session,
organization_id=org_id, organization_id=org_id,
user_id=async_test_user.id, user_id=async_test_user.id,
@@ -337,7 +338,7 @@ class TestRemoveUser:
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_remove_user_success(self, async_test_db, async_test_user): async def test_remove_user_success(self, async_test_db, async_test_user):
"""Test successfully removing a user from organization_crud.""" """Test successfully removing a user from organization_repo."""
_test_engine, AsyncTestingSessionLocal = async_test_db _test_engine, AsyncTestingSessionLocal = async_test_db
async with AsyncTestingSessionLocal() as session: async with AsyncTestingSessionLocal() as session:
@@ -356,7 +357,7 @@ class TestRemoveUser:
org_id = org.id org_id = org.id
async with AsyncTestingSessionLocal() as session: async with AsyncTestingSessionLocal() as session:
result = await organization_crud.remove_user( result = await organization_repo.remove_user(
session, organization_id=org_id, user_id=async_test_user.id session, organization_id=org_id, user_id=async_test_user.id
) )
@@ -384,7 +385,7 @@ class TestRemoveUser:
org_id = org.id org_id = org.id
async with AsyncTestingSessionLocal() as session: async with AsyncTestingSessionLocal() as session:
result = await organization_crud.remove_user( result = await organization_repo.remove_user(
session, organization_id=org_id, user_id=uuid4() session, organization_id=org_id, user_id=uuid4()
) )
@@ -415,7 +416,7 @@ class TestUpdateUserRole:
org_id = org.id org_id = org.id
async with AsyncTestingSessionLocal() as session: async with AsyncTestingSessionLocal() as session:
result = await organization_crud.update_user_role( result = await organization_repo.update_user_role(
session, session,
organization_id=org_id, organization_id=org_id,
user_id=async_test_user.id, user_id=async_test_user.id,
@@ -438,7 +439,7 @@ class TestUpdateUserRole:
org_id = org.id org_id = org.id
async with AsyncTestingSessionLocal() as session: async with AsyncTestingSessionLocal() as session:
result = await organization_crud.update_user_role( result = await organization_repo.update_user_role(
session, session,
organization_id=org_id, organization_id=org_id,
user_id=uuid4(), user_id=uuid4(),
@@ -474,7 +475,7 @@ class TestGetOrganizationMembers:
org_id = org.id org_id = org.id
async with AsyncTestingSessionLocal() as session: async with AsyncTestingSessionLocal() as session:
members, total = await organization_crud.get_organization_members( members, total = await organization_repo.get_organization_members(
session, organization_id=org_id session, organization_id=org_id
) )
@@ -507,7 +508,7 @@ class TestGetOrganizationMembers:
org_id = org.id org_id = org.id
async with AsyncTestingSessionLocal() as session: async with AsyncTestingSessionLocal() as session:
members, total = await organization_crud.get_organization_members( members, total = await organization_repo.get_organization_members(
session, organization_id=org_id, skip=0, limit=10 session, organization_id=org_id, skip=0, limit=10
) )
@@ -538,7 +539,7 @@ class TestGetUserOrganizations:
await session.commit() await session.commit()
async with AsyncTestingSessionLocal() as session: async with AsyncTestingSessionLocal() as session:
orgs = await organization_crud.get_user_organizations( orgs = await organization_repo.get_user_organizations(
session, user_id=async_test_user.id session, user_id=async_test_user.id
) )
@@ -574,7 +575,7 @@ class TestGetUserOrganizations:
await session.commit() await session.commit()
async with AsyncTestingSessionLocal() as session: async with AsyncTestingSessionLocal() as session:
orgs = await organization_crud.get_user_organizations( orgs = await organization_repo.get_user_organizations(
session, user_id=async_test_user.id, is_active=True session, user_id=async_test_user.id, is_active=True
) )
@@ -587,7 +588,7 @@ class TestGetUserRole:
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_get_user_role_in_org_success(self, async_test_db, async_test_user): async def test_get_user_role_in_org_success(self, async_test_db, async_test_user):
"""Test getting user role in organization_crud.""" """Test getting user role in organization_repo."""
_test_engine, AsyncTestingSessionLocal = async_test_db _test_engine, AsyncTestingSessionLocal = async_test_db
async with AsyncTestingSessionLocal() as session: async with AsyncTestingSessionLocal() as session:
@@ -606,7 +607,7 @@ class TestGetUserRole:
org_id = org.id org_id = org.id
async with AsyncTestingSessionLocal() as session: async with AsyncTestingSessionLocal() as session:
role = await organization_crud.get_user_role_in_org( role = await organization_repo.get_user_role_in_org(
session, user_id=async_test_user.id, organization_id=org_id session, user_id=async_test_user.id, organization_id=org_id
) )
@@ -624,7 +625,7 @@ class TestGetUserRole:
org_id = org.id org_id = org.id
async with AsyncTestingSessionLocal() as session: async with AsyncTestingSessionLocal() as session:
role = await organization_crud.get_user_role_in_org( role = await organization_repo.get_user_role_in_org(
session, user_id=uuid4(), organization_id=org_id session, user_id=uuid4(), organization_id=org_id
) )
@@ -655,7 +656,7 @@ class TestIsUserOrgOwner:
org_id = org.id org_id = org.id
async with AsyncTestingSessionLocal() as session: async with AsyncTestingSessionLocal() as session:
is_owner = await organization_crud.is_user_org_owner( is_owner = await organization_repo.is_user_org_owner(
session, user_id=async_test_user.id, organization_id=org_id session, user_id=async_test_user.id, organization_id=org_id
) )
@@ -682,7 +683,7 @@ class TestIsUserOrgOwner:
org_id = org.id org_id = org.id
async with AsyncTestingSessionLocal() as session: async with AsyncTestingSessionLocal() as session:
is_owner = await organization_crud.is_user_org_owner( is_owner = await organization_repo.is_user_org_owner(
session, user_id=async_test_user.id, organization_id=org_id session, user_id=async_test_user.id, organization_id=org_id
) )
@@ -719,7 +720,7 @@ class TestGetMultiWithMemberCounts:
( (
orgs_with_counts, orgs_with_counts,
total, total,
) = await organization_crud.get_multi_with_member_counts(session) ) = await organization_repo.get_multi_with_member_counts(session)
assert total == 2 assert total == 2
assert len(orgs_with_counts) == 2 assert len(orgs_with_counts) == 2
@@ -744,7 +745,7 @@ class TestGetMultiWithMemberCounts:
( (
orgs_with_counts, orgs_with_counts,
total, total,
) = await organization_crud.get_multi_with_member_counts( ) = await organization_repo.get_multi_with_member_counts(
session, is_active=True session, is_active=True
) )
@@ -766,7 +767,7 @@ class TestGetMultiWithMemberCounts:
( (
orgs_with_counts, orgs_with_counts,
total, total,
) = await organization_crud.get_multi_with_member_counts( ) = await organization_repo.get_multi_with_member_counts(
session, search="tech" session, search="tech"
) )
@@ -800,7 +801,7 @@ class TestGetUserOrganizationsWithDetails:
async with AsyncTestingSessionLocal() as session: async with AsyncTestingSessionLocal() as session:
orgs_with_details = ( orgs_with_details = (
await organization_crud.get_user_organizations_with_details( await organization_repo.get_user_organizations_with_details(
session, user_id=async_test_user.id session, user_id=async_test_user.id
) )
) )
@@ -840,7 +841,7 @@ class TestGetUserOrganizationsWithDetails:
async with AsyncTestingSessionLocal() as session: async with AsyncTestingSessionLocal() as session:
orgs_with_details = ( orgs_with_details = (
await organization_crud.get_user_organizations_with_details( await organization_repo.get_user_organizations_with_details(
session, user_id=async_test_user.id, is_active=True session, user_id=async_test_user.id, is_active=True
) )
) )
@@ -873,7 +874,7 @@ class TestIsUserOrgAdmin:
org_id = org.id org_id = org.id
async with AsyncTestingSessionLocal() as session: async with AsyncTestingSessionLocal() as session:
is_admin = await organization_crud.is_user_org_admin( is_admin = await organization_repo.is_user_org_admin(
session, user_id=async_test_user.id, organization_id=org_id session, user_id=async_test_user.id, organization_id=org_id
) )
@@ -900,7 +901,7 @@ class TestIsUserOrgAdmin:
org_id = org.id org_id = org.id
async with AsyncTestingSessionLocal() as session: async with AsyncTestingSessionLocal() as session:
is_admin = await organization_crud.is_user_org_admin( is_admin = await organization_repo.is_user_org_admin(
session, user_id=async_test_user.id, organization_id=org_id session, user_id=async_test_user.id, organization_id=org_id
) )
@@ -927,7 +928,7 @@ class TestIsUserOrgAdmin:
org_id = org.id org_id = org.id
async with AsyncTestingSessionLocal() as session: async with AsyncTestingSessionLocal() as session:
is_admin = await organization_crud.is_user_org_admin( is_admin = await organization_repo.is_user_org_admin(
session, user_id=async_test_user.id, organization_id=org_id session, user_id=async_test_user.id, organization_id=org_id
) )
@@ -936,7 +937,7 @@ class TestIsUserOrgAdmin:
class TestOrganizationExceptionHandlers: class TestOrganizationExceptionHandlers:
""" """
Test exception handlers in organization CRUD methods. Test exception handlers in organization repository methods.
Uses mocks to trigger database errors and verify proper error handling. Uses mocks to trigger database errors and verify proper error handling.
Covers lines: 33-35, 57-62, 114-116, 130-132, 207-209, 258-260, 291-294, 326-329, 385-387, 409-411, 466-468, 491-493 Covers lines: 33-35, 57-62, 114-116, 130-132, 207-209, 258-260, 291-294, 326-329, 385-387, 409-411, 466-468, 491-493
""" """
@@ -951,7 +952,7 @@ class TestOrganizationExceptionHandlers:
session, "execute", side_effect=Exception("Database connection lost") session, "execute", side_effect=Exception("Database connection lost")
): ):
with pytest.raises(Exception, match="Database connection lost"): with pytest.raises(Exception, match="Database connection lost"):
await organization_crud.get_by_slug(session, slug="test-slug") await organization_repo.get_by_slug(session, slug="test-slug")
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_create_integrity_error_non_slug(self, async_test_db): async def test_create_integrity_error_non_slug(self, async_test_db):
@@ -972,8 +973,10 @@ class TestOrganizationExceptionHandlers:
with patch.object(session, "commit", side_effect=mock_commit): with patch.object(session, "commit", side_effect=mock_commit):
with patch.object(session, "rollback", new_callable=AsyncMock): with patch.object(session, "rollback", new_callable=AsyncMock):
org_in = OrganizationCreate(name="Test", slug="test") org_in = OrganizationCreate(name="Test", slug="test")
with pytest.raises(ValueError, match="Database integrity error"): with pytest.raises(
await organization_crud.create(session, obj_in=org_in) IntegrityConstraintError, match="Database integrity error"
):
await organization_repo.create(session, obj_in=org_in)
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_create_unexpected_error(self, async_test_db): async def test_create_unexpected_error(self, async_test_db):
@@ -987,7 +990,7 @@ class TestOrganizationExceptionHandlers:
with patch.object(session, "rollback", new_callable=AsyncMock): with patch.object(session, "rollback", new_callable=AsyncMock):
org_in = OrganizationCreate(name="Test", slug="test") org_in = OrganizationCreate(name="Test", slug="test")
with pytest.raises(RuntimeError, match="Unexpected error"): with pytest.raises(RuntimeError, match="Unexpected error"):
await organization_crud.create(session, obj_in=org_in) await organization_repo.create(session, obj_in=org_in)
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_get_multi_with_filters_database_error(self, async_test_db): async def test_get_multi_with_filters_database_error(self, async_test_db):
@@ -999,7 +1002,7 @@ class TestOrganizationExceptionHandlers:
session, "execute", side_effect=Exception("Query timeout") session, "execute", side_effect=Exception("Query timeout")
): ):
with pytest.raises(Exception, match="Query timeout"): with pytest.raises(Exception, match="Query timeout"):
await organization_crud.get_multi_with_filters(session) await organization_repo.get_multi_with_filters(session)
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_get_member_count_database_error(self, async_test_db): async def test_get_member_count_database_error(self, async_test_db):
@@ -1013,7 +1016,7 @@ class TestOrganizationExceptionHandlers:
session, "execute", side_effect=Exception("Count query failed") session, "execute", side_effect=Exception("Count query failed")
): ):
with pytest.raises(Exception, match="Count query failed"): with pytest.raises(Exception, match="Count query failed"):
await organization_crud.get_member_count( await organization_repo.get_member_count(
session, organization_id=uuid4() session, organization_id=uuid4()
) )
@@ -1027,7 +1030,7 @@ class TestOrganizationExceptionHandlers:
session, "execute", side_effect=Exception("Complex query failed") session, "execute", side_effect=Exception("Complex query failed")
): ):
with pytest.raises(Exception, match="Complex query failed"): with pytest.raises(Exception, match="Complex query failed"):
await organization_crud.get_multi_with_member_counts(session) await organization_repo.get_multi_with_member_counts(session)
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_add_user_integrity_error(self, async_test_db, async_test_user): async def test_add_user_integrity_error(self, async_test_db, async_test_user):
@@ -1058,9 +1061,10 @@ class TestOrganizationExceptionHandlers:
with patch.object(session, "commit", side_effect=mock_commit): with patch.object(session, "commit", side_effect=mock_commit):
with patch.object(session, "rollback", new_callable=AsyncMock): with patch.object(session, "rollback", new_callable=AsyncMock):
with pytest.raises( with pytest.raises(
ValueError, match="Failed to add user to organization" IntegrityConstraintError,
match="Failed to add user to organization",
): ):
await organization_crud.add_user( await organization_repo.add_user(
session, session,
organization_id=org_id, organization_id=org_id,
user_id=async_test_user.id, user_id=async_test_user.id,
@@ -1078,7 +1082,7 @@ class TestOrganizationExceptionHandlers:
session, "execute", side_effect=Exception("Delete failed") session, "execute", side_effect=Exception("Delete failed")
): ):
with pytest.raises(Exception, match="Delete failed"): with pytest.raises(Exception, match="Delete failed"):
await organization_crud.remove_user( await organization_repo.remove_user(
session, organization_id=uuid4(), user_id=async_test_user.id session, organization_id=uuid4(), user_id=async_test_user.id
) )
@@ -1096,7 +1100,7 @@ class TestOrganizationExceptionHandlers:
session, "execute", side_effect=Exception("Update failed") session, "execute", side_effect=Exception("Update failed")
): ):
with pytest.raises(Exception, match="Update failed"): with pytest.raises(Exception, match="Update failed"):
await organization_crud.update_user_role( await organization_repo.update_user_role(
session, session,
organization_id=uuid4(), organization_id=uuid4(),
user_id=async_test_user.id, user_id=async_test_user.id,
@@ -1115,7 +1119,7 @@ class TestOrganizationExceptionHandlers:
session, "execute", side_effect=Exception("Members query failed") session, "execute", side_effect=Exception("Members query failed")
): ):
with pytest.raises(Exception, match="Members query failed"): with pytest.raises(Exception, match="Members query failed"):
await organization_crud.get_organization_members( await organization_repo.get_organization_members(
session, organization_id=uuid4() session, organization_id=uuid4()
) )
@@ -1131,7 +1135,7 @@ class TestOrganizationExceptionHandlers:
session, "execute", side_effect=Exception("User orgs query failed") session, "execute", side_effect=Exception("User orgs query failed")
): ):
with pytest.raises(Exception, match="User orgs query failed"): with pytest.raises(Exception, match="User orgs query failed"):
await organization_crud.get_user_organizations( await organization_repo.get_user_organizations(
session, user_id=async_test_user.id session, user_id=async_test_user.id
) )
@@ -1147,7 +1151,7 @@ class TestOrganizationExceptionHandlers:
session, "execute", side_effect=Exception("Details query failed") session, "execute", side_effect=Exception("Details query failed")
): ):
with pytest.raises(Exception, match="Details query failed"): with pytest.raises(Exception, match="Details query failed"):
await organization_crud.get_user_organizations_with_details( await organization_repo.get_user_organizations_with_details(
session, user_id=async_test_user.id session, user_id=async_test_user.id
) )
@@ -1165,6 +1169,6 @@ class TestOrganizationExceptionHandlers:
session, "execute", side_effect=Exception("Role query failed") session, "execute", side_effect=Exception("Role query failed")
): ):
with pytest.raises(Exception, match="Role query failed"): with pytest.raises(Exception, match="Role query failed"):
await organization_crud.get_user_role_in_org( await organization_repo.get_user_role_in_org(
session, user_id=async_test_user.id, organization_id=uuid4() session, user_id=async_test_user.id, organization_id=uuid4()
) )

View File

@@ -1,6 +1,6 @@
# tests/crud/test_session_async.py # tests/repositories/test_session_async.py
""" """
Comprehensive tests for async session CRUD operations. Comprehensive tests for async session repository operations.
""" """
from datetime import UTC, datetime, timedelta from datetime import UTC, datetime, timedelta
@@ -8,8 +8,9 @@ from uuid import uuid4
import pytest import pytest
from app.crud.session import session as session_crud from app.core.repository_exceptions import InvalidInputError
from app.models.user_session import UserSession from app.models.user_session import UserSession
from app.repositories.session import session_repo as session_repo
from app.schemas.sessions import SessionCreate from app.schemas.sessions import SessionCreate
@@ -36,7 +37,7 @@ class TestGetByJti:
await session.commit() await session.commit()
async with AsyncTestingSessionLocal() as session: async with AsyncTestingSessionLocal() as session:
result = await session_crud.get_by_jti(session, jti="test_jti_123") result = await session_repo.get_by_jti(session, jti="test_jti_123")
assert result is not None assert result is not None
assert result.refresh_token_jti == "test_jti_123" assert result.refresh_token_jti == "test_jti_123"
@@ -46,7 +47,7 @@ class TestGetByJti:
_test_engine, AsyncTestingSessionLocal = async_test_db _test_engine, AsyncTestingSessionLocal = async_test_db
async with AsyncTestingSessionLocal() as session: async with AsyncTestingSessionLocal() as session:
result = await session_crud.get_by_jti(session, jti="nonexistent") result = await session_repo.get_by_jti(session, jti="nonexistent")
assert result is None assert result is None
@@ -73,7 +74,7 @@ class TestGetActiveByJti:
await session.commit() await session.commit()
async with AsyncTestingSessionLocal() as session: async with AsyncTestingSessionLocal() as session:
result = await session_crud.get_active_by_jti(session, jti="active_jti") result = await session_repo.get_active_by_jti(session, jti="active_jti")
assert result is not None assert result is not None
assert result.is_active is True assert result.is_active is True
@@ -97,7 +98,7 @@ class TestGetActiveByJti:
await session.commit() await session.commit()
async with AsyncTestingSessionLocal() as session: async with AsyncTestingSessionLocal() as session:
result = await session_crud.get_active_by_jti(session, jti="inactive_jti") result = await session_repo.get_active_by_jti(session, jti="inactive_jti")
assert result is None assert result is None
@@ -134,7 +135,7 @@ class TestGetUserSessions:
await session.commit() await session.commit()
async with AsyncTestingSessionLocal() as session: async with AsyncTestingSessionLocal() as session:
results = await session_crud.get_user_sessions( results = await session_repo.get_user_sessions(
session, user_id=str(async_test_user.id), active_only=True session, user_id=str(async_test_user.id), active_only=True
) )
assert len(results) == 1 assert len(results) == 1
@@ -161,7 +162,7 @@ class TestGetUserSessions:
await session.commit() await session.commit()
async with AsyncTestingSessionLocal() as session: async with AsyncTestingSessionLocal() as session:
results = await session_crud.get_user_sessions( results = await session_repo.get_user_sessions(
session, user_id=str(async_test_user.id), active_only=False session, user_id=str(async_test_user.id), active_only=False
) )
assert len(results) == 3 assert len(results) == 3
@@ -172,7 +173,7 @@ class TestCreateSession:
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_create_session_success(self, async_test_db, async_test_user): async def test_create_session_success(self, async_test_db, async_test_user):
"""Test successfully creating a session_crud.""" """Test successfully creating a session_repo."""
_test_engine, AsyncTestingSessionLocal = async_test_db _test_engine, AsyncTestingSessionLocal = async_test_db
async with AsyncTestingSessionLocal() as session: async with AsyncTestingSessionLocal() as session:
@@ -188,7 +189,7 @@ class TestCreateSession:
location_city="San Francisco", location_city="San Francisco",
location_country="USA", location_country="USA",
) )
result = await session_crud.create_session(session, obj_in=session_data) result = await session_repo.create_session(session, obj_in=session_data)
assert result.user_id == async_test_user.id assert result.user_id == async_test_user.id
assert result.refresh_token_jti == "new_jti" assert result.refresh_token_jti == "new_jti"
@@ -201,7 +202,7 @@ class TestDeactivate:
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_deactivate_success(self, async_test_db, async_test_user): async def test_deactivate_success(self, async_test_db, async_test_user):
"""Test successfully deactivating a session_crud.""" """Test successfully deactivating a session_repo."""
_test_engine, AsyncTestingSessionLocal = async_test_db _test_engine, AsyncTestingSessionLocal = async_test_db
async with AsyncTestingSessionLocal() as session: async with AsyncTestingSessionLocal() as session:
@@ -220,7 +221,7 @@ class TestDeactivate:
session_id = user_session.id session_id = user_session.id
async with AsyncTestingSessionLocal() as session: async with AsyncTestingSessionLocal() as session:
result = await session_crud.deactivate(session, session_id=str(session_id)) result = await session_repo.deactivate(session, session_id=str(session_id))
assert result is not None assert result is not None
assert result.is_active is False assert result.is_active is False
@@ -230,7 +231,7 @@ class TestDeactivate:
_test_engine, AsyncTestingSessionLocal = async_test_db _test_engine, AsyncTestingSessionLocal = async_test_db
async with AsyncTestingSessionLocal() as session: async with AsyncTestingSessionLocal() as session:
result = await session_crud.deactivate(session, session_id=str(uuid4())) result = await session_repo.deactivate(session, session_id=str(uuid4()))
assert result is None assert result is None
@@ -261,7 +262,7 @@ class TestDeactivateAllUserSessions:
await session.commit() await session.commit()
async with AsyncTestingSessionLocal() as session: async with AsyncTestingSessionLocal() as session:
count = await session_crud.deactivate_all_user_sessions( count = await session_repo.deactivate_all_user_sessions(
session, user_id=str(async_test_user.id) session, user_id=str(async_test_user.id)
) )
assert count == 2 assert count == 2
@@ -291,7 +292,7 @@ class TestUpdateLastUsed:
await session.refresh(user_session) await session.refresh(user_session)
old_time = user_session.last_used_at old_time = user_session.last_used_at
result = await session_crud.update_last_used(session, session=user_session) result = await session_repo.update_last_used(session, session=user_session)
assert result.last_used_at > old_time assert result.last_used_at > old_time
@@ -320,7 +321,7 @@ class TestGetUserSessionCount:
await session.commit() await session.commit()
async with AsyncTestingSessionLocal() as session: async with AsyncTestingSessionLocal() as session:
count = await session_crud.get_user_session_count( count = await session_repo.get_user_session_count(
session, user_id=str(async_test_user.id) session, user_id=str(async_test_user.id)
) )
assert count == 3 assert count == 3
@@ -331,7 +332,7 @@ class TestGetUserSessionCount:
_test_engine, AsyncTestingSessionLocal = async_test_db _test_engine, AsyncTestingSessionLocal = async_test_db
async with AsyncTestingSessionLocal() as session: async with AsyncTestingSessionLocal() as session:
count = await session_crud.get_user_session_count( count = await session_repo.get_user_session_count(
session, user_id=str(uuid4()) session, user_id=str(uuid4())
) )
assert count == 0 assert count == 0
@@ -363,7 +364,7 @@ class TestUpdateRefreshToken:
new_jti = "new_jti_123" new_jti = "new_jti_123"
new_expires = datetime.now(UTC) + timedelta(days=14) new_expires = datetime.now(UTC) + timedelta(days=14)
result = await session_crud.update_refresh_token( result = await session_repo.update_refresh_token(
session, session,
session=user_session, session=user_session,
new_jti=new_jti, new_jti=new_jti,
@@ -409,7 +410,7 @@ class TestCleanupExpired:
# Cleanup # Cleanup
async with AsyncTestingSessionLocal() as session: async with AsyncTestingSessionLocal() as session:
count = await session_crud.cleanup_expired(session, keep_days=30) count = await session_repo.cleanup_expired(session, keep_days=30)
assert count == 1 assert count == 1
@pytest.mark.asyncio @pytest.mark.asyncio
@@ -435,7 +436,7 @@ class TestCleanupExpired:
# Cleanup # Cleanup
async with AsyncTestingSessionLocal() as session: async with AsyncTestingSessionLocal() as session:
count = await session_crud.cleanup_expired(session, keep_days=30) count = await session_repo.cleanup_expired(session, keep_days=30)
assert count == 0 # Should not delete recent sessions assert count == 0 # Should not delete recent sessions
@pytest.mark.asyncio @pytest.mark.asyncio
@@ -461,7 +462,7 @@ class TestCleanupExpired:
# Cleanup # Cleanup
async with AsyncTestingSessionLocal() as session: async with AsyncTestingSessionLocal() as session:
count = await session_crud.cleanup_expired(session, keep_days=30) count = await session_repo.cleanup_expired(session, keep_days=30)
assert count == 0 # Should not delete active sessions assert count == 0 # Should not delete active sessions
@@ -492,7 +493,7 @@ class TestCleanupExpiredForUser:
# Cleanup for user # Cleanup for user
async with AsyncTestingSessionLocal() as session: async with AsyncTestingSessionLocal() as session:
count = await session_crud.cleanup_expired_for_user( count = await session_repo.cleanup_expired_for_user(
session, user_id=str(async_test_user.id) session, user_id=str(async_test_user.id)
) )
assert count == 1 assert count == 1
@@ -503,8 +504,8 @@ class TestCleanupExpiredForUser:
_test_engine, AsyncTestingSessionLocal = async_test_db _test_engine, AsyncTestingSessionLocal = async_test_db
async with AsyncTestingSessionLocal() as session: async with AsyncTestingSessionLocal() as session:
with pytest.raises(ValueError, match="Invalid user ID format"): with pytest.raises(InvalidInputError, match="Invalid user ID format"):
await session_crud.cleanup_expired_for_user( await session_repo.cleanup_expired_for_user(
session, user_id="not-a-valid-uuid" session, user_id="not-a-valid-uuid"
) )
@@ -532,7 +533,7 @@ class TestCleanupExpiredForUser:
# Cleanup # Cleanup
async with AsyncTestingSessionLocal() as session: async with AsyncTestingSessionLocal() as session:
count = await session_crud.cleanup_expired_for_user( count = await session_repo.cleanup_expired_for_user(
session, user_id=str(async_test_user.id) session, user_id=str(async_test_user.id)
) )
assert count == 0 # Should not delete active sessions assert count == 0 # Should not delete active sessions
@@ -564,7 +565,7 @@ class TestGetUserSessionsWithUser:
# Get with user relationship # Get with user relationship
async with AsyncTestingSessionLocal() as session: async with AsyncTestingSessionLocal() as session:
results = await session_crud.get_user_sessions( results = await session_repo.get_user_sessions(
session, user_id=str(async_test_user.id), with_user=True session, user_id=str(async_test_user.id), with_user=True
) )
assert len(results) >= 1 assert len(results) >= 1

View File

@@ -1,6 +1,6 @@
# tests/crud/test_session_db_failures.py # tests/repositories/test_session_db_failures.py
""" """
Comprehensive tests for session CRUD database failure scenarios. Comprehensive tests for session repository database failure scenarios.
""" """
from datetime import UTC, datetime, timedelta from datetime import UTC, datetime, timedelta
@@ -10,12 +10,13 @@ from uuid import uuid4
import pytest import pytest
from sqlalchemy.exc import OperationalError from sqlalchemy.exc import OperationalError
from app.crud.session import session as session_crud from app.core.repository_exceptions import IntegrityConstraintError
from app.models.user_session import UserSession from app.models.user_session import UserSession
from app.repositories.session import session_repo as session_repo
from app.schemas.sessions import SessionCreate from app.schemas.sessions import SessionCreate
class TestSessionCRUDGetByJtiFailures: class TestSessionRepositoryGetByJtiFailures:
"""Test get_by_jti exception handling.""" """Test get_by_jti exception handling."""
@pytest.mark.asyncio @pytest.mark.asyncio
@@ -30,10 +31,10 @@ class TestSessionCRUDGetByJtiFailures:
with patch.object(session, "execute", side_effect=mock_execute): with patch.object(session, "execute", side_effect=mock_execute):
with pytest.raises(OperationalError): with pytest.raises(OperationalError):
await session_crud.get_by_jti(session, jti="test_jti") await session_repo.get_by_jti(session, jti="test_jti")
class TestSessionCRUDGetActiveByJtiFailures: class TestSessionRepositoryGetActiveByJtiFailures:
"""Test get_active_by_jti exception handling.""" """Test get_active_by_jti exception handling."""
@pytest.mark.asyncio @pytest.mark.asyncio
@@ -48,10 +49,10 @@ class TestSessionCRUDGetActiveByJtiFailures:
with patch.object(session, "execute", side_effect=mock_execute): with patch.object(session, "execute", side_effect=mock_execute):
with pytest.raises(OperationalError): with pytest.raises(OperationalError):
await session_crud.get_active_by_jti(session, jti="test_jti") await session_repo.get_active_by_jti(session, jti="test_jti")
class TestSessionCRUDGetUserSessionsFailures: class TestSessionRepositoryGetUserSessionsFailures:
"""Test get_user_sessions exception handling.""" """Test get_user_sessions exception handling."""
@pytest.mark.asyncio @pytest.mark.asyncio
@@ -68,12 +69,12 @@ class TestSessionCRUDGetUserSessionsFailures:
with patch.object(session, "execute", side_effect=mock_execute): with patch.object(session, "execute", side_effect=mock_execute):
with pytest.raises(OperationalError): with pytest.raises(OperationalError):
await session_crud.get_user_sessions( await session_repo.get_user_sessions(
session, user_id=str(async_test_user.id) session, user_id=str(async_test_user.id)
) )
class TestSessionCRUDCreateSessionFailures: class TestSessionRepositoryCreateSessionFailures:
"""Test create_session exception handling.""" """Test create_session exception handling."""
@pytest.mark.asyncio @pytest.mark.asyncio
@@ -102,8 +103,10 @@ class TestSessionCRUDCreateSessionFailures:
last_used_at=datetime.now(UTC), last_used_at=datetime.now(UTC),
) )
with pytest.raises(ValueError, match="Failed to create session"): with pytest.raises(
await session_crud.create_session(session, obj_in=session_data) IntegrityConstraintError, match="Failed to create session"
):
await session_repo.create_session(session, obj_in=session_data)
mock_rollback.assert_called_once() mock_rollback.assert_called_once()
@@ -133,13 +136,15 @@ class TestSessionCRUDCreateSessionFailures:
last_used_at=datetime.now(UTC), last_used_at=datetime.now(UTC),
) )
with pytest.raises(ValueError, match="Failed to create session"): with pytest.raises(
await session_crud.create_session(session, obj_in=session_data) IntegrityConstraintError, match="Failed to create session"
):
await session_repo.create_session(session, obj_in=session_data)
mock_rollback.assert_called_once() mock_rollback.assert_called_once()
class TestSessionCRUDDeactivateFailures: class TestSessionRepositoryDeactivateFailures:
"""Test deactivate exception handling.""" """Test deactivate exception handling."""
@pytest.mark.asyncio @pytest.mark.asyncio
@@ -177,14 +182,14 @@ class TestSessionCRUDDeactivateFailures:
session, "rollback", new_callable=AsyncMock session, "rollback", new_callable=AsyncMock
) as mock_rollback: ) as mock_rollback:
with pytest.raises(OperationalError): with pytest.raises(OperationalError):
await session_crud.deactivate( await session_repo.deactivate(
session, session_id=str(session_id) session, session_id=str(session_id)
) )
mock_rollback.assert_called_once() mock_rollback.assert_called_once()
class TestSessionCRUDDeactivateAllFailures: class TestSessionRepositoryDeactivateAllFailures:
"""Test deactivate_all_user_sessions exception handling.""" """Test deactivate_all_user_sessions exception handling."""
@pytest.mark.asyncio @pytest.mark.asyncio
@@ -204,14 +209,14 @@ class TestSessionCRUDDeactivateAllFailures:
session, "rollback", new_callable=AsyncMock session, "rollback", new_callable=AsyncMock
) as mock_rollback: ) as mock_rollback:
with pytest.raises(OperationalError): with pytest.raises(OperationalError):
await session_crud.deactivate_all_user_sessions( await session_repo.deactivate_all_user_sessions(
session, user_id=str(async_test_user.id) session, user_id=str(async_test_user.id)
) )
mock_rollback.assert_called_once() mock_rollback.assert_called_once()
class TestSessionCRUDUpdateLastUsedFailures: class TestSessionRepositoryUpdateLastUsedFailures:
"""Test update_last_used exception handling.""" """Test update_last_used exception handling."""
@pytest.mark.asyncio @pytest.mark.asyncio
@@ -254,12 +259,12 @@ class TestSessionCRUDUpdateLastUsedFailures:
session, "rollback", new_callable=AsyncMock session, "rollback", new_callable=AsyncMock
) as mock_rollback: ) as mock_rollback:
with pytest.raises(OperationalError): with pytest.raises(OperationalError):
await session_crud.update_last_used(session, session=sess) await session_repo.update_last_used(session, session=sess)
mock_rollback.assert_called_once() mock_rollback.assert_called_once()
class TestSessionCRUDUpdateRefreshTokenFailures: class TestSessionRepositoryUpdateRefreshTokenFailures:
"""Test update_refresh_token exception handling.""" """Test update_refresh_token exception handling."""
@pytest.mark.asyncio @pytest.mark.asyncio
@@ -302,7 +307,7 @@ class TestSessionCRUDUpdateRefreshTokenFailures:
session, "rollback", new_callable=AsyncMock session, "rollback", new_callable=AsyncMock
) as mock_rollback: ) as mock_rollback:
with pytest.raises(OperationalError): with pytest.raises(OperationalError):
await session_crud.update_refresh_token( await session_repo.update_refresh_token(
session, session,
session=sess, session=sess,
new_jti=str(uuid4()), new_jti=str(uuid4()),
@@ -312,7 +317,7 @@ class TestSessionCRUDUpdateRefreshTokenFailures:
mock_rollback.assert_called_once() mock_rollback.assert_called_once()
class TestSessionCRUDCleanupExpiredFailures: class TestSessionRepositoryCleanupExpiredFailures:
"""Test cleanup_expired exception handling.""" """Test cleanup_expired exception handling."""
@pytest.mark.asyncio @pytest.mark.asyncio
@@ -332,12 +337,12 @@ class TestSessionCRUDCleanupExpiredFailures:
session, "rollback", new_callable=AsyncMock session, "rollback", new_callable=AsyncMock
) as mock_rollback: ) as mock_rollback:
with pytest.raises(OperationalError): with pytest.raises(OperationalError):
await session_crud.cleanup_expired(session, keep_days=30) await session_repo.cleanup_expired(session, keep_days=30)
mock_rollback.assert_called_once() mock_rollback.assert_called_once()
class TestSessionCRUDCleanupExpiredForUserFailures: class TestSessionRepositoryCleanupExpiredForUserFailures:
"""Test cleanup_expired_for_user exception handling.""" """Test cleanup_expired_for_user exception handling."""
@pytest.mark.asyncio @pytest.mark.asyncio
@@ -357,14 +362,14 @@ class TestSessionCRUDCleanupExpiredForUserFailures:
session, "rollback", new_callable=AsyncMock session, "rollback", new_callable=AsyncMock
) as mock_rollback: ) as mock_rollback:
with pytest.raises(OperationalError): with pytest.raises(OperationalError):
await session_crud.cleanup_expired_for_user( await session_repo.cleanup_expired_for_user(
session, user_id=str(async_test_user.id) session, user_id=str(async_test_user.id)
) )
mock_rollback.assert_called_once() mock_rollback.assert_called_once()
class TestSessionCRUDGetUserSessionCountFailures: class TestSessionRepositoryGetUserSessionCountFailures:
"""Test get_user_session_count exception handling.""" """Test get_user_session_count exception handling."""
@pytest.mark.asyncio @pytest.mark.asyncio
@@ -381,6 +386,6 @@ class TestSessionCRUDGetUserSessionCountFailures:
with patch.object(session, "execute", side_effect=mock_execute): with patch.object(session, "execute", side_effect=mock_execute):
with pytest.raises(OperationalError): with pytest.raises(OperationalError):
await session_crud.get_user_session_count( await session_repo.get_user_session_count(
session, user_id=str(async_test_user.id) session, user_id=str(async_test_user.id)
) )

View File

@@ -1,11 +1,12 @@
# tests/crud/test_user_async.py # tests/repositories/test_user_async.py
""" """
Comprehensive tests for async user CRUD operations. Comprehensive tests for async user repository operations.
""" """
import pytest import pytest
from app.crud.user import user as user_crud from app.core.repository_exceptions import DuplicateEntryError, InvalidInputError
from app.repositories.user import user_repo as user_repo
from app.schemas.users import UserCreate, UserUpdate from app.schemas.users import UserCreate, UserUpdate
@@ -18,7 +19,7 @@ class TestGetByEmail:
_test_engine, AsyncTestingSessionLocal = async_test_db _test_engine, AsyncTestingSessionLocal = async_test_db
async with AsyncTestingSessionLocal() as session: async with AsyncTestingSessionLocal() as session:
result = await user_crud.get_by_email(session, email=async_test_user.email) result = await user_repo.get_by_email(session, email=async_test_user.email)
assert result is not None assert result is not None
assert result.email == async_test_user.email assert result.email == async_test_user.email
assert result.id == async_test_user.id assert result.id == async_test_user.id
@@ -29,7 +30,7 @@ class TestGetByEmail:
_test_engine, AsyncTestingSessionLocal = async_test_db _test_engine, AsyncTestingSessionLocal = async_test_db
async with AsyncTestingSessionLocal() as session: async with AsyncTestingSessionLocal() as session:
result = await user_crud.get_by_email( result = await user_repo.get_by_email(
session, email="nonexistent@example.com" session, email="nonexistent@example.com"
) )
assert result is None assert result is None
@@ -40,7 +41,7 @@ class TestCreate:
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_create_user_success(self, async_test_db): async def test_create_user_success(self, async_test_db):
"""Test successfully creating a user_crud.""" """Test successfully creating a user_repo."""
_test_engine, AsyncTestingSessionLocal = async_test_db _test_engine, AsyncTestingSessionLocal = async_test_db
async with AsyncTestingSessionLocal() as session: async with AsyncTestingSessionLocal() as session:
@@ -51,7 +52,7 @@ class TestCreate:
last_name="User", last_name="User",
phone_number="+1234567890", phone_number="+1234567890",
) )
result = await user_crud.create(session, obj_in=user_data) result = await user_repo.create(session, obj_in=user_data)
assert result.email == "newuser@example.com" assert result.email == "newuser@example.com"
assert result.first_name == "New" assert result.first_name == "New"
@@ -75,7 +76,7 @@ class TestCreate:
last_name="User", last_name="User",
is_superuser=True, is_superuser=True,
) )
result = await user_crud.create(session, obj_in=user_data) result = await user_repo.create(session, obj_in=user_data)
assert result.is_superuser is True assert result.is_superuser is True
assert result.email == "superuser@example.com" assert result.email == "superuser@example.com"
@@ -93,8 +94,8 @@ class TestCreate:
last_name="User", last_name="User",
) )
with pytest.raises(ValueError) as exc_info: with pytest.raises(DuplicateEntryError) as exc_info:
await user_crud.create(session, obj_in=user_data) await user_repo.create(session, obj_in=user_data)
assert "already exists" in str(exc_info.value).lower() assert "already exists" in str(exc_info.value).lower()
@@ -109,12 +110,12 @@ class TestUpdate:
async with AsyncTestingSessionLocal() as session: async with AsyncTestingSessionLocal() as session:
# Get fresh copy of user # Get fresh copy of user
user = await user_crud.get(session, id=str(async_test_user.id)) user = await user_repo.get(session, id=str(async_test_user.id))
update_data = UserUpdate( update_data = UserUpdate(
first_name="Updated", last_name="Name", phone_number="+9876543210" first_name="Updated", last_name="Name", phone_number="+9876543210"
) )
result = await user_crud.update(session, db_obj=user, obj_in=update_data) result = await user_repo.update(session, db_obj=user, obj_in=update_data)
assert result.first_name == "Updated" assert result.first_name == "Updated"
assert result.last_name == "Name" assert result.last_name == "Name"
@@ -133,16 +134,16 @@ class TestUpdate:
first_name="Pass", first_name="Pass",
last_name="Test", last_name="Test",
) )
user = await user_crud.create(session, obj_in=user_data) user = await user_repo.create(session, obj_in=user_data)
user_id = user.id user_id = user.id
old_password_hash = user.password_hash old_password_hash = user.password_hash
# Update the password # Update the password
async with AsyncTestingSessionLocal() as session: async with AsyncTestingSessionLocal() as session:
user = await user_crud.get(session, id=str(user_id)) user = await user_repo.get(session, id=str(user_id))
update_data = UserUpdate(password="NewDifferentPassword123!") update_data = UserUpdate(password="NewDifferentPassword123!")
result = await user_crud.update(session, db_obj=user, obj_in=update_data) result = await user_repo.update(session, db_obj=user, obj_in=update_data)
await session.refresh(result) await session.refresh(result)
assert result.password_hash != old_password_hash assert result.password_hash != old_password_hash
@@ -157,10 +158,10 @@ class TestUpdate:
_test_engine, AsyncTestingSessionLocal = async_test_db _test_engine, AsyncTestingSessionLocal = async_test_db
async with AsyncTestingSessionLocal() as session: async with AsyncTestingSessionLocal() as session:
user = await user_crud.get(session, id=str(async_test_user.id)) user = await user_repo.get(session, id=str(async_test_user.id))
update_dict = {"first_name": "DictUpdate"} update_dict = {"first_name": "DictUpdate"}
result = await user_crud.update(session, db_obj=user, obj_in=update_dict) result = await user_repo.update(session, db_obj=user, obj_in=update_dict)
assert result.first_name == "DictUpdate" assert result.first_name == "DictUpdate"
@@ -174,7 +175,7 @@ class TestGetMultiWithTotal:
_test_engine, AsyncTestingSessionLocal = async_test_db _test_engine, AsyncTestingSessionLocal = async_test_db
async with AsyncTestingSessionLocal() as session: async with AsyncTestingSessionLocal() as session:
users, total = await user_crud.get_multi_with_total( users, total = await user_repo.get_multi_with_total(
session, skip=0, limit=10 session, skip=0, limit=10
) )
assert total >= 1 assert total >= 1
@@ -195,10 +196,10 @@ class TestGetMultiWithTotal:
first_name=f"User{i}", first_name=f"User{i}",
last_name="Test", last_name="Test",
) )
await user_crud.create(session, obj_in=user_data) await user_repo.create(session, obj_in=user_data)
async with AsyncTestingSessionLocal() as session: async with AsyncTestingSessionLocal() as session:
users, _total = await user_crud.get_multi_with_total( users, _total = await user_repo.get_multi_with_total(
session, skip=0, limit=10, sort_by="email", sort_order="asc" session, skip=0, limit=10, sort_by="email", sort_order="asc"
) )
@@ -221,10 +222,10 @@ class TestGetMultiWithTotal:
first_name=f"User{i}", first_name=f"User{i}",
last_name="Test", last_name="Test",
) )
await user_crud.create(session, obj_in=user_data) await user_repo.create(session, obj_in=user_data)
async with AsyncTestingSessionLocal() as session: async with AsyncTestingSessionLocal() as session:
users, _total = await user_crud.get_multi_with_total( users, _total = await user_repo.get_multi_with_total(
session, skip=0, limit=10, sort_by="email", sort_order="desc" session, skip=0, limit=10, sort_by="email", sort_order="desc"
) )
@@ -246,7 +247,7 @@ class TestGetMultiWithTotal:
first_name="Active", first_name="Active",
last_name="User", last_name="User",
) )
await user_crud.create(session, obj_in=active_user) await user_repo.create(session, obj_in=active_user)
inactive_user = UserCreate( inactive_user = UserCreate(
email="inactive@example.com", email="inactive@example.com",
@@ -254,15 +255,15 @@ class TestGetMultiWithTotal:
first_name="Inactive", first_name="Inactive",
last_name="User", last_name="User",
) )
created_inactive = await user_crud.create(session, obj_in=inactive_user) created_inactive = await user_repo.create(session, obj_in=inactive_user)
# Deactivate the user # Deactivate the user
await user_crud.update( await user_repo.update(
session, db_obj=created_inactive, obj_in={"is_active": False} session, db_obj=created_inactive, obj_in={"is_active": False}
) )
async with AsyncTestingSessionLocal() as session: async with AsyncTestingSessionLocal() as session:
users, _total = await user_crud.get_multi_with_total( users, _total = await user_repo.get_multi_with_total(
session, skip=0, limit=100, filters={"is_active": True} session, skip=0, limit=100, filters={"is_active": True}
) )
@@ -282,10 +283,10 @@ class TestGetMultiWithTotal:
first_name="Searchable", first_name="Searchable",
last_name="UserName", last_name="UserName",
) )
await user_crud.create(session, obj_in=user_data) await user_repo.create(session, obj_in=user_data)
async with AsyncTestingSessionLocal() as session: async with AsyncTestingSessionLocal() as session:
users, total = await user_crud.get_multi_with_total( users, total = await user_repo.get_multi_with_total(
session, skip=0, limit=100, search="Searchable" session, skip=0, limit=100, search="Searchable"
) )
@@ -306,16 +307,16 @@ class TestGetMultiWithTotal:
first_name=f"Page{i}", first_name=f"Page{i}",
last_name="User", last_name="User",
) )
await user_crud.create(session, obj_in=user_data) await user_repo.create(session, obj_in=user_data)
async with AsyncTestingSessionLocal() as session: async with AsyncTestingSessionLocal() as session:
# Get first page # Get first page
users_page1, total = await user_crud.get_multi_with_total( users_page1, total = await user_repo.get_multi_with_total(
session, skip=0, limit=2 session, skip=0, limit=2
) )
# Get second page # Get second page
users_page2, total2 = await user_crud.get_multi_with_total( users_page2, total2 = await user_repo.get_multi_with_total(
session, skip=2, limit=2 session, skip=2, limit=2
) )
@@ -330,8 +331,8 @@ class TestGetMultiWithTotal:
_test_engine, AsyncTestingSessionLocal = async_test_db _test_engine, AsyncTestingSessionLocal = async_test_db
async with AsyncTestingSessionLocal() as session: async with AsyncTestingSessionLocal() as session:
with pytest.raises(ValueError) as exc_info: with pytest.raises(InvalidInputError) as exc_info:
await user_crud.get_multi_with_total(session, skip=-1, limit=10) await user_repo.get_multi_with_total(session, skip=-1, limit=10)
assert "skip must be non-negative" in str(exc_info.value) assert "skip must be non-negative" in str(exc_info.value)
@@ -341,8 +342,8 @@ class TestGetMultiWithTotal:
_test_engine, AsyncTestingSessionLocal = async_test_db _test_engine, AsyncTestingSessionLocal = async_test_db
async with AsyncTestingSessionLocal() as session: async with AsyncTestingSessionLocal() as session:
with pytest.raises(ValueError) as exc_info: with pytest.raises(InvalidInputError) as exc_info:
await user_crud.get_multi_with_total(session, skip=0, limit=-1) await user_repo.get_multi_with_total(session, skip=0, limit=-1)
assert "limit must be non-negative" in str(exc_info.value) assert "limit must be non-negative" in str(exc_info.value)
@@ -352,8 +353,8 @@ class TestGetMultiWithTotal:
_test_engine, AsyncTestingSessionLocal = async_test_db _test_engine, AsyncTestingSessionLocal = async_test_db
async with AsyncTestingSessionLocal() as session: async with AsyncTestingSessionLocal() as session:
with pytest.raises(ValueError) as exc_info: with pytest.raises(InvalidInputError) as exc_info:
await user_crud.get_multi_with_total(session, skip=0, limit=1001) await user_repo.get_multi_with_total(session, skip=0, limit=1001)
assert "Maximum limit is 1000" in str(exc_info.value) assert "Maximum limit is 1000" in str(exc_info.value)
@@ -376,12 +377,12 @@ class TestBulkUpdateStatus:
first_name=f"Bulk{i}", first_name=f"Bulk{i}",
last_name="User", last_name="User",
) )
user = await user_crud.create(session, obj_in=user_data) user = await user_repo.create(session, obj_in=user_data)
user_ids.append(user.id) user_ids.append(user.id)
# Bulk deactivate # Bulk deactivate
async with AsyncTestingSessionLocal() as session: async with AsyncTestingSessionLocal() as session:
count = await user_crud.bulk_update_status( count = await user_repo.bulk_update_status(
session, user_ids=user_ids, is_active=False session, user_ids=user_ids, is_active=False
) )
assert count == 3 assert count == 3
@@ -389,7 +390,7 @@ class TestBulkUpdateStatus:
# Verify all are inactive # Verify all are inactive
async with AsyncTestingSessionLocal() as session: async with AsyncTestingSessionLocal() as session:
for user_id in user_ids: for user_id in user_ids:
user = await user_crud.get(session, id=str(user_id)) user = await user_repo.get(session, id=str(user_id))
assert user.is_active is False assert user.is_active is False
@pytest.mark.asyncio @pytest.mark.asyncio
@@ -398,7 +399,7 @@ class TestBulkUpdateStatus:
_test_engine, AsyncTestingSessionLocal = async_test_db _test_engine, AsyncTestingSessionLocal = async_test_db
async with AsyncTestingSessionLocal() as session: async with AsyncTestingSessionLocal() as session:
count = await user_crud.bulk_update_status( count = await user_repo.bulk_update_status(
session, user_ids=[], is_active=False session, user_ids=[], is_active=False
) )
assert count == 0 assert count == 0
@@ -416,21 +417,21 @@ class TestBulkUpdateStatus:
first_name="Reactivate", first_name="Reactivate",
last_name="User", last_name="User",
) )
user = await user_crud.create(session, obj_in=user_data) user = await user_repo.create(session, obj_in=user_data)
# Deactivate # Deactivate
await user_crud.update(session, db_obj=user, obj_in={"is_active": False}) await user_repo.update(session, db_obj=user, obj_in={"is_active": False})
user_id = user.id user_id = user.id
# Reactivate # Reactivate
async with AsyncTestingSessionLocal() as session: async with AsyncTestingSessionLocal() as session:
count = await user_crud.bulk_update_status( count = await user_repo.bulk_update_status(
session, user_ids=[user_id], is_active=True session, user_ids=[user_id], is_active=True
) )
assert count == 1 assert count == 1
# Verify active # Verify active
async with AsyncTestingSessionLocal() as session: async with AsyncTestingSessionLocal() as session:
user = await user_crud.get(session, id=str(user_id)) user = await user_repo.get(session, id=str(user_id))
assert user.is_active is True assert user.is_active is True
@@ -452,24 +453,24 @@ class TestBulkSoftDelete:
first_name=f"Delete{i}", first_name=f"Delete{i}",
last_name="User", last_name="User",
) )
user = await user_crud.create(session, obj_in=user_data) user = await user_repo.create(session, obj_in=user_data)
user_ids.append(user.id) user_ids.append(user.id)
# Bulk delete # Bulk delete
async with AsyncTestingSessionLocal() as session: async with AsyncTestingSessionLocal() as session:
count = await user_crud.bulk_soft_delete(session, user_ids=user_ids) count = await user_repo.bulk_soft_delete(session, user_ids=user_ids)
assert count == 3 assert count == 3
# Verify all are soft deleted # Verify all are soft deleted
async with AsyncTestingSessionLocal() as session: async with AsyncTestingSessionLocal() as session:
for user_id in user_ids: for user_id in user_ids:
user = await user_crud.get(session, id=str(user_id)) user = await user_repo.get(session, id=str(user_id))
assert user.deleted_at is not None assert user.deleted_at is not None
assert user.is_active is False assert user.is_active is False
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_bulk_soft_delete_with_exclusion(self, async_test_db): async def test_bulk_soft_delete_with_exclusion(self, async_test_db):
"""Test bulk soft delete with excluded user_crud.""" """Test bulk soft delete with excluded user_repo."""
_test_engine, AsyncTestingSessionLocal = async_test_db _test_engine, AsyncTestingSessionLocal = async_test_db
# Create multiple users # Create multiple users
@@ -482,20 +483,20 @@ class TestBulkSoftDelete:
first_name=f"Exclude{i}", first_name=f"Exclude{i}",
last_name="User", last_name="User",
) )
user = await user_crud.create(session, obj_in=user_data) user = await user_repo.create(session, obj_in=user_data)
user_ids.append(user.id) user_ids.append(user.id)
# Bulk delete, excluding first user # Bulk delete, excluding first user
exclude_id = user_ids[0] exclude_id = user_ids[0]
async with AsyncTestingSessionLocal() as session: async with AsyncTestingSessionLocal() as session:
count = await user_crud.bulk_soft_delete( count = await user_repo.bulk_soft_delete(
session, user_ids=user_ids, exclude_user_id=exclude_id session, user_ids=user_ids, exclude_user_id=exclude_id
) )
assert count == 2 # Only 2 deleted assert count == 2 # Only 2 deleted
# Verify excluded user is NOT deleted # Verify excluded user is NOT deleted
async with AsyncTestingSessionLocal() as session: async with AsyncTestingSessionLocal() as session:
excluded_user = await user_crud.get(session, id=str(exclude_id)) excluded_user = await user_repo.get(session, id=str(exclude_id))
assert excluded_user.deleted_at is None assert excluded_user.deleted_at is None
@pytest.mark.asyncio @pytest.mark.asyncio
@@ -504,7 +505,7 @@ class TestBulkSoftDelete:
_test_engine, AsyncTestingSessionLocal = async_test_db _test_engine, AsyncTestingSessionLocal = async_test_db
async with AsyncTestingSessionLocal() as session: async with AsyncTestingSessionLocal() as session:
count = await user_crud.bulk_soft_delete(session, user_ids=[]) count = await user_repo.bulk_soft_delete(session, user_ids=[])
assert count == 0 assert count == 0
@pytest.mark.asyncio @pytest.mark.asyncio
@@ -520,12 +521,12 @@ class TestBulkSoftDelete:
first_name="Only", first_name="Only",
last_name="User", last_name="User",
) )
user = await user_crud.create(session, obj_in=user_data) user = await user_repo.create(session, obj_in=user_data)
user_id = user.id user_id = user.id
# Try to delete but exclude # Try to delete but exclude
async with AsyncTestingSessionLocal() as session: async with AsyncTestingSessionLocal() as session:
count = await user_crud.bulk_soft_delete( count = await user_repo.bulk_soft_delete(
session, user_ids=[user_id], exclude_user_id=user_id session, user_ids=[user_id], exclude_user_id=user_id
) )
assert count == 0 assert count == 0
@@ -543,15 +544,15 @@ class TestBulkSoftDelete:
first_name="PreDeleted", first_name="PreDeleted",
last_name="User", last_name="User",
) )
user = await user_crud.create(session, obj_in=user_data) user = await user_repo.create(session, obj_in=user_data)
user_id = user.id user_id = user.id
# First deletion # First deletion
await user_crud.bulk_soft_delete(session, user_ids=[user_id]) await user_repo.bulk_soft_delete(session, user_ids=[user_id])
# Try to delete again # Try to delete again
async with AsyncTestingSessionLocal() as session: async with AsyncTestingSessionLocal() as session:
count = await user_crud.bulk_soft_delete(session, user_ids=[user_id]) count = await user_repo.bulk_soft_delete(session, user_ids=[user_id])
assert count == 0 # Already deleted assert count == 0 # Already deleted
@@ -560,16 +561,16 @@ class TestUtilityMethods:
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_is_active_true(self, async_test_db, async_test_user): async def test_is_active_true(self, async_test_db, async_test_user):
"""Test is_active returns True for active user_crud.""" """Test is_active returns True for active user_repo."""
_test_engine, AsyncTestingSessionLocal = async_test_db _test_engine, AsyncTestingSessionLocal = async_test_db
async with AsyncTestingSessionLocal() as session: async with AsyncTestingSessionLocal() as session:
user = await user_crud.get(session, id=str(async_test_user.id)) user = await user_repo.get(session, id=str(async_test_user.id))
assert user_crud.is_active(user) is True assert user_repo.is_active(user) is True
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_is_active_false(self, async_test_db): async def test_is_active_false(self, async_test_db):
"""Test is_active returns False for inactive user_crud.""" """Test is_active returns False for inactive user_repo."""
_test_engine, AsyncTestingSessionLocal = async_test_db _test_engine, AsyncTestingSessionLocal = async_test_db
async with AsyncTestingSessionLocal() as session: async with AsyncTestingSessionLocal() as session:
@@ -579,10 +580,10 @@ class TestUtilityMethods:
first_name="Inactive", first_name="Inactive",
last_name="User", last_name="User",
) )
user = await user_crud.create(session, obj_in=user_data) user = await user_repo.create(session, obj_in=user_data)
await user_crud.update(session, db_obj=user, obj_in={"is_active": False}) await user_repo.update(session, db_obj=user, obj_in={"is_active": False})
assert user_crud.is_active(user) is False assert user_repo.is_active(user) is False
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_is_superuser_true(self, async_test_db, async_test_superuser): async def test_is_superuser_true(self, async_test_db, async_test_superuser):
@@ -590,22 +591,22 @@ class TestUtilityMethods:
_test_engine, AsyncTestingSessionLocal = async_test_db _test_engine, AsyncTestingSessionLocal = async_test_db
async with AsyncTestingSessionLocal() as session: async with AsyncTestingSessionLocal() as session:
user = await user_crud.get(session, id=str(async_test_superuser.id)) user = await user_repo.get(session, id=str(async_test_superuser.id))
assert user_crud.is_superuser(user) is True assert user_repo.is_superuser(user) is True
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_is_superuser_false(self, async_test_db, async_test_user): async def test_is_superuser_false(self, async_test_db, async_test_user):
"""Test is_superuser returns False for regular user_crud.""" """Test is_superuser returns False for regular user_repo."""
_test_engine, AsyncTestingSessionLocal = async_test_db _test_engine, AsyncTestingSessionLocal = async_test_db
async with AsyncTestingSessionLocal() as session: async with AsyncTestingSessionLocal() as session:
user = await user_crud.get(session, id=str(async_test_user.id)) user = await user_repo.get(session, id=str(async_test_user.id))
assert user_crud.is_superuser(user) is False assert user_repo.is_superuser(user) is False
class TestUserExceptionHandlers: class TestUserExceptionHandlers:
""" """
Test exception handlers in user CRUD methods. Test exception handlers in user repository methods.
Covers lines: 30-32, 205-208, 257-260 Covers lines: 30-32, 205-208, 257-260
""" """
@@ -621,7 +622,7 @@ class TestUserExceptionHandlers:
session, "execute", side_effect=Exception("Database query failed") session, "execute", side_effect=Exception("Database query failed")
): ):
with pytest.raises(Exception, match="Database query failed"): with pytest.raises(Exception, match="Database query failed"):
await user_crud.get_by_email(session, email="test@example.com") await user_repo.get_by_email(session, email="test@example.com")
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_bulk_update_status_database_error( async def test_bulk_update_status_database_error(
@@ -639,7 +640,7 @@ class TestUserExceptionHandlers:
): ):
with patch.object(session, "rollback", new_callable=AsyncMock): with patch.object(session, "rollback", new_callable=AsyncMock):
with pytest.raises(Exception, match="Bulk update failed"): with pytest.raises(Exception, match="Bulk update failed"):
await user_crud.bulk_update_status( await user_repo.bulk_update_status(
session, user_ids=[async_test_user.id], is_active=False session, user_ids=[async_test_user.id], is_active=False
) )
@@ -659,6 +660,6 @@ class TestUserExceptionHandlers:
): ):
with patch.object(session, "rollback", new_callable=AsyncMock): with patch.object(session, "rollback", new_callable=AsyncMock):
with pytest.raises(Exception, match="Bulk delete failed"): with pytest.raises(Exception, match="Bulk delete failed"):
await user_crud.bulk_soft_delete( await user_repo.bulk_soft_delete(
session, user_ids=[async_test_user.id] session, user_ids=[async_test_user.id]
) )

View File

@@ -10,6 +10,7 @@ from app.core.auth import (
get_password_hash, get_password_hash,
verify_password, verify_password,
) )
from app.core.exceptions import DuplicateError
from app.models.user import User from app.models.user import User
from app.schemas.users import Token, UserCreate from app.schemas.users import Token, UserCreate
from app.services.auth_service import AuthenticationError, AuthService from app.services.auth_service import AuthenticationError, AuthService
@@ -152,9 +153,9 @@ class TestAuthServiceUserCreation:
last_name="User", last_name="User",
) )
# Should raise AuthenticationError # Should raise DuplicateError for duplicate email
async with AsyncTestingSessionLocal() as session: async with AsyncTestingSessionLocal() as session:
with pytest.raises(AuthenticationError): with pytest.raises(DuplicateError):
await AuthService.create_user(db=session, user_data=user_data) await AuthService.create_user(db=session, user_data=user_data)

View File

@@ -269,18 +269,18 @@ class TestClientValidation:
async def test_validate_client_legacy_sha256_hash( async def test_validate_client_legacy_sha256_hash(
self, db, confidential_client_legacy_hash self, db, confidential_client_legacy_hash
): ):
"""Test validating a client with legacy SHA-256 hash (backward compatibility).""" """Test that legacy SHA-256 hash is rejected with clear error message."""
client, secret = confidential_client_legacy_hash client, secret = confidential_client_legacy_hash
validated = await service.validate_client(db, client.client_id, secret) with pytest.raises(service.InvalidClientError, match="deprecated hash format"):
assert validated.client_id == client.client_id await service.validate_client(db, client.client_id, secret)
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_validate_client_legacy_sha256_wrong_secret( async def test_validate_client_legacy_sha256_wrong_secret(
self, db, confidential_client_legacy_hash self, db, confidential_client_legacy_hash
): ):
"""Test legacy SHA-256 client rejects wrong secret.""" """Test that legacy SHA-256 client with wrong secret is rejected."""
client, _ = confidential_client_legacy_hash client, _ = confidential_client_legacy_hash
with pytest.raises(service.InvalidClientError, match="Invalid client secret"): with pytest.raises(service.InvalidClientError, match="deprecated hash format"):
await service.validate_client(db, client.client_id, "wrong_secret") await service.validate_client(db, client.client_id, "wrong_secret")
def test_validate_redirect_uri_success(self, public_client): def test_validate_redirect_uri_success(self, public_client):

View File

@@ -11,7 +11,8 @@ from uuid import uuid4
import pytest import pytest
from app.core.exceptions import AuthenticationError from app.core.exceptions import AuthenticationError
from app.crud.oauth import oauth_account, oauth_state from app.repositories.oauth_account import oauth_account_repo as oauth_account
from app.repositories.oauth_state import oauth_state_repo as oauth_state
from app.schemas.oauth import OAuthAccountCreate, OAuthStateCreate from app.schemas.oauth import OAuthAccountCreate, OAuthStateCreate
from app.services.oauth_service import OAUTH_PROVIDERS, OAuthService from app.services.oauth_service import OAUTH_PROVIDERS, OAuthService

Some files were not shown because too many files have changed in this diff Show More