Compare commits

...

120 Commits

Author SHA1 Message Date
Felipe Cardoso
a94e29d99c chore(frontend): remove unnecessary newline in overrides field of package.json 2026-03-01 19:40:11 +01:00
Felipe Cardoso
81e48c73ca fix(tests): handle missing schemathesis gracefully in API contract tests
- Replaced `pytest.mark.skipif` with `pytest.skip` to better manage scenarios where `schemathesis` is not installed.
- Added a fallback test function to ensure explicit handling for missing dependencies.
2026-03-01 19:32:49 +01:00
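The pattern this commit describes can be sketched as follows. This is an illustrative reconstruction, not the project's actual test file; the flag and function names are hypothetical:

```python
# Sketch: skip at runtime when an optional test dependency is absent,
# and keep an explicit fallback test so the module never silently
# collects zero tests. Names here are illustrative.
import pytest

try:
    import schemathesis  # noqa: F401
    HAS_SCHEMATHESIS = True
except ImportError:
    HAS_SCHEMATHESIS = False


def test_schemathesis_available_or_skipped():
    """Fallback test: records explicitly why the contract suite was skipped."""
    if not HAS_SCHEMATHESIS:
        pytest.skip("schemathesis is not installed; API contract tests skipped")
    assert HAS_SCHEMATHESIS
```

Unlike a module-level `pytest.mark.skipif`, the runtime `pytest.skip` call leaves a visible, self-explaining entry in the test report.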
Felipe Cardoso
a3f78dc801 refactor(tests): replace crud references with repo across repository test files
- Updated import statements and test logic to align with `repositories` naming changes.
- Adjusted documentation and test names for consistency with the updated naming convention.
- Improved test descriptions to reflect the repository-based structure.
2026-03-01 19:22:16 +01:00
Felipe Cardoso
07309013d7 chore(frontend): update scripts and docs to use bun run test for consistency
- Replaced `bun test` with `bun run test` in all documentation and scripts for uniformity.
- Removed outdated `glob` override in package configurations.
2026-03-01 18:44:48 +01:00
Felipe Cardoso
846fc31190 feat(api): enhance KeyMap and FieldsConfig handling for improved flexibility
- Added support for unmapped fields in `KeyMap` definitions and parsing.
- Updated `buildKeyMap` to allow aliasing keys without transport layer mappings.
- Improved parameter assignment logic to handle optional `in` mappings.
- Enhanced handling of `allowExtra` fields for more concise and robust configurations.
2026-03-01 18:01:34 +01:00
Felipe Cardoso
ff7a67cb58 chore(frontend): migrate from npm to Bun for dependency management and scripts
- Updated README to replace npm commands with Bun equivalents.
- Added `bun.lock` file to track Bun-managed dependencies.
2026-03-01 18:00:43 +01:00
Felipe Cardoso
0760a8284d feat(tests): add comprehensive benchmarks for auth and performance-critical endpoints
- Introduced benchmarks for password hashing, verification, and JWT token operations.
- Added latency tests for `/register`, `/refresh`, `/sessions`, and `/users/me` endpoints.
- Updated `BENCHMARKS.md` with new tests, thresholds, and execution details.
2026-03-01 17:01:44 +01:00
Felipe Cardoso
ce4d0c7b0d feat(backend): enhance performance benchmarking with baseline detection and documentation
- Updated `make benchmark-check` in Makefile to detect and handle missing baselines, creating them if not found.
- Added `.benchmarks` directory to `.gitignore` for local baseline exclusions.
- Linked benchmarking documentation in `ARCHITECTURE.md` and added comprehensive `BENCHMARKS.md` guide.
2026-03-01 16:30:06 +01:00
Felipe Cardoso
4ceb8ad98c feat(backend): add performance benchmarks and API security tests
- Introduced `benchmark`, `benchmark-save`, and `benchmark-check` Makefile targets for performance testing.
- Added API security fuzzing through the `test-api-security` Makefile target, leveraging Schemathesis.
- Updated Dockerfiles to use Alpine for security and CVE mitigation.
- Enhanced security with `scan-image` and `scan-images` targets for Docker image vulnerability scanning via Trivy.
- Integrated `pytest-benchmark` for performance regression detection, with tests for key API endpoints.
- Extended `uv.lock` and `pyproject.toml` to include performance benchmarking dependencies.
2026-03-01 16:16:18 +01:00
Felipe Cardoso
f8aafb250d fix(backend): suppress license-check output in Makefile for cleaner logs
- Redirect pip-licenses output to `/dev/null` to reduce noise during license checks.
- Retain success and compliance messages for clear feedback.
2026-03-01 14:24:22 +01:00
Felipe Cardoso
4385d20ca6 fix(tests): simplify invalid token test logic in test_auth_security.py
- Removed unnecessary try-except block for JWT encoding failures.
- Adjusted test to directly verify `TokenInvalidError` during decoding.
- Clarified comment on HMAC algorithm compatibility (`HS384` vs. `HS256`).
2026-03-01 14:24:17 +01:00
Felipe Cardoso
1a36907f10 refactor(backend): replace python-jose and passlib with PyJWT and bcrypt for security and simplicity
- Migrated JWT token handling from `python-jose` to `PyJWT`, reducing dependencies and improving error clarity.
- Replaced `passlib` bcrypt integration with direct `bcrypt` usage for password hashing.
- Updated `Makefile`, removing unused CVE ignore based on the replaced dependencies.
- Reflected changes in `ARCHITECTURE.md` and adjusted function headers in `auth.py`.
- Cleaned up `uv.lock` and `pyproject.toml` to remove unused dependencies (`ecdsa`, `rsa`, etc.) and add `PyJWT`.
- Refactored tests and services to align with the updated libraries (`PyJWT` error handling, decoding, and validation).
2026-03-01 14:02:04 +01:00
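For context on what the migrated code path does: the HS256 tokens that PyJWT produces via `jwt.encode(payload, secret, algorithm="HS256")` are an HMAC-SHA256 over the base64url-encoded header and payload. The stdlib sketch below illustrates that mechanism only; it is not the PyJWT API and not the project's code:

```python
# Stdlib illustration of the HS256 mechanism behind PyJWT's
# jwt.encode(..., algorithm="HS256"). Mechanism sketch only.
import base64
import hashlib
import hmac
import json


def b64url(data: bytes) -> bytes:
    # JWT uses unpadded base64url encoding.
    return base64.urlsafe_b64encode(data).rstrip(b"=")


def hs256_encode(payload: dict, secret: str) -> str:
    header = b64url(json.dumps({"alg": "HS256", "typ": "JWT"}).encode())
    body = b64url(json.dumps(payload).encode())
    signing_input = header + b"." + body
    sig = hmac.new(secret.encode(), signing_input, hashlib.sha256).digest()
    return (signing_input + b"." + b64url(sig)).decode()
```

Because the signature covers only header and payload plus the secret, dropping `python-jose` changes the error classes raised on decode failures, not the token format, which is why the migration is mostly a test and exception-handling exercise.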
Felipe Cardoso
0553a1fc53 refactor(logging): switch to parameterized logging for improved performance and clarity
- Replaced f-strings with parameterized logging calls across routes, services, and repositories to optimize log message evaluation.
- Improved exception handling by using `logger.exception` where appropriate for automatic traceback logging.
2026-03-01 13:38:15 +01:00
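The before/after of this refactor can be sketched as follows (the module and function names are illustrative, not the project's):

```python
# Parameterized logging defers string interpolation until the record
# is actually emitted; an f-string builds the message even when the
# level is filtered out. `logger.exception` logs at ERROR level and
# appends the active traceback automatically.
import logging

logger = logging.getLogger("app.services")


def delete_user(user_id: int) -> None:
    # Interpolation happens only if DEBUG is enabled:
    logger.debug("deleting user %s", user_id)
    try:
        raise RuntimeError("db unavailable")  # stand-in for a real failure
    except RuntimeError:
        # Preferred over logger.error(f"..."): traceback is included for free.
        logger.exception("failed to delete user %s", user_id)
```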
Felipe Cardoso
57e969ed67 chore(backend): extend Makefile with audit, validation, and security targets
- Added `dep-audit`, `license-check`, `audit`, `validate-all`, and `check` targets for security and quality checks.
- Updated `.PHONY` to include new targets.
- Enhanced `help` command documentation with descriptions of the new commands.
- Updated `ARCHITECTURE.md`, `CLAUDE.md`, and `uv.lock` to reflect related changes. Upgraded dependencies where necessary.
2026-03-01 12:03:34 +01:00
Felipe Cardoso
68275b1dd3 refactor(docs): update architecture to reflect repository migration
- Rename CRUD layer to Repository layer throughout architecture documentation.
- Update dependency injection examples to use repository classes.
- Add async SQLAlchemy pattern for Repository methods (`select()` and transactions).
- Replace CRUD references in FEATURE_EXAMPLE.md with Repository-focused implementation details.
- Highlight repository class responsibilities and remove outdated CRUD patterns.
2026-03-01 11:13:51 +01:00
Felipe Cardoso
80d2dc0cb2 fix(backend): clear VIRTUAL_ENV before invoking pyright
Prevents a spurious warning when the shell's VIRTUAL_ENV points to a
different project's venv. Pyright detects the mismatch and warns; clearing
the variable inline forces pyright to resolve the venv from pyrightconfig.json.
2026-02-28 19:48:33 +01:00
Felipe Cardoso
a8aa416ecb refactor(backend): migrate type checking from mypy to pyright
Replace mypy>=1.8.0 with pyright>=1.1.390. Remove all [tool.mypy] and
[tool.pydantic-mypy] sections from pyproject.toml and add
pyrightconfig.json (standard mode, SQLAlchemy false-positive rules
suppressed globally).

Fixes surfaced by pyright:
- Remove unreachable except AuthError clauses in login/login_oauth (same class as AuthenticationError)
- Fix Pydantic v2 list Field: min_items/max_items → min_length/max_length
- Split OAuthProviderConfig TypedDict into required + optional(email_url) inheritance
- Move JWTError/ExpiredSignatureError from lazy try-block imports to module level
- Add timezone-aware guard to UserSession.is_expired to match sibling models
- Fix is_active: bool → bool | None in three organization repo signatures
- Initialize search_filter = None before conditional block (possibly unbound fix)
- Add bool() casts to model is_expired and repo is_active/is_superuser returns
- Restructure except (JWTError, Exception) into separate except clauses
2026-02-28 19:12:40 +01:00
Felipe Cardoso
4c6bf55bcc refactor(backend): improve formatting in services, repositories & tests
- Consistently format multi-line function headers, exception handling, and repository method calls for readability.
- Reorganize misplaced imports across modules (e.g., services & tests) into proper sorted order.
- Adjust indentation, line breaks, and spacing inconsistencies in tests and migration files.
- Clean up unnecessary trailing newlines and reorganize `__all__` declarations for consistency.
2026-02-28 18:37:56 +01:00
Felipe Cardoso
98b455fdc3 refactor(backend): enforce route→service→repo layered architecture
- introduce custom repository exception hierarchy (DuplicateEntryError,
  IntegrityConstraintError, InvalidInputError) replacing raw ValueError
- eliminate all direct repository imports and raw SQL from route layer
- add UserService, SessionService, OrganizationService to service layer
- add get_stats/get_org_distribution service methods replacing admin inline SQL
- fix timing side-channel in authenticate_user via dummy bcrypt check
- replace SHA-256 client secret fallback with explicit InvalidClientError
- replace assert with InvalidGrantError in authorization code exchange
- replace N+1 token revocation loops with bulk UPDATE statements
- rename oauth account token fields (drop misleading 'encrypted' suffix)
- add Alembic migration 0003 for token field column rename
- add 45 new service/repository tests; 975 passing, 94% coverage
2026-02-27 09:32:57 +01:00
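The timing side-channel fix mentioned above follows a standard pattern: when the user does not exist, verify the password against a dummy hash anyway, so "unknown user" and "wrong password" take comparable time. The sketch below uses a SHA-256 stand-in for `bcrypt.checkpw` to stay self-contained; the real fix does the same dance with bcrypt:

```python
# Sketch of the timing-side-channel mitigation in authenticate_user.
# `verify` is a stand-in for bcrypt.checkpw (assumption, not the
# project's code); the structure is what matters.
import hashlib
import hmac

_DUMMY_HASH = hashlib.sha256(b"dummy-password").hexdigest()


def verify(password: str, password_hash: str) -> bool:
    candidate = hashlib.sha256(password.encode()).hexdigest()
    return hmac.compare_digest(candidate, password_hash)


def authenticate_user(user, password: str) -> bool:
    if user is None:
        verify(password, _DUMMY_HASH)  # burn the same work, then fail
        return False
    return verify(password, user["password_hash"])
```

Without the dummy check, an attacker can distinguish valid from invalid usernames by measuring response latency, since the expensive hash comparison is skipped for unknown users.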
Felipe Cardoso
0646c96b19 Add semicolons to mockServiceWorker.js for consistent style compliance
- Updated `mockServiceWorker.js` by adding missing semicolons across the file for improved code consistency and adherence to style guidelines.
- Refactored multi-line logical expressions into single-line where applicable, maintaining readability.
2026-01-01 13:21:31 +01:00
Felipe Cardoso
62afb328fe Upgrade dependencies in package-lock.json
- Upgraded various dependencies across `@esbuild`, `@eslint`, `@hey-api`, and `@img` packages to their latest versions.
- Removed unused `json5` dependency under `@babel/core`.
- Ensured integrity hashes are updated to reflect changes.
2026-01-01 13:21:23 +01:00
Felipe Cardoso
b9a746bc16 Refactor component props formatting for consistency in extends usage across UI and documentation files 2026-01-01 13:19:36 +01:00
Felipe Cardoso
de8e18e97d Update GitHub repository URLs across components and tests
- Replaced all occurrences of the previous repository URL (`your-org/fast-next-template`) with the updated repository URL (`cardosofelipe/pragma-stack.git`) in both frontend components and test files.
- Adjusted related test assertions and documentation links accordingly.
2026-01-01 13:15:08 +01:00
Felipe Cardoso
a3e557d022 Update E2E test for security headers to include worker-src validation 2025-12-26 19:00:18 +01:00
Felipe Cardoso
4e357db25d Update E2E test for security headers to include worker-src validation 2025-12-26 19:00:11 +01:00
Felipe Cardoso
568aad3673 Add E2E tests for security headers
- Implemented tests to verify OWASP-compliant security headers, including X-Frame-Options, X-Content-Type-Options, Referrer-Policy, Permissions-Policy, and Content-Security-Policy.
- Ensured deprecated headers like X-XSS-Protection are not set.
- Validated security headers across multiple routes.
- Updated Playwright configuration to include the new test suite.
2025-12-10 14:53:40 +01:00
Felipe Cardoso
ddcf926158 Add OWASP-compliant security headers to Next.js configuration
- Implemented security headers following OWASP 2025 recommendations, including X-Frame-Options, X-Content-Type-Options, Referrer-Policy, Permissions-Policy, and Content-Security-Policy.
- Applied headers globally across all routes for enhanced security.
2025-12-10 13:55:15 +01:00
Felipe Cardoso
865eeece58 Update dependencies in package-lock.json
- Upgraded multiple packages including `@next/*`, `next`, `js-yaml`, `glob`, and `mdast-util-to-hast` to ensure compatibility and enhance performance.
- Addressed potential security and functionality improvements with newer versions.
2025-12-10 11:19:59 +01:00
Felipe Cardoso
05fb3612f9 Update README header and visuals
- Reorganized the README header for improved branding and clarity.
- Added landing page preview to enhance documentation visuals.
2025-11-27 19:30:09 +01:00
Felipe Cardoso
1b2e7dde35 Refactor OAuth divider component and update README visuals
- Simplified the OAuth divider component with a cleaner layout for improved UI consistency.
- Updated README to include and organize new visuals for key sections, enhancing documentation clarity.
2025-11-27 19:07:28 +01:00
Felipe Cardoso
29074f26a6 Remove outdated documentation files
- Deleted `I18N_IMPLEMENTATION_PLAN.md` and `PROJECT_PROGRESS.md` to declutter the repository.
- These documents were finalized, no longer relevant, and superseded by implemented features and external references.
2025-11-27 18:55:29 +01:00
Felipe Cardoso
77ed190310 Add Makefile targets for database management and improve dev/production workflows
- Introduced `drop-db` and `reset-db` targets for streamlined database operations, including database recreation and migration applications.
- Added `help` target to document available Makefile commands for both development and production environments.
- Expanded Makefile with new targets like `push-images` and `deploy` to enhance production deployment workflows.
- Consolidated redundant code and added descriptions for improved maintainability and user experience.
2025-11-27 10:52:30 +01:00
Felipe Cardoso
2bbe925cef Clean up Alembic migrations
- Removed outdated and redundant Alembic migration files to streamline the migration directory. This improves maintainability and eliminates duplicate or unused scripts.
2025-11-27 09:12:30 +01:00
Felipe Cardoso
4a06b96b2e Update tests to reflect OAuth 2.0 and i18n features
- Replaced outdated assertions with updated content for 'OAuth 2.0 + i18n Ready' across HeroSection, Key Features, and E2E tests.
- Updated TechStack tests to validate inclusion of `next-intl` and `pytest`.
- Refined badge and feature test cases to align with OAuth and internationalization updates.
2025-11-27 07:33:57 +01:00
Felipe Cardoso
088c1725b0 Update ContextSection and TechStackSection with OAuth 2.0 and i18n features
- Replaced outdated features with 'OAuth 2.0 + Social Login' and 'i18n Ready' in ContextSection.
- Updated TechStackSection to include OAuth 2.0 (social login + provider mode) and next-intl (English, Italian) support.
- Refined descriptions in FeatureGrid and HeroSection to highlight new features.
- Improved messaging around OAuth and internationalization readiness across components.
2025-11-26 14:44:12 +01:00
Felipe Cardoso
7ba1767cea Refactor E2E tests for OAuth provider workflows
- Renamed unused `code_verifier` variables to `_code_verifier` for clarity.
- Improved test readability by reformatting long lines and assertions.
- Streamlined `get` request calls by consolidating parameters into single lines.
2025-11-26 14:10:25 +01:00
Felipe Cardoso
c63b6a4f76 Add E2E tests for OAuth consent page workflows
- Added tests for OAuth consent page covering parameter validation, unauthenticated user redirects, authenticated user interactions, scope management, and consent API calls.
- Verified behaviors such as error handling, toggling scopes, loading states, and authorize/deny actions.
- Updated utility methods with `loginViaUI` for improved test setup.
2025-11-26 14:06:36 +01:00
Felipe Cardoso
803b720530 Add comprehensive E2E tests for OAuth provider workflows
- Introduced E2E test coverage for OAuth Provider mode, covering metadata discovery, client management, authorization flows, token operations, consent management, and security checks.
- Verified PKCE enforcement, consent submission, token rotation, and introspection.
- Expanded fixtures and utility methods for testing real OAuth scenarios with PostgreSQL via Testcontainers.
2025-11-26 14:06:20 +01:00
Felipe Cardoso
7ff00426f2 Add detailed OAuth documentation and configuration examples
- Updated `ARCHITECTURE.md` with thorough explanations of OAuth Consumer and Provider modes, supported flows, security features, and endpoints.
- Enhanced `.env.template` with environment variables for OAuth Provider mode setup.
- Expanded `README.md` to highlight OAuth Provider mode capabilities and MCP integration features.
- Added OAuth configuration section to `AGENTS.md`, including key settings for both social login and provider mode.
2025-11-26 13:38:55 +01:00
Felipe Cardoso
b3f0dd4005 Add full OAuth provider functionality and enhance flows
- Implemented OAuth 2.0 Authorization Server endpoints per RFCs, including token, introspection, revocation, and metadata discovery.
- Added user consent submission, listing, and revocation APIs alongside frontend integration for improved UX.
- Enforced stricter OAuth security measures (PKCE, state validation, scopes).
- Refactored schemas and services for consistency and expanded coverage of OAuth workflows.
- Updated documentation and type definitions for new API behaviors.
2025-11-26 13:23:44 +01:00
Felipe Cardoso
707315facd Suppress jsdom XMLHttpRequest errors in Jest tests
- Added `jest.environment.js` to create a custom Jest environment that filters out harmless XMLHttpRequest errors from jsdom's VirtualConsole.
- Updated `jest.config.js` to use the custom environment, reducing noisy test outputs.
2025-11-26 11:23:56 +01:00
Felipe Cardoso
38114b79f9 Mark OAuth consent page as excluded from unit tests 2025-11-26 09:52:47 +01:00
Felipe Cardoso
1cb3658369 Exclude email from user update payload in UserFormDialog 2025-11-26 09:47:10 +01:00
Felipe Cardoso
dc875c5c95 Enhance OAuth security, PKCE, and state validation
- Enforced stricter PKCE requirements by rejecting insecure 'plain' method for public clients.
- Transitioned client secret hashing to bcrypt for improved security and migration compatibility.
- Added constant-time comparison for state parameter validation to prevent timing attacks.
- Improved error handling and logging for OAuth workflows, including malformed headers and invalid scopes.
- Upgraded Google OIDC token validation to verify both signature and nonce.
- Refactored OAuth service methods and schemas for better readability, consistency, and compliance with RFC specifications.
2025-11-26 00:14:26 +01:00
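The constant-time state comparison above is typically built on `hmac.compare_digest`, which runs in time independent of where the first mismatching byte occurs. A minimal sketch (function names are illustrative):

```python
# Sketch of CSRF state generation and constant-time validation for
# the OAuth flow. hmac.compare_digest avoids early-exit comparison,
# closing the timing side channel on the state parameter.
import hmac
import secrets


def new_state() -> str:
    # Unguessable, URL-safe state value bound to the auth request.
    return secrets.token_urlsafe(32)


def validate_state(expected: str, received: str) -> bool:
    return hmac.compare_digest(expected.encode(), received.encode())
```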
Felipe Cardoso
0ea428b718 Refactor tests for improved readability and fixture consistency
- Reformatted headers in E2E tests to improve readability and ensure consistent style.
- Updated confidential client fixture to use bcrypt for secret hashing, enhancing security and testing backward compatibility with legacy SHA-256 hashes.
- Added new test cases for PKCE verification, rejecting insecure 'plain' methods, and improved error handling.
- Refined session workflows and user agent handling in E2E tests for session management.
- Consolidated schema operation tests and fixed minor formatting inconsistencies.
2025-11-26 00:13:53 +01:00
Felipe Cardoso
400d6f6f75 Enhance OAuth security and state validation
- Implemented stricter OAuth security measures, including CSRF protection via state parameter validation and redirect_uri checks.
- Updated OAuth models to support timezone-aware datetime comparisons, replacing deprecated `utcnow`.
- Enhanced logging for malformed Basic auth headers during token, introspect, and revoke requests.
- Added allowlist validation for OAuth provider domains to prevent open redirect attacks.
- Improved nonce validation for OpenID Connect tokens, ensuring token integrity during Google provider flows.
- Updated E2E and unit tests to cover new security features and expanded OAuth state handling scenarios.
2025-11-25 23:50:43 +01:00
Felipe Cardoso
7716468d72 Add E2E tests for admin and organization workflows
- Introduced E2E tests for admin user and organization management workflows: user listing, creation, updates, bulk actions, and organization membership management.
- Added comprehensive tests for organization CRUD operations, membership visibility, roles, and permission validation.
- Expanded fixtures for superuser and member setup to streamline testing of admin-specific operations.
- Verified pagination, filtering, and action consistency across admin endpoints.
2025-11-25 23:50:02 +01:00
Felipe Cardoso
48f052200f Add OAuth provider mode and MCP integration
- Introduced full OAuth 2.0 Authorization Server functionality for MCP clients.
- Updated documentation with details on endpoints, scopes, and consent management.
- Added a new frontend OAuth consent page for user authorization flows.
- Implemented database models for authorization codes, refresh tokens, and user consents.
- Created unit tests for service methods (PKCE verification, client validation, scope handling).
- Included comprehensive integration tests for OAuth provider workflows.
2025-11-25 23:18:19 +01:00
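The PKCE verification these unit tests cover is defined by RFC 7636: for the S256 method, the stored challenge must equal the unpadded base64url encoding of SHA-256 over the verifier. A sketch of the check, including the rejection of the insecure `plain` method that a later commit enforces (function name is illustrative):

```python
# RFC 7636 S256 check: challenge == BASE64URL(SHA256(verifier)),
# with the 'plain' method (and anything unrecognized) refused.
import base64
import hashlib


def verify_pkce(verifier: str, challenge: str, method: str = "S256") -> bool:
    if method != "S256":
        return False  # 'plain' offers no protection for public clients
    digest = hashlib.sha256(verifier.encode("ascii")).digest()
    computed = base64.urlsafe_b64encode(digest).rstrip(b"=").decode("ascii")
    return computed == challenge
```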
Felipe Cardoso
fbb030da69 Add E2E workflow tests for organizations, users, sessions, and API contracts
- Introduced comprehensive E2E tests for organization workflows: creation, membership management, and updates.
- Added tests for user management workflows: profile viewing, updates, password changes, and settings.
- Implemented session management tests, including listing, revocation, multi-device handling, and cleanup.
- Included API contract validation tests using Schemathesis, covering protected endpoints and schema structure.
- Enhanced E2E testing infrastructure with full PostgreSQL support and detailed workflow coverage.
2025-11-25 23:13:28 +01:00
Felipe Cardoso
d49f819469 Expand OAuth documentation and roadmap details
- Updated `README.md` to include OAuth/Social Login (Google, GitHub) with PKCE support under Authentication section.
- Adjusted roadmap and status sections in documentation to reflect completed OAuth/social login implementation.
- Clarified future plans by replacing "Additional authentication methods (OAuth, SSO)" with "SSO/SAML authentication".
2025-11-25 22:28:53 +01:00
Felipe Cardoso
507f2e9c00 Refactor E2E tests and fixtures for improved readability and consistency
- Reformatted assertions in `test_database_workflows.py` for better readability.
- Refactored `postgres_url` transformation logic in `conftest.py` for improved clarity.
- Adjusted import handling in `test_api_contracts.py` to streamline usage of Hypothesis and Schemathesis libraries.
2025-11-25 22:27:11 +01:00
Felipe Cardoso
c0b253a010 Add support for E2E testing infrastructure and OAuth configurations
- Introduced make commands for E2E tests using Testcontainers and Schemathesis.
- Updated `.env.demo` with configurable OAuth settings for Google and GitHub.
- Enhanced `README.md` with updated environment setup instructions.
- Added E2E testing dependencies and markers in `pyproject.toml` for real PostgreSQL and API contract validation.
- Included new libraries (`arrow`, `attrs`, `docker`, etc.) for testing and schema validation workflows.
2025-11-25 22:24:23 +01:00
Felipe Cardoso
fcbcff99e9 Add E2E tests for OAuth authentication flows and provider integrations
- Implemented comprehensive E2E tests for OAuth buttons on login and register pages, including Google and GitHub provider interactions.
- Verified OAuth provider buttons' visibility, icons, and proper API integration with mock endpoints.
- Added button interaction tests to ensure correct API calls for authorization and state handling.
- Updated `playwright.config.ts` to include the new `auth-oauth.spec.ts` in test configurations.
- Extended mock handlers in `overrides.ts` and `auth.ts` to support OAuth-specific API workflows and demo scenarios.
2025-11-25 10:40:37 +01:00
Felipe Cardoso
b49678b7df Add E2E tests for authentication flows and admin user management
- Implemented comprehensive E2E tests for critical authentication flows, including login, session management, and logout workflows.
- Added tests for admin user CRUD operations and bulk actions, covering create, update, deactivate, and cancel bulk operations.
- Updated `auth.ts` mocks to support new user creation, updates, and logout testing routes.
- Refactored skipped tests in `settings-profile.spec.ts` and `settings-password.spec.ts` with detailed rationale for omission (e.g., `react-hook-form` state handling limitations).
- Introduced `auth-flows.spec.ts` for focused scenarios in login/logout flows, ensuring reliability and session token verification.
2025-11-25 09:36:42 +01:00
Felipe Cardoso
aeed9dfdbc Add unit tests for OAuthButtons and LinkedAccountsSettings components
- Introduced comprehensive test coverage for `OAuthButtons` and `LinkedAccountsSettings`, including loading states, button behaviors, error handling, and custom class support.
- Implemented `LinkedAccountsPage` tests for rendering and component integration.
- Adjusted E2E coverage exclusions in various components, focusing on UI-heavy and animation-based flows best suited for E2E tests.
- Refined Jest coverage thresholds to align with improved unit test additions.
2025-11-25 08:52:11 +01:00
Felipe Cardoso
13f617828b Add comprehensive tests for OAuth callback flows and update pyproject.toml
- Extended OAuth callback tests to cover various scenarios (e.g., account linking, user creation, inactive users, and token/user info failures).
- Added `app/init_db.py` to the excluded files in `pyproject.toml`.
2025-11-25 08:26:41 +01:00
Felipe Cardoso
84e0a7fe81 Add OAuth flows and UI integration
- Implemented OAuth endpoints (providers list, authorization, callback, linked accounts management).
- Added UI translations for OAuth workflows (auth process messages, linked accounts management).
- Extended TypeScript types and React hooks to support OAuth features.
- Updated app configuration with OAuth-specific settings and provider details.
- Introduced skeleton implementations for authorization and token endpoints in provider mode.
- Included unit test and integration hooks for OAuth capabilities.
2025-11-25 07:59:20 +01:00
Felipe Cardoso
063a35e698 Fix permissions 2025-11-25 01:20:29 +01:00
Felipe Cardoso
a2246fb6e1 Kindly provide the git diff content for an accurate commit message recommendation. 2025-11-25 01:13:40 +01:00
Felipe Cardoso
16ee4e0cb3 Initial implementation of OAuth models, endpoints, and migrations
- Added models for `OAuthClient`, `OAuthState`, and `OAuthAccount`.
- Created Pydantic schemas to support OAuth flows, client management, and linked accounts.
- Implemented skeleton endpoints for OAuth Provider mode: authorization, token, and revocation.
- Updated router imports to include new `/oauth` and `/oauth/provider` routes.
- Added Alembic migration script to create OAuth-related database tables.
- Enhanced `users` table to allow OAuth-only accounts by making `password_hash` nullable.
2025-11-25 00:37:23 +01:00
Felipe Cardoso
e6792c2d6c Update settings-sessions.spec.ts to clarify E2E test skip reason
- Revised the skip rationale to highlight API mocking race condition as the cause.
- Updated documentation with feature status, including production readiness and comprehensive unit test coverage.
2025-11-24 21:57:52 +01:00
Felipe Cardoso
1d20b149dc Refactor e2e tests for clarity and skip outdated cases
- Improved `auth-guard.spec.ts` test formatting for readability by adjusting destructuring syntax.
- Updated `settings-sessions.spec.ts` to note feature completion and skipped tests pending auth storage debugging.
- Removed outdated and redundant test cases from `homepage.spec.ts` to streamline coverage.
- Enabled and updated assertion in `settings-password.spec.ts` to check updated heading for password change form.
2025-11-24 21:38:23 +01:00
Felipe Cardoso
570848cc2d Refactor e2e tests for improved reliability and consistency
- Updated `auth-guard.spec.ts` to configure localStorage before navigation using `context.addInitScript`.
- Enhanced test stability with explicit `waitForLoadState` calls after page reloads.
- Refactored `admin-dashboard.spec.ts` for more descriptive test names aligning with chart updates. Adjusted lazy-loading behavior in the analytics section.
- Reworked `homepage.spec.ts` tests to improve headline and badge visibility checks. Added scroll-triggered animation handling for stats section.
- Enhanced MSW handler in `auth.ts` with mock data for user growth and registration activity charts. Added organization and user status distribution data.
2025-11-24 20:55:04 +01:00
Felipe Cardoso
6b970765ba Refactor components and scripts for improved typing, cleanup unused imports
- Updated chart components (`OrganizationDistributionChart`, `RegistrationActivityChart`, `UserGrowthChart`) with stricter TypeScript interfaces (`TooltipProps`).
- Removed unused imports (`useState`, `Badge`, `API_BASE_URL`) from `DemoModeBanner` and MSW scripts.
- Adjusted MSW function parameters (`_method`, `_operation`) to suppress unused variable warnings.
2025-11-24 20:30:58 +01:00
Felipe Cardoso
e79215b4de Refactor tests, documentation, and component code for consistent formatting and improved readability
- Reformatted test files (`RegistrationActivityChart.test.tsx`, `DemoCredentialsModal.test.tsx`) for indentation consistency.
- Reduced inline style verbosity across components and docs (`DemoModeBanner`, `CodeBlock`, `MarkdownContent`).
- Enhanced Markdown documentation (`sync-msw-with-openapi.md`, `MSW_AUTO_GENERATION.md`) with spacing updates for improved clarity.
- Updated MSW configuration to simplify locale route handling in `browser.ts`.
2025-11-24 20:25:40 +01:00
Felipe Cardoso
3bf28aa121 Override MSW handlers to support custom authentication workflows
- Added mock handlers for `login`, `register`, and `refresh` endpoints with realistic network delay.
- Implemented JWT token generation utilities to simulate authentication flows.
- Enhanced handler configurations for user data validation and session management.
2025-11-24 20:23:15 +01:00
Felipe Cardoso
cda9810a7e Add auto-generated MSW handlers for API endpoints
- Created `generated.ts` to include handlers for all endpoints defined in the OpenAPI specification.
- Simplified demo mode setup by centralizing auto-generated MSW configurations.
- Added handling for authentication, user, organization, and admin API endpoints.
- Included support for realistic network delay simulation and demo session management.
2025-11-24 19:52:40 +01:00
Felipe Cardoso
d47bd34a92 Add comprehensive tests for RegistrationActivityChart and update empty state assertions
- Added new test suite for `RegistrationActivityChart` covering rendering, loading, empty, and error states.
- Updated existing chart tests (`UserStatusChart`, `OrganizationDistributionChart`, `UserGrowthChart`) to assert correct empty state messages.
- Replaced `SessionActivityChart` references in admin tests with `RegistrationActivityChart`.
2025-11-24 19:49:41 +01:00
Felipe Cardoso
5b0ae54365 Remove MSW handlers and update demo credentials for improved standardization
- Deleted `admin.ts`, `auth.ts`, and `users.ts` MSW handler files to streamline demo mode setup.
- Updated demo credentials logic in `DemoCredentialsModal` and `DemoModeBanner` for stronger password requirements (≥12 characters).
- Refined documentation in `CLAUDE.md` to align with new credential standards and auto-generated MSW workflows.
2025-11-24 19:20:28 +01:00
Felipe Cardoso
372af25aaa Refactor Markdown rendering and code blocks styling
- Enhanced Markdown heading hierarchy with subtle anchors and improved spacing.
- Improved styling for links, blockquotes, tables, and horizontal rules using reusable components (`Alert`, `Badge`, `Table`, `Separator`).
- Standardized code block background, button transitions, and copy-to-clipboard feedback.
- Refined readability and visual hierarchy of text elements across Markdown content.
2025-11-24 18:58:01 +01:00
Felipe Cardoso
d0b717a128 Enhance demo mode credential validation and refine MSW configuration
- Updated demo credential logic to accept any password ≥8 characters for improved UX.
- Improved MSW configuration to ignore non-API requests and warn only for unhandled API calls.
- Adjusted `DemoModeBanner` to reflect updated password requirements for demo credentials.
2025-11-24 18:54:05 +01:00
Felipe Cardoso
9d40aece30 Refactor chart components for improved formatting and import optimization
- Consolidated `recharts` imports for `BarChart`, `AreaChart`, and `LineChart` components.
- Reformatted inline styles for tooltips and axis elements to enhance readability and maintain consistency.
- Applied minor cleanups for improved project code styling.
2025-11-24 18:42:13 +01:00
Felipe Cardoso
487c8a3863 Add demo mode support with MSW integration and documentation
- Integrated Mock Service Worker (MSW) for frontend-only demo mode, allowing API call interception without requiring a backend.
- Added `DemoModeBanner` component to indicate active demo mode and display demo credentials.
- Enhanced configuration with `DEMO_MODE` flag and demo credentials for user and admin access.
- Updated ESLint configuration to exclude MSW-related files from linting and coverage.
- Created comprehensive `DEMO_MODE.md` documentation for setup and usage guidelines, including deployment instructions and troubleshooting.
- Updated package dependencies to include MSW and related libraries.
2025-11-24 18:42:05 +01:00
Felipe Cardoso
8659e884e9 Refactor code formatting and suppress security warnings
- Reformatted dicts, loops, and logger calls for improved readability and consistency.
- Suppressed `bandit` warnings (`# noqa: S311`) for non-critical random number generation in demo data.
2025-11-24 17:58:26 +01:00
Felipe Cardoso
a05def5906 Add registration_activity chart and enhance admin statistics
- Introduced `RegistrationActivityChart` to display user registration trends over 14 days.
- Enhanced `AdminStatsResponse` with `registration_activity`, providing improved insights for admin users.
- Updated demo data to include realistic registration activity and organization details.
- Refactored admin page to use updated statistics data model and improved query handling.
- Fixed inconsistent timezone handling in statistical analytics and demo user timestamps.
2025-11-24 17:42:43 +01:00
Felipe Cardoso
9f655913b1 Add adminGetStats API and extend statistics types for admin dashboard
- Introduced `adminGetStats` API endpoint for fetching aggregated admin dashboard statistics.
- Expanded `AdminStatsResponse` to include `registration_activity` and new type definitions for `UserGrowthData`, `OrgDistributionData`, and `UserStatusData`.
- Added `AdminGetStatsData` and `AdminGetStatsResponses` types to improve API integration consistency.
- Updated client generation and type annotations to support the new endpoint structure.
2025-11-24 16:28:59 +01:00
Felipe Cardoso
13abd159fa Remove deprecated middleware and update component tests for branding and auth enhancements
- Deleted `middleware.disabled.ts` as it is no longer needed.
- Refactored `HeroSection` and `HomePage` tests to align with updated branding and messaging.
- Modified `DemoCredentialsModal` to support auto-filled demo credentials in login links.
- Mocked `ThemeToggle`, `LocaleSwitcher`, and `DemoCredentialsModal` in relevant tests.
- Updated admin tests to use `QueryClientProvider` and refactored API mocks for `AdminPage`.
- Replaced test assertions for stats section and badges with new branding content.
2025-11-24 15:04:49 +01:00
Felipe Cardoso
acfe59c8b3 Refactor admin stats API and charts data models for consistency
- Updated `AdminStatsResponse` with streamlined type annotations and added `AdminStatsData` type definition.
- Renamed chart data model fields (`totalUsers` → `total_users`, `activeUsers` → `active_users`, `members` → `value`, etc.) for alignment with backend naming conventions.
- Adjusted related test files to reflect updated data model structure.
- Improved readability of `AdminPage` component by reformatting destructuring in `useQuery`.
2025-11-24 12:44:45 +01:00
Felipe Cardoso
2e4700ae9b Refactor user growth chart data model and enhance demo user creation
- Renamed `totalUsers` and `activeUsers` to `total_users` and `active_users` across frontend and backend for consistency.
- Enhanced demo user creation by randomizing `created_at` dates for realistic charts.
- Expanded demo data to include `is_active` for demo users, improving user status representation.
- Refined admin dashboard statistics to support updated user growth data model.
2025-11-21 14:15:05 +01:00
Felipe Cardoso
8c83e2a699 Add comprehensive demo data loading logic and .env.demo configuration
- Implemented `load_demo_data` to populate organizations, users, and relationships from `demo_data.json`.
- Refactored database initialization to handle demo-specific passwords and multi-entity creation in demo mode.
- Added `demo_data.json` with sample organizations and users for better demo showcase.
- Introduced `.env.demo` to simplify environment setup for demo scenarios.
- Updated `.gitignore` to include `.env.demo` while keeping other `.env` files excluded.
2025-11-21 08:39:07 +01:00
Felipe Cardoso
9b6356b0db Add comprehensive demo data loading logic and .env.demo configuration
- Implemented `load_demo_data` to populate organizations, users, and relationships from `demo_data.json`.
- Refactored database initialization to handle demo-specific passwords and multi-entity creation in demo mode.
- Added `demo_data.json` with sample organizations and users for better demo showcase.
- Introduced `.env.demo` to simplify environment setup for demo scenarios.
- Updated `.gitignore` to include `.env.demo` while keeping other `.env` files excluded.
2025-11-21 08:23:18 +01:00
Felipe Cardoso
a410586cfb Enable demo mode features, auto-fill demo credentials, and enhance branding integration
- Added `DEMO_MODE` to backend configuration with relaxed security support for specific demo accounts.
- Updated password validators to allow predefined weak passwords in demo mode.
- Auto-fill login forms with demo credentials via query parameters for improved demo accessibility.
- Introduced demo user creation logic during database initialization if `DEMO_MODE` is enabled.
- Replaced `img` tags with `next/image` for consistent and optimized visuals in branding elements.
- Refined footer, header, and layout components to incorporate improved logo handling.
2025-11-21 07:42:40 +01:00
Felipe Cardoso
0e34cab921 Add logs and logs-dev targets to Makefile for streamlined log access 2025-11-21 07:32:11 +01:00
Felipe Cardoso
3cf3858fca Update Makefile to refine clean-slate target with explicit dev compose file and orphan removal 2025-11-21 07:25:22 +01:00
Felipe Cardoso
db0c555041 Add ThemeToggle to Header component
- Integrated `ThemeToggle` for light/dark mode functionality in both desktop and mobile views.
- Adjusted layout styles to accommodate new control next to `LocaleSwitcher` with consistent spacing.
2025-11-20 15:16:49 +01:00
Felipe Cardoso
51ad80071a Ensure virtualenv binaries are on PATH in entrypoint script for consistent command execution 2025-11-20 15:16:30 +01:00
Felipe Cardoso
d730ab7526 Update .dockerignore, alembic revision, and entrypoint script for consistency and reliability
- Expanded `.dockerignore` to exclude Python and packaging-related artifacts for cleaner Docker builds.
- Updated Alembic `down_revision` in migration script to reflect correct dependency chain.
- Modified entrypoint script to use `uv` with `--no-project` flag, preventing permission issues in bind-mounted volumes.
2025-11-20 15:12:55 +01:00
Felipe Cardoso
b218be9318 Add logo icon to components and update branding assets
- Integrated `logo-icon.svg` into headers, footer, and development layout for consistent branding.
- Updated `logo.svg` and `logo-icon.svg` with improved gradient and filter naming for clarity.
- Enhanced `README.md` and branding documentation with logo visuals and descriptions.
- Refined visual identity details in docs to emphasize the branding hierarchy and usage.
2025-11-20 14:55:24 +01:00
Felipe Cardoso
e6813c87c3 Add new SVG assets for logo and logo icon
- Introduced `logo.svg` to serve as the primary logo asset with layered design and gradient styling.
- Added `logo-icon.svg` for compact use cases with gradient consistency and simplified structure.
2025-11-20 13:38:42 +01:00
Felipe Cardoso
210204eb7a Revise home page content to align with "PragmaStack" branding
- Updated headers, descriptions, and key messaging across sections for clarity and consistency.
- Replaced outdated stats with branding-focused data, emphasizing open-source, type safety, and documentation quality.
- Refined tone to highlight pragmatic, reliable values over technical metrics.
- Adjusted GitHub icon SVG for accessibility and inline clarity.
2025-11-20 13:16:18 +01:00
Felipe Cardoso
6ad4cda3f4 Refine backend README to align with "PragmaStack" branding and enhance messaging for clarity and engagement. 2025-11-20 13:07:28 +01:00
Felipe Cardoso
54ceaa6f5d Rebrand README to emphasize "PragmaStack" identity and refine messaging for clarity and consistency. 2025-11-20 13:01:11 +01:00
Felipe Cardoso
34e7f69465 Replace "FastNext" references with "PragmaStack" in migration script and configuration settings 2025-11-20 13:01:05 +01:00
Felipe Cardoso
8fdbc2b359 Improve code consistency and documentation readability
- Standardized Markdown formatting across documentation files.
- Fixed inconsistent usage of inline code blocks and improved syntax clarity.
- Updated tests and JSX for cleaner formatting and better readability.
- Adjusted E2E test navigation handlers for multiline code consistency.
- Simplified TypeScript configuration and organized JSON structure for better maintainability.
2025-11-20 12:58:46 +01:00
Felipe Cardoso
28b1cc6e48 Replace "FastNext" branding with "PragmaStack" across the project
- Updated all references, metadata, and templates to reflect the new branding, including layout files, components, and documentation.
- Replaced hardcoded color tokens like `green-600` with semantic tokens (`success`, `warning`, etc.) for improved design consistency.
- Enhanced `globals.css` with new color tokens for success, warning, and destructive states using the OKLCH color model.
- Added comprehensive branding guidelines and updated the design system documentation to align with the new identity.
- Updated tests and mocks to reflect the branding changes and ensured all visual/verbal references match "PragmaStack".
- Added new `branding/README.md` and `branding` docs for mission, values, and visual identity definition.
2025-11-20 12:55:30 +01:00
Felipe Cardoso
5a21847382 Update to Next.js 16 and enhance ESLint configuration
- Migrated from Next.js 15 to Next.js 16, updating all related dependencies and configurations.
- Enhanced ESLint setup with stricter rules, expanded plugin support, and improved type-aware linting options.
- Archived middleware by renaming it to `middleware.disabled.ts` for potential future use.
2025-11-20 12:49:45 +01:00
Felipe Cardoso
444d495f83 Refactor metadata handling for improved maintainability and localization support
- Extracted server-only metadata generation logic into separate files, reducing inline logic in page components.
- Added `/* istanbul ignore file */` annotations for E2E-covered framework-level metadata.
- Standardized `generateMetadata` export patterns across auth, admin, and error pages for consistency.
- Enhanced maintainability and readability by centralizing metadata definitions for each route.
2025-11-20 10:07:15 +01:00
Felipe Cardoso
a943f79ce7 Refactor i18n routing tests with jest mocks and enhance coverage
- Replaced i18n routing tests with new mocked implementations for `next-intl/routing` and `next-intl/navigation`.
- Improved test coverage by introducing component-based tests for navigation hooks and link behavior.
- Updated assertions for clarity and consistency in locale configuration and navigation logic.
2025-11-20 09:45:29 +01:00
Felipe Cardoso
f54905abd0 Update README and documentation with i18n, feature enhancements, and SEO improvements
- Added comprehensive details for internationalization (i18n) support via `next-intl`, including locale-based routing and type-safe translations.
- Highlighted new UX features: animated marketing landing page, toasts, charts, markdown rendering, and session tracking.
- Enhanced SEO capabilities with dynamic sitemaps, robots.txt, and locale-aware metadata.
- Updated `/dev` portal information with live component playground details.
- Documented newly integrated libraries, utilities, and testing updates for better developer insight.
2025-11-20 09:45:03 +01:00
Felipe Cardoso
0105e765b3 Add tests for auth storage logic and i18n routing configuration
- Added comprehensive unit tests for `auth/storage` to handle SSR, E2E paths, storage method selection, and error handling.
- Introduced tests for `i18n/routing` to validate locale configuration, navigation hooks, and link preservation.
- Updated Jest coverage exclusions to include `
2025-11-20 09:24:15 +01:00
Felipe Cardoso
bb06b450fd Delete outdated E2E documentation and performance optimization guides.
- Removed `E2E_COVERAGE_GUIDE.md` and `E2E_PERFORMANCE_OPTIMIZATION.md` from `frontend/docs` due to redundancy and irrelevance to recent workflows.
- Cleared unused scripts (`convert-v8-to-istanbul.ts` and `merge-coverage.ts`) from `frontend/scripts`.
2025-11-19 14:56:24 +01:00
Felipe Cardoso
c1d6a04276 Document AI assistant guidance and improve developer workflows
- Added and updated `CLAUDE.md` to provide comprehensive guidance for integrating Claude Code into project workflows.
- Created `AGENTS.md` for general AI assistant context, including architecture, workflows, and tooling specifics.
- Updated `README.md` with references to AI-focused documentation for better discoverability.
- Simplified instructions and refined file organization to enhance usability for developers and AI assistants.
2025-11-19 14:45:29 +01:00
Felipe Cardoso
d7b333385d Add test cases for session revocation and update test coverage annotations
- Introduced unit tests for individual and bulk session revocation in `SessionsManager` with success callback assertions.
- Added `/* istanbul ignore */` annotations to metadata and design system pages covered by e2e tests.
2025-11-19 14:38:46 +01:00
Felipe Cardoso
f02320e57c Add tests for LocaleSwitcher component and update metadata generation
- Introduced unit tests for `LocaleSwitcher` to cover rendering, UX, accessibility, and locale switching logic.
- Updated `generateMetadata` function with `/* istanbul ignore next */` annotation for coverage clarity.
2025-11-19 14:27:03 +01:00
Felipe Cardoso
3ec589293c Add tests for i18n metadata utilities and improve locale-based metadata generation
- Introduced comprehensive unit tests for `generateLocalizedMetadata` and `generatePageMetadata` utilities.
- Enhanced `siteConfig` validation assertions for structure and localization support.
- Updated metadata generation to handle empty paths, canonical URLs, language alternates, and Open Graph data consistently.
- Annotated server-side middleware with istanbul ignore for coverage clarity.
2025-11-19 14:23:06 +01:00
Felipe Cardoso
7b1bea2966 Refactor i18n integration and update tests for improved localization
- Updated test components (`PasswordResetConfirmForm`, `PasswordChangeForm`) to use i18n keys directly, ensuring accurate validation messages.
- Refined translations in `it.json` to standardize format and content.
- Replaced text-based labels with localized strings in `PasswordResetRequestForm` and `RegisterForm`.
- Introduced `generateLocalizedMetadata` utility and updated layout metadata generation for locale-aware SEO.
- Enhanced e2e tests with locale-prefixed routes and updated assertions for consistency.
- Added comprehensive i18n documentation (`I18N.md`) for usage, architecture, and testing.
2025-11-19 14:07:13 +01:00
Felipe Cardoso
da7b6b5bfa Implement extensive localization improvements across forms and components
- Refactored `it.json` translations with added keys for authentication, admin panel, and settings.
- Updated authentication forms (`LoginForm`, `RegisterForm`, `PasswordResetConfirmForm`) to use localized strings via `next-intl`.
- Enhanced password validation schemas with dynamic translations and refined error messages.
- Adjusted `Header` and related components to include localized navigation and status elements.
- Improved placeholder hints, button labels, and inline validation messages for seamless localization.
2025-11-19 03:02:59 +01:00
Felipe Cardoso
7aa63d79df Implement extensive localization improvements across forms and components
- Refactored `it.json` translations with added keys for authentication, admin panel, and settings.
- Updated authentication forms (`LoginForm`, `RegisterForm`, `PasswordResetConfirmForm`) to use localized strings via `next-intl`.
- Enhanced password validation schemas with dynamic translations and refined error messages.
- Adjusted `Header` and related components to include localized navigation and status elements.
- Improved placeholder hints, button labels, and inline validation messages for seamless localization.
2025-11-19 03:02:13 +01:00
Felipe Cardoso
333c9c40af Add locale switcher component and integrate internationalization improvements
- Introduced `LocaleSwitcher` component for language selection with support for locale-aware dropdown and ARIA accessibility.
- Updated layouts (`Header`, `Breadcrumbs`, `Home`) to include the new locale switcher.
- Expanded localization files (`en.json`, `it.json`) with new keys for language switching.
- Adjusted i18n configuration to enhance routing and message imports.
- Updated Jest module mappings to mock new i18n components and utilities.
2025-11-19 01:31:51 +01:00
Felipe Cardoso
0b192ce030 Update e2e tests and mocks for locale-based routing
- Adjusted assertions and navigation tests to include `/en` locale prefix for consistency.
- Updated next-intl and components-i18n mocks to support locale handling in tests.
- Renamed "Components" link and related references to "Design System" in homepage tests.
- Disabled typing delay in debounce test for improved test reliability.
2025-11-19 01:31:35 +01:00
Felipe Cardoso
da021d0640 Update tests and e2e files to support locale-based routing
- Replaced static paths with dynamic locale subpaths (`/[locale]/*`) in imports, URLs, and assertions across tests.
- Updated `next-intl` mocks for improved compatibility with `locale`-aware components.
- Standardized `page.goto` and navigation tests with `/en` as the base locale for consistency.
2025-11-18 23:26:10 +01:00
Felipe Cardoso
d1b47006f4 Remove all obsolete authentication, settings, admin, and demo-related components and pages
- Eliminated redundant components, pages, and layouts related to authentication (`login`, `register`, `password-reset`, etc.), user settings, admin, and demos.
- Simplified the frontend structure by removing unused dynamic imports, forms, and test code.
- Refactored configurations and metadata imports to exclude references to removed features.
- Streamlined the project for future development and improved maintainability by discarding legacy and unused code.
2025-11-18 12:41:57 +01:00
Felipe Cardoso
a73d3c7d3e Refactor multiline formatting, link definitions, and code consistency across components and tests
- Improved readability by updating multiline statements and object definitions.
- Applied consistent link and button wrapping in `DemoSection` and other components.
- Enhanced test assertions and helper functions with uniform formatting and parentheses usage.
2025-11-18 07:25:23 +01:00
Felipe Cardoso
55ae92c460 Refactor i18n setup and improve structure for maintainability
- Relocated `i18n` configuration files to `src/lib/i18n` for better organization.
- Removed obsolete `request.ts` and `routing.ts` files, simplifying `i18n` setup within the project.
- Added extensive tests for `i18n/utils` to validate locale-related utilities, including locale name, native name, and flag retrieval.
- Introduced a detailed `I18N_IMPLEMENTATION_PLAN.md` to document implementation phases, decisions, and recommendations for future extensions.
- Enhanced TypeScript definitions and modularity across i18n utilities for improved developer experience.
2025-11-18 07:23:54 +01:00
Felipe Cardoso
fe6a98c379 Add internationalization (i18n) with next-intl and Italian translations
- Integrated `next-intl` for server-side and client-side i18n support.
- Added English (`en.json`) and Italian (`it.json`) localization files.
- Configured routing with locale-based subdirectories (`/[locale]/path`) using `next-intl`.
- Introduced type-safe i18n utilities and TypeScript definitions for translation keys.
- Updated middleware to handle locale detection and routing.
- Implemented dynamic translation loading to reduce bundle size.
- Enhanced developer experience with auto-complete and compile-time validation for i18n keys.
2025-11-17 20:27:09 +01:00
Felipe Cardoso
b7c1191335 Refactor locale validation and update style consistency across schemas, tests, and migrations
- Replaced `SUPPORTED_LOCALES` with `supported_locales` for naming consistency.
- Applied formatting improvements to multiline statements for better readability.
- Cleaned up redundant comments and streamlined test assertions.
2025-11-17 20:04:03 +01:00
Felipe Cardoso
68e04a911a Add user locale preference support and locale detection logic
- Introduced `locale` field in user model and schemas with BCP 47 format validation.
- Created Alembic migration to add `locale` column to the `users` table with indexing for better query performance.
- Implemented `get_locale` dependency to detect locale using user preference, `Accept-Language` header, or default to English.
- Added extensive tests for locale validation, dependency logic, and fallback handling.
- Enhanced documentation and comments detailing the locale detection workflow and SUPPORTED_LOCALES configuration.
2025-11-17 19:47:50 +01:00
Felipe Cardoso
3001484948 Update Makefile with dev-full target and frontend scaling option
- Added new `dev-full` target to start all development services, including the frontend.
- Modified `dev` target to exclude the frontend and provide instructions for running it locally.
- Updated `.PHONY` to include the new `dev-full` target.
2025-11-16 20:02:15 +01:00
Felipe Cardoso
c9f4772196 Add and enhance tests for mobile navigation, demo modal, and forbidden page metadata
- Added new test cases for mobile navigation links and buttons in `Header` component.
- Enhanced `Home` tests to verify demo modal behavior (open/close functionality).
- Added metadata validation test for the forbidden page.
- Introduced comprehensive test suite for the DemoTour page, covering structure, navigation, categories, accessibility, and CTAs.
2025-11-16 19:38:46 +01:00
Felipe Cardoso
14e5839476 Update test suite to reflect "Design System" renaming and improved navigation structure
- Replaced "Components" references with "Design System" in both links and test assertions.
- Adjusted `DemoCredentialsModal` tests to include separate links for user/admin login and updated text expectations.
- Enhanced `Home` tests with new demo content validation (`User Dashboard`) and renamed navigation elements.
2025-11-12 17:48:22 +01:00
394 changed files with 38312 additions and 40245 deletions

.env.demo Normal file

@@ -0,0 +1,55 @@
# Common settings
PROJECT_NAME=App
VERSION=1.0.0
# Database settings
POSTGRES_USER=postgres
POSTGRES_PASSWORD=postgres
POSTGRES_DB=app
POSTGRES_HOST=db
POSTGRES_PORT=5432
DATABASE_URL=postgresql://postgres:postgres@db:5432/app
# Backend settings
BACKEND_PORT=8000
# CRITICAL: Generate a secure SECRET_KEY for production!
# Generate with: python -c 'import secrets; print(secrets.token_urlsafe(32))'
# Must be at least 32 characters
SECRET_KEY=demo_secret_key_for_testing_only_do_not_use_in_prod
ENVIRONMENT=development
DEMO_MODE=true
DEBUG=true
BACKEND_CORS_ORIGINS=["http://localhost:3000"]
FIRST_SUPERUSER_EMAIL=admin@example.com
# IMPORTANT: Use a strong password (min 12 chars, mixed case, digits)
# Default weak passwords like 'Admin123' are rejected
FIRST_SUPERUSER_PASSWORD=Admin123!
# OAuth Configuration (Social Login)
# Set OAUTH_ENABLED=true and configure at least one provider
OAUTH_ENABLED=false
OAUTH_AUTO_LINK_BY_EMAIL=true
# Google OAuth (from Google Cloud Console > APIs & Services > Credentials)
# https://console.cloud.google.com/apis/credentials
# OAUTH_GOOGLE_CLIENT_ID=your-google-client-id.apps.googleusercontent.com
# OAUTH_GOOGLE_CLIENT_SECRET=your-google-client-secret
# GitHub OAuth (from GitHub > Settings > Developer settings > OAuth Apps)
# https://github.com/settings/developers
# OAUTH_GITHUB_CLIENT_ID=your-github-client-id
# OAUTH_GITHUB_CLIENT_SECRET=your-github-client-secret
# OAuth Provider Mode (Authorization Server for MCP/third-party clients)
# Set OAUTH_PROVIDER_ENABLED=true to act as an OAuth 2.0 Authorization Server
OAUTH_PROVIDER_ENABLED=true
# IMPORTANT: Must be HTTPS in production!
OAUTH_ISSUER=http://localhost:8000
# Frontend settings
FRONTEND_PORT=3000
FRONTEND_URL=http://localhost:3000
NEXT_PUBLIC_API_URL=http://localhost:8000
NEXT_PUBLIC_API_BASE_URL=http://localhost:8000
NEXT_PUBLIC_APP_URL=http://localhost:3000
NODE_ENV=development
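The `.env.demo` comments stress generating a real `SECRET_KEY` instead of shipping the demo value. A minimal sketch of the suggested generation step using only the standard library (`generate_secret_key` is an illustrative helper name, not part of the template):

```python
import secrets


def generate_secret_key(n_bytes: int = 32) -> str:
    """Return a URL-safe random key; 32 random bytes encode to 43 characters,
    comfortably above the 'at least 32 characters' requirement."""
    return secrets.token_urlsafe(n_bytes)


key = generate_secret_key()
```

The same one-liner appears in the file's own comment as `python -c 'import secrets; print(secrets.token_urlsafe(32))'`.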


@@ -17,6 +17,7 @@ BACKEND_PORT=8000
 # Must be at least 32 characters
 SECRET_KEY=your_secret_key_here_REPLACE_WITH_GENERATED_KEY_32_CHARS_MIN
 ENVIRONMENT=development
+DEMO_MODE=false
 DEBUG=true
 BACKEND_CORS_ORIGINS=["http://localhost:3000"]
 FIRST_SUPERUSER_EMAIL=admin@example.com
@@ -24,7 +25,31 @@ FIRST_SUPERUSER_EMAIL=admin@example.com
 # Default weak passwords like 'Admin123' are rejected
 FIRST_SUPERUSER_PASSWORD=YourStrongPassword123!
+# OAuth Configuration (Social Login)
+# Set OAUTH_ENABLED=true and configure at least one provider
+OAUTH_ENABLED=false
+OAUTH_AUTO_LINK_BY_EMAIL=true
+# Google OAuth (from Google Cloud Console > APIs & Services > Credentials)
+# https://console.cloud.google.com/apis/credentials
+# OAUTH_GOOGLE_CLIENT_ID=your-google-client-id.apps.googleusercontent.com
+# OAUTH_GOOGLE_CLIENT_SECRET=your-google-client-secret
+# GitHub OAuth (from GitHub > Settings > Developer settings > OAuth Apps)
+# https://github.com/settings/developers
+# OAUTH_GITHUB_CLIENT_ID=your-github-client-id
+# OAUTH_GITHUB_CLIENT_SECRET=your-github-client-secret
+# OAuth Provider Mode (Authorization Server for MCP/third-party clients)
+# Set OAUTH_PROVIDER_ENABLED=true to act as an OAuth 2.0 Authorization Server
+OAUTH_PROVIDER_ENABLED=false
+# IMPORTANT: Must be HTTPS in production!
+OAUTH_ISSUER=http://localhost:8000
 # Frontend settings
 FRONTEND_PORT=3000
+FRONTEND_URL=http://localhost:3000
 NEXT_PUBLIC_API_URL=http://localhost:8000
+NEXT_PUBLIC_API_BASE_URL=http://localhost:8000
+NEXT_PUBLIC_APP_URL=http://localhost:3000
 NODE_ENV=development


@@ -41,7 +41,7 @@ To enable CI/CD workflows:
 - Runs on: Push to main/develop, PRs affecting frontend code
 - Tests: Frontend unit tests (Jest)
 - Coverage: Uploads to Codecov
-- Fast: Uses npm cache
+- Fast: Uses bun cache
 ### `e2e-tests.yml`
 - Runs on: All pushes and PRs

@@ -0,0 +1,77 @@
# Backend E2E Tests CI Pipeline
#
# Runs end-to-end tests with real PostgreSQL via Testcontainers
# and validates API contracts with Schemathesis.
#
# To enable: Rename this file to backend-e2e-tests.yml
name: Backend E2E Tests
on:
  push:
    branches: [main, develop]
    paths:
      - 'backend/**'
      - '.github/workflows/backend-e2e-tests.yml'
  pull_request:
    branches: [main, develop]
    paths:
      - 'backend/**'
  workflow_dispatch:
jobs:
  e2e-tests:
    runs-on: ubuntu-latest
    # E2E test failures don't block merge - they're advisory
    continue-on-error: true
    steps:
      - name: Checkout code
        uses: actions/checkout@v4
      - name: Set up Python
        uses: actions/setup-python@v5
        with:
          python-version: '3.12'
      - name: Install uv
        uses: astral-sh/setup-uv@v4
        with:
          version: "latest"
      - name: Cache uv dependencies
        uses: actions/cache@v4
        with:
          path: ~/.cache/uv
          key: uv-${{ runner.os }}-${{ hashFiles('backend/uv.lock') }}
          restore-keys: |
            uv-${{ runner.os }}-
      - name: Install dependencies (with E2E)
        working-directory: ./backend
        run: uv sync --extra dev --extra e2e
      - name: Check Docker availability
        id: docker-check
        run: |
          if docker info > /dev/null 2>&1; then
            echo "available=true" >> $GITHUB_OUTPUT
            echo "Docker is available"
          else
            echo "available=false" >> $GITHUB_OUTPUT
            echo "::warning::Docker not available - E2E tests will be skipped"
          fi
      - name: Run E2E tests
        if: steps.docker-check.outputs.available == 'true'
        working-directory: ./backend
        env:
          IS_TEST: "True"
          SECRET_KEY: "e2e-test-secret-key-minimum-32-characters-long"
          PYTHONPATH: "."
        run: |
          uv run pytest tests/e2e/ -v --tb=short
      - name: E2E tests skipped
        if: steps.docker-check.outputs.available != 'true'
        run: echo "E2E tests were skipped due to Docker unavailability"

.gitignore vendored

@@ -187,7 +187,7 @@ coverage.xml
 .hypothesis/
 .pytest_cache/
 cover/
+backend/.benchmarks
 # Translations
 *.mo
 *.pot
@@ -268,6 +268,7 @@ celerybeat.pid
 .env
 .env.*
 !.env.template
+!.env.demo
 .venv
 env/
 venv/

AGENTS.md Normal file

@@ -0,0 +1,315 @@
# AGENTS.md
AI coding assistant context for FastAPI + Next.js Full-Stack Template.
## Quick Start
```bash
# Backend (Python with uv)
cd backend
make install-dev # Install dependencies
make test # Run tests
uv run uvicorn app.main:app --reload # Start dev server
# Frontend (Node.js)
cd frontend
bun install # Install dependencies
bun run dev # Start dev server
bun run generate:api # Generate API client from OpenAPI
bun run test:e2e # Run E2E tests
```
**Access points:**
- Frontend: **http://localhost:3000**
- Backend API: **http://localhost:8000**
- API Docs: **http://localhost:8000/docs**
Default superuser (change in production):
- Email: `admin@example.com`
- Password: `admin123`
## Project Architecture
**Full-stack TypeScript/Python application:**
```
├── backend/ # FastAPI backend
│ ├── app/
│ │ ├── api/ # API routes (auth, users, organizations, admin)
│ │ ├── core/ # Core functionality (auth, config, database)
│ │ ├── repositories/ # Repository pattern (database operations)
│ │ ├── models/ # SQLAlchemy ORM models
│ │ ├── schemas/ # Pydantic request/response schemas
│ │ ├── services/ # Business logic layer
│ │ └── utils/ # Utilities (security, device detection)
│ ├── tests/ # 96% coverage, 987 tests
│ └── alembic/ # Database migrations
└── frontend/ # Next.js 16 frontend
├── src/
│ ├── app/ # App Router pages (Next.js 16)
│ ├── components/ # React components
│ ├── lib/
│ │ ├── api/ # Auto-generated API client
│ │ └── stores/ # Zustand state management
│ └── hooks/ # Custom React hooks
└── e2e/ # Playwright E2E tests (56 passing)
```
## Critical Implementation Notes
### Authentication Flow
- **JWT-based**: Access tokens (15 min) + refresh tokens (7 days)
- **OAuth/Social Login**: Google and GitHub with PKCE support
- **Session tracking**: Database-backed with device info, IP, user agent
- **Token refresh**: Validates JTI in database, not just JWT signature
- **Authorization**: FastAPI dependencies in `api/dependencies/auth.py`
- `get_current_user`: Requires valid access token
- `get_current_active_user`: Requires active account
- `get_optional_current_user`: Accepts authenticated or anonymous
- `get_current_superuser`: Requires superuser flag
### OAuth Provider Mode (MCP Integration)
Full OAuth 2.0 Authorization Server for MCP (Model Context Protocol) clients:
- **Authorization Code Flow with PKCE**: RFC 7636 compliant
- **JWT access tokens**: Self-contained, no DB lookup required
- **Opaque refresh tokens**: Stored hashed in database, supports rotation
- **Token introspection**: RFC 7662 compliant endpoint
- **Token revocation**: RFC 7009 compliant endpoint
- **Server metadata**: RFC 8414 compliant discovery endpoint
- **Consent management**: User can review and revoke app permissions
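The PKCE part of the flow reduces to two small computations. A stdlib sketch of the RFC 7636 S256 scheme (function names are illustrative, not the project's API):

```python
# RFC 7636 S256: challenge = BASE64URL(SHA256(verifier)), padding stripped.
import base64
import hashlib
import secrets

def make_pkce_pair():
    """Client side: generate code_verifier and derive code_challenge."""
    verifier = base64.urlsafe_b64encode(secrets.token_bytes(32)).rstrip(b"=").decode()
    digest = hashlib.sha256(verifier.encode("ascii")).digest()
    challenge = base64.urlsafe_b64encode(digest).rstrip(b"=").decode()
    return verifier, challenge

def verify_s256(verifier: str, challenge: str) -> bool:
    """Server side: what the token endpoint checks when code_verifier arrives."""
    digest = hashlib.sha256(verifier.encode("ascii")).digest()
    return base64.urlsafe_b64encode(digest).rstrip(b"=").decode() == challenge
```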
**API endpoints:**
- `GET /.well-known/oauth-authorization-server` - Server metadata
- `GET /oauth/provider/authorize` - Authorization endpoint
- `POST /oauth/provider/authorize/consent` - Consent submission
- `POST /oauth/provider/token` - Token endpoint
- `POST /oauth/provider/revoke` - Token revocation
- `POST /oauth/provider/introspect` - Token introspection
- Client management endpoints (admin only)
**Scopes supported:** `openid`, `profile`, `email`, `read:users`, `write:users`, `admin`
**OAuth Configuration (backend `.env`):**
```bash
# OAuth Social Login (as OAuth Consumer)
OAUTH_ENABLED=true # Enable OAuth social login
OAUTH_AUTO_LINK_BY_EMAIL=true # Auto-link accounts by email
OAUTH_STATE_EXPIRE_MINUTES=10 # CSRF state expiration
# Google OAuth
OAUTH_GOOGLE_CLIENT_ID=your-google-client-id
OAUTH_GOOGLE_CLIENT_SECRET=your-google-client-secret
# GitHub OAuth
OAUTH_GITHUB_CLIENT_ID=your-github-client-id
OAUTH_GITHUB_CLIENT_SECRET=your-github-client-secret
# OAuth Provider Mode (as Authorization Server for MCP)
OAUTH_PROVIDER_ENABLED=true # Enable OAuth provider mode
OAUTH_ISSUER=https://api.yourdomain.com # JWT issuer URL (must be HTTPS in production)
```
### Database Pattern
- **Async SQLAlchemy 2.0** with PostgreSQL
- **Connection pooling**: 20 base connections, 50 max overflow
- **Repository base class**: `repositories/base.py` with common operations
- **Migrations**: Alembic with helper script `migrate.py`
- `python migrate.py auto "message"` - Generate and apply
- `python migrate.py list` - View history
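The repository pattern above can be pictured with an in-memory stand-in; `repositories/base.py` is the real, SQLAlchemy-backed version, and all names in this sketch are illustrative:

```python
# In-memory sketch of the repository base-class idea; the real base class
# wraps an AsyncSession and issues SELECT/INSERT statements instead.
import uuid
from dataclasses import dataclass, field
from typing import Dict, Generic, Optional, TypeVar

T = TypeVar("T")

@dataclass
class User:
    email: str
    id: str = field(default_factory=lambda: str(uuid.uuid4()))

class BaseRepository(Generic[T]):
    """Common operations shared by all repositories."""

    def __init__(self) -> None:
        self._rows: Dict[str, T] = {}

    def create(self, obj: T) -> T:
        self._rows[obj.id] = obj  # real version: db.add + commit + refresh
        return obj

    def get(self, id: str) -> Optional[T]:
        return self._rows.get(id)  # real version: await db.execute(select(...))

class UserRepository(BaseRepository[User]):
    """Model-specific queries live in the subclass."""

    def get_by_email(self, email: str) -> Optional[User]:
        return next((u for u in self._rows.values() if u.email == email), None)
```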
### Frontend State Management
- **Zustand stores**: Lightweight state management
- **TanStack Query**: API data fetching/caching
- **Auto-generated client**: From OpenAPI spec via `bun run generate:api`
- **Dependency Injection**: ALWAYS use `useAuth()` from `AuthContext`, NEVER import `useAuthStore` directly
### Internationalization (i18n)
- **next-intl v4**: Type-safe internationalization library
- **Locale routing**: `/en/*`, `/it/*` (English and Italian supported)
- **Translation files**: `frontend/messages/en.json`, `frontend/messages/it.json`
- **LocaleSwitcher**: Component for seamless language switching
- **SEO-friendly**: Locale-aware metadata, sitemaps, and robots.txt
- **Type safety**: Full TypeScript support for translations
- **Utilities**: `frontend/src/lib/i18n/` (metadata, routing, utils)
### Organization System
Three-tier RBAC:
- **Owner**: Full control (delete org, manage all members)
- **Admin**: Add/remove members, assign admin role (not owner)
- **Member**: Read-only organization access
Permission dependencies in `api/dependencies/permissions.py`:
- `require_organization_owner`
- `require_organization_admin`
- `require_organization_member`
- `can_manage_organization_member`
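The three tiers form a strict hierarchy, which the permission checks can exploit. A minimal sketch (the real FastAPI dependencies raise `HTTPException(403)` rather than returning booleans, and these names are illustrative):

```python
# Three-tier RBAC as an ordered enum: higher tiers inherit lower-tier rights.
from enum import IntEnum

class OrgRole(IntEnum):
    MEMBER = 1
    ADMIN = 2
    OWNER = 3

def has_role(user_role: OrgRole, required: OrgRole) -> bool:
    # require_organization_member / _admin / _owner reduce to this comparison
    return user_role >= required

def can_manage_member(actor: OrgRole, target: OrgRole) -> bool:
    # Admins manage members; only owners manage admins; nobody manages owners
    return actor >= OrgRole.ADMIN and actor > target
```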
### Testing Infrastructure
**Backend Unit/Integration (pytest + SQLite):**
- 96% coverage, 987 tests
- Security-focused: JWT attacks, session hijacking, privilege escalation
- Async fixtures in `tests/conftest.py`
- Run: `IS_TEST=True uv run pytest` or `make test`
- Coverage: `make test-cov`
**Backend E2E (pytest + Testcontainers + Schemathesis):**
- Real PostgreSQL via Docker containers
- OpenAPI contract testing with Schemathesis
- Install: `make install-e2e`
- Run: `make test-e2e`
- Schema tests: `make test-e2e-schema`
- Docs: `backend/docs/E2E_TESTING.md`
**Frontend Unit Tests (Jest):**
- 97% coverage
- Component, hook, and utility testing
- Run: `bun run test`
- Coverage: `bun run test:coverage`
**Frontend E2E Tests (Playwright):**
- 56 passing, 1 skipped (zero flaky tests)
- Complete user flows (auth, navigation, settings)
- Run: `bun run test:e2e`
- UI mode: `bun run test:e2e:ui`
### Development Tooling
**Backend:**
- **uv**: Modern Python package manager (10-100x faster than pip)
- **Ruff**: All-in-one linting/formatting (replaces Black, Flake8, isort)
- **Pyright**: Static type checking (strict mode)
- **pip-audit**: Dependency vulnerability scanning (OSV database)
- **detect-secrets**: Hardcoded secrets detection
- **pip-licenses**: License compliance checking
- **pre-commit**: Git hook framework (Ruff, detect-secrets, standard checks)
- **Makefile**: `make help` for all commands
**Frontend:**
- **Next.js 16**: App Router with React 19
- **TypeScript**: Full type safety
- **TailwindCSS + shadcn/ui**: Design system
- **ESLint + Prettier**: Code quality
### Environment Configuration
**Backend** (`.env`):
```bash
POSTGRES_USER=postgres
POSTGRES_PASSWORD=your_password
POSTGRES_HOST=db
POSTGRES_PORT=5432
POSTGRES_DB=app
SECRET_KEY=your-secret-key-min-32-chars
ENVIRONMENT=development|production
CSP_MODE=relaxed|strict|disabled
FIRST_SUPERUSER_EMAIL=admin@example.com
FIRST_SUPERUSER_PASSWORD=admin123
BACKEND_CORS_ORIGINS=["http://localhost:3000"]
```
**Frontend** (`.env.local`):
```bash
NEXT_PUBLIC_API_URL=http://localhost:8000/api/v1
```
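Note that `BACKEND_CORS_ORIGINS` is a JSON list, not a comma-separated string. A stdlib sketch of how such values might be assembled into settings — a stand-in for the Pydantic `BaseSettings` class the backend actually uses; the field names and URL shape here are assumptions:

```python
# Illustrative settings loader; the real backend uses Pydantic BaseSettings.
import json

def load_settings(env: dict) -> dict:
    secret = env["SECRET_KEY"]
    if len(secret) < 32:
        raise ValueError("SECRET_KEY must be at least 32 characters")
    return {
        # async driver DSN assembled from the POSTGRES_* variables (assumed shape)
        "database_url": (
            f"postgresql+asyncpg://{env['POSTGRES_USER']}:{env['POSTGRES_PASSWORD']}"
            f"@{env['POSTGRES_HOST']}:{env['POSTGRES_PORT']}/{env['POSTGRES_DB']}"
        ),
        "cors_origins": json.loads(env["BACKEND_CORS_ORIGINS"]),  # JSON list, not CSV
        "environment": env.get("ENVIRONMENT", "development"),
    }
```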
## Common Development Workflows
### Adding a New API Endpoint
1. **Define schema** in `backend/app/schemas/`
2. **Create repository** in `backend/app/repositories/`
3. **Implement route** in `backend/app/api/routes/`
4. **Register router** in `backend/app/api/main.py`
5. **Write tests** in `backend/tests/api/`
6. **Generate frontend client**: `bun run generate:api`
### Database Migrations
```bash
cd backend
python migrate.py generate "description" # Create migration
python migrate.py apply # Apply migrations
python migrate.py auto "description" # Generate + apply
```
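A helper like `migrate.py` presumably wraps the Alembic CLI; a hypothetical sketch that only builds the underlying command lines (the real script's internals may differ):

```python
# Hypothetical wrapper logic: map migrate.py actions to Alembic invocations.
def alembic_commands(action: str, message: str = ""):
    if action == "generate":
        return [["alembic", "revision", "--autogenerate", "-m", message]]
    if action == "apply":
        return [["alembic", "upgrade", "head"]]
    if action == "auto":  # generate + apply in one step
        return alembic_commands("generate", message) + alembic_commands("apply")
    if action == "list":
        return [["alembic", "history"]]
    raise ValueError(f"unknown action: {action}")
```

The real script would pass each command list to `subprocess.run`; building the argv lists separately keeps the mapping trivially testable.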
### Frontend Component Development
1. **Create component** in `frontend/src/components/`
2. **Follow design system** (see `frontend/docs/design-system/`)
3. **Use dependency injection** for auth (`useAuth()` not `useAuthStore`)
4. **Write tests** in `frontend/tests/` or `__tests__/`
5. **Run type check**: `bun run type-check`
## Security Features
- **Password hashing**: bcrypt with salt rounds
- **Rate limiting**: 60 req/min default, 10 req/min on auth endpoints
- **Security headers**: CSP, X-Frame-Options, HSTS, etc.
- **CSRF protection**: Built into FastAPI
- **Session revocation**: Database-backed session tracking
- **Comprehensive security tests**: JWT algorithm attacks, session hijacking, privilege escalation
- **Dependency vulnerability scanning**: `make dep-audit` (pip-audit against OSV database)
- **License compliance**: `make license-check` (blocks GPL-3.0/AGPL)
- **Secrets detection**: Pre-commit hook blocks hardcoded secrets
- **Unified security pipeline**: `make audit` (all security checks), `make check` (quality + security + tests)
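The per-IP limit can be pictured as a sliding window. A stdlib sketch of the idea (the project actually uses `slowapi` decorators, not this class):

```python
# Sliding-window rate limiter: allow at most `limit` requests per `window`
# seconds for each client IP. Illustrative only; slowapi does this in the app.
import time
from collections import defaultdict, deque

class RateLimiter:
    def __init__(self, limit: int = 60, window: float = 60.0):
        self.limit = limit
        self.window = window
        self.hits = defaultdict(deque)  # client_ip -> timestamps of recent requests

    def allow(self, client_ip: str, now: float = None) -> bool:
        now = time.monotonic() if now is None else now
        q = self.hits[client_ip]
        while q and now - q[0] >= self.window:  # drop requests outside the window
            q.popleft()
        if len(q) >= self.limit:
            return False  # the real API would respond with HTTP 429
        q.append(now)
        return True
```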
## Docker Deployment
```bash
# Development (with hot reload)
docker-compose -f docker-compose.dev.yml up
# Production
docker-compose up -d
# Run migrations
docker-compose exec backend alembic upgrade head
# Create first superuser
docker-compose exec backend python -c "from app.init_db import init_db; import asyncio; asyncio.run(init_db())"
```
## Documentation
**For comprehensive documentation, see:**
- **[README.md](./README.md)** - User-facing project overview
- **[CLAUDE.md](./CLAUDE.md)** - Claude Code-specific guidance
- **Backend docs**: `backend/docs/` (Architecture, Coding Standards, Common Pitfalls, Feature Examples)
- **Frontend docs**: `frontend/docs/` (Design System, Architecture, E2E Testing)
- **API docs**: http://localhost:8000/docs (Swagger UI when running)
## Current Status (Nov 2025)
### Completed Features ✅
- Authentication system (JWT with refresh tokens, OAuth/social login)
- **OAuth Provider Mode (MCP-ready)**: Full OAuth 2.0 Authorization Server
- Session management (device tracking, revocation)
- User management (full lifecycle, password change)
- Organization system (multi-tenant with RBAC)
- Admin panel (user/org management, bulk operations)
- **Internationalization (i18n)** with English and Italian
- Comprehensive test coverage (96% backend, 97% frontend unit, 56 E2E tests)
- Design system documentation
- **Marketing landing page** with animations
- **`/dev` documentation portal** with live examples
- **Toast notifications**, charts, markdown rendering
- **SEO optimization** (sitemap, robots.txt, locale metadata)
- Docker deployment
### In Progress 🚧
- Frontend admin pages (70% complete)
- Email integration (templates ready, SMTP pending)
### Planned 🔮
- GitHub Actions CI/CD
- Additional languages (Spanish, French, German, etc.)
- SSO/SAML authentication
- Real-time notifications (WebSockets)
- Webhook system
- Background job processing
- File upload/storage

## CLAUDE.md

# CLAUDE.md
Claude Code context for FastAPI + Next.js Full-Stack Template.
**See [AGENTS.md](./AGENTS.md) for project context, architecture, and development commands.**
## Claude Code-Specific Guidance
### Critical User Preferences
#### File Operations - NEVER Use Heredoc/Cat Append
**ALWAYS use Read/Write/Edit tools instead of `cat >> file << EOF` commands.**
This triggers manual approval dialogs and disrupts workflow.
```bash
# WRONG ❌
cat >> file << EOF
...
EOF
# CORRECT ✅ - Use Read, then Write tools
```
#### Work Style
- User prefers autonomous operation without frequent interruptions
- Ask for batch permissions upfront for long work sessions
- Work independently, document decisions clearly
- Only use emojis if the user explicitly requests it
### When Working with This Stack
**Dependency Management:**
- Backend uses **uv** (modern Python package manager), not pip
- Always use the `uv run` prefix: `IS_TEST=True uv run pytest`
- Or use Makefile commands: `make test`, `make install-dev`
- Add dependencies: `uv add <package>` or `uv add --dev <package>`
**Database Migrations:**
- Use the `migrate.py` helper script, not Alembic directly
- Generate + apply: `python migrate.py auto "message"`
- Never commit migrations without testing them first
- Check current state: `python migrate.py current`
**Frontend API Client Generation:**
- Run `bun run generate:api` after backend schema changes
- Client is auto-generated from the OpenAPI spec
- Located in `frontend/src/lib/api/generated/`
- NEVER manually edit generated files
**Testing Commands:**
- Backend unit/integration: `IS_TEST=True uv run pytest` (always prefix with `IS_TEST=True`)
- Backend E2E (requires Docker): `make test-e2e`
- Frontend unit: `bun run test`
- Frontend E2E: `bun run test:e2e`
- Use `make test` or `make test-cov` in backend for convenience
**Security & Quality Commands (Backend):**
- `make validate` — lint + format + type checks
- `make audit` — dependency vulnerabilities + license compliance
- `make validate-all` — quality + security checks
- `make check` — **full pipeline**: quality + security + tests
**Backend E2E Testing (requires Docker):**
- Install deps: `make install-e2e`
- Run all E2E tests: `make test-e2e`
- Run schema tests only: `make test-e2e-schema`
- Run all tests: `make test-all` (unit + E2E)
- Uses Testcontainers (real PostgreSQL) + Schemathesis (OpenAPI contract testing)
- Markers: `@pytest.mark.e2e`, `@pytest.mark.postgres`, `@pytest.mark.schemathesis`
- See `backend/docs/E2E_TESTING.md` for the complete guide
### 🔴 CRITICAL: Auth Store Dependency Injection Pattern
ALWAYS read auth state via `useAuth()` from `AuthContext` (`const { user, isAuthenticated } = useAuth();`); NEVER import `useAuthStore` directly. Only two files may legitimately use the real store:
1. `AuthContext.tsx` - DI boundary, legitimately needs real store
2. `client.ts` - Non-React context, uses dynamic import + `__TEST_AUTH_STORE__` check
### E2E Test Best Practices
When writing or fixing Playwright tests:
**Navigation Pattern:**
```typescript
// ✅ CORRECT - Use Promise.all for Next.js Link clicks
await Promise.all([
  page.waitForURL('/target', { timeout: 10000 }),
  link.click()
]);
```
**Selectors:**
- Use ID-based selectors for validation errors: `#email-error`
- Error IDs use dashes not underscores: `#new-password-error`
- Target `.border-destructive[role="alert"]` to avoid Next.js route announcer conflicts
- Avoid generic `[role="alert"]`, which matches multiple elements
**URL Assertions:**
```typescript
// ✅ Use regex to handle query params
await expect(page).toHaveURL(/\/auth\/login/);
// ❌ Don't use exact strings (fails with query params)
await expect(page).toHaveURL('/auth/login');
```
**Configuration:**
- Uses 12 workers in non-CI mode (`playwright.config.ts`)
- Reduces to 2 workers in CI for stability
- Tests are designed to be non-flaky with proper waits
### Important Implementation Details
**Authentication Testing:**
- Backend fixtures in `tests/conftest.py`:
  - `async_test_db`: Fresh SQLite per test
  - `async_test_user` / `async_test_superuser`: Pre-created users
  - `user_token` / `superuser_token`: Access tokens for API calls
- Always use `@pytest.mark.asyncio` for async tests
- Use `@pytest_asyncio.fixture` for async fixtures
**Database Testing:**
```python
# Mock database exceptions correctly
from unittest.mock import patch, AsyncMock

async def mock_commit():
    raise OperationalError("Connection lost", {}, Exception())

with patch.object(session, 'commit', side_effect=mock_commit):
    with patch.object(session, 'rollback', new_callable=AsyncMock) as mock_rollback:
        with pytest.raises(OperationalError):
            await repo_method(session, obj_in=data)
        mock_rollback.assert_called_once()
```
**Frontend Component Development:**
- Follow design system docs in `frontend/docs/design-system/`
- Read `08-ai-guidelines.md` for AI code generation rules
- Use parent-controlled spacing (see `04-spacing-philosophy.md`)
- WCAG AA compliance required (see `07-accessibility.md`)
**Security Considerations:**
- Backend has comprehensive security tests (JWT attacks, session hijacking)
- Never skip security headers in production
- Rate limiting is configured in route decorators: `@limiter.limit("10/minute")`
- Session revocation is database-backed, not just JWT expiry
- Run `make audit` to check for dependency vulnerabilities and license compliance
- Run `make check` for the full pipeline: quality + security + tests
- Pre-commit hooks enforce Ruff lint/format and detect-secrets on every commit
- Setup hooks: `cd backend && uv run pre-commit install`
### Common Workflows Guidance
**When Adding a New Feature:**
1. Start with backend schema and repository
2. Implement API route with proper authorization
3. Write backend tests (aim for >90% coverage)
4. Generate frontend API client: `bun run generate:api`
5. Implement frontend components
6. Write frontend unit tests
7. Add E2E tests for critical flows
8. Update relevant documentation
**When Fixing Tests:**
- Backend: Check test database isolation and async fixture usage
- Frontend unit: Verify mocking of `useAuth()`, not `useAuthStore`
- E2E: Use the `Promise.all()` pattern and regex URL assertions
**When Debugging:**
- Backend: Check that the `IS_TEST=True` environment variable is set
- Frontend: Run `bun run type-check` first
- E2E: Use `bun run test:e2e:debug` for step-by-step debugging
- Check logs: Backend has detailed error logging
**Demo Mode (Frontend-Only Showcase):**
- Enable: `echo "NEXT_PUBLIC_DEMO_MODE=true" > frontend/.env.local`
- Uses MSW (Mock Service Worker) to intercept API calls in the browser
- Zero backend required - perfect for Vercel deployments
- **Fully Automated**: MSW handlers auto-generated from OpenAPI spec
  - Run `bun run generate:api` → updates both API client AND MSW handlers
  - No manual synchronization needed
- Demo credentials (any password ≥8 chars works):
  - User: `demo@example.com` / `DemoPass123`
  - Admin: `admin@example.com` / `AdminPass123`
- **Safe**: MSW never runs during tests (Jest or Playwright)
- **Coverage**: Mock files excluded from linting and coverage
- **Documentation**: `frontend/docs/DEMO_MODE.md` for the complete guide
### Tool Usage Preferences
**Prefer specialized tools over bash:**
- Use Read/Write/Edit tools for file operations
- Never use `cat`, `echo >`, or heredoc for file manipulation
- Use the Task tool with `subagent_type=Explore` for codebase exploration
- Use the Grep tool for code search, not bash `grep`
**When to use parallel tool calls:**
- Independent git commands: `git status`, `git diff`, `git log`
- Reading multiple unrelated files
- Running multiple test suites simultaneously
- Independent validation steps
## Custom Skills
No Claude Code Skills installed yet. To create one, invoke the built-in "skill-creator" skill.
**Potential skill ideas for this project:**
- API endpoint generator workflow (schema → repository → route → tests → frontend client)
- Component generator with design system compliance
- Database migration troubleshooting helper
- Test coverage analyzer and improvement suggester
- E2E test generator for new features
## Additional Resources
**Comprehensive Documentation:**
- [AGENTS.md](./AGENTS.md) - Framework-agnostic AI assistant context
- [README.md](./README.md) - User-facing project overview
- `backend/docs/` - Backend architecture, coding standards, common pitfalls
- `frontend/docs/design-system/` - Complete design system guide
**API Documentation (when running):**
- Swagger UI: http://localhost:8000/docs
- ReDoc: http://localhost:8000/redoc
- OpenAPI JSON: http://localhost:8000/api/v1/openapi.json
**Testing Documentation:**
- Backend tests: `backend/tests/` (97% coverage)
- Frontend E2E: `frontend/e2e/README.md`
- Design system: `frontend/docs/design-system/08-ai-guidelines.md`
---
**For project architecture, development commands, and general context, see [AGENTS.md](./AGENTS.md).**
### Adding a New API Endpoint
1. **Create schema** (`backend/app/schemas/`):
```python
class ItemCreate(BaseModel):
    name: str
    description: Optional[str] = None

class ItemResponse(BaseModel):
    id: UUID
    name: str
    created_at: datetime
```
2. **Create CRUD operations** (`backend/app/crud/`):
```python
class CRUDItem(CRUDBase[Item, ItemCreate, ItemUpdate]):
    async def get_by_name(self, db: AsyncSession, name: str) -> Item | None:
        result = await db.execute(select(Item).where(Item.name == name))
        return result.scalar_one_or_none()

item = CRUDItem(Item)
```
3. **Create route** (`backend/app/api/routes/items.py`):
```python
from app.api.dependencies.auth import get_current_user

@router.post("/", response_model=ItemResponse)
async def create_item(
    item_in: ItemCreate,
    current_user: User = Depends(get_current_user),
    db: AsyncSession = Depends(get_db),
):
    item = await item_crud.create(db, obj_in=item_in)
    return item
```
4. **Register router** (`backend/app/api/main.py`):
```python
from app.api.routes import items
api_router.include_router(items.router, prefix="/items", tags=["Items"])
```
5. **Write tests** (`backend/tests/api/test_items.py`):
```python
@pytest.mark.asyncio
async def test_create_item(client, user_token):
    response = await client.post(
        "/api/v1/items",
        headers={"Authorization": f"Bearer {user_token}"},
        json={"name": "Test Item"},
    )
    assert response.status_code == 201
```
6. **Generate frontend client**:
```bash
cd frontend
npm run generate:api
```
### Adding a New React Component
1. **Create component** (`frontend/src/components/`):
```typescript
export function MyComponent() {
  const { user } = useAuthStore();
  return <div>Hello {user?.firstName}</div>;
}
```
2. **Add tests** (`frontend/src/components/__tests__/`):
```typescript
import { render, screen } from '@testing-library/react';

test('renders component', () => {
  render(<MyComponent />);
  expect(screen.getByText(/Hello/)).toBeInTheDocument();
});
```
3. **Add to page** (`frontend/src/app/page.tsx`):
```typescript
import { MyComponent } from '@/components/MyComponent';

export default function Page() {
  return <MyComponent />;
}
```
## Development Tooling Stack
**State-of-the-art Python tooling (Nov 2025):**
### Dependency Management: uv
- **Fast**: 10-100x faster than pip
- **Reliable**: Reproducible builds with `uv.lock` lockfile
- **Modern**: Built by Astral (Ruff creators) in Rust
- **Commands**:
- `make install-dev` - Install all dependencies
- `make sync` - Sync from lockfile
- `uv add <package>` - Add new dependency
- `uv add --dev <package>` - Add dev dependency
### Code Quality: Ruff + mypy
- **Ruff**: All-in-one linting, formatting, and import sorting
- Replaces: Black, Flake8, isort
- **10-100x faster** than alternatives
- `make lint`, `make format`, `make validate`
- **mypy**: Type checking with Pydantic plugin
- Gradual typing approach
- Strategic per-module configurations
### Configuration: pyproject.toml
- Single source of truth for all tools
- Dependencies defined in `[project.dependencies]`
- Dev dependencies in `[project.optional-dependencies]`
- Tool configs: Ruff, mypy, pytest, coverage
## Current Project Status (Nov 2025)
### Completed Features
- ✅ Authentication system (JWT with refresh tokens)
- ✅ Session management (device tracking, revocation)
- ✅ User management (CRUD, password change)
- ✅ Organization system (multi-tenant with roles)
- ✅ Admin panel (user/org management, bulk operations)
- ✅ E2E test suite (56 passing, 1 skipped, zero flaky tests)
### Test Coverage
- **Backend**: 97% overall (743 tests, all passing) ✅
- Comprehensive security testing (JWT attacks, session hijacking, privilege escalation)
- User CRUD: 100% ✅
- Session CRUD: 100% ✅
- Auth routes: 99% ✅
- Organization routes: 100% ✅
- Permissions: 100% ✅
- 84 missing lines justified (defensive code, error handlers, production-only code)
- **Frontend E2E**: 56 passing, 1 skipped across 7 files ✅
- auth-login.spec.ts (19 tests)
- auth-register.spec.ts (14 tests)
- auth-password-reset.spec.ts (10 tests)
- navigation.spec.ts (10 tests)
- settings-password.spec.ts (3 tests)
- settings-profile.spec.ts (2 tests)
- settings-navigation.spec.ts (5 tests)
- settings-sessions.spec.ts (1 skipped - route not yet implemented)
## Email Service Integration
The project includes a **placeholder email service** (`backend/app/services/email_service.py`) designed for easy integration with production email providers.
### Current Implementation
**Console Backend (Default)**:
- Logs email content to console/logs instead of sending
- Safe for development and testing
- No external dependencies required
### Production Integration
To enable email functionality, implement one of these approaches:
**Option 1: SMTP Integration** (Recommended for most use cases)
```python
# In app/services/email_service.py, complete the SMTPEmailBackend implementation
from aiosmtplib import SMTP
from email.mime.text import MIMEText
from email.mime.multipart import MIMEMultipart
# Add environment variables to .env:
# SMTP_HOST=smtp.gmail.com
# SMTP_PORT=587
# SMTP_USERNAME=your-email@gmail.com
# SMTP_PASSWORD=your-app-password
```
**Option 2: Third-Party Service** (SendGrid, AWS SES, Mailgun, etc.)
```python
# Create a new backend class, e.g., SendGridEmailBackend
import sendgrid


class SendGridEmailBackend(EmailBackend):
    def __init__(self, api_key: str):
        self.api_key = api_key
        self.client = sendgrid.SendGridAPIClient(api_key)

    async def send_email(self, to, subject, html_content, text_content=None):
        # Implement SendGrid sending logic here
        ...


# Update the global instance in email_service.py:
# email_service = EmailService(SendGridEmailBackend(settings.SENDGRID_API_KEY))
```
**Option 3: External Microservice**
- Use a dedicated email microservice via HTTP API
- Implement `HTTPEmailBackend` that makes async HTTP requests
### Email Templates Included
The service includes pre-built templates for:
- **Password Reset**: `send_password_reset_email()` - 1 hour expiry
- **Email Verification**: `send_email_verification()` - 24 hour expiry
Both include responsive HTML and plain text versions.
### Integration Points
Email sending is called from:
- `app/api/routes/auth.py` - Password reset flow (placeholder comments)
- Registration flow - Ready for email verification integration
**Note**: Current auth routes have placeholder comments where email functionality should be integrated. Search for "TODO: Send email" in the codebase.
## API Documentation
Once the backend is running:
- **Swagger UI**: http://localhost:8000/docs
- **ReDoc**: http://localhost:8000/redoc
- **OpenAPI JSON**: http://localhost:8000/api/v1/openapi.json
## Troubleshooting
### Tests failing with "Module was never imported"
Run in a single process: `pytest -n 0`
### Coverage not improving despite new tests
- Verify tests actually execute endpoints (check response.status_code)
- Generate HTML coverage: `pytest --cov=app --cov-report=html -n 0`
- Check for dependency override issues in test fixtures
### Frontend type errors
```bash
bun run type-check     # Check all types via the package script
npx tsc --noEmit       # Invoke the TypeScript compiler directly
```
### E2E tests flaking
- Check worker count (should be 4, not 16+)
- Use `Promise.all()` for navigation
- Use regex for URL assertions
- Target specific selectors (avoid generic `[role="alert"]`)
### Database migration conflicts
```bash
python migrate.py list # Check migration history
alembic downgrade -1 # Downgrade one revision
alembic upgrade head # Re-apply
```
## Additional Documentation
### Backend Documentation
- `backend/docs/ARCHITECTURE.md`: System architecture and design patterns
- `backend/docs/CODING_STANDARDS.md`: Code quality standards and best practices
- `backend/docs/COMMON_PITFALLS.md`: Common mistakes and how to avoid them
- `backend/docs/FEATURE_EXAMPLE.md`: Step-by-step feature implementation guide
### Frontend Documentation
- **`frontend/docs/ARCHITECTURE_FIX_REPORT.md`**: ⭐ Critical DI pattern fixes (READ THIS!)
- `frontend/e2e/README.md`: E2E testing setup and guidelines
- **`frontend/docs/design-system/`**: Comprehensive design system documentation
- `README.md`: Hub with learning paths (start here)
- `00-quick-start.md`: 5-minute crash course
- `01-foundations.md`: Colors (OKLCH), typography, spacing, shadows
- `02-components.md`: shadcn/ui component library guide
- `03-layouts.md`: Layout patterns (Grid vs Flex decision trees)
- `04-spacing-philosophy.md`: Parent-controlled spacing strategy
- `05-component-creation.md`: When to create vs compose components
- `06-forms.md`: Form patterns with react-hook-form + Zod
- `07-accessibility.md`: WCAG AA compliance, keyboard navigation, screen readers
- `08-ai-guidelines.md`: **AI code generation rules (read this!)**
- `99-reference.md`: Quick reference cheat sheet (bookmark this)

@@ -90,22 +90,27 @@ Ready to write some code? Awesome!
 ```bash
 cd backend

-# Setup virtual environment
-python -m venv .venv
-source .venv/bin/activate
+# Install dependencies (uv manages virtual environment automatically)
+make install-dev

-# Install dependencies
-pip install -r requirements.txt
+# Setup pre-commit hooks
+uv run pre-commit install

 # Setup environment
 cp .env.example .env
 # Edit .env with your settings

 # Run migrations
-alembic upgrade head
+python migrate.py apply

+# Run quality + security checks
+make validate-all

 # Run tests
-IS_TEST=True pytest
+make test

+# Run full pipeline (quality + security + tests)
+make check

 # Start dev server
 uvicorn app.main:app --reload
@@ -117,20 +122,20 @@ uvicorn app.main:app --reload
 cd frontend

 # Install dependencies
-npm install
+bun install

 # Setup environment
 cp .env.local.example .env.local

 # Generate API client
-npm run generate:api
+bun run generate:api

 # Run tests
-npm test
-npm run test:e2e:ui
+bun run test
+bun run test:e2e:ui

 # Start dev server
-npm run dev
+bun run dev
 ```

 ---
@@ -199,7 +204,7 @@ export function UserProfile({ userId }: UserProfileProps) {
 ### Key Patterns
-- **Backend**: Use CRUD pattern, keep routes thin, business logic in services
+- **Backend**: Use repository pattern, keep routes thin, business logic in services
 - **Frontend**: Use React Query for server state, Zustand for client state
 - **Both**: Handle errors gracefully, log appropriately, write tests
@@ -320,7 +325,7 @@ Fixed stuff
 ### Before Submitting
 - [ ] Code follows project style guidelines
-- [ ] All tests pass locally
+- [ ] `make check` passes (quality + security + tests) in backend
 - [ ] New tests added for new features
 - [ ] Documentation updated if needed
 - [ ] No merge conflicts with `main`

File diff suppressed because it is too large

Makefile (119 changed lines)

@@ -1,31 +1,124 @@
-.PHONY: dev prod down clean clean-slate
+.PHONY: help dev dev-full prod down logs logs-dev clean clean-slate drop-db reset-db push-images deploy scan-images
 VERSION ?= latest
-REGISTRY := gitea.pragmazest.com/cardosofelipe/app
+REGISTRY ?= ghcr.io/cardosofelipe/pragma-stack

+# Default target
+help:
+	@echo "FastAPI + Next.js Full-Stack Template"
+	@echo ""
+	@echo "Development:"
+	@echo "  make dev         - Start backend + db (frontend runs separately)"
+	@echo "  make dev-full    - Start all services including frontend"
+	@echo "  make down        - Stop all services"
+	@echo "  make logs-dev    - Follow dev container logs"
+	@echo ""
+	@echo "Database:"
+	@echo "  make drop-db     - Drop and recreate empty database"
+	@echo "  make reset-db    - Drop database and apply all migrations"
+	@echo ""
+	@echo "Production:"
+	@echo "  make prod        - Start production stack"
+	@echo "  make deploy      - Pull and deploy latest images"
+	@echo "  make push-images - Build and push images to registry"
+	@echo "  make scan-images - Scan production images for CVEs (requires trivy)"
+	@echo "  make logs        - Follow production container logs"
+	@echo ""
+	@echo "Cleanup:"
+	@echo "  make clean       - Stop containers"
+	@echo "  make clean-slate - Stop containers AND delete volumes (DATA LOSS!)"
+	@echo ""
+	@echo "Subdirectory commands:"
+	@echo "  cd backend && make help  - Backend-specific commands"
+	@echo "  cd frontend && npm run   - Frontend-specific commands"

+# ============================================================================
+# Development
+# ============================================================================
 dev:
-	docker compose -f docker-compose.dev.yml up --build -d
+	# Bring up all dev services except the frontend
+	docker compose -f docker-compose.dev.yml up --build -d --scale frontend=0
+	@echo ""
+	@echo "Frontend is not started by 'make dev'."
+	@echo "To run the frontend locally, open a new terminal and run:"
+	@echo "  cd frontend && npm run dev"

-prod:
-	docker compose up --build -d
+dev-full:
+	# Bring up all dev services including the frontend (full stack)
+	docker compose -f docker-compose.dev.yml up --build -d

 down:
 	docker compose down

+logs:
+	docker compose logs -f

+logs-dev:
+	docker compose -f docker-compose.dev.yml logs -f

+# ============================================================================
+# Database Management
+# ============================================================================
+drop-db:
+	@echo "Dropping local database..."
+	@docker compose -f docker-compose.dev.yml exec -T db psql -U postgres -c "DROP DATABASE IF EXISTS app WITH (FORCE);" 2>/dev/null || \
+		docker compose -f docker-compose.dev.yml exec -T db psql -U postgres -c "DROP DATABASE IF EXISTS app;"
+	@docker compose -f docker-compose.dev.yml exec -T db psql -U postgres -c "CREATE DATABASE app;"
+	@echo "Database dropped and recreated (empty)"

+reset-db: drop-db
+	@echo "Applying migrations..."
+	@cd backend && uv run python migrate.py --local apply
+	@echo "Database reset complete!"

+# ============================================================================
+# Production / Deployment
+# ============================================================================
+prod:
+	docker compose up --build -d

 deploy:
 	docker compose -f docker-compose.deploy.yml pull
 	docker compose -f docker-compose.deploy.yml up -d

-clean:
-	docker compose down
-# WARNING! THIS REMOVES CONTAINERS AND VOLUMES AS WELL - DO NOT USE THIS UNLESS YOU WANT TO START OVER WITH DATA AND ALL
-clean-slate:
-	docker compose down -v

 push-images:
 	docker build -t $(REGISTRY)/backend:$(VERSION) ./backend
 	docker build -t $(REGISTRY)/frontend:$(VERSION) ./frontend
 	docker push $(REGISTRY)/backend:$(VERSION)
 	docker push $(REGISTRY)/frontend:$(VERSION)

+scan-images:
+	@docker info > /dev/null 2>&1 || (echo "❌ Docker is not running!"; exit 1)
+	@echo "🐳 Building and scanning production images for CVEs..."
+	docker build -t $(REGISTRY)/backend:scan --target production ./backend
+	docker build -t $(REGISTRY)/frontend:scan --target runner ./frontend
+	@echo ""
+	@echo "=== Backend Image Scan ==="
+	@if command -v trivy > /dev/null 2>&1; then \
+		trivy image --severity HIGH,CRITICAL --exit-code 1 $(REGISTRY)/backend:scan; \
+	else \
+		echo "  Trivy not found locally, using Docker to run Trivy..."; \
+		docker run --rm -v /var/run/docker.sock:/var/run/docker.sock aquasec/trivy image --severity HIGH,CRITICAL --exit-code 1 $(REGISTRY)/backend:scan; \
+	fi
+	@echo ""
+	@echo "=== Frontend Image Scan ==="
+	@if command -v trivy > /dev/null 2>&1; then \
+		trivy image --severity HIGH,CRITICAL --exit-code 1 $(REGISTRY)/frontend:scan; \
+	else \
+		docker run --rm -v /var/run/docker.sock:/var/run/docker.sock aquasec/trivy image --severity HIGH,CRITICAL --exit-code 1 $(REGISTRY)/frontend:scan; \
+	fi
+	@echo "✅ No HIGH/CRITICAL CVEs found in production images!"

+# ============================================================================
+# Cleanup
+# ============================================================================
+clean:
+	docker compose down

+# WARNING! THIS REMOVES CONTAINERS AND VOLUMES AS WELL - DO NOT USE THIS UNLESS YOU WANT TO START OVER WITH DATA AND ALL
+clean-slate:
+	docker compose -f docker-compose.dev.yml down -v --remove-orphans

README.md (224 changed lines)

@@ -1,29 +1,29 @@
# FastAPI + Next.js Full-Stack Template # <img src="frontend/public/logo.svg" alt="PragmaStack" width="32" height="32" style="vertical-align: middle" /> PragmaStack
> **Production-ready, security-first, full-stack TypeScript/Python template with authentication, multi-tenancy, and a comprehensive admin panel.** > **The Pragmatic Full-Stack Template. Production-ready, security-first, and opinionated.**
<!--
TODO: Replace these static badges with dynamic CI/CD badges when GitHub Actions is set up
Example: https://github.com/YOUR_ORG/YOUR_REPO/actions/workflows/backend-tests.yml/badge.svg
-->
[![Backend Unit Tests](https://img.shields.io/badge/backend_unit_tests-passing-success)](./backend/tests)
[![Backend Coverage](https://img.shields.io/badge/backend_coverage-97%25-brightgreen)](./backend/tests) [![Backend Coverage](https://img.shields.io/badge/backend_coverage-97%25-brightgreen)](./backend/tests)
[![Frontend Unit Tests](https://img.shields.io/badge/frontend_unit_tests-passing-success)](./frontend/tests)
[![Frontend Coverage](https://img.shields.io/badge/frontend_coverage-97%25-brightgreen)](./frontend/tests) [![Frontend Coverage](https://img.shields.io/badge/frontend_coverage-97%25-brightgreen)](./frontend/tests)
[![E2E Tests](https://img.shields.io/badge/e2e_tests-passing-success)](./frontend/e2e) [![E2E Tests](https://img.shields.io/badge/e2e_tests-passing-success)](./frontend/e2e)
[![License: MIT](https://img.shields.io/badge/License-MIT-blue.svg)](./LICENSE) [![License: MIT](https://img.shields.io/badge/License-MIT-blue.svg)](./LICENSE)
[![PRs Welcome](https://img.shields.io/badge/PRs-welcome-brightgreen.svg)](./CONTRIBUTING.md) [![PRs Welcome](https://img.shields.io/badge/PRs-welcome-brightgreen.svg)](./CONTRIBUTING.md)
![Landing Page](docs/images/landing.png)
--- ---
## Why This Template? ## Why PragmaStack?
Building a modern full-stack application from scratch means solving the same problems over and over: authentication, authorization, multi-tenancy, admin panels, session management, database migrations, API documentation, testing infrastructure... Building a modern full-stack application often leads to "analysis paralysis" or "boilerplate fatigue". You spend weeks setting up authentication, testing, and linting before writing a single line of business logic.
**This template gives you all of that, battle-tested and ready to go.** **PragmaStack cuts through the noise.**
Instead of spending weeks on boilerplate, you can focus on building your unique features. Whether you're building a SaaS product, an internal tool, or a side project, this template provides a rock-solid foundation with modern best practices baked in. We provide a **pragmatic**, opinionated foundation that prioritizes:
- **Speed**: Ship features, not config files.
- **Robustness**: Security and testing are not optional.
- **Clarity**: Code that is easy to read and maintain.
Whether you're building a SaaS, an internal tool, or a side project, PragmaStack gives you a solid starting point without the bloat.
--- ---
@@ -31,12 +31,26 @@ Instead of spending weeks on boilerplate, you can focus on building your unique
### 🔐 **Authentication & Security** ### 🔐 **Authentication & Security**
- JWT-based authentication with access + refresh tokens - JWT-based authentication with access + refresh tokens
- **OAuth/Social Login** (Google, GitHub) with PKCE support
- **OAuth 2.0 Authorization Server** (MCP-ready) for third-party integrations
- Session management with device tracking and revocation - Session management with device tracking and revocation
- Password reset flow (email integration ready) - Password reset flow (email integration ready)
- Secure password hashing (bcrypt) - Secure password hashing (bcrypt)
- CSRF protection, rate limiting, and security headers - CSRF protection, rate limiting, and security headers
- Comprehensive security tests (JWT algorithm attacks, session hijacking, privilege escalation) - Comprehensive security tests (JWT algorithm attacks, session hijacking, privilege escalation)
### 🔌 **OAuth Provider Mode (MCP Integration)**
Full OAuth 2.0 Authorization Server for Model Context Protocol (MCP) and third-party clients:
- **RFC 7636**: Authorization Code Flow with PKCE (S256 only)
- **RFC 8414**: Server metadata discovery at `/.well-known/oauth-authorization-server`
- **RFC 7662**: Token introspection endpoint
- **RFC 7009**: Token revocation endpoint
- **JWT access tokens**: Self-contained, configurable lifetime
- **Opaque refresh tokens**: Secure rotation, database-backed revocation
- **Consent management**: Users can review and revoke app permissions
- **Client management**: Admin endpoints for registering OAuth clients
- **Scopes**: `openid`, `profile`, `email`, `read:users`, `write:users`, `admin`
### 👥 **Multi-Tenancy & Organizations** ### 👥 **Multi-Tenancy & Organizations**
- Full organization system with role-based access control (Owner, Admin, Member) - Full organization system with role-based access control (Owner, Admin, Member)
- Invite/remove members, manage permissions - Invite/remove members, manage permissions
@@ -44,18 +58,35 @@ Instead of spending weeks on boilerplate, you can focus on building your unique
- User can belong to multiple organizations - User can belong to multiple organizations
### 🛠️ **Admin Panel** ### 🛠️ **Admin Panel**
- Complete user management (CRUD, activate/deactivate, bulk operations) - Complete user management (full lifecycle, activate/deactivate, bulk operations)
- Organization management (create, edit, delete, member management) - Organization management (create, edit, delete, member management)
- Session monitoring across all users - Session monitoring across all users
- Real-time statistics dashboard - Real-time statistics dashboard
- Admin-only routes with proper authorization - Admin-only routes with proper authorization
### 🎨 **Modern Frontend** ### 🎨 **Modern Frontend**
- Next.js 15 with App Router and React 19 - Next.js 16 with App Router and React 19
- Comprehensive design system built on shadcn/ui + TailwindCSS - **PragmaStack Design System** built on shadcn/ui + TailwindCSS
- Pre-configured theme with dark mode support (coming soon) - Pre-configured theme with dark mode support (coming soon)
- Responsive, accessible components (WCAG AA compliant) - Responsive, accessible components (WCAG AA compliant)
- Developer documentation at `/dev` (in progress) - Rich marketing landing page with animated components
- Live component showcase and documentation at `/dev`
### 🌍 **Internationalization (i18n)**
- Built-in multi-language support with next-intl v4
- Locale-based routing (`/en/*`, `/it/*`)
- Seamless language switching with LocaleSwitcher component
- SEO-friendly URLs and metadata per locale
- Translation files for English and Italian (easily extensible)
- Type-safe translations throughout the app
### 🎯 **Content & UX Features**
- **Toast notifications** with Sonner for elegant user feedback
- **Smooth animations** powered by Framer Motion
- **Markdown rendering** with syntax highlighting (GitHub Flavored Markdown)
- **Charts and visualizations** ready with Recharts
- **SEO optimization** with dynamic sitemap and robots.txt generation
- **Session tracking UI** with device information and revocation controls
### 🧪 **Comprehensive Testing** ### 🧪 **Comprehensive Testing**
- **Backend Testing**: ~97% unit test coverage - **Backend Testing**: ~97% unit test coverage
@@ -75,9 +106,10 @@ Instead of spending weeks on boilerplate, you can focus on building your unique
### 📚 **Developer Experience** ### 📚 **Developer Experience**
- Auto-generated TypeScript API client from OpenAPI spec - Auto-generated TypeScript API client from OpenAPI spec
- Interactive API documentation (Swagger + ReDoc) - Interactive API documentation (Swagger + ReDoc)
- Database migrations with Alembic - Database migrations with Alembic helper script
- Hot reload in development - Hot reload in development for both frontend and backend
- Comprehensive code documentation - Comprehensive code documentation and design system docs
- Live component playground at `/dev` with code examples
- Docker support for easy deployment - Docker support for easy deployment
- VSCode workspace settings included - VSCode workspace settings included
@@ -89,6 +121,68 @@ Instead of spending weeks on boilerplate, you can focus on building your unique
- Health check endpoints - Health check endpoints
- Production security headers - Production security headers
- Rate limiting on sensitive endpoints - Rate limiting on sensitive endpoints
- SEO optimization with dynamic sitemaps and robots.txt
- Multi-language SEO with locale-specific metadata
- Performance monitoring and bundle analysis
---
## 📸 Screenshots
<details>
<summary>Click to view screenshots</summary>
### Landing Page
![Landing Page](docs/images/landing.png)
### Authentication
![Login Page](docs/images/login.png)
### Admin Dashboard
![Admin Dashboard](docs/images/admin-dashboard.png)
### Design System
![Components](docs/images/design-system.png)
</details>
---
## 🎭 Demo Mode
**Try the frontend without a backend!** Perfect for:
- **Free deployment** on Vercel (no backend costs)
- **Portfolio showcasing** with live demos
- **Client presentations** without infrastructure setup
### Quick Start
```bash
cd frontend
echo "NEXT_PUBLIC_DEMO_MODE=true" > .env.local
bun run dev
```
**Demo Credentials:**
- Regular user: `demo@example.com` / `DemoPass123`
- Admin user: `admin@example.com` / `AdminPass123`
Demo mode uses [Mock Service Worker (MSW)](https://mswjs.io/) to intercept API calls in the browser. Your code remains unchanged - the same components work with both real and mocked backends.
**Key Features:**
- ✅ Zero backend required
- ✅ All features functional (auth, admin, stats)
- ✅ Realistic network delays and errors
- ✅ Does NOT interfere with tests (97%+ coverage maintained)
- ✅ One-line toggle: `NEXT_PUBLIC_DEMO_MODE=true`
📖 **[Complete Demo Mode Documentation](./frontend/docs/DEMO_MODE.md)**
--- ---
@@ -103,13 +197,18 @@ Instead of spending weeks on boilerplate, you can focus on building your unique
- **[pytest](https://pytest.org/)** - Testing framework with async support - **[pytest](https://pytest.org/)** - Testing framework with async support
### Frontend ### Frontend
- **[Next.js 15](https://nextjs.org/)** - React framework with App Router - **[Next.js 16](https://nextjs.org/)** - React framework with App Router
- **[React 19](https://react.dev/)** - UI library - **[React 19](https://react.dev/)** - UI library
- **[TypeScript](https://www.typescriptlang.org/)** - Type-safe JavaScript - **[TypeScript](https://www.typescriptlang.org/)** - Type-safe JavaScript
- **[TailwindCSS](https://tailwindcss.com/)** - Utility-first CSS framework - **[TailwindCSS](https://tailwindcss.com/)** - Utility-first CSS framework
- **[shadcn/ui](https://ui.shadcn.com/)** - Beautiful, accessible component library - **[shadcn/ui](https://ui.shadcn.com/)** - Beautiful, accessible component library
- **[next-intl](https://next-intl.dev/)** - Internationalization (i18n) with type safety
- **[TanStack Query](https://tanstack.com/query)** - Powerful data fetching/caching - **[TanStack Query](https://tanstack.com/query)** - Powerful data fetching/caching
- **[Zustand](https://zustand-demo.pmnd.rs/)** - Lightweight state management - **[Zustand](https://zustand-demo.pmnd.rs/)** - Lightweight state management
- **[Framer Motion](https://www.framer.com/motion/)** - Production-ready animation library
- **[Sonner](https://sonner.emilkowal.ski/)** - Beautiful toast notifications
- **[Recharts](https://recharts.org/)** - Composable charting library
- **[React Markdown](https://github.com/remarkjs/react-markdown)** - Markdown rendering with GFM support
- **[Playwright](https://playwright.dev/)** - End-to-end testing - **[Playwright](https://playwright.dev/)** - End-to-end testing
### DevOps ### DevOps
@@ -135,12 +234,11 @@ The fastest way to get started is with Docker:
```bash ```bash
# Clone the repository # Clone the repository
git clone https://github.com/yourusername/fast-next-template.git git clone https://github.com/cardosofelipe/pragma-stack.git
cd fast-next-template cd fast-next-template
# Copy environment files # Copy environment file
cp backend/.env.example backend/.env cp .env.template .env
cp frontend/.env.local.example frontend/.env.local
# Start all services (backend, frontend, database) # Start all services (backend, frontend, database)
docker-compose up docker-compose up
@@ -200,17 +298,17 @@ uvicorn app.main:app --reload --host 0.0.0.0 --port 8000
cd frontend cd frontend
# Install dependencies # Install dependencies
npm install bun install
# Setup environment # Setup environment
cp .env.local.example .env.local cp .env.local.example .env.local
# Edit .env.local with your backend URL # Edit .env.local with your backend URL
# Generate API client # Generate API client
npm run generate:api bun run generate:api
# Start development server # Start development server
npm run dev bun run dev
``` ```
Visit http://localhost:3000 to see your app! Visit http://localhost:3000 to see your app!
@@ -224,7 +322,7 @@ Visit http://localhost:3000 to see your app!
│ ├── app/ │ ├── app/
│ │ ├── api/ # API routes and dependencies │ │ ├── api/ # API routes and dependencies
│ │ ├── core/ # Core functionality (auth, config, database) │ │ ├── core/ # Core functionality (auth, config, database)
│ │ ├── crud/ # Database operations │ │ ├── repositories/ # Repository pattern (database operations)
│ │ ├── models/ # SQLAlchemy models │ │ ├── models/ # SQLAlchemy models
│ │ ├── schemas/ # Pydantic schemas │ │ ├── schemas/ # Pydantic schemas
│ │ ├── services/ # Business logic │ │ ├── services/ # Business logic
@@ -279,7 +377,7 @@ open htmlcov/index.html
``` ```
**Test types:** **Test types:**
- **Unit tests**: CRUD operations, utilities, business logic - **Unit tests**: Repository operations, utilities, business logic
- **Integration tests**: API endpoints with database - **Integration tests**: API endpoints with database
- **Security tests**: JWT algorithm attacks, session hijacking, privilege escalation - **Security tests**: JWT algorithm attacks, session hijacking, privilege escalation
- **Error handling tests**: Database failures, validation errors - **Error handling tests**: Database failures, validation errors
@@ -292,13 +390,13 @@ open htmlcov/index.html
cd frontend cd frontend
# Run unit tests # Run unit tests
npm test bun run test
# Run with coverage # Run with coverage
npm run test:coverage bun run test:coverage
# Watch mode # Watch mode
npm run test:watch bun run test:watch
``` ```
**Test types:** **Test types:**
@@ -316,10 +414,10 @@ npm run test:watch
cd frontend cd frontend
# Run E2E tests # Run E2E tests
npm run test:e2e bun run test:e2e
# Run E2E tests in UI mode (recommended for development) # Run E2E tests in UI mode (recommended for development)
npm run test:e2e:ui bun run test:e2e:ui
# Run specific test file # Run specific test file
npx playwright test auth-login.spec.ts npx playwright test auth-login.spec.ts
@@ -338,6 +436,17 @@ npx playwright show-report
--- ---
## 🤖 AI-Friendly Documentation
This project includes comprehensive documentation designed for AI coding assistants:
- **[AGENTS.md](./AGENTS.md)** - Framework-agnostic AI assistant context for PragmaStack
- **[CLAUDE.md](./CLAUDE.md)** - Claude Code-specific guidance
These files provide AI assistants with the **PragmaStack** architecture, patterns, and best practices.
---
## 🗄️ Database Migrations ## 🗄️ Database Migrations
The template uses Alembic for database migrations: The template uses Alembic for database migrations:
@@ -365,22 +474,25 @@ python migrate.py current
## 📖 Documentation ## 📖 Documentation
### AI Assistant Documentation
- **[AGENTS.md](./AGENTS.md)** - Framework-agnostic AI coding assistant context
- **[CLAUDE.md](./CLAUDE.md)** - Claude Code-specific guidance and preferences
### Backend Documentation ### Backend Documentation
- **[ARCHITECTURE.md](./backend/docs/ARCHITECTURE.md)** - System architecture and design patterns - **[ARCHITECTURE.md](./backend/docs/ARCHITECTURE.md)** - System architecture and design patterns
- **[CODING_STANDARDS.md](./backend/docs/CODING_STANDARDS.md)** - Code quality standards - **[CODING_STANDARDS.md](./backend/docs/CODING_STANDARDS.md)** - Code quality standards
- **[COMMON_PITFALLS.md](./backend/docs/COMMON_PITFALLS.md)** - Common mistakes to avoid - **[COMMON_PITFALLS.md](./backend/docs/COMMON_PITFALLS.md)** - Common mistakes to avoid
- **[FEATURE_EXAMPLE.md](./backend/docs/FEATURE_EXAMPLE.md)** - Step-by-step feature guide - **[FEATURE_EXAMPLE.md](./backend/docs/FEATURE_EXAMPLE.md)** - Step-by-step feature guide
- **[CLAUDE.md](./CLAUDE.md)** - Comprehensive development guide
### Frontend Documentation ### Frontend Documentation
- **[Design System Docs](./frontend/docs/design-system/)** - Complete design system guide - **[PragmaStack Design System](./frontend/docs/design-system/)** - Complete design system guide
- Quick start, foundations (colors, typography, spacing) - Quick start, foundations (colors, typography, spacing)
- Component library guide - Component library guide
- Layout patterns, spacing philosophy - Layout patterns, spacing philosophy
- Forms, accessibility, AI guidelines - Forms, accessibility, AI guidelines
- **[ARCHITECTURE_FIX_REPORT.md](./frontend/docs/ARCHITECTURE_FIX_REPORT.md)** - Critical dependency injection patterns
- **[E2E Testing Guide](./frontend/e2e/README.md)** - E2E testing setup and best practices - **[E2E Testing Guide](./frontend/e2e/README.md)** - E2E testing setup and best practices
### API Documentation ### API Documentation
@@ -429,37 +541,43 @@ docker-compose down
## 🛣️ Roadmap & Status
### ✅ Completed
- [x] Authentication system (JWT, refresh tokens, session management, OAuth)
- [x] User management (full lifecycle, profile, password change)
- [x] Organization system with RBAC (Owner, Admin, Member)
- [x] Admin panel (users, organizations, sessions, statistics)
- [x] **Internationalization (i18n)** with next-intl (English + Italian)
- [x] Backend testing infrastructure (~97% coverage)
- [x] Frontend unit testing infrastructure (~97% coverage)
- [x] Frontend E2E testing (Playwright, zero flaky tests)
- [x] Design system documentation
- [x] **Marketing landing page** with animated components
- [x] **`/dev` documentation portal** with live component examples
- [x] **Toast notifications** system (Sonner)
- [x] **Charts and visualizations** (Recharts)
- [x] **Animation system** (Framer Motion)
- [x] **Markdown rendering** with syntax highlighting
- [x] **SEO optimization** (sitemap, robots.txt, locale-aware metadata)
- [x] Database migrations with helper script
- [x] Docker deployment
- [x] API documentation (OpenAPI/Swagger)
### 🚧 In Progress
- [ ] Email integration (templates ready, SMTP pending)
### 🔮 Planned
- [ ] GitHub Actions CI/CD pipelines
- [ ] Dynamic test coverage badges from CI
- [ ] E2E test coverage reporting
- [ ] OAuth token encryption at rest (security hardening)
- [ ] Additional languages (Spanish, French, German, etc.)
- [ ] SSO/SAML authentication
- [ ] Real-time notifications with WebSockets
- [ ] Webhook system
- [ ] File upload/storage (S3-compatible)
- [ ] Audit logging system
- [ ] API versioning example
---
## 🤝 Contributing
@@ -489,7 +607,7 @@ Contributions are welcome! Whether you're fixing bugs, improving documentation,
### Reporting Issues
Found a bug? Have a suggestion? [Open an issue](https://github.com/cardosofelipe/pragma-stack/issues)!
Please include:
- Clear description of the issue/suggestion
@@ -523,8 +641,8 @@ This template is built on the shoulders of giants:
## 💬 Questions?
- **Documentation**: Check the `/docs` folders in backend and frontend
- **Issues**: [GitHub Issues](https://github.com/cardosofelipe/pragma-stack/issues)
- **Discussions**: [GitHub Discussions](https://github.com/cardosofelipe/pragma-stack/discussions)
---


@@ -11,16 +11,19 @@ omit =
app/utils/auth_test_utils.py
# Async implementations not yet in use
app/repositories/base_async.py
app/core/database_async.py
# CLI scripts - run manually, not tested
app/init_db.py
# __init__ files with no logic
app/__init__.py
app/api/__init__.py
app/api/routes/__init__.py
app/api/dependencies/__init__.py
app/core/__init__.py
app/repositories/__init__.py
app/models/__init__.py
app/schemas/__init__.py
app/services/__init__.py


@@ -1,2 +1,17 @@
.venv
*.iml
# Python build and cache artifacts
__pycache__/
.pytest_cache/
.mypy_cache/
.ruff_cache/
*.pyc
*.pyo
# Packaging artifacts
*.egg-info/
build/
dist/
htmlcov/
.uv_cache/


@@ -0,0 +1,44 @@
# Pre-commit hooks for backend quality and security checks.
#
# Install:
# cd backend && uv run pre-commit install
#
# Run manually on all files:
# cd backend && uv run pre-commit run --all-files
#
# Skip hooks temporarily:
# git commit --no-verify
#
repos:
# ── Code Quality ──────────────────────────────────────────────────────────
- repo: https://github.com/astral-sh/ruff-pre-commit
rev: v0.14.4
hooks:
- id: ruff
args: [--fix, --exit-non-zero-on-fix]
- id: ruff-format
# ── General File Hygiene ──────────────────────────────────────────────────
- repo: https://github.com/pre-commit/pre-commit-hooks
rev: v5.0.0
hooks:
- id: trailing-whitespace
- id: end-of-file-fixer
- id: check-yaml
- id: check-toml
- id: check-merge-conflict
- id: check-added-large-files
args: [--maxkb=500]
- id: debug-statements
# ── Security ──────────────────────────────────────────────────────────────
- repo: https://github.com/Yelp/detect-secrets
rev: v1.5.0
hooks:
- id: detect-secrets
args: ['--baseline', '.secrets.baseline']
exclude: |
(?x)^(
.*\.lock$|
.*\.svg$
)$
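The `exclude` block above uses Python's verbose regex syntax (`(?x)`), so the newlines and indentation inside it are ignored when the pattern is compiled. A quick sketch of how the filter behaves (the sample filenames are illustrative, not from the repo):

```python
import re

# Same pattern as the detect-secrets exclude block: verbose mode
# strips the whitespace, leaving ^(.*\.lock$|.*\.svg$)$
pattern = re.compile(
    r"""(?x)^(
        .*\.lock$|
        .*\.svg$
    )$"""
)

# Lockfiles and SVGs are skipped by the secrets scan
print(bool(pattern.match("bun.lock")))           # True
print(bool(pattern.match("assets/logo.svg")))    # True
# Regular source files are still scanned
print(bool(pattern.match("app/core/auth.py")))   # False
```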

backend/.secrets.baseline (new file, 1073 lines)
File diff suppressed because it is too large


@@ -1,9 +1,6 @@
# Development stage
FROM python:3.12-slim AS development
WORKDIR /app
ENV PYTHONDONTWRITEBYTECODE=1 \
PYTHONUNBUFFERED=1 \
@@ -31,19 +28,16 @@ COPY . .
COPY entrypoint.sh /usr/local/bin/
RUN chmod +x /usr/local/bin/entrypoint.sh
# Note: Running as root in development for bind mount compatibility
# Production stage uses non-root user for security
ENTRYPOINT ["/usr/local/bin/entrypoint.sh"]
# Production stage — Alpine eliminates glibc CVEs (e.g. CVE-2026-0861)
FROM python:3.12-alpine AS production
# Create non-root user
RUN addgroup -S appuser && adduser -S -G appuser appuser
WORKDIR /app
ENV PYTHONDONTWRITEBYTECODE=1 \
@@ -54,18 +48,18 @@ ENV PYTHONDONTWRITEBYTECODE=1 \
UV_NO_CACHE=1
# Install system dependencies and uv
RUN apk add --no-cache postgresql-client curl ca-certificates && \
curl -LsSf https://astral.sh/uv/install.sh | sh && \
mv /root/.local/bin/uv* /usr/local/bin/
# Copy dependency files
COPY pyproject.toml uv.lock ./
# Install build dependencies, compile Python packages, then remove build deps
RUN apk add --no-cache --virtual .build-deps \
gcc g++ musl-dev python3-dev linux-headers libffi-dev openssl-dev && \
uv sync --frozen --no-dev && \
apk del .build-deps
# Copy application code
COPY . .


@@ -1,4 +1,7 @@
.PHONY: help lint lint-fix format format-check type-check test test-cov validate clean install-dev sync check-docker install-e2e test-e2e test-e2e-schema test-all dep-audit license-check audit validate-all check benchmark benchmark-check benchmark-save scan-image test-api-security
# Prevent a stale VIRTUAL_ENV in the caller's shell from confusing uv
unexport VIRTUAL_ENV
# Default target
help:
@@ -6,6 +9,7 @@ help:
@echo "" @echo ""
@echo "Setup:" @echo "Setup:"
@echo " make install-dev - Install all dependencies with uv (includes dev)" @echo " make install-dev - Install all dependencies with uv (includes dev)"
@echo " make install-e2e - Install E2E test dependencies (requires Docker)"
@echo " make sync - Sync dependencies from uv.lock" @echo " make sync - Sync dependencies from uv.lock"
@echo "" @echo ""
@echo "Quality Checks:" @echo "Quality Checks:"
@@ -13,12 +17,30 @@ help:
@echo " make lint-fix - Run Ruff linter with auto-fix" @echo " make lint-fix - Run Ruff linter with auto-fix"
@echo " make format - Format code with Ruff" @echo " make format - Format code with Ruff"
@echo " make format-check - Check if code is formatted" @echo " make format-check - Check if code is formatted"
@echo " make type-check - Run mypy type checking" @echo " make type-check - Run pyright type checking"
@echo " make validate - Run all checks (lint + format + types)" @echo " make validate - Run all checks (lint + format + types + schema fuzz)"
@echo ""
@echo "Performance:"
@echo " make benchmark - Run performance benchmarks"
@echo " make benchmark-save - Run benchmarks and save as baseline"
@echo " make benchmark-check - Run benchmarks and compare against baseline"
@echo ""
@echo "Security & Audit:"
@echo " make dep-audit - Scan dependencies for known vulnerabilities"
@echo " make license-check - Check dependency license compliance"
@echo " make audit - Run all security audits (deps + licenses)"
@echo " make scan-image - Scan Docker image for CVEs (requires trivy)"
@echo " make validate-all - Run all quality + security checks"
@echo " make check - Full pipeline: quality + security + tests"
@echo "" @echo ""
@echo "Testing:" @echo "Testing:"
@echo " make test - Run pytest" @echo " make test - Run pytest (unit/integration, SQLite)"
@echo " make test-cov - Run pytest with coverage report" @echo " make test-cov - Run pytest with coverage report"
@echo " make test-e2e - Run E2E tests (PostgreSQL, requires Docker)"
@echo " make test-e2e-schema - Run Schemathesis API schema tests"
@echo " make test-all - Run all tests (unit + E2E)"
@echo " make check-docker - Check if Docker is available"
@echo " make check - Full pipeline: quality + security + tests"
@echo "" @echo ""
@echo "Cleanup:" @echo "Cleanup:"
@echo " make clean - Remove cache and build artifacts" @echo " make clean - Remove cache and build artifacts"
@@ -58,12 +80,52 @@ format-check:
@uv run ruff format --check app/ tests/
type-check:
@echo "🔎 Running pyright type checking..."
@uv run pyright app/
validate: lint format-check type-check test-api-security
@echo "✅ All quality checks passed!"
# API Security Testing (Schemathesis property-based fuzzing)
test-api-security: check-docker
@echo "🔐 Running Schemathesis API security fuzzing..."
@IS_TEST=True PYTHONPATH=. uv run pytest tests/e2e/ -v -m "schemathesis" --tb=short -n 0
@echo "✅ API schema security tests passed!"
# ============================================================================
# Security & Audit
# ============================================================================
dep-audit:
@echo "🔒 Scanning dependencies for known vulnerabilities..."
@uv run pip-audit --desc --skip-editable
@echo "✅ No known vulnerabilities found!"
license-check:
@echo "📜 Checking dependency license compliance..."
@uv run pip-licenses --fail-on="GPL-3.0-or-later;AGPL-3.0-or-later" --format=plain > /dev/null
@echo "✅ All dependency licenses are compliant!"
audit: dep-audit license-check
@echo "✅ All security audits passed!"
scan-image: check-docker
@echo "🐳 Scanning Docker image for OS-level CVEs with Trivy..."
@docker build -t pragma-backend:scan -q --target production .
@if command -v trivy > /dev/null 2>&1; then \
trivy image --severity HIGH,CRITICAL --exit-code 1 pragma-backend:scan; \
else \
echo " Trivy not found locally, using Docker to run Trivy..."; \
docker run --rm -v /var/run/docker.sock:/var/run/docker.sock aquasec/trivy image --severity HIGH,CRITICAL --exit-code 1 pragma-backend:scan; \
fi
@echo "✅ No HIGH/CRITICAL CVEs found in Docker image!"
validate-all: validate audit
@echo "✅ All quality + security checks passed!"
check: validate-all test
@echo "✅ Full validation pipeline complete!"
# ============================================================================
# Testing
# ============================================================================
@@ -77,6 +139,68 @@ test-cov:
@IS_TEST=True PYTHONPATH=. uv run pytest --cov=app --cov-report=term-missing --cov-report=html -n 16
@echo "📊 Coverage report generated in htmlcov/index.html"
# ============================================================================
# E2E Testing (requires Docker)
# ============================================================================
check-docker:
@docker info > /dev/null 2>&1 || (echo ""; \
echo "Docker is not running!"; \
echo ""; \
echo "E2E tests require Docker to be running."; \
echo "Please start Docker Desktop or Docker Engine and try again."; \
echo ""; \
echo "Quick start:"; \
echo " macOS/Windows: Open Docker Desktop"; \
echo " Linux: sudo systemctl start docker"; \
echo ""; \
exit 1)
@echo "Docker is available"
install-e2e:
@echo "📦 Installing E2E test dependencies..."
@uv sync --extra dev --extra e2e
@echo "✅ E2E dependencies installed!"
test-e2e: check-docker
@echo "🧪 Running E2E tests with PostgreSQL..."
@IS_TEST=True PYTHONPATH=. uv run pytest tests/e2e/ -v --tb=short -n 0
@echo "✅ E2E tests complete!"
test-e2e-schema: check-docker
@echo "🧪 Running Schemathesis API schema tests..."
@IS_TEST=True PYTHONPATH=. uv run pytest tests/e2e/ -v -m "schemathesis" --tb=short -n 0
# ============================================================================
# Performance Benchmarks
# ============================================================================
benchmark:
@echo "⏱️ Running performance benchmarks..."
@IS_TEST=True PYTHONPATH=. uv run pytest tests/benchmarks/ -v --benchmark-only --benchmark-sort=mean -p no:xdist --override-ini='addopts='
benchmark-save:
@echo "⏱️ Running benchmarks and saving baseline..."
@IS_TEST=True PYTHONPATH=. uv run pytest tests/benchmarks/ -v --benchmark-only --benchmark-save=baseline --benchmark-sort=mean -p no:xdist --override-ini='addopts='
@echo "✅ Benchmark baseline saved to .benchmarks/"
benchmark-check:
@echo "⏱️ Running benchmarks and comparing against baseline..."
@if find .benchmarks -name '*_baseline*' -print -quit 2>/dev/null | grep -q .; then \
IS_TEST=True PYTHONPATH=. uv run pytest tests/benchmarks/ -v --benchmark-only --benchmark-compare=0001_baseline --benchmark-sort=mean --benchmark-compare-fail=mean:200% -p no:xdist --override-ini='addopts='; \
echo "✅ No performance regressions detected!"; \
else \
echo "⚠️ No benchmark baseline found. Run 'make benchmark-save' first to create one."; \
echo " Running benchmarks without comparison..."; \
IS_TEST=True PYTHONPATH=. uv run pytest tests/benchmarks/ -v --benchmark-only --benchmark-save=baseline --benchmark-sort=mean -p no:xdist --override-ini='addopts='; \
echo "✅ Benchmark baseline created. Future runs of 'make benchmark-check' will compare against it."; \
fi
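As we read pytest-benchmark's `--benchmark-compare-fail=mean:200%` expression used above, a run fails when its mean regresses by more than 200% relative to the saved baseline, i.e. roughly a 3x slowdown. A sketch of that threshold arithmetic, with made-up numbers:

```python
def regression_exceeds(baseline_mean: float, current_mean: float, pct: float) -> bool:
    """True if current_mean is more than pct percent slower than baseline_mean."""
    return current_mean > baseline_mean * (1 + pct / 100)

# Illustrative baseline: a hashing benchmark with a 10 ms mean
baseline = 0.010
print(regression_exceeds(baseline, 0.025, 200))  # False: 2.5x slower still passes
print(regression_exceeds(baseline, 0.035, 200))  # True: 3.5x slower fails the check
```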
test-all:
@echo "🧪 Running ALL tests (unit + E2E)..."
@$(MAKE) test
@$(MAKE) test-e2e
# ============================================================================
# Cleanup
# ============================================================================
@@ -85,7 +209,7 @@ clean:
@echo "🧹 Cleaning up..." @echo "🧹 Cleaning up..."
@find . -type d -name "__pycache__" -exec rm -rf {} + 2>/dev/null || true @find . -type d -name "__pycache__" -exec rm -rf {} + 2>/dev/null || true
@find . -type d -name ".pytest_cache" -exec rm -rf {} + 2>/dev/null || true @find . -type d -name ".pytest_cache" -exec rm -rf {} + 2>/dev/null || true
@find . -type d -name ".mypy_cache" -exec rm -rf {} + 2>/dev/null || true @find . -type d -name ".pyright" -exec rm -rf {} + 2>/dev/null || true
@find . -type d -name ".ruff_cache" -exec rm -rf {} + 2>/dev/null || true @find . -type d -name ".ruff_cache" -exec rm -rf {} + 2>/dev/null || true
@find . -type d -name "*.egg-info" -exec rm -rf {} + 2>/dev/null || true @find . -type d -name "*.egg-info" -exec rm -rf {} + 2>/dev/null || true
@find . -type d -name "htmlcov" -exec rm -rf {} + 2>/dev/null || true @find . -type d -name "htmlcov" -exec rm -rf {} + 2>/dev/null || true


@@ -1,10 +1,12 @@
# PragmaStack Backend API
> The pragmatic, production-ready FastAPI backend for PragmaStack.
## Overview
Opinionated, secure, and fast. This backend provides the solid foundation you need to ship features, not boilerplate.
Features:
- **Authentication**: JWT with refresh tokens, session management, device tracking
- **Database**: Async PostgreSQL with SQLAlchemy 2.0, Alembic migrations
@@ -12,7 +14,9 @@ Production-ready FastAPI backend featuring:
- **Multi-tenancy**: Organization-based access control with roles (Owner/Admin/Member)
- **Testing**: 97%+ coverage with security-focused test suite
- **Performance**: Async throughout, connection pooling, optimized queries
- **Modern Tooling**: uv for dependencies, Ruff for linting/formatting, Pyright for type checking
- **Security Auditing**: Automated dependency vulnerability scanning, license compliance, secrets detection
- **Pre-commit Hooks**: Ruff, detect-secrets, and standard checks on every commit
## Quick Start
@@ -147,7 +151,7 @@ uv pip list --outdated
# Run any Python command via uv (no activation needed)
uv run python script.py
uv run pytest
uv run pyright app/
# Or activate the virtual environment
source .venv/bin/activate
@@ -169,12 +173,22 @@ make lint # Run Ruff linter (check only)
make lint-fix # Run Ruff with auto-fix
make format # Format code with Ruff
make format-check # Check if code is formatted
make type-check # Run Pyright type checking
make validate # Run all checks (lint + format + types)
# Security & Audit
make dep-audit # Scan dependencies for known vulnerabilities (CVEs)
make license-check # Check dependency license compliance
make audit # Run all security audits (deps + licenses)
make validate-all # Run all quality + security checks
make check # Full pipeline: quality + security + tests
# Testing
make test # Run all tests
make test-cov # Run tests with coverage report
make test-e2e # Run E2E tests (PostgreSQL, requires Docker)
make test-e2e-schema # Run Schemathesis API schema tests
make test-all # Run all tests (unit + E2E)
# Utilities
make clean # Remove cache and build artifacts
@@ -250,7 +264,7 @@ app/
│   ├── database.py # Database engine setup
│   ├── auth.py # JWT token handling
│   └── exceptions.py # Custom exceptions
├── repositories/ # Repository pattern (database operations)
├── models/ # SQLAlchemy ORM models
├── schemas/ # Pydantic request/response schemas
├── services/ # Business logic layer
@@ -350,18 +364,29 @@ open htmlcov/index.html
# Using Makefile (recommended)
make lint # Ruff linting
make format # Ruff formatting
make type-check # Pyright type checking
make validate # All checks at once
# Security audits
make dep-audit # Scan dependencies for CVEs
make license-check # Check license compliance
make audit # All security audits
make validate-all # Quality + security checks
make check # Full pipeline: quality + security + tests
# Using uv directly
uv run ruff check app/ tests/
uv run ruff format app/ tests/
uv run pyright app/
```
**Tools:**
- **Ruff**: All-in-one linting, formatting, and import sorting (replaces Black, Flake8, isort)
- **Pyright**: Static type checking (strict mode)
- **pip-audit**: Dependency vulnerability scanning against the OSV database
- **pip-licenses**: Dependency license compliance checking
- **detect-secrets**: Hardcoded secrets/credentials detection
- **pre-commit**: Git hook framework for automated checks on every commit
All configurations are in `pyproject.toml`.
@@ -437,7 +462,7 @@ See [docs/FEATURE_EXAMPLE.md](docs/FEATURE_EXAMPLE.md) for step-by-step guide.
Quick overview:
1. Create Pydantic schemas in `app/schemas/`
2. Create repository in `app/repositories/`
3. Create route in `app/api/routes/`
4. Register router in `app/api/main.py`
5. Write tests in `tests/api/`
@@ -587,13 +612,42 @@ Configured in `app/core/config.py`:
- **Security Headers**: CSP, HSTS, X-Frame-Options, etc.
- **Input Validation**: Pydantic schemas, SQL injection prevention (ORM)
### Security Auditing
Automated, deterministic security checks are built into the development workflow:
```bash
# Scan dependencies for known vulnerabilities (CVEs)
make dep-audit
# Check dependency license compliance (blocks GPL-3.0/AGPL)
make license-check
# Run all security audits
make audit
# Full pipeline: quality + security + tests
make check
```
**Pre-commit hooks** automatically run on every commit:
- **Ruff** lint + format checks
- **detect-secrets** blocks commits containing hardcoded secrets
- **Standard checks**: trailing whitespace, YAML/TOML validation, merge conflict detection, large file prevention
Setup pre-commit hooks:
```bash
uv run pre-commit install
```
### Security Best Practices
1. **Never commit secrets**: Use `.env` files (git-ignored), enforced by detect-secrets pre-commit hook
2. **Strong SECRET_KEY**: Min 32 chars, cryptographically random
3. **HTTPS in production**: Required for token security
4. **Regular updates**: Keep dependencies current (`uv sync --upgrade`), run `make dep-audit` to check for CVEs
5. **Audit logs**: Monitor authentication events
6. **Run `make check` before pushing**: Validates quality, security, and tests in one command
---
@@ -643,7 +697,11 @@ logging.basicConfig(level=logging.INFO)
**Built with modern Python tooling:**
- 🚀 **uv** - 10-100x faster dependency management
- **Ruff** - 10-100x faster linting & formatting
- 🔍 **Pyright** - Static type checking (strict mode)
- **pytest** - Comprehensive test suite
- 🔒 **pip-audit** - Dependency vulnerability scanning
- 🔑 **detect-secrets** - Hardcoded secrets detection
- 📜 **pip-licenses** - License compliance checking
- 🪝 **pre-commit** - Automated git hooks
**All configured in a single `pyproject.toml` file!**


@@ -2,6 +2,13 @@
script_location = app/alembic
sqlalchemy.url = postgresql://postgres:postgres@db:5432/app
# Use sequential naming: 0001_message.py, 0002_message.py, etc.
# The rev_id is still used internally but filename is cleaner
file_template = %%(rev)s_%%(slug)s
# Allow specifying custom revision IDs via --rev-id flag
revision_environment = true
[loggers]
keys = root,sqlalchemy,alembic
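The doubled `%%` in `file_template` escapes configparser interpolation, so Alembic sees the plain `%(rev)s_%(slug)s` template and appends the `.py` extension itself. A minimal sketch of the resulting filename (the revision id and slug are illustrative):

```python
# Alembic substitutes these fields into the template when writing the file;
# the revision id here comes from the --rev-id flag enabled by
# revision_environment = true.
template = "%(rev)s_%(slug)s"
filename = template % {"rev": "0002", "slug": "add_perf_indexes"} + ".py"
print(filename)  # 0002_add_perf_indexes.py
```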


@@ -22,6 +22,25 @@ from app.models import *
# access to the values within the .ini file in use.
config = context.config
def include_object(object, name, type_, reflected, compare_to):
"""
Filter objects for autogenerate.
Skip comparing functional indexes (like LOWER(column)) and partial indexes
(with WHERE clauses) as Alembic cannot reliably detect these from models.
These should be managed manually via dedicated performance migrations.
Convention: Any index starting with "ix_perf_" is automatically excluded.
This allows adding new performance indexes without updating this file.
"""
if type_ == "index" and name:
# Convention-based: any index prefixed with ix_perf_ is manual
if name.startswith("ix_perf_"):
return False
return True
# Interpret the config file for Python logging.
# This line sets up loggers basically.
if config.config_file_name is not None:
@@ -100,6 +119,8 @@ def run_migrations_offline() -> None:
target_metadata=target_metadata,
literal_binds=True,
dialect_opts={"paramstyle": "named"},
compare_type=True,
include_object=include_object,
)
with context.begin_transaction():
@@ -123,7 +144,12 @@ def run_migrations_online() -> None:
)
with connectable.connect() as connection:
context.configure(
connection=connection,
target_metadata=target_metadata,
compare_type=True,
include_object=include_object,
)
with context.begin_transaction():
context.run_migrations()
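The `ix_perf_` convention enforced by `include_object` above is easy to exercise in isolation; a small sketch (the index names are illustrative):

```python
def include_object(object, name, type_, reflected, compare_to):
    """Skip any index following the ix_perf_ manual-migration convention."""
    if type_ == "index" and name:
        # Convention-based: any index prefixed with ix_perf_ is managed
        # manually via dedicated performance migrations
        if name.startswith("ix_perf_"):
            return False
    return True

# A functional/partial performance index is excluded from autogenerate...
print(include_object(None, "ix_perf_users_email_lower", "index", False, None))  # False
# ...while ordinary model-driven indexes and tables are still compared
print(include_object(None, "ix_users_email", "index", False, None))  # True
print(include_object(None, "users", "table", False, None))  # True
```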


@@ -0,0 +1,446 @@
"""initial models
Revision ID: 0001
Revises:
Create Date: 2025-11-27 09:08:09.464506
"""
from collections.abc import Sequence
import sqlalchemy as sa
from alembic import op
from sqlalchemy.dialects import postgresql
# revision identifiers, used by Alembic.
revision: str = "0001"
down_revision: str | None = None
branch_labels: str | Sequence[str] | None = None
depends_on: str | Sequence[str] | None = None
def upgrade() -> None:
# ### commands auto generated by Alembic - please adjust! ###
op.create_table(
"oauth_states",
sa.Column("state", sa.String(length=255), nullable=False),
sa.Column("code_verifier", sa.String(length=128), nullable=True),
sa.Column("nonce", sa.String(length=255), nullable=True),
sa.Column("provider", sa.String(length=50), nullable=False),
sa.Column("redirect_uri", sa.String(length=500), nullable=True),
sa.Column("user_id", sa.UUID(), nullable=True),
sa.Column("expires_at", sa.DateTime(timezone=True), nullable=False),
sa.Column("id", sa.UUID(), nullable=False),
sa.Column("created_at", sa.DateTime(timezone=True), nullable=False),
sa.Column("updated_at", sa.DateTime(timezone=True), nullable=False),
sa.PrimaryKeyConstraint("id"),
)
op.create_index(
op.f("ix_oauth_states_state"), "oauth_states", ["state"], unique=True
)
op.create_table(
"organizations",
sa.Column("name", sa.String(length=255), nullable=False),
sa.Column("slug", sa.String(length=255), nullable=False),
sa.Column("description", sa.Text(), nullable=True),
sa.Column("is_active", sa.Boolean(), nullable=False),
sa.Column("settings", postgresql.JSONB(astext_type=sa.Text()), nullable=True),
sa.Column("id", sa.UUID(), nullable=False),
sa.Column("created_at", sa.DateTime(timezone=True), nullable=False),
sa.Column("updated_at", sa.DateTime(timezone=True), nullable=False),
sa.PrimaryKeyConstraint("id"),
)
op.create_index(
op.f("ix_organizations_is_active"), "organizations", ["is_active"], unique=False
)
op.create_index(
op.f("ix_organizations_name"), "organizations", ["name"], unique=False
)
op.create_index(
"ix_organizations_name_active",
"organizations",
["name", "is_active"],
unique=False,
)
op.create_index(
op.f("ix_organizations_slug"), "organizations", ["slug"], unique=True
)
op.create_index(
"ix_organizations_slug_active",
"organizations",
["slug", "is_active"],
unique=False,
)
op.create_table(
"users",
sa.Column("email", sa.String(length=255), nullable=False),
sa.Column("password_hash", sa.String(length=255), nullable=True),
sa.Column("first_name", sa.String(length=100), nullable=False),
sa.Column("last_name", sa.String(length=100), nullable=True),
sa.Column("phone_number", sa.String(length=20), nullable=True),
sa.Column("is_active", sa.Boolean(), nullable=False),
sa.Column("is_superuser", sa.Boolean(), nullable=False),
sa.Column(
"preferences", postgresql.JSONB(astext_type=sa.Text()), nullable=True
),
sa.Column("locale", sa.String(length=10), nullable=True),
sa.Column("deleted_at", sa.DateTime(timezone=True), nullable=True),
sa.Column("id", sa.UUID(), nullable=False),
sa.Column("created_at", sa.DateTime(timezone=True), nullable=False),
sa.Column("updated_at", sa.DateTime(timezone=True), nullable=False),
sa.PrimaryKeyConstraint("id"),
)
op.create_index(op.f("ix_users_deleted_at"), "users", ["deleted_at"], unique=False)
op.create_index(op.f("ix_users_email"), "users", ["email"], unique=True)
op.create_index(op.f("ix_users_is_active"), "users", ["is_active"], unique=False)
op.create_index(
op.f("ix_users_is_superuser"), "users", ["is_superuser"], unique=False
)
op.create_index(op.f("ix_users_locale"), "users", ["locale"], unique=False)
op.create_table(
"oauth_accounts",
sa.Column("user_id", sa.UUID(), nullable=False),
sa.Column("provider", sa.String(length=50), nullable=False),
sa.Column("provider_user_id", sa.String(length=255), nullable=False),
sa.Column("provider_email", sa.String(length=255), nullable=True),
sa.Column("access_token_encrypted", sa.String(length=2048), nullable=True),
sa.Column("refresh_token_encrypted", sa.String(length=2048), nullable=True),
sa.Column("token_expires_at", sa.DateTime(timezone=True), nullable=True),
sa.Column("id", sa.UUID(), nullable=False),
sa.Column("created_at", sa.DateTime(timezone=True), nullable=False),
sa.Column("updated_at", sa.DateTime(timezone=True), nullable=False),
sa.ForeignKeyConstraint(["user_id"], ["users.id"], ondelete="CASCADE"),
sa.PrimaryKeyConstraint("id"),
sa.UniqueConstraint(
"provider", "provider_user_id", name="uq_oauth_provider_user"
),
)
op.create_index(
op.f("ix_oauth_accounts_provider"), "oauth_accounts", ["provider"], unique=False
)
op.create_index(
op.f("ix_oauth_accounts_provider_email"),
"oauth_accounts",
["provider_email"],
unique=False,
)
op.create_index(
op.f("ix_oauth_accounts_user_id"), "oauth_accounts", ["user_id"], unique=False
)
op.create_index(
"ix_oauth_accounts_user_provider",
"oauth_accounts",
["user_id", "provider"],
unique=False,
)
op.create_table(
"oauth_clients",
sa.Column("client_id", sa.String(length=64), nullable=False),
sa.Column("client_secret_hash", sa.String(length=255), nullable=True),
sa.Column("client_name", sa.String(length=255), nullable=False),
sa.Column("client_description", sa.String(length=1000), nullable=True),
sa.Column("client_type", sa.String(length=20), nullable=False),
sa.Column(
"redirect_uris", postgresql.JSONB(astext_type=sa.Text()), nullable=False
),
sa.Column(
"allowed_scopes", postgresql.JSONB(astext_type=sa.Text()), nullable=False
),
sa.Column("access_token_lifetime", sa.String(length=10), nullable=False),
sa.Column("refresh_token_lifetime", sa.String(length=10), nullable=False),
sa.Column("is_active", sa.Boolean(), nullable=False),
sa.Column("owner_user_id", sa.UUID(), nullable=True),
sa.Column("mcp_server_url", sa.String(length=2048), nullable=True),
sa.Column("id", sa.UUID(), nullable=False),
sa.Column("created_at", sa.DateTime(timezone=True), nullable=False),
sa.Column("updated_at", sa.DateTime(timezone=True), nullable=False),
sa.ForeignKeyConstraint(["owner_user_id"], ["users.id"], ondelete="SET NULL"),
sa.PrimaryKeyConstraint("id"),
)
op.create_index(
op.f("ix_oauth_clients_client_id"), "oauth_clients", ["client_id"], unique=True
)
op.create_index(
op.f("ix_oauth_clients_is_active"), "oauth_clients", ["is_active"], unique=False
)
op.create_table(
"user_organizations",
sa.Column("user_id", sa.UUID(), nullable=False),
sa.Column("organization_id", sa.UUID(), nullable=False),
sa.Column(
"role",
sa.Enum("OWNER", "ADMIN", "MEMBER", "GUEST", name="organizationrole"),
nullable=False,
),
sa.Column("is_active", sa.Boolean(), nullable=False),
sa.Column("custom_permissions", sa.String(length=500), nullable=True),
sa.Column("created_at", sa.DateTime(timezone=True), nullable=False),
sa.Column("updated_at", sa.DateTime(timezone=True), nullable=False),
sa.ForeignKeyConstraint(
["organization_id"], ["organizations.id"], ondelete="CASCADE"
),
sa.ForeignKeyConstraint(["user_id"], ["users.id"], ondelete="CASCADE"),
sa.PrimaryKeyConstraint("user_id", "organization_id"),
)
op.create_index(
"ix_user_org_org_active",
"user_organizations",
["organization_id", "is_active"],
unique=False,
)
op.create_index("ix_user_org_role", "user_organizations", ["role"], unique=False)
op.create_index(
"ix_user_org_user_active",
"user_organizations",
["user_id", "is_active"],
unique=False,
)
op.create_index(
op.f("ix_user_organizations_is_active"),
"user_organizations",
["is_active"],
unique=False,
)
op.create_table(
"user_sessions",
sa.Column("user_id", sa.UUID(), nullable=False),
sa.Column("refresh_token_jti", sa.String(length=255), nullable=False),
sa.Column("device_name", sa.String(length=255), nullable=True),
sa.Column("device_id", sa.String(length=255), nullable=True),
sa.Column("ip_address", sa.String(length=45), nullable=True),
sa.Column("user_agent", sa.String(length=500), nullable=True),
sa.Column("last_used_at", sa.DateTime(timezone=True), nullable=False),
sa.Column("expires_at", sa.DateTime(timezone=True), nullable=False),
sa.Column("is_active", sa.Boolean(), nullable=False),
sa.Column("location_city", sa.String(length=100), nullable=True),
sa.Column("location_country", sa.String(length=100), nullable=True),
sa.Column("id", sa.UUID(), nullable=False),
sa.Column("created_at", sa.DateTime(timezone=True), nullable=False),
sa.Column("updated_at", sa.DateTime(timezone=True), nullable=False),
sa.ForeignKeyConstraint(["user_id"], ["users.id"], ondelete="CASCADE"),
sa.PrimaryKeyConstraint("id"),
)
op.create_index(
op.f("ix_user_sessions_is_active"), "user_sessions", ["is_active"], unique=False
)
op.create_index(
"ix_user_sessions_jti_active",
"user_sessions",
["refresh_token_jti", "is_active"],
unique=False,
)
op.create_index(
op.f("ix_user_sessions_refresh_token_jti"),
"user_sessions",
["refresh_token_jti"],
unique=True,
)
op.create_index(
"ix_user_sessions_user_active",
"user_sessions",
["user_id", "is_active"],
unique=False,
)
op.create_index(
op.f("ix_user_sessions_user_id"), "user_sessions", ["user_id"], unique=False
)
op.create_table(
"oauth_authorization_codes",
sa.Column("code", sa.String(length=128), nullable=False),
sa.Column("client_id", sa.String(length=64), nullable=False),
sa.Column("user_id", sa.UUID(), nullable=False),
sa.Column("redirect_uri", sa.String(length=2048), nullable=False),
sa.Column("scope", sa.String(length=1000), nullable=False),
sa.Column("code_challenge", sa.String(length=128), nullable=True),
sa.Column("code_challenge_method", sa.String(length=10), nullable=True),
sa.Column("state", sa.String(length=256), nullable=True),
sa.Column("nonce", sa.String(length=256), nullable=True),
sa.Column("expires_at", sa.DateTime(timezone=True), nullable=False),
sa.Column("used", sa.Boolean(), nullable=False),
sa.Column("id", sa.UUID(), nullable=False),
sa.Column("created_at", sa.DateTime(timezone=True), nullable=False),
sa.Column("updated_at", sa.DateTime(timezone=True), nullable=False),
sa.ForeignKeyConstraint(
["client_id"], ["oauth_clients.client_id"], ondelete="CASCADE"
),
sa.ForeignKeyConstraint(["user_id"], ["users.id"], ondelete="CASCADE"),
sa.PrimaryKeyConstraint("id"),
)
op.create_index(
"ix_oauth_authorization_codes_client_user",
"oauth_authorization_codes",
["client_id", "user_id"],
unique=False,
)
op.create_index(
op.f("ix_oauth_authorization_codes_code"),
"oauth_authorization_codes",
["code"],
unique=True,
)
op.create_index(
"ix_oauth_authorization_codes_expires_at",
"oauth_authorization_codes",
["expires_at"],
unique=False,
)
op.create_table(
"oauth_consents",
sa.Column("user_id", sa.UUID(), nullable=False),
sa.Column("client_id", sa.String(length=64), nullable=False),
sa.Column("granted_scopes", sa.String(length=1000), nullable=False),
sa.Column("id", sa.UUID(), nullable=False),
sa.Column("created_at", sa.DateTime(timezone=True), nullable=False),
sa.Column("updated_at", sa.DateTime(timezone=True), nullable=False),
sa.ForeignKeyConstraint(
["client_id"], ["oauth_clients.client_id"], ondelete="CASCADE"
),
sa.ForeignKeyConstraint(["user_id"], ["users.id"], ondelete="CASCADE"),
sa.PrimaryKeyConstraint("id"),
)
op.create_index(
"ix_oauth_consents_user_client",
"oauth_consents",
["user_id", "client_id"],
unique=True,
)
op.create_table(
"oauth_provider_refresh_tokens",
sa.Column("token_hash", sa.String(length=64), nullable=False),
sa.Column("jti", sa.String(length=64), nullable=False),
sa.Column("client_id", sa.String(length=64), nullable=False),
sa.Column("user_id", sa.UUID(), nullable=False),
sa.Column("scope", sa.String(length=1000), nullable=False),
sa.Column("expires_at", sa.DateTime(timezone=True), nullable=False),
sa.Column("revoked", sa.Boolean(), nullable=False),
sa.Column("last_used_at", sa.DateTime(timezone=True), nullable=True),
sa.Column("device_info", sa.String(length=500), nullable=True),
sa.Column("ip_address", sa.String(length=45), nullable=True),
sa.Column("id", sa.UUID(), nullable=False),
sa.Column("created_at", sa.DateTime(timezone=True), nullable=False),
sa.Column("updated_at", sa.DateTime(timezone=True), nullable=False),
sa.ForeignKeyConstraint(
["client_id"], ["oauth_clients.client_id"], ondelete="CASCADE"
),
sa.ForeignKeyConstraint(["user_id"], ["users.id"], ondelete="CASCADE"),
sa.PrimaryKeyConstraint("id"),
)
op.create_index(
"ix_oauth_provider_refresh_tokens_client_user",
"oauth_provider_refresh_tokens",
["client_id", "user_id"],
unique=False,
)
op.create_index(
"ix_oauth_provider_refresh_tokens_expires_at",
"oauth_provider_refresh_tokens",
["expires_at"],
unique=False,
)
op.create_index(
op.f("ix_oauth_provider_refresh_tokens_jti"),
"oauth_provider_refresh_tokens",
["jti"],
unique=True,
)
op.create_index(
op.f("ix_oauth_provider_refresh_tokens_revoked"),
"oauth_provider_refresh_tokens",
["revoked"],
unique=False,
)
op.create_index(
op.f("ix_oauth_provider_refresh_tokens_token_hash"),
"oauth_provider_refresh_tokens",
["token_hash"],
unique=True,
)
op.create_index(
"ix_oauth_provider_refresh_tokens_user_revoked",
"oauth_provider_refresh_tokens",
["user_id", "revoked"],
unique=False,
)
# ### end Alembic commands ###
def downgrade() -> None:
# ### commands auto generated by Alembic - please adjust! ###
op.drop_index(
"ix_oauth_provider_refresh_tokens_user_revoked",
table_name="oauth_provider_refresh_tokens",
)
op.drop_index(
op.f("ix_oauth_provider_refresh_tokens_token_hash"),
table_name="oauth_provider_refresh_tokens",
)
op.drop_index(
op.f("ix_oauth_provider_refresh_tokens_revoked"),
table_name="oauth_provider_refresh_tokens",
)
op.drop_index(
op.f("ix_oauth_provider_refresh_tokens_jti"),
table_name="oauth_provider_refresh_tokens",
)
op.drop_index(
"ix_oauth_provider_refresh_tokens_expires_at",
table_name="oauth_provider_refresh_tokens",
)
op.drop_index(
"ix_oauth_provider_refresh_tokens_client_user",
table_name="oauth_provider_refresh_tokens",
)
op.drop_table("oauth_provider_refresh_tokens")
op.drop_index("ix_oauth_consents_user_client", table_name="oauth_consents")
op.drop_table("oauth_consents")
op.drop_index(
"ix_oauth_authorization_codes_expires_at",
table_name="oauth_authorization_codes",
)
op.drop_index(
op.f("ix_oauth_authorization_codes_code"),
table_name="oauth_authorization_codes",
)
op.drop_index(
"ix_oauth_authorization_codes_client_user",
table_name="oauth_authorization_codes",
)
op.drop_table("oauth_authorization_codes")
op.drop_index(op.f("ix_user_sessions_user_id"), table_name="user_sessions")
op.drop_index("ix_user_sessions_user_active", table_name="user_sessions")
op.drop_index(
op.f("ix_user_sessions_refresh_token_jti"), table_name="user_sessions"
)
op.drop_index("ix_user_sessions_jti_active", table_name="user_sessions")
op.drop_index(op.f("ix_user_sessions_is_active"), table_name="user_sessions")
op.drop_table("user_sessions")
op.drop_index(
op.f("ix_user_organizations_is_active"), table_name="user_organizations"
)
op.drop_index("ix_user_org_user_active", table_name="user_organizations")
op.drop_index("ix_user_org_role", table_name="user_organizations")
op.drop_index("ix_user_org_org_active", table_name="user_organizations")
op.drop_table("user_organizations")
# Drop the enum type created for user_organizations.role; autogenerate
# omits this, and a later upgrade would fail with "type organizationrole
# already exists" if it were left behind.
op.execute("DROP TYPE IF EXISTS organizationrole")
op.drop_index(op.f("ix_oauth_clients_is_active"), table_name="oauth_clients")
op.drop_index(op.f("ix_oauth_clients_client_id"), table_name="oauth_clients")
op.drop_table("oauth_clients")
op.drop_index("ix_oauth_accounts_user_provider", table_name="oauth_accounts")
op.drop_index(op.f("ix_oauth_accounts_user_id"), table_name="oauth_accounts")
op.drop_index(op.f("ix_oauth_accounts_provider_email"), table_name="oauth_accounts")
op.drop_index(op.f("ix_oauth_accounts_provider"), table_name="oauth_accounts")
op.drop_table("oauth_accounts")
op.drop_index(op.f("ix_users_locale"), table_name="users")
op.drop_index(op.f("ix_users_is_superuser"), table_name="users")
op.drop_index(op.f("ix_users_is_active"), table_name="users")
op.drop_index(op.f("ix_users_email"), table_name="users")
op.drop_index(op.f("ix_users_deleted_at"), table_name="users")
op.drop_table("users")
op.drop_index("ix_organizations_slug_active", table_name="organizations")
op.drop_index(op.f("ix_organizations_slug"), table_name="organizations")
op.drop_index("ix_organizations_name_active", table_name="organizations")
op.drop_index(op.f("ix_organizations_name"), table_name="organizations")
op.drop_index(op.f("ix_organizations_is_active"), table_name="organizations")
op.drop_table("organizations")
op.drop_index(op.f("ix_oauth_states_state"), table_name="oauth_states")
op.drop_table("oauth_states")
# ### end Alembic commands ###


@@ -0,0 +1,127 @@
"""Add performance indexes
Revision ID: 0002
Revises: 0001
Create Date: 2025-11-27
Performance indexes that Alembic cannot auto-detect:
- Functional indexes (LOWER expressions)
- Partial indexes (WHERE clauses)
These indexes use the ix_perf_ prefix and are excluded from autogenerate
via the include_object() function in env.py.
"""
from collections.abc import Sequence
import sqlalchemy as sa
from alembic import op
# revision identifiers, used by Alembic.
revision: str = "0002"
down_revision: str | None = "0001"
branch_labels: str | Sequence[str] | None = None
depends_on: str | Sequence[str] | None = None
def upgrade() -> None:
# ==========================================================================
# USERS TABLE - Performance indexes for authentication
# ==========================================================================
# Case-insensitive email lookup for login/registration
# Query: SELECT * FROM users WHERE LOWER(email) = LOWER(:email) AND deleted_at IS NULL
# Impact: High - every login, registration check, password reset
op.create_index(
"ix_perf_users_email_lower",
"users",
[sa.text("LOWER(email)")],
unique=False,
postgresql_where=sa.text("deleted_at IS NULL"),
)
# Active users lookup (non-soft-deleted)
# Query: SELECT * FROM users WHERE deleted_at IS NULL AND ...
# Impact: Medium - user listings, admin queries
op.create_index(
"ix_perf_users_active",
"users",
["is_active"],
unique=False,
postgresql_where=sa.text("deleted_at IS NULL"),
)
# ==========================================================================
# ORGANIZATIONS TABLE - Performance indexes for multi-tenant lookups
# ==========================================================================
# Case-insensitive slug lookup for URL routing
# Query: SELECT * FROM organizations WHERE LOWER(slug) = LOWER(:slug) AND is_active = true
# Impact: Medium - every organization page load
op.create_index(
"ix_perf_organizations_slug_lower",
"organizations",
[sa.text("LOWER(slug)")],
unique=False,
postgresql_where=sa.text("is_active = true"),
)
# ==========================================================================
# USER SESSIONS TABLE - Performance indexes for session management
# ==========================================================================
# Expired session cleanup
# Query: SELECT * FROM user_sessions WHERE expires_at < NOW() AND is_active = true
# Impact: Medium - background cleanup jobs
op.create_index(
"ix_perf_user_sessions_expires",
"user_sessions",
["expires_at"],
unique=False,
postgresql_where=sa.text("is_active = true"),
)
# ==========================================================================
# OAUTH PROVIDER TOKENS - Performance indexes for token management
# ==========================================================================
# Expired refresh token cleanup
# Query: SELECT * FROM oauth_provider_refresh_tokens WHERE expires_at < NOW() AND revoked = false
# Impact: Medium - OAuth token cleanup, validation
op.create_index(
"ix_perf_oauth_refresh_tokens_expires",
"oauth_provider_refresh_tokens",
["expires_at"],
unique=False,
postgresql_where=sa.text("revoked = false"),
)
# ==========================================================================
# OAUTH AUTHORIZATION CODES - Performance indexes for auth flow
# ==========================================================================
# Expired authorization code cleanup
# Query: DELETE FROM oauth_authorization_codes WHERE expires_at < NOW() AND used = false
# Impact: Low-Medium - OAuth cleanup jobs
op.create_index(
"ix_perf_oauth_auth_codes_expires",
"oauth_authorization_codes",
["expires_at"],
unique=False,
postgresql_where=sa.text("used = false"),
)
def downgrade() -> None:
# Drop indexes in reverse order
op.drop_index(
"ix_perf_oauth_auth_codes_expires", table_name="oauth_authorization_codes"
)
op.drop_index(
"ix_perf_oauth_refresh_tokens_expires",
table_name="oauth_provider_refresh_tokens",
)
op.drop_index("ix_perf_user_sessions_expires", table_name="user_sessions")
op.drop_index("ix_perf_organizations_slug_lower", table_name="organizations")
op.drop_index("ix_perf_users_active", table_name="users")
op.drop_index("ix_perf_users_email_lower", table_name="users")
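
The docstring for this revision says `ix_perf_` indexes are excluded from autogenerate via the `include_object()` function in `env.py`. That hook is not shown in this changeset; a minimal sketch of what it would look like (the exact filtering logic is an assumption) is:

```python
def include_object(obj, name, type_, reflected, compare_to):
    """Alembic autogenerate filter: skip hand-written performance indexes.

    Indexes whose names start with "ix_perf_" are maintained manually
    (functional/partial indexes Alembic cannot model), so autogenerate
    must not emit drop or recreate statements for them.
    """
    if type_ == "index" and name is not None and name.startswith("ix_perf_"):
        return False
    return True
```

In `env.py` this would be registered with `context.configure(..., include_object=include_object)` so `alembic revision --autogenerate` leaves these indexes untouched.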


@@ -0,0 +1,35 @@
"""Rename oauth_accounts token fields: drop the _encrypted suffix
Revision ID: 0003
Revises: 0002
Create Date: 2026-02-27 01:03:18.869178
"""
from collections.abc import Sequence
from alembic import op
# revision identifiers, used by Alembic.
revision: str = "0003"
down_revision: str | None = "0002"
branch_labels: str | Sequence[str] | None = None
depends_on: str | Sequence[str] | None = None
def upgrade() -> None:
op.alter_column(
"oauth_accounts", "access_token_encrypted", new_column_name="access_token"
)
op.alter_column(
"oauth_accounts", "refresh_token_encrypted", new_column_name="refresh_token"
)
def downgrade() -> None:
op.alter_column(
"oauth_accounts", "access_token", new_column_name="access_token_encrypted"
)
op.alter_column(
"oauth_accounts", "refresh_token", new_column_name="refresh_token_encrypted"
)


@@ -1,78 +0,0 @@
"""add_performance_indexes
Revision ID: 1174fffbe3e4
Revises: fbf6318a8a36
Create Date: 2025-11-01 04:15:25.367010
"""
from collections.abc import Sequence
import sqlalchemy as sa
from alembic import op
# revision identifiers, used by Alembic.
revision: str = "1174fffbe3e4"
down_revision: str | None = "fbf6318a8a36"
branch_labels: str | Sequence[str] | None = None
depends_on: str | Sequence[str] | None = None
def upgrade() -> None:
"""Add performance indexes for optimized queries."""
# Index for session cleanup queries
# Optimizes: DELETE WHERE is_active = FALSE AND expires_at < now AND created_at < cutoff
op.create_index(
"ix_user_sessions_cleanup",
"user_sessions",
["is_active", "expires_at", "created_at"],
unique=False,
postgresql_where=sa.text("is_active = false"),
)
# Functional indexes for case-insensitive user lookups
# Optimizes exact matches such as: WHERE LOWER(email) = LOWER(:term)
# Note: a LOWER() b-tree index cannot serve leading-wildcard searches
# (email ILIKE '%search%'); enable the pg_trgm extension with a GIN
# trigram index for that access pattern.
op.create_index(
"ix_users_email_lower",
"users",
[sa.text("LOWER(email)")],
unique=False,
postgresql_where=sa.text("deleted_at IS NULL"),
)
op.create_index(
"ix_users_first_name_lower",
"users",
[sa.text("LOWER(first_name)")],
unique=False,
postgresql_where=sa.text("deleted_at IS NULL"),
)
op.create_index(
"ix_users_last_name_lower",
"users",
[sa.text("LOWER(last_name)")],
unique=False,
postgresql_where=sa.text("deleted_at IS NULL"),
)
# Index for organization search
op.create_index(
"ix_organizations_name_lower",
"organizations",
[sa.text("LOWER(name)")],
unique=False,
)
def downgrade() -> None:
"""Remove performance indexes."""
# Drop indexes in reverse order
op.drop_index("ix_organizations_name_lower", table_name="organizations")
op.drop_index("ix_users_last_name_lower", table_name="users")
op.drop_index("ix_users_first_name_lower", table_name="users")
op.drop_index("ix_users_email_lower", table_name="users")
op.drop_index("ix_user_sessions_cleanup", table_name="user_sessions")
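
The note above points at `pg_trgm` for the `ILIKE '%search%'` patterns that `LOWER()` b-tree indexes cannot serve. A sketch of the GIN trigram index DDL this implies, built with SQLAlchemy (the index name `ix_perf_users_email_trgm` is illustrative, not part of these migrations):

```python
import sqlalchemy as sa
from sqlalchemy.dialects import postgresql
from sqlalchemy.schema import CreateIndex

metadata = sa.MetaData()
users = sa.Table("users", metadata, sa.Column("email", sa.String(255)))

# GIN trigram index: unlike a LOWER() b-tree index, this can serve
# leading-wildcard searches such as email ILIKE '%search%'.
# Requires the pg_trgm extension to be installed first.
trgm_index = sa.Index(
    "ix_perf_users_email_trgm",  # illustrative name, not in the migrations above
    users.c.email,
    postgresql_using="gin",
    postgresql_ops={"email": "gin_trgm_ops"},
)

ddl = str(CreateIndex(trgm_index).compile(dialect=postgresql.dialect()))
print(ddl)
```

In a migration this would be preceded by `op.execute("CREATE EXTENSION IF NOT EXISTS pg_trgm")`, which needs sufficient privileges on the target database.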


@@ -1,36 +0,0 @@
"""add_soft_delete_to_users
Revision ID: 2d0fcec3b06d
Revises: 9e4f2a1b8c7d
Create Date: 2025-10-30 16:40:21.000021
"""
from collections.abc import Sequence
import sqlalchemy as sa
from alembic import op
# revision identifiers, used by Alembic.
revision: str = "2d0fcec3b06d"
down_revision: str | None = "9e4f2a1b8c7d"
branch_labels: str | Sequence[str] | None = None
depends_on: str | Sequence[str] | None = None
def upgrade() -> None:
# Add deleted_at column for soft deletes
op.add_column(
"users", sa.Column("deleted_at", sa.DateTime(timezone=True), nullable=True)
)
# Add index on deleted_at for efficient queries
op.create_index("ix_users_deleted_at", "users", ["deleted_at"])
def downgrade() -> None:
# Remove index
op.drop_index("ix_users_deleted_at", table_name="users")
# Remove column
op.drop_column("users", "deleted_at")


@@ -1,46 +0,0 @@
"""Add all initial models
Revision ID: 38bf9e7e74b3
Revises: 7396957cbe80
Create Date: 2025-02-28 09:19:33.212278
"""
from collections.abc import Sequence
import sqlalchemy as sa
from alembic import op
# revision identifiers, used by Alembic.
revision: str = "38bf9e7e74b3"
down_revision: str | None = "7396957cbe80"
branch_labels: str | Sequence[str] | None = None
depends_on: str | Sequence[str] | None = None
def upgrade() -> None:
op.create_table(
"users",
sa.Column("email", sa.String(), nullable=False),
sa.Column("password_hash", sa.String(), nullable=False),
sa.Column("first_name", sa.String(), nullable=False),
sa.Column("last_name", sa.String(), nullable=True),
sa.Column("phone_number", sa.String(), nullable=True),
sa.Column("is_active", sa.Boolean(), nullable=False),
sa.Column("is_superuser", sa.Boolean(), nullable=False),
sa.Column("preferences", sa.JSON(), nullable=True),
sa.Column("id", sa.UUID(), nullable=False),
sa.Column("created_at", sa.DateTime(timezone=True), nullable=False),
sa.Column("updated_at", sa.DateTime(timezone=True), nullable=False),
sa.PrimaryKeyConstraint("id"),
)
op.create_index(op.f("ix_users_email"), "users", ["email"], unique=True)
# ### end Alembic commands ###
def downgrade() -> None:
# ### commands auto generated by Alembic - please adjust! ###
op.drop_index(op.f("ix_users_email"), table_name="users")
op.drop_table("users")
# ### end Alembic commands ###


@@ -1,89 +0,0 @@
"""add_user_sessions_table
Revision ID: 549b50ea888d
Revises: b76c725fc3cf
Create Date: 2025-10-31 07:41:18.729544
"""
from collections.abc import Sequence
import sqlalchemy as sa
from alembic import op
# revision identifiers, used by Alembic.
revision: str = "549b50ea888d"
down_revision: str | None = "b76c725fc3cf"
branch_labels: str | Sequence[str] | None = None
depends_on: str | Sequence[str] | None = None
def upgrade() -> None:
# Create user_sessions table for per-device session management
op.create_table(
"user_sessions",
sa.Column("id", sa.UUID(), nullable=False),
sa.Column("user_id", sa.UUID(), nullable=False),
sa.Column("refresh_token_jti", sa.String(length=255), nullable=False),
sa.Column("device_name", sa.String(length=255), nullable=True),
sa.Column("device_id", sa.String(length=255), nullable=True),
sa.Column("ip_address", sa.String(length=45), nullable=True),
sa.Column("user_agent", sa.String(length=500), nullable=True),
sa.Column("last_used_at", sa.DateTime(timezone=True), nullable=False),
sa.Column("expires_at", sa.DateTime(timezone=True), nullable=False),
sa.Column("is_active", sa.Boolean(), nullable=False, server_default="true"),
sa.Column("location_city", sa.String(length=100), nullable=True),
sa.Column("location_country", sa.String(length=100), nullable=True),
sa.Column("created_at", sa.DateTime(timezone=True), nullable=False),
sa.Column("updated_at", sa.DateTime(timezone=True), nullable=False),
sa.PrimaryKeyConstraint("id"),
)
# Create foreign key to users table
op.create_foreign_key(
"fk_user_sessions_user_id",
"user_sessions",
"users",
["user_id"],
["id"],
ondelete="CASCADE",
)
# Create indexes for performance
# 1. Lookup session by refresh token JTI (most common query)
op.create_index(
"ix_user_sessions_jti", "user_sessions", ["refresh_token_jti"], unique=True
)
# 2. Lookup sessions by user ID
op.create_index("ix_user_sessions_user_id", "user_sessions", ["user_id"])
# 3. Composite index for active sessions by user
op.create_index(
"ix_user_sessions_user_active", "user_sessions", ["user_id", "is_active"]
)
# 4. Index on expires_at for cleanup job
op.create_index("ix_user_sessions_expires_at", "user_sessions", ["expires_at"])
# 5. Composite index for active session lookup by JTI
op.create_index(
"ix_user_sessions_jti_active",
"user_sessions",
["refresh_token_jti", "is_active"],
)
def downgrade() -> None:
# Drop indexes first
op.drop_index("ix_user_sessions_jti_active", table_name="user_sessions")
op.drop_index("ix_user_sessions_expires_at", table_name="user_sessions")
op.drop_index("ix_user_sessions_user_active", table_name="user_sessions")
op.drop_index("ix_user_sessions_user_id", table_name="user_sessions")
op.drop_index("ix_user_sessions_jti", table_name="user_sessions")
# Drop foreign key
op.drop_constraint("fk_user_sessions_user_id", "user_sessions", type_="foreignkey")
# Drop table
op.drop_table("user_sessions")


@@ -1,23 +0,0 @@
"""Initial empty migration
Revision ID: 7396957cbe80
Revises:
Create Date: 2025-02-27 12:47:46.445313
"""
from collections.abc import Sequence
# revision identifiers, used by Alembic.
revision: str = "7396957cbe80"
down_revision: str | None = None
branch_labels: str | Sequence[str] | None = None
depends_on: str | Sequence[str] | None = None
def upgrade() -> None:
pass
def downgrade() -> None:
pass


@@ -1,116 +0,0 @@
"""Add missing indexes and fix column types
Revision ID: 9e4f2a1b8c7d
Revises: 38bf9e7e74b3
Create Date: 2025-10-30 10:00:00.000000
"""
from collections.abc import Sequence
import sqlalchemy as sa
from alembic import op
# revision identifiers, used by Alembic.
revision: str = "9e4f2a1b8c7d"
down_revision: str | None = "38bf9e7e74b3"
branch_labels: str | Sequence[str] | None = None
depends_on: str | Sequence[str] | None = None
def upgrade() -> None:
# Add missing indexes for is_active and is_superuser
op.create_index(op.f("ix_users_is_active"), "users", ["is_active"], unique=False)
op.create_index(
op.f("ix_users_is_superuser"), "users", ["is_superuser"], unique=False
)
# Fix column types to match model definitions with explicit lengths
op.alter_column(
"users",
"email",
existing_type=sa.String(),
type_=sa.String(length=255),
nullable=False,
)
op.alter_column(
"users",
"password_hash",
existing_type=sa.String(),
type_=sa.String(length=255),
nullable=False,
)
op.alter_column(
"users",
"first_name",
existing_type=sa.String(),
type_=sa.String(length=100),
nullable=False,
server_default="user",
) # Add server default
op.alter_column(
"users",
"last_name",
existing_type=sa.String(),
type_=sa.String(length=100),
nullable=True,
)
op.alter_column(
"users",
"phone_number",
existing_type=sa.String(),
type_=sa.String(length=20),
nullable=True,
)
def downgrade() -> None:
# Revert column types
op.alter_column(
"users",
"phone_number",
existing_type=sa.String(length=20),
type_=sa.String(),
nullable=True,
)
op.alter_column(
"users",
"last_name",
existing_type=sa.String(length=100),
type_=sa.String(),
nullable=True,
)
op.alter_column(
"users",
"first_name",
existing_type=sa.String(length=100),
type_=sa.String(),
nullable=False,
server_default=None,
) # Remove server default
op.alter_column(
"users",
"password_hash",
existing_type=sa.String(length=255),
type_=sa.String(),
nullable=False,
)
op.alter_column(
"users",
"email",
existing_type=sa.String(length=255),
type_=sa.String(),
nullable=False,
)
# Drop indexes
op.drop_index(op.f("ix_users_is_superuser"), table_name="users")
op.drop_index(op.f("ix_users_is_active"), table_name="users")


@@ -1,48 +0,0 @@
"""add_composite_indexes
Revision ID: b76c725fc3cf
Revises: 2d0fcec3b06d
Create Date: 2025-10-30 16:41:33.273135
"""
from collections.abc import Sequence
import sqlalchemy as sa
from alembic import op
# revision identifiers, used by Alembic.
revision: str = "b76c725fc3cf"
down_revision: str | None = "2d0fcec3b06d"
branch_labels: str | Sequence[str] | None = None
depends_on: str | Sequence[str] | None = None
def upgrade() -> None:
# Add composite indexes for common query patterns
# Composite index for filtering active users by role
op.create_index(
"ix_users_active_superuser",
"users",
["is_active", "is_superuser"],
postgresql_where=sa.text("deleted_at IS NULL"),
)
# Composite index for sorting active users by creation date
op.create_index(
"ix_users_active_created",
"users",
["is_active", "created_at"],
postgresql_where=sa.text("deleted_at IS NULL"),
)
# Composite index for email lookup of non-deleted users
op.create_index("ix_users_email_not_deleted", "users", ["email", "deleted_at"])
def downgrade() -> None:
# Remove composite indexes
op.drop_index("ix_users_email_not_deleted", table_name="users")
op.drop_index("ix_users_active_created", table_name="users")
op.drop_index("ix_users_active_superuser", table_name="users")
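
For a partial index like `ix_users_active_superuser` to be usable, the query has to repeat the partial predicate (`deleted_at IS NULL`) alongside the indexed columns. A sketch of a query shaped to match it (table definition abbreviated from the migrations above):

```python
import sqlalchemy as sa

metadata = sa.MetaData()
users = sa.Table(
    "users",
    metadata,
    sa.Column("id", sa.Integer, primary_key=True),
    sa.Column("is_active", sa.Boolean),
    sa.Column("is_superuser", sa.Boolean),
    sa.Column("deleted_at", sa.DateTime(timezone=True)),
)

# Filters on both indexed columns and repeats the partial-index
# predicate so the planner can choose the partial index.
stmt = sa.select(users.c.id).where(
    users.c.is_active.is_(True),
    users.c.is_superuser.is_(True),
    users.c.deleted_at.is_(None),
)
sql = str(stmt)
print(sql)
```

Dropping the `deleted_at IS NULL` clause would make the partial index inapplicable, silently falling back to a sequential scan or another index.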


@@ -1,127 +0,0 @@
"""add_organizations_and_user_organizations
Revision ID: fbf6318a8a36
Revises: 549b50ea888d
Create Date: 2025-10-31 12:08:05.141353
"""
from collections.abc import Sequence
import sqlalchemy as sa
from alembic import op
# revision identifiers, used by Alembic.
revision: str = "fbf6318a8a36"
down_revision: str | None = "549b50ea888d"
branch_labels: str | Sequence[str] | None = None
depends_on: str | Sequence[str] | None = None
def upgrade() -> None:
# Create organizations table
op.create_table(
"organizations",
sa.Column("id", sa.UUID(), nullable=False),
sa.Column("name", sa.String(length=255), nullable=False),
sa.Column("slug", sa.String(length=255), nullable=False),
sa.Column("description", sa.Text(), nullable=True),
sa.Column("is_active", sa.Boolean(), nullable=False, server_default="true"),
sa.Column("settings", sa.JSON(), nullable=True),
sa.Column("created_at", sa.DateTime(timezone=True), nullable=False),
sa.Column("updated_at", sa.DateTime(timezone=True), nullable=False),
sa.PrimaryKeyConstraint("id"),
)
# Create indexes for organizations
op.create_index("ix_organizations_name", "organizations", ["name"])
op.create_index("ix_organizations_slug", "organizations", ["slug"], unique=True)
op.create_index("ix_organizations_is_active", "organizations", ["is_active"])
op.create_index(
"ix_organizations_name_active", "organizations", ["name", "is_active"]
)
op.create_index(
"ix_organizations_slug_active", "organizations", ["slug", "is_active"]
)
# Create user_organizations junction table
op.create_table(
"user_organizations",
sa.Column("user_id", sa.UUID(), nullable=False),
sa.Column("organization_id", sa.UUID(), nullable=False),
sa.Column(
"role",
sa.Enum("OWNER", "ADMIN", "MEMBER", "GUEST", name="organizationrole"),
            nullable=False,
            server_default="MEMBER",
        ),
        sa.Column("is_active", sa.Boolean(), nullable=False, server_default="true"),
        sa.Column("custom_permissions", sa.String(length=500), nullable=True),
        sa.Column("created_at", sa.DateTime(timezone=True), nullable=False),
        sa.Column("updated_at", sa.DateTime(timezone=True), nullable=False),
        sa.PrimaryKeyConstraint("user_id", "organization_id"),
    )

    # Create foreign keys
    op.create_foreign_key(
        "fk_user_organizations_user_id",
        "user_organizations",
        "users",
        ["user_id"],
        ["id"],
        ondelete="CASCADE",
    )
    op.create_foreign_key(
        "fk_user_organizations_organization_id",
        "user_organizations",
        "organizations",
        ["organization_id"],
        ["id"],
        ondelete="CASCADE",
    )

    # Create indexes for user_organizations
    op.create_index("ix_user_organizations_role", "user_organizations", ["role"])
    op.create_index(
        "ix_user_organizations_is_active", "user_organizations", ["is_active"]
    )
    op.create_index(
        "ix_user_org_user_active", "user_organizations", ["user_id", "is_active"]
    )
    op.create_index(
        "ix_user_org_org_active", "user_organizations", ["organization_id", "is_active"]
    )


def downgrade() -> None:
    # Drop indexes for user_organizations
    op.drop_index("ix_user_org_org_active", table_name="user_organizations")
    op.drop_index("ix_user_org_user_active", table_name="user_organizations")
    op.drop_index("ix_user_organizations_is_active", table_name="user_organizations")
    op.drop_index("ix_user_organizations_role", table_name="user_organizations")

    # Drop foreign keys
    op.drop_constraint(
        "fk_user_organizations_organization_id",
        "user_organizations",
        type_="foreignkey",
    )
    op.drop_constraint(
        "fk_user_organizations_user_id", "user_organizations", type_="foreignkey"
    )

    # Drop user_organizations table
    op.drop_table("user_organizations")

    # Drop indexes for organizations
    op.drop_index("ix_organizations_slug_active", table_name="organizations")
    op.drop_index("ix_organizations_name_active", table_name="organizations")
    op.drop_index("ix_organizations_is_active", table_name="organizations")
    op.drop_index("ix_organizations_slug", table_name="organizations")
    op.drop_index("ix_organizations_name", table_name="organizations")

    # Drop organizations table
    op.drop_table("organizations")

    # Drop enum type
    op.execute("DROP TYPE IF EXISTS organizationrole")

View File

@@ -1,12 +1,12 @@
 from fastapi import Depends, Header, HTTPException, status
 from fastapi.security import OAuth2PasswordBearer
 from fastapi.security.utils import get_authorization_scheme_param
-from sqlalchemy import select
 from sqlalchemy.ext.asyncio import AsyncSession

 from app.core.auth import TokenExpiredError, TokenInvalidError, get_token_data
 from app.core.database import get_db
 from app.models.user import User
+from app.repositories.user import user_repo

 # OAuth2 configuration
 oauth2_scheme = OAuth2PasswordBearer(tokenUrl="/api/v1/auth/login")
@@ -32,9 +32,8 @@ async def get_current_user(
     # Decode token and get user ID
     token_data = get_token_data(token)

-    # Get user from database
-    result = await db.execute(select(User).where(User.id == token_data.user_id))
-    user = result.scalar_one_or_none()
+    # Get user from database via repository
+    user = await user_repo.get(db, id=str(token_data.user_id))

     if not user:
         raise HTTPException(
@@ -144,8 +143,7 @@ async def get_optional_current_user(
     try:
         token_data = get_token_data(token)
-        result = await db.execute(select(User).where(User.id == token_data.user_id))
-        user = result.scalar_one_or_none()
+        user = await user_repo.get(db, id=str(token_data.user_id))

         if not user or not user.is_active:
             return None
         return user

View File

@@ -0,0 +1,132 @@
# app/api/dependencies/locale.py
"""
Locale detection dependency for internationalization (i18n).

Implements a three-tier fallback system:
1. User's saved preference (if authenticated and user.locale is set)
2. Accept-Language header (for unauthenticated users or no saved preference)
3. Default to English ("en")
"""

from fastapi import Depends, Request

from app.api.dependencies.auth import get_optional_current_user
from app.models.user import User

# Supported locales (BCP 47 format)
# Template showcases English and Italian
# Users can extend by adding more locales here
# Note: Stored in lowercase for case-insensitive matching
SUPPORTED_LOCALES = {"en", "it", "en-us", "en-gb", "it-it"}
DEFAULT_LOCALE = "en"


def parse_accept_language(accept_language: str) -> str | None:
    """
    Parse the Accept-Language header and return the best matching supported locale.

    The Accept-Language header format is:
        "it-IT,it;q=0.9,en-US;q=0.8,en;q=0.7"

    This function checks locales in the order listed (clients conventionally list
    them in descending quality order) and returns the first one that matches our
    supported locales.

    Args:
        accept_language: The Accept-Language header value

    Returns:
        The best matching locale code (lowercase), or None if no match found

    Examples:
        >>> parse_accept_language("it-IT,it;q=0.9,en;q=0.8")
        "it-it"  # exact match, lowercased; falls back to "it" if "it-it" is unsupported
        >>> parse_accept_language("fr-FR,fr;q=0.9")
        None  # French not supported
    """
    if not accept_language:
        return None

    # Split by comma to get individual locale entries
    # Format: "locale;q=weight" or just "locale"
    locales = []
    for entry in accept_language.split(","):
        # Remove quality value (;q=0.9) if present
        locale = entry.split(";")[0].strip()
        if locale:
            locales.append(locale)

    # Check each locale in priority order
    for locale in locales:
        locale_lower = locale.lower()

        # Try exact match first (e.g., "it-IT")
        if locale_lower in SUPPORTED_LOCALES:
            return locale_lower

        # Try language code only (e.g., "it" from "it-IT")
        lang_code = locale_lower.split("-")[0]
        if lang_code in SUPPORTED_LOCALES:
            return lang_code

    return None


async def get_locale(
    request: Request,
    current_user: User | None = Depends(get_optional_current_user),
) -> str:
    """
    Detect and return the appropriate locale for the current request.

    Three-tier fallback system:
    1. **User Preference** (highest priority)
       - If the user is authenticated and has a saved locale preference, use it
       - This persists across sessions and devices
    2. **Accept-Language Header** (second priority)
       - Parse the Accept-Language header from the request
       - Match against supported locales
       - Common for browser requests
    3. **Default Locale** (fallback)
       - Return "en" (English) if no user preference and no header match

    Args:
        request: The FastAPI request object (for accessing headers)
        current_user: The current authenticated user (optional)

    Returns:
        A valid locale code from SUPPORTED_LOCALES (guaranteed to be supported)

    Examples:
        >>> # Authenticated user with saved preference
        >>> await get_locale(request, user_with_locale_it)
        "it"
        >>> # Unauthenticated user with an Italian browser
        >>> # (request has Accept-Language: it-IT,it;q=0.9)
        >>> await get_locale(request, None)
        "it-it"
        >>> # Unauthenticated user with an unsupported language
        >>> # (request has Accept-Language: fr-FR,fr;q=0.9)
        >>> await get_locale(request, None)
        "en"
    """
    # Priority 1: User's saved preference
    if current_user and current_user.locale:
        # Validate that the saved locale is still supported
        # (in case SUPPORTED_LOCALES changed after the user set a preference)
        locale_value = str(current_user.locale)
        if locale_value in SUPPORTED_LOCALES:
            return locale_value

    # Priority 2: Accept-Language header
    accept_language = request.headers.get("accept-language", "")
    if accept_language:
        detected_locale = parse_accept_language(accept_language)
        if detected_locale:
            return detected_locale

    # Priority 3: Default fallback
    return DEFAULT_LOCALE

View File

@@ -15,9 +15,9 @@ from sqlalchemy.ext.asyncio import AsyncSession
 from app.api.dependencies.auth import get_current_user
 from app.core.database import get_db
-from app.crud.organization import organization as organization_crud
 from app.models.user import User
 from app.models.user_organization import OrganizationRole
+from app.services.organization_service import organization_service


 def require_superuser(current_user: User = Depends(get_current_user)) -> User:
@@ -81,7 +81,7 @@ class OrganizationPermission:
             return current_user

         # Get user's role in organization
-        user_role = await organization_crud.get_user_role_in_org(
+        user_role = await organization_service.get_user_role_in_org(
            db, user_id=current_user.id, organization_id=organization_id
        )
@@ -123,7 +123,7 @@ async def require_org_membership(
     if current_user.is_superuser:
         return current_user

-    user_role = await organization_crud.get_user_role_in_org(
+    user_role = await organization_service.get_user_role_in_org(
        db, user_id=current_user.id, organization_id=organization_id
    )

View File

@@ -0,0 +1,41 @@
# app/api/dependencies/services.py
"""FastAPI dependency functions for service singletons."""

from app.services import oauth_provider_service
from app.services.auth_service import AuthService
from app.services.oauth_service import OAuthService
from app.services.organization_service import OrganizationService, organization_service
from app.services.session_service import SessionService, session_service
from app.services.user_service import UserService, user_service


def get_auth_service() -> AuthService:
    """Return a new AuthService instance for dependency injection."""
    return AuthService()


def get_user_service() -> UserService:
    """Return the UserService singleton for dependency injection."""
    return user_service


def get_organization_service() -> OrganizationService:
    """Return the OrganizationService singleton for dependency injection."""
    return organization_service


def get_session_service() -> SessionService:
    """Return the SessionService singleton for dependency injection."""
    return session_service


def get_oauth_service() -> OAuthService:
    """Return a new OAuthService instance for dependency injection."""
    return OAuthService()


def get_oauth_provider_service():
    """Return the oauth_provider_service module for dependency injection."""
    return oauth_provider_service
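Note that the file mixes two lifetimes: some providers return a module-level singleton (shared by every request), while `get_auth_service` and `get_oauth_service` construct a fresh instance per call. A minimal sketch of the difference, using a hypothetical `SessionService` stand-in rather than the real service classes:

```python
class SessionService:
    """Hypothetical stand-in for a stateless service class."""


# Module-level singleton, mirroring the session_service import above
session_service = SessionService()


def get_session_service() -> SessionService:
    # Every request that depends on this shares one instance
    return session_service


def get_oauth_service() -> SessionService:
    # A fresh instance per request, like the real get_oauth_service
    return SessionService()


a, b = get_session_service(), get_session_service()
c, d = get_oauth_service(), get_oauth_service()
print(a is b, c is d)  # True False
```

Singletons are appropriate when the service holds no per-request state; per-call construction is the safer default when it might.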

View File

@@ -1,9 +1,21 @@
 from fastapi import APIRouter

-from app.api.routes import admin, auth, organizations, sessions, users
+from app.api.routes import (
+    admin,
+    auth,
+    oauth,
+    oauth_provider,
+    organizations,
+    sessions,
+    users,
+)

 api_router = APIRouter()

 api_router.include_router(auth.router, prefix="/auth", tags=["Authentication"])
+api_router.include_router(oauth.router, prefix="/oauth", tags=["OAuth"])
+api_router.include_router(
+    oauth_provider.router, prefix="/oauth", tags=["OAuth Provider"]
+)
 api_router.include_router(users.router, prefix="/users", tags=["Users"])
 api_router.include_router(sessions.router, prefix="/sessions", tags=["Sessions"])
 api_router.include_router(admin.router, prefix="/admin", tags=["Admin"])

View File

@@ -7,6 +7,7 @@ for managing the application.
 """

 import logging
+from datetime import UTC, datetime, timedelta
 from enum import Enum
 from typing import Any
 from uuid import UUID
@@ -23,9 +24,7 @@ from app.core.exceptions import (
     ErrorCode,
     NotFoundError,
 )
-from app.crud.organization import organization as organization_crud
-from app.crud.session import session as session_crud
-from app.crud.user import user as user_crud
+from app.core.repository_exceptions import DuplicateEntryError
 from app.models.user import User
 from app.models.user_organization import OrganizationRole
@@ -43,6 +42,9 @@ from app.schemas.organizations import (
 )
 from app.schemas.sessions import AdminSessionResponse
 from app.schemas.users import UserCreate, UserResponse, UserUpdate
+from app.services.organization_service import organization_service
+from app.services.session_service import session_service
+from app.services.user_service import user_service

 logger = logging.getLogger(__name__)
@@ -63,7 +65,7 @@ class BulkUserAction(BaseModel):
     action: BulkAction = Field(..., description="Action to perform on selected users")
     user_ids: list[UUID] = Field(
-        ..., min_items=1, max_items=100, description="List of user IDs (max 100)"
+        ..., min_length=1, max_length=100, description="List of user IDs (max 100)"
     )
@@ -80,6 +82,186 @@
# ===== User Management Endpoints =====


class UserGrowthData(BaseModel):
    date: str
    total_users: int
    active_users: int


class OrgDistributionData(BaseModel):
    name: str
    value: int


class RegistrationActivityData(BaseModel):
    date: str
    registrations: int


class UserStatusData(BaseModel):
    name: str
    value: int


class AdminStatsResponse(BaseModel):
    user_growth: list[UserGrowthData]
    organization_distribution: list[OrgDistributionData]
    registration_activity: list[RegistrationActivityData]
    user_status: list[UserStatusData]


def _generate_demo_stats() -> AdminStatsResponse:  # pragma: no cover
    """Generate demo statistics for empty databases."""
    from random import randint

    # Demo user growth (last 30 days)
    user_growth = []
    total = 10
    for i in range(29, -1, -1):
        date = datetime.now(UTC) - timedelta(days=i)
        total += randint(0, 3)  # noqa: S311
        user_growth.append(
            UserGrowthData(
                date=date.strftime("%b %d"),
                total_users=total,
                active_users=int(total * 0.85),
            )
        )

    # Demo organization distribution
    org_dist = [
        OrgDistributionData(name="Engineering", value=12),
        OrgDistributionData(name="Product", value=8),
        OrgDistributionData(name="Sales", value=15),
        OrgDistributionData(name="Marketing", value=6),
        OrgDistributionData(name="Support", value=5),
        OrgDistributionData(name="Operations", value=4),
    ]

    # Demo registration activity (last 14 days)
    registration_activity = []
    for i in range(13, -1, -1):
        date = datetime.now(UTC) - timedelta(days=i)
        registration_activity.append(
            RegistrationActivityData(
                date=date.strftime("%b %d"),
                registrations=randint(0, 5),  # noqa: S311
            )
        )

    # Demo user status
    user_status = [
        UserStatusData(name="Active", value=45),
        UserStatusData(name="Inactive", value=5),
    ]

    return AdminStatsResponse(
        user_growth=user_growth,
        organization_distribution=org_dist,
        registration_activity=registration_activity,
        user_status=user_status,
    )


@router.get(
    "/stats",
    response_model=AdminStatsResponse,
    summary="Admin: Get Dashboard Stats",
    description="Get aggregated statistics for the admin dashboard (admin only)",
    operation_id="admin_get_stats",
)
async def admin_get_stats(
    admin: User = Depends(require_superuser),
    db: AsyncSession = Depends(get_db),
) -> Any:
    """Get admin dashboard statistics with real data from database."""
    from app.core.config import settings

    stats = await user_service.get_stats(db)
    total_users = stats["total_users"]
    active_count = stats["active_count"]
    inactive_count = stats["inactive_count"]
    all_users = stats["all_users"]

    # If the database is essentially empty (only the admin user), return demo data
    if total_users <= 1 and settings.DEMO_MODE:  # pragma: no cover
        logger.info("Returning demo stats data (empty database in demo mode)")
        return _generate_demo_stats()

    # 1. User Growth (last 30 days)
    user_growth = []
    for i in range(29, -1, -1):
        date = datetime.now(UTC) - timedelta(days=i)
        date_start = date.replace(hour=0, minute=0, second=0, microsecond=0, tzinfo=UTC)
        date_end = date_start + timedelta(days=1)
        total_users_on_date = sum(
            1
            for u in all_users
            if u.created_at and u.created_at.replace(tzinfo=UTC) < date_end
        )
        active_users_on_date = sum(
            1
            for u in all_users
            if u.created_at
            and u.created_at.replace(tzinfo=UTC) < date_end
            and u.is_active
        )
        user_growth.append(
            UserGrowthData(
                date=date.strftime("%b %d"),
                total_users=total_users_on_date,
                active_users=active_users_on_date,
            )
        )

    # 2. Organization Distribution - top 6 organizations by member count
    org_rows = await organization_service.get_org_distribution(db, limit=6)
    org_dist = [OrgDistributionData(name=r["name"], value=r["value"]) for r in org_rows]

    # 3. User Registration Activity (last 14 days)
    registration_activity = []
    for i in range(13, -1, -1):
        date = datetime.now(UTC) - timedelta(days=i)
        date_start = date.replace(hour=0, minute=0, second=0, microsecond=0, tzinfo=UTC)
        date_end = date_start + timedelta(days=1)
        day_registrations = sum(
            1
            for u in all_users
            if u.created_at
            and date_start <= u.created_at.replace(tzinfo=UTC) < date_end
        )
        registration_activity.append(
            RegistrationActivityData(
                date=date.strftime("%b %d"),
                registrations=day_registrations,
            )
        )

    # 4. User Status - active vs inactive
    logger.info(
        "User status counts - Active: %s, Inactive: %s", active_count, inactive_count
    )
    user_status = [
        UserStatusData(name="Active", value=active_count),
        UserStatusData(name="Inactive", value=inactive_count),
    ]

    return AdminStatsResponse(
        user_growth=user_growth,
        organization_distribution=org_dist,
        registration_activity=registration_activity,
        user_status=user_status,
    )


# ===== User Management Endpoints =====


@router.get(
    "/users",
    response_model=PaginatedResponse[UserResponse],
@@ -110,7 +292,7 @@ async def admin_list_users(
             filters["is_superuser"] = is_superuser

         # Get users with search
-        users, total = await user_crud.get_multi_with_total(
+        users, total = await user_service.list_users(
             db,
             skip=pagination.offset,
             limit=pagination.limit,
@@ -130,7 +312,7 @@ async def admin_list_users(

         return PaginatedResponse(data=users, pagination=pagination_meta)
     except Exception as e:
-        logger.error(f"Error listing users (admin): {e!s}", exc_info=True)
+        logger.exception("Error listing users (admin): %s", e)
         raise
@@ -153,14 +335,14 @@ async def admin_create_user(
     Allows setting is_superuser and other fields.
     """
     try:
-        user = await user_crud.create(db, obj_in=user_in)
-        logger.info(f"Admin {admin.email} created user {user.email}")
+        user = await user_service.create_user(db, user_in)
+        logger.info("Admin %s created user %s", admin.email, user.email)
         return user
-    except ValueError as e:
-        logger.warning(f"Failed to create user: {e!s}")
-        raise NotFoundError(message=str(e), error_code=ErrorCode.USER_ALREADY_EXISTS)
+    except DuplicateEntryError as e:
+        logger.warning("Failed to create user: %s", e)
+        raise DuplicateError(message=str(e), error_code=ErrorCode.USER_ALREADY_EXISTS)
     except Exception as e:
-        logger.error(f"Error creating user (admin): {e!s}", exc_info=True)
+        logger.exception("Error creating user (admin): %s", e)
         raise
@@ -177,11 +359,7 @@ async def admin_get_user(
     db: AsyncSession = Depends(get_db),
 ) -> Any:
     """Get detailed information about a specific user."""
-    user = await user_crud.get(db, id=user_id)
-    if not user:
-        raise NotFoundError(
-            message=f"User {user_id} not found", error_code=ErrorCode.USER_NOT_FOUND
-        )
+    user = await user_service.get_user(db, str(user_id))
     return user
@@ -200,20 +378,13 @@ async def admin_update_user(
 ) -> Any:
     """Update user information with admin privileges."""
     try:
-        user = await user_crud.get(db, id=user_id)
-        if not user:
-            raise NotFoundError(
-                message=f"User {user_id} not found", error_code=ErrorCode.USER_NOT_FOUND
-            )
-        updated_user = await user_crud.update(db, db_obj=user, obj_in=user_in)
-        logger.info(f"Admin {admin.email} updated user {updated_user.email}")
+        user = await user_service.get_user(db, str(user_id))
+        updated_user = await user_service.update_user(db, user=user, obj_in=user_in)
+        logger.info("Admin %s updated user %s", admin.email, updated_user.email)
         return updated_user
-    except NotFoundError:
-        raise
     except Exception as e:
-        logger.error(f"Error updating user (admin): {e!s}", exc_info=True)
+        logger.exception("Error updating user (admin): %s", e)
         raise
@@ -231,11 +402,7 @@ async def admin_delete_user(
 ) -> Any:
     """Soft delete a user (sets deleted_at timestamp)."""
     try:
-        user = await user_crud.get(db, id=user_id)
-        if not user:
-            raise NotFoundError(
-                message=f"User {user_id} not found", error_code=ErrorCode.USER_NOT_FOUND
-            )
+        user = await user_service.get_user(db, str(user_id))

         # Prevent deleting yourself
         if user.id == admin.id:
@@ -245,17 +412,15 @@ async def admin_delete_user(
                 error_code=ErrorCode.OPERATION_FORBIDDEN,
             )

-        await user_crud.soft_delete(db, id=user_id)
-        logger.info(f"Admin {admin.email} deleted user {user.email}")
+        await user_service.soft_delete_user(db, str(user_id))
+        logger.info("Admin %s deleted user %s", admin.email, user.email)

         return MessageResponse(
             success=True, message=f"User {user.email} has been deleted"
         )
-    except NotFoundError:
-        raise
     except Exception as e:
-        logger.error(f"Error deleting user (admin): {e!s}", exc_info=True)
+        logger.exception("Error deleting user (admin): %s", e)
         raise
@@ -273,23 +438,16 @@ async def admin_activate_user(
 ) -> Any:
     """Activate a user account."""
     try:
-        user = await user_crud.get(db, id=user_id)
-        if not user:
-            raise NotFoundError(
-                message=f"User {user_id} not found", error_code=ErrorCode.USER_NOT_FOUND
-            )
-        await user_crud.update(db, db_obj=user, obj_in={"is_active": True})
-        logger.info(f"Admin {admin.email} activated user {user.email}")
+        user = await user_service.get_user(db, str(user_id))
+        await user_service.update_user(db, user=user, obj_in={"is_active": True})
+        logger.info("Admin %s activated user %s", admin.email, user.email)

         return MessageResponse(
             success=True, message=f"User {user.email} has been activated"
         )
-    except NotFoundError:
-        raise
     except Exception as e:
-        logger.error(f"Error activating user (admin): {e!s}", exc_info=True)
+        logger.exception("Error activating user (admin): %s", e)
         raise
@@ -307,11 +465,7 @@ async def admin_deactivate_user(
 ) -> Any:
     """Deactivate a user account."""
     try:
-        user = await user_crud.get(db, id=user_id)
-        if not user:
-            raise NotFoundError(
-                message=f"User {user_id} not found", error_code=ErrorCode.USER_NOT_FOUND
-            )
+        user = await user_service.get_user(db, str(user_id))

         # Prevent deactivating yourself
         if user.id == admin.id:
@@ -321,17 +475,15 @@ async def admin_deactivate_user(
                 error_code=ErrorCode.OPERATION_FORBIDDEN,
             )

-        await user_crud.update(db, db_obj=user, obj_in={"is_active": False})
-        logger.info(f"Admin {admin.email} deactivated user {user.email}")
+        await user_service.update_user(db, user=user, obj_in={"is_active": False})
+        logger.info("Admin %s deactivated user %s", admin.email, user.email)

         return MessageResponse(
             success=True, message=f"User {user.email} has been deactivated"
         )
-    except NotFoundError:
-        raise
     except Exception as e:
-        logger.error(f"Error deactivating user (admin): {e!s}", exc_info=True)
+        logger.exception("Error deactivating user (admin): %s", e)
         raise
@@ -356,19 +508,19 @@ async def admin_bulk_user_action(
     try:
         # Use efficient bulk operations instead of loop
         if bulk_action.action == BulkAction.ACTIVATE:
-            affected_count = await user_crud.bulk_update_status(
+            affected_count = await user_service.bulk_update_status(
                 db, user_ids=bulk_action.user_ids, is_active=True
             )
         elif bulk_action.action == BulkAction.DEACTIVATE:
-            affected_count = await user_crud.bulk_update_status(
+            affected_count = await user_service.bulk_update_status(
                 db, user_ids=bulk_action.user_ids, is_active=False
             )
         elif bulk_action.action == BulkAction.DELETE:
             # bulk_soft_delete automatically excludes the admin user
-            affected_count = await user_crud.bulk_soft_delete(
+            affected_count = await user_service.bulk_soft_delete(
                 db, user_ids=bulk_action.user_ids, exclude_user_id=admin.id
             )
-        else:
+        else:  # pragma: no cover
             raise ValueError(f"Unsupported bulk action: {bulk_action.action}")

         # Calculate failed count (requested - affected)
@@ -376,8 +528,11 @@ async def admin_bulk_user_action(
         failed_count = requested_count - affected_count

-        logger.info(
-            f"Admin {admin.email} performed bulk {bulk_action.action.value} "
-            f"on {affected_count} users ({failed_count} skipped/failed)"
-        )
+        logger.info(
+            "Admin %s performed bulk %s on %s users (%s skipped/failed)",
+            admin.email,
+            bulk_action.action.value,
+            affected_count,
+            failed_count,
+        )

         return BulkActionResult(
@@ -388,8 +543,8 @@ async def admin_bulk_user_action(
             failed_ids=None,  # Bulk operations don't track individual failures
         )
-    except Exception as e:
-        logger.error(f"Error in bulk user action: {e!s}", exc_info=True)
+    except Exception as e:  # pragma: no cover
+        logger.exception("Error in bulk user action: %s", e)
         raise
@@ -413,7 +568,7 @@ async def admin_list_organizations(
     """List all organizations with filtering and search."""
     try:
         # Use optimized method that gets member counts in single query (no N+1)
-        orgs_with_data, total = await organization_crud.get_multi_with_member_counts(
+        orgs_with_data, total = await organization_service.get_multi_with_member_counts(
             db,
             skip=pagination.offset,
             limit=pagination.limit,
@@ -450,7 +605,7 @@ async def admin_list_organizations(

         return PaginatedResponse(data=orgs_with_count, pagination=pagination_meta)
     except Exception as e:
-        logger.error(f"Error listing organizations (admin): {e!s}", exc_info=True)
+        logger.exception("Error listing organizations (admin): %s", e)
         raise
@@ -469,8 +624,8 @@ async def admin_create_organization(
 ) -> Any:
     """Create a new organization."""
     try:
-        org = await organization_crud.create(db, obj_in=org_in)
-        logger.info(f"Admin {admin.email} created organization {org.name}")
+        org = await organization_service.create_organization(db, obj_in=org_in)
+        logger.info("Admin %s created organization %s", admin.email, org.name)

         # Add member count
         org_dict = {
@@ -486,11 +641,11 @@ async def admin_create_organization(
         }
         return OrganizationResponse(**org_dict)
-    except ValueError as e:
-        logger.warning(f"Failed to create organization: {e!s}")
-        raise NotFoundError(message=str(e), error_code=ErrorCode.ALREADY_EXISTS)
+    except DuplicateEntryError as e:
+        logger.warning("Failed to create organization: %s", e)
+        raise DuplicateError(message=str(e), error_code=ErrorCode.ALREADY_EXISTS)
     except Exception as e:
-        logger.error(f"Error creating organization (admin): {e!s}", exc_info=True)
+        logger.exception("Error creating organization (admin): %s", e)
         raise
@@ -507,12 +662,7 @@ async def admin_get_organization(
     db: AsyncSession = Depends(get_db),
 ) -> Any:
     """Get detailed information about a specific organization."""
-    org = await organization_crud.get(db, id=org_id)
-    if not org:
-        raise NotFoundError(
-            message=f"Organization {org_id} not found", error_code=ErrorCode.NOT_FOUND
-        )
+    org = await organization_service.get_organization(db, str(org_id))

     org_dict = {
         "id": org.id,
         "name": org.name,
@@ -522,7 +672,7 @@ async def admin_get_organization(
         "settings": org.settings,
         "created_at": org.created_at,
         "updated_at": org.updated_at,
-        "member_count": await organization_crud.get_member_count(
+        "member_count": await organization_service.get_member_count(
             db, organization_id=org.id
         ),
     }
@@ -544,15 +694,11 @@ async def admin_update_organization(
 ) -> Any:
     """Update organization information."""
     try:
-        org = await organization_crud.get(db, id=org_id)
-        if not org:
-            raise NotFoundError(
-                message=f"Organization {org_id} not found",
-                error_code=ErrorCode.NOT_FOUND,
-            )
-        updated_org = await organization_crud.update(db, db_obj=org, obj_in=org_in)
-        logger.info(f"Admin {admin.email} updated organization {updated_org.name}")
+        org = await organization_service.get_organization(db, str(org_id))
+        updated_org = await organization_service.update_organization(
+            db, org=org, obj_in=org_in
+        )
+        logger.info("Admin %s updated organization %s", admin.email, updated_org.name)

         org_dict = {
             "id": updated_org.id,
@@ -563,16 +709,14 @@ async def admin_update_organization(
             "settings": updated_org.settings,
             "created_at": updated_org.created_at,
             "updated_at": updated_org.updated_at,
-            "member_count": await organization_crud.get_member_count(
+            "member_count": await organization_service.get_member_count(
                 db, organization_id=updated_org.id
             ),
         }
         return OrganizationResponse(**org_dict)
-    except NotFoundError:
-        raise
     except Exception as e:
-        logger.error(f"Error updating organization (admin): {e!s}", exc_info=True)
+        logger.exception("Error updating organization (admin): %s", e)
         raise
@@ -590,24 +734,16 @@ async def admin_delete_organization(
 ) -> Any:
     """Delete an organization and all its relationships."""
     try:
-        org = await organization_crud.get(db, id=org_id)
-        if not org:
-            raise NotFoundError(
-                message=f"Organization {org_id} not found",
-                error_code=ErrorCode.NOT_FOUND,
-            )
-        await organization_crud.remove(db, id=org_id)
-        logger.info(f"Admin {admin.email} deleted organization {org.name}")
+        org = await organization_service.get_organization(db, str(org_id))
+        await organization_service.remove_organization(db, str(org_id))
+        logger.info("Admin %s deleted organization %s", admin.email, org.name)

         return MessageResponse(
             success=True, message=f"Organization {org.name} has been deleted"
         )
-    except NotFoundError:
-        raise
     except Exception as e:
-        logger.error(f"Error deleting organization (admin): {e!s}", exc_info=True)
+        logger.exception("Error deleting organization (admin): %s", e)
         raise
@@ -627,14 +763,8 @@ async def admin_list_organization_members(
) -> Any: ) -> Any:
"""List all members of an organization.""" """List all members of an organization."""
try: try:
org = await organization_crud.get(db, id=org_id) await organization_service.get_organization(db, str(org_id)) # validates exists
if not org: members, total = await organization_service.get_organization_members(
raise NotFoundError(
message=f"Organization {org_id} not found",
error_code=ErrorCode.NOT_FOUND,
)
members, total = await organization_crud.get_organization_members(
db, db,
organization_id=org_id, organization_id=org_id,
skip=pagination.offset, skip=pagination.offset,
@@ -657,9 +787,7 @@ async def admin_list_organization_members(
     except NotFoundError:
         raise
     except Exception as e:
-        logger.error(
-            f"Error listing organization members (admin): {e!s}", exc_info=True
-        )
+        logger.exception("Error listing organization members (admin): %s", e)
         raise
@@ -687,45 +815,32 @@ async def admin_add_organization_member(
 ) -> Any:
     """Add a user to an organization."""
     try:
-        org = await organization_crud.get(db, id=org_id)
-        if not org:
-            raise NotFoundError(
-                message=f"Organization {org_id} not found",
-                error_code=ErrorCode.NOT_FOUND,
-            )
-        user = await user_crud.get(db, id=request.user_id)
-        if not user:
-            raise NotFoundError(
-                message=f"User {request.user_id} not found",
-                error_code=ErrorCode.USER_NOT_FOUND,
-            )
-        await organization_crud.add_user(
+        org = await organization_service.get_organization(db, str(org_id))
+        user = await user_service.get_user(db, str(request.user_id))
+        await organization_service.add_member(
             db, organization_id=org_id, user_id=request.user_id, role=request.role
         )
         logger.info(
-            f"Admin {admin.email} added user {user.email} to organization {org.name} "
-            f"with role {request.role.value}"
+            "Admin %s added user %s to organization %s with role %s",
+            admin.email,
+            user.email,
+            org.name,
+            request.role.value,
         )
         return MessageResponse(
             success=True, message=f"User {user.email} added to organization {org.name}"
         )
-    except ValueError as e:
-        logger.warning(f"Failed to add user to organization: {e!s}")
-        # Use DuplicateError for "already exists" scenarios
+    except DuplicateEntryError as e:
+        logger.warning("Failed to add user to organization: %s", e)
         raise DuplicateError(
             message=str(e), error_code=ErrorCode.USER_ALREADY_EXISTS, field="user_id"
         )
-    except NotFoundError:
-        raise
     except Exception as e:
-        logger.error(
-            f"Error adding member to organization (admin): {e!s}", exc_info=True
-        )
+        logger.exception("Error adding member to organization (admin): %s", e)
         raise
@@ -744,20 +859,10 @@ async def admin_remove_organization_member(
 ) -> Any:
     """Remove a user from an organization."""
     try:
-        org = await organization_crud.get(db, id=org_id)
-        if not org:
-            raise NotFoundError(
-                message=f"Organization {org_id} not found",
-                error_code=ErrorCode.NOT_FOUND,
-            )
-        user = await user_crud.get(db, id=user_id)
-        if not user:
-            raise NotFoundError(
-                message=f"User {user_id} not found", error_code=ErrorCode.USER_NOT_FOUND
-            )
-        success = await organization_crud.remove_user(
+        org = await organization_service.get_organization(db, str(org_id))
+        user = await user_service.get_user(db, str(user_id))
+        success = await organization_service.remove_member(
             db, organization_id=org_id, user_id=user_id
         )
@@ -768,7 +873,10 @@ async def admin_remove_organization_member(
         )
         logger.info(
-            f"Admin {admin.email} removed user {user.email} from organization {org.name}"
+            "Admin %s removed user %s from organization %s",
+            admin.email,
+            user.email,
+            org.name,
         )
         return MessageResponse(
@@ -778,10 +886,8 @@ async def admin_remove_organization_member(
     except NotFoundError:
         raise
-    except Exception as e:
-        logger.error(
-            f"Error removing member from organization (admin): {e!s}", exc_info=True
-        )
+    except Exception as e:  # pragma: no cover
+        logger.exception("Error removing member from organization (admin): %s", e)
         raise
@@ -811,7 +917,7 @@ async def admin_list_sessions(
     """List all sessions across all users with filtering and pagination."""
     try:
         # Get sessions with user info (eager loaded to prevent N+1)
-        sessions, total = await session_crud.get_all_sessions(
+        sessions, total = await session_service.get_all_sessions(
             db,
             skip=pagination.offset,
             limit=pagination.limit,
@@ -850,7 +956,10 @@ async def admin_list_sessions(
             session_responses.append(session_response)
         logger.info(
-            f"Admin {admin.email} listed {len(session_responses)} sessions (total: {total})"
+            "Admin %s listed %s sessions (total: %s)",
+            admin.email,
+            len(session_responses),
+            total,
         )
         pagination_meta = create_pagination_meta(
@@ -862,6 +971,6 @@ async def admin_list_sessions(
         return PaginatedResponse(data=session_responses, pagination=pagination_meta)
-    except Exception as e:
-        logger.error(f"Error listing sessions (admin): {e!s}", exc_info=True)
+    except Exception as e:  # pragma: no cover
+        logger.exception("Error listing sessions (admin): %s", e)
         raise
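
The recurring change across these hunks swaps eager f-string logging for lazy %-style arguments and `logger.exception`. A minimal sketch of the difference, under illustrative names (the logger name and helper function are not part of the diff):

```python
import logging

logger = logging.getLogger("app.api.routes.admin")  # illustrative name
logger.addHandler(logging.NullHandler())
logger.propagate = False  # keep the sketch quiet

def log_delete_error(e: Exception) -> str:
    """Log an admin delete failure the way the updated code does."""
    # Old style: logger.error(f"... {e!s}", exc_info=True) formats eagerly.
    # New style: %-args are interpolated only if the record is actually
    # emitted, and logger.exception() implies exc_info=True automatically.
    try:
        raise e
    except Exception as err:
        logger.exception("Error deleting organization (admin): %s", err)
        # Both styles render the same final message text:
        return "Error deleting organization (admin): %s" % err
```

Besides deferring string formatting, the constant message template lets log aggregators group records that differ only in their arguments.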


@@ -15,16 +15,14 @@ from app.core.auth import (
     TokenExpiredError,
     TokenInvalidError,
     decode_token,
-    get_password_hash,
 )
 from app.core.database import get_db
 from app.core.exceptions import (
     AuthenticationError as AuthError,
     DatabaseError,
+    DuplicateError,
     ErrorCode,
 )
-from app.crud.session import session as session_crud
-from app.crud.user import user as user_crud
 from app.models.user import User
 from app.schemas.common import MessageResponse
 from app.schemas.sessions import LogoutRequest, SessionCreate
@@ -39,6 +37,8 @@ from app.schemas.users import (
 )
 from app.services.auth_service import AuthenticationError, AuthService
 from app.services.email_service import email_service
+from app.services.session_service import session_service
+from app.services.user_service import user_service
 from app.utils.device import extract_device_info
 from app.utils.security import create_password_reset_token, verify_password_reset_token
@@ -91,17 +91,18 @@ async def _create_login_session(
         location_country=device_info.location_country,
         )
-        await session_crud.create_session(db, obj_in=session_data)
+        await session_service.create_session(db, obj_in=session_data)
         logger.info(
-            f"{login_type.capitalize()} successful: {user.email} from {device_info.device_name} "
-            f"(IP: {device_info.ip_address})"
+            "%s successful: %s from %s (IP: %s)",
+            login_type.capitalize(),
+            user.email,
+            device_info.device_name,
+            device_info.ip_address,
         )
     except Exception as session_err:
         # Log but don't fail login if session creation fails
-        logger.error(
-            f"Failed to create session for {user.email}: {session_err!s}", exc_info=True
-        )
+        logger.exception("Failed to create session for %s: %s", user.email, session_err)

 @router.post(
@@ -123,15 +124,21 @@ async def register_user(
     try:
         user = await AuthService.create_user(db, user_data)
         return user
-    except AuthenticationError as e:
+    except DuplicateError:
         # SECURITY: Don't reveal if email exists - generic error message
-        logger.warning(f"Registration failed: {e!s}")
+        logger.warning("Registration failed: duplicate email %s", user_data.email)
+        raise HTTPException(
+            status_code=status.HTTP_400_BAD_REQUEST,
+            detail="Registration failed. Please check your information and try again.",
+        )
+    except AuthError as e:
+        logger.warning("Registration failed: %s", e)
         raise HTTPException(
             status_code=status.HTTP_400_BAD_REQUEST,
             detail="Registration failed. Please check your information and try again.",
         )
     except Exception as e:
-        logger.error(f"Unexpected error during registration: {e!s}", exc_info=True)
+        logger.exception("Unexpected error during registration: %s", e)
         raise DatabaseError(
             message="An unexpected error occurred. Please try again later.",
             error_code=ErrorCode.INTERNAL_ERROR,
@@ -159,7 +166,7 @@ async def login(
         # Explicitly check for None result and raise correct exception
         if user is None:
-            logger.warning(f"Invalid login attempt for: {login_data.email}")
+            logger.warning("Invalid login attempt for: %s", login_data.email)
             raise AuthError(
                 message="Invalid email or password",
                 error_code=ErrorCode.INVALID_CREDENTIALS,
@@ -175,14 +182,11 @@ async def login(
     except AuthenticationError as e:
         # Handle specific authentication errors like inactive accounts
-        logger.warning(f"Authentication failed: {e!s}")
+        logger.warning("Authentication failed: %s", e)
         raise AuthError(message=str(e), error_code=ErrorCode.INVALID_CREDENTIALS)
-    except AuthError:
-        # Re-raise custom auth exceptions without modification
-        raise
     except Exception as e:
         # Handle unexpected errors
-        logger.error(f"Unexpected error during login: {e!s}", exc_info=True)
+        logger.exception("Unexpected error during login: %s", e)
         raise DatabaseError(
             message="An unexpected error occurred. Please try again later.",
             error_code=ErrorCode.INTERNAL_ERROR,
@@ -224,13 +228,10 @@ async def login_oauth(
         # Return full token response with user data
         return tokens
     except AuthenticationError as e:
-        logger.warning(f"OAuth authentication failed: {e!s}")
+        logger.warning("OAuth authentication failed: %s", e)
         raise AuthError(message=str(e), error_code=ErrorCode.INVALID_CREDENTIALS)
-    except AuthError:
-        # Re-raise custom auth exceptions without modification
-        raise
     except Exception as e:
-        logger.error(f"Unexpected error during OAuth login: {e!s}", exc_info=True)
+        logger.exception("Unexpected error during OAuth login: %s", e)
         raise DatabaseError(
             message="An unexpected error occurred. Please try again later.",
             error_code=ErrorCode.INTERNAL_ERROR,
@@ -259,11 +260,12 @@ async def refresh_token(
         )
         # Check if session exists and is active
-        session = await session_crud.get_active_by_jti(db, jti=refresh_payload.jti)
+        session = await session_service.get_active_by_jti(db, jti=refresh_payload.jti)
         if not session:
             logger.warning(
-                f"Refresh token used for inactive or non-existent session: {refresh_payload.jti}"
+                "Refresh token used for inactive or non-existent session: %s",
+                refresh_payload.jti,
             )
             raise HTTPException(
                 status_code=status.HTTP_401_UNAUTHORIZED,
@@ -279,16 +281,14 @@ async def refresh_token(
         # Update session with new refresh token JTI and expiration
         try:
-            await session_crud.update_refresh_token(
+            await session_service.update_refresh_token(
                 db,
                 session=session,
                 new_jti=new_refresh_payload.jti,
                 new_expires_at=datetime.fromtimestamp(new_refresh_payload.exp, tz=UTC),
             )
         except Exception as session_err:
-            logger.error(
-                f"Failed to update session {session.id}: {session_err!s}", exc_info=True
-            )
+            logger.exception("Failed to update session %s: %s", session.id, session_err)
             # Continue anyway - tokens are already issued
         return tokens
@@ -311,7 +311,7 @@ async def refresh_token(
         # Re-raise HTTP exceptions (like session revoked)
         raise
     except Exception as e:
-        logger.error(f"Unexpected error during token refresh: {e!s}")
+        logger.error("Unexpected error during token refresh: %s", e)
         raise HTTPException(
             status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
             detail="An unexpected error occurred. Please try again later.",
@@ -347,7 +347,7 @@ async def request_password_reset(
     """
     try:
         # Look up user by email
-        user = await user_crud.get_by_email(db, email=reset_request.email)
+        user = await user_service.get_by_email(db, email=reset_request.email)
         # Only send email if user exists and is active
         if user and user.is_active:
@@ -358,11 +358,12 @@ async def request_password_reset(
             await email_service.send_password_reset_email(
                 to_email=user.email, reset_token=reset_token, user_name=user.first_name
             )
-            logger.info(f"Password reset requested for {user.email}")
+            logger.info("Password reset requested for %s", user.email)
         else:
             # Log attempt but don't reveal if email exists
             logger.warning(
-                f"Password reset requested for non-existent or inactive email: {reset_request.email}"
+                "Password reset requested for non-existent or inactive email: %s",
+                reset_request.email,
             )
         # Always return success to prevent email enumeration
@@ -371,7 +372,7 @@ async def request_password_reset(
             message="If your email is registered, you will receive a password reset link shortly",
         )
     except Exception as e:
-        logger.error(f"Error processing password reset request: {e!s}", exc_info=True)
+        logger.exception("Error processing password reset request: %s", e)
         # Still return success to prevent information leakage
         return MessageResponse(
             success=True,
@@ -412,40 +413,34 @@ async def confirm_password_reset(
             detail="Invalid or expired password reset token",
         )
-        # Look up user
-        user = await user_crud.get_by_email(db, email=email)
-        if not user:
-            raise HTTPException(
-                status_code=status.HTTP_404_NOT_FOUND, detail="User not found"
-            )
-        if not user.is_active:
-            raise HTTPException(
-                status_code=status.HTTP_400_BAD_REQUEST,
-                detail="User account is inactive",
-            )
-        # Update password
-        user.password_hash = get_password_hash(reset_confirm.new_password)
-        db.add(user)
-        await db.commit()
+        # Reset password via service (validates user exists and is active)
+        try:
+            user = await AuthService.reset_password(
+                db, email=email, new_password=reset_confirm.new_password
+            )
+        except AuthenticationError as e:
+            err_msg = str(e)
+            if "inactive" in err_msg.lower():
+                raise HTTPException(
+                    status_code=status.HTTP_400_BAD_REQUEST, detail=err_msg
+                )
+            raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail=err_msg)
         # SECURITY: Invalidate all existing sessions after password reset
         # This prevents stolen sessions from being used after password change
-        from app.crud.session import session as session_crud
         try:
-            deactivated_count = await session_crud.deactivate_all_user_sessions(
+            deactivated_count = await session_service.deactivate_all_user_sessions(
                 db, user_id=str(user.id)
             )
             logger.info(
-                f"Password reset successful for {user.email}, invalidated {deactivated_count} sessions"
+                "Password reset successful for %s, invalidated %s sessions",
+                user.email,
+                deactivated_count,
             )
         except Exception as session_error:
             # Log but don't fail password reset if session invalidation fails
             logger.error(
-                f"Failed to invalidate sessions after password reset: {session_error!s}"
+                "Failed to invalidate sessions after password reset: %s", session_error
             )
         return MessageResponse(
@@ -456,7 +451,7 @@ async def confirm_password_reset(
     except HTTPException:
         raise
     except Exception as e:
-        logger.error(f"Error confirming password reset: {e!s}", exc_info=True)
+        logger.exception("Error confirming password reset: %s", e)
         await db.rollback()
         raise HTTPException(
             status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
@@ -506,19 +501,21 @@ async def logout(
         )
     except (TokenExpiredError, TokenInvalidError) as e:
         # Even if token is expired/invalid, try to deactivate session
-        logger.warning(f"Logout with invalid/expired token: {e!s}")
+        logger.warning("Logout with invalid/expired token: %s", e)
         # Don't fail - return success anyway
         return MessageResponse(success=True, message="Logged out successfully")
     # Find the session by JTI
-    session = await session_crud.get_by_jti(db, jti=refresh_payload.jti)
+    session = await session_service.get_by_jti(db, jti=refresh_payload.jti)
     if session:
         # Verify session belongs to current user (security check)
         if str(session.user_id) != str(current_user.id):
             logger.warning(
-                f"User {current_user.id} attempted to logout session {session.id} "
-                f"belonging to user {session.user_id}"
+                "User %s attempted to logout session %s belonging to user %s",
+                current_user.id,
+                session.id,
+                session.user_id,
             )
             raise HTTPException(
                 status_code=status.HTTP_403_FORBIDDEN,
@@ -526,17 +523,20 @@ async def logout(
             )
         # Deactivate the session
-        await session_crud.deactivate(db, session_id=str(session.id))
+        await session_service.deactivate(db, session_id=str(session.id))
         logger.info(
-            f"User {current_user.id} logged out from {session.device_name} "
-            f"(session {session.id})"
+            "User %s logged out from %s (session %s)",
+            current_user.id,
+            session.device_name,
+            session.id,
         )
     else:
         # Session not found - maybe already deleted or never existed
         # Return success anyway (idempotent)
         logger.info(
-            f"Logout requested for non-existent session (JTI: {refresh_payload.jti})"
+            "Logout requested for non-existent session (JTI: %s)",
+            refresh_payload.jti,
         )
     return MessageResponse(success=True, message="Logged out successfully")
@@ -544,9 +544,7 @@ async def logout(
     except HTTPException:
         raise
     except Exception as e:
-        logger.error(
-            f"Error during logout for user {current_user.id}: {e!s}", exc_info=True
-        )
+        logger.exception("Error during logout for user %s: %s", current_user.id, e)
         # Don't expose error details
         return MessageResponse(success=True, message="Logged out successfully")
@@ -584,12 +582,12 @@ async def logout_all(
     """
     try:
         # Deactivate all sessions for this user
-        count = await session_crud.deactivate_all_user_sessions(
+        count = await session_service.deactivate_all_user_sessions(
             db, user_id=str(current_user.id)
         )
         logger.info(
-            f"User {current_user.id} logged out from all devices ({count} sessions)"
+            "User %s logged out from all devices (%s sessions)", current_user.id, count
         )
         return MessageResponse(
@@ -598,9 +596,7 @@ async def logout_all(
         )
     except Exception as e:
-        logger.error(
-            f"Error during logout-all for user {current_user.id}: {e!s}", exc_info=True
-        )
+        logger.exception("Error during logout-all for user %s: %s", current_user.id, e)
         await db.rollback()
         raise HTTPException(
             status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,


@@ -0,0 +1,434 @@
# app/api/routes/oauth.py
"""
OAuth routes for social authentication.
Endpoints:
- GET /oauth/providers - List enabled OAuth providers
- GET /oauth/authorize/{provider} - Get authorization URL
- POST /oauth/callback/{provider} - Handle OAuth callback
- GET /oauth/accounts - List linked OAuth accounts
- DELETE /oauth/accounts/{provider} - Unlink an OAuth account
"""
import logging
import os
from datetime import UTC, datetime
from typing import Any
from fastapi import APIRouter, Depends, HTTPException, Query, Request, status
from slowapi import Limiter
from slowapi.util import get_remote_address
from sqlalchemy.ext.asyncio import AsyncSession
from app.api.dependencies.auth import get_current_user, get_optional_current_user
from app.core.auth import decode_token
from app.core.config import settings
from app.core.database import get_db
from app.core.exceptions import AuthenticationError as AuthError
from app.models.user import User
from app.schemas.oauth import (
OAuthAccountsListResponse,
OAuthCallbackRequest,
OAuthCallbackResponse,
OAuthProvidersResponse,
OAuthUnlinkResponse,
)
from app.schemas.sessions import SessionCreate
from app.schemas.users import Token
from app.services.oauth_service import OAuthService
from app.services.session_service import session_service
from app.utils.device import extract_device_info
router = APIRouter()
logger = logging.getLogger(__name__)
# Initialize limiter for this router
limiter = Limiter(key_func=get_remote_address)
# Use higher rate limits in test environment
IS_TEST = os.getenv("IS_TEST", "False") == "True"
RATE_MULTIPLIER = 100 if IS_TEST else 1
async def _create_oauth_login_session(
db: AsyncSession,
request: Request,
user: User,
tokens: Token,
provider: str,
) -> None:
"""
Create a session record for successful OAuth login.
This is a best-effort operation - login succeeds even if session creation fails.
"""
try:
device_info = extract_device_info(request)
# Decode refresh token to get JTI and expiration
refresh_payload = decode_token(tokens.refresh_token, verify_type="refresh")
session_data = SessionCreate(
user_id=user.id,
refresh_token_jti=refresh_payload.jti,
device_name=device_info.device_name or f"OAuth ({provider})",
device_id=device_info.device_id,
ip_address=device_info.ip_address,
user_agent=device_info.user_agent,
last_used_at=datetime.now(UTC),
expires_at=datetime.fromtimestamp(refresh_payload.exp, tz=UTC),
location_city=device_info.location_city,
location_country=device_info.location_country,
)
await session_service.create_session(db, obj_in=session_data)
logger.info(
"OAuth login successful: %s via %s from %s (IP: %s)",
user.email,
provider,
device_info.device_name,
device_info.ip_address,
)
except Exception as session_err:
# Log but don't fail login if session creation fails
logger.exception(
"Failed to create session for OAuth login %s: %s", user.email, session_err
)
@router.get(
"/providers",
response_model=OAuthProvidersResponse,
summary="List OAuth Providers",
description="""
Get list of enabled OAuth providers for the login/register UI.
Returns:
List of enabled providers with display info.
""",
operation_id="list_oauth_providers",
)
async def list_providers() -> Any:
"""
Get list of enabled OAuth providers.
This endpoint is public (no authentication required) as it's needed
for the login/register UI to display available social login options.
"""
return OAuthService.get_enabled_providers()
@router.get(
"/authorize/{provider}",
response_model=dict,
summary="Get OAuth Authorization URL",
description="""
Get the authorization URL to redirect the user to the OAuth provider.
The frontend should redirect the user to the returned URL.
After authentication, the provider will redirect back to the callback URL.
**Rate Limit**: 10 requests/minute
""",
operation_id="get_oauth_authorization_url",
)
@limiter.limit(f"{10 * RATE_MULTIPLIER}/minute")
async def get_authorization_url(
request: Request,
provider: str,
redirect_uri: str = Query(
..., description="Frontend callback URL after OAuth completes"
),
current_user: User | None = Depends(get_optional_current_user),
db: AsyncSession = Depends(get_db),
) -> Any:
"""
Get OAuth authorization URL.
Args:
provider: OAuth provider (google, github)
redirect_uri: Frontend callback URL
current_user: Current user (optional, for account linking)
db: Database session
Returns:
dict with authorization_url and state
"""
if not settings.OAUTH_ENABLED:
raise HTTPException(
status_code=status.HTTP_400_BAD_REQUEST,
detail="OAuth is not enabled",
)
try:
# If user is logged in, this is an account linking flow
user_id = str(current_user.id) if current_user else None
url, state = await OAuthService.create_authorization_url(
db,
provider=provider,
redirect_uri=redirect_uri,
user_id=user_id,
)
return {
"authorization_url": url,
"state": state,
}
except AuthError as e:
logger.warning("OAuth authorization failed: %s", e)
raise HTTPException(
status_code=status.HTTP_400_BAD_REQUEST,
detail=str(e),
)
except Exception as e:
logger.exception("OAuth authorization error: %s", e)
raise HTTPException(
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
detail="Failed to create authorization URL",
)
@router.post(
"/callback/{provider}",
response_model=OAuthCallbackResponse,
summary="OAuth Callback",
description="""
Handle OAuth callback from provider.
The frontend should call this endpoint with the code and state
parameters received from the OAuth provider redirect.
Returns:
JWT tokens for the authenticated user.
**Rate Limit**: 10 requests/minute
""",
operation_id="handle_oauth_callback",
)
@limiter.limit(f"{10 * RATE_MULTIPLIER}/minute")
async def handle_callback(
request: Request,
provider: str,
callback_data: OAuthCallbackRequest,
redirect_uri: str = Query(
..., description="Must match the redirect_uri used in authorization"
),
db: AsyncSession = Depends(get_db),
) -> Any:
"""
Handle OAuth callback.
Args:
provider: OAuth provider (google, github)
callback_data: Code and state from provider
redirect_uri: Original redirect URI (for validation)
db: Database session
Returns:
OAuthCallbackResponse with tokens
"""
if not settings.OAUTH_ENABLED:
raise HTTPException(
status_code=status.HTTP_400_BAD_REQUEST,
detail="OAuth is not enabled",
)
try:
result = await OAuthService.handle_callback(
db,
code=callback_data.code,
state=callback_data.state,
redirect_uri=redirect_uri,
)
# Create session for the login (need to get the user first)
# Note: This requires fetching the user from the token
# For now, we skip session creation here as the result doesn't include user info
# The session will be created on next request if needed
return result
except AuthError as e:
logger.warning("OAuth callback failed: %s", e)
raise HTTPException(
status_code=status.HTTP_401_UNAUTHORIZED,
detail=str(e),
)
except Exception as e:
logger.exception("OAuth callback error: %s", e)
raise HTTPException(
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
detail="OAuth authentication failed",
)
@router.get(
"/accounts",
response_model=OAuthAccountsListResponse,
summary="List Linked OAuth Accounts",
description="""
Get list of OAuth accounts linked to the current user.
Requires authentication.
""",
operation_id="list_oauth_accounts",
)
async def list_accounts(
current_user: User = Depends(get_current_user),
db: AsyncSession = Depends(get_db),
) -> Any:
"""
List OAuth accounts linked to the current user.
Args:
current_user: Current authenticated user
db: Database session
Returns:
List of linked OAuth accounts
"""
accounts = await OAuthService.get_user_accounts(db, user_id=current_user.id)
return OAuthAccountsListResponse(accounts=accounts)
@router.delete(
"/accounts/{provider}",
response_model=OAuthUnlinkResponse,
summary="Unlink OAuth Account",
description="""
Unlink an OAuth provider from the current user.
The user must have either a password set or another OAuth provider
linked to ensure they can still log in.
**Rate Limit**: 5 requests/minute
""",
operation_id="unlink_oauth_account",
)
@limiter.limit(f"{5 * RATE_MULTIPLIER}/minute")
async def unlink_account(
request: Request,
provider: str,
current_user: User = Depends(get_current_user),
db: AsyncSession = Depends(get_db),
) -> Any:
"""
Unlink an OAuth provider from the current user.
Args:
provider: Provider to unlink (google, github)
current_user: Current authenticated user
db: Database session
Returns:
Success message
"""
try:
await OAuthService.unlink_provider(
db,
user=current_user,
provider=provider,
)
return OAuthUnlinkResponse(
success=True,
message=f"{provider.capitalize()} account unlinked successfully",
)
except AuthError as e:
logger.warning("OAuth unlink failed for %s: %s", current_user.email, e)
raise HTTPException(
status_code=status.HTTP_400_BAD_REQUEST,
detail=str(e),
)
except Exception as e:
logger.exception("OAuth unlink error: %s", e)
raise HTTPException(
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
detail="Failed to unlink OAuth account",
)
@router.post(
"/link/{provider}",
response_model=dict,
summary="Start Account Linking",
description="""
Start the OAuth flow to link a new provider to the current user.
This is a convenience endpoint that redirects to /authorize/{provider}
with the current user context.
**Rate Limit**: 10 requests/minute
""",
operation_id="start_oauth_link",
)
@limiter.limit(f"{10 * RATE_MULTIPLIER}/minute")
async def start_link(
request: Request,
provider: str,
redirect_uri: str = Query(
..., description="Frontend callback URL after OAuth completes"
),
current_user: User = Depends(get_current_user),
db: AsyncSession = Depends(get_db),
) -> Any:
"""
Start OAuth account linking flow.
This endpoint requires authentication and will initiate an OAuth flow
to link a new provider to the current user's account.
Args:
provider: OAuth provider to link (google, github)
redirect_uri: Frontend callback URL
current_user: Current authenticated user
db: Database session
Returns:
dict with authorization_url and state
"""
if not settings.OAUTH_ENABLED:
raise HTTPException(
status_code=status.HTTP_400_BAD_REQUEST,
detail="OAuth is not enabled",
)
# Check if user already has this provider linked
existing = await OAuthService.get_user_account_by_provider(
db, user_id=current_user.id, provider=provider
)
if existing:
raise HTTPException(
status_code=status.HTTP_400_BAD_REQUEST,
detail=f"You already have a {provider} account linked",
)
try:
url, state = await OAuthService.create_authorization_url(
db,
provider=provider,
redirect_uri=redirect_uri,
user_id=str(current_user.id),
)
return {
"authorization_url": url,
"state": state,
}
except AuthError as e:
logger.warning("OAuth link authorization failed: %s", e)
raise HTTPException(
status_code=status.HTTP_400_BAD_REQUEST,
detail=str(e),
)
except Exception as e:
logger.exception("OAuth link error: %s", e)
raise HTTPException(
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
detail="Failed to create authorization URL",
)
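The link flow above returns a JSON body instead of redirecting directly, so the frontend is responsible for the browser hop and for CSRF-checking `state` when the provider calls back. A minimal sketch of that client-side handling (the helper names and the session store are hypothetical, not part of this codebase):

```python
def begin_link(body: dict, session_store: dict) -> str:
    """Persist the CSRF state, then return the URL to send the browser to."""
    session_store["oauth_state"] = body["state"]
    return body["authorization_url"]


def verify_callback(params: dict, session_store: dict) -> bool:
    """On the frontend callback, the returned state must match what we stored."""
    return params.get("state") == session_store.pop("oauth_state", None)
```

Popping the stored state makes it single-use, so a replayed callback fails the check.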

View File

@@ -0,0 +1,824 @@
# app/api/routes/oauth_provider.py
"""
OAuth Provider routes (Authorization Server mode) for MCP integration.
Implements OAuth 2.0 Authorization Server endpoints:
- GET /.well-known/oauth-authorization-server - Server metadata (RFC 8414)
- GET /oauth/provider/authorize - Authorization endpoint
- POST /oauth/provider/token - Token endpoint
- POST /oauth/provider/revoke - Token revocation (RFC 7009)
- POST /oauth/provider/introspect - Token introspection (RFC 7662)
- Client management endpoints
Security features:
- PKCE required for public clients (S256)
- CSRF protection via state parameter
- Secure token handling
- Rate limiting on sensitive endpoints
"""
import logging
from typing import Any
from urllib.parse import urlencode
from fastapi import APIRouter, Depends, Form, HTTPException, Query, Request, status
from fastapi.responses import RedirectResponse
from slowapi import Limiter
from slowapi.util import get_remote_address
from sqlalchemy.ext.asyncio import AsyncSession
from app.api.dependencies.auth import (
get_current_active_user,
get_current_superuser,
get_optional_current_user,
)
from app.core.config import settings
from app.core.database import get_db
from app.models.user import User
from app.schemas.oauth import (
OAuthClientCreate,
OAuthClientResponse,
OAuthServerMetadata,
OAuthTokenIntrospectionResponse,
OAuthTokenResponse,
)
from app.services import oauth_provider_service as provider_service
router = APIRouter()
# Separate router for RFC 8414 well-known endpoint (registered at root level)
wellknown_router = APIRouter()
logger = logging.getLogger(__name__)
limiter = Limiter(key_func=get_remote_address)
def require_provider_enabled():
"""Dependency to check if OAuth provider mode is enabled."""
if not settings.OAUTH_PROVIDER_ENABLED:
raise HTTPException(
status_code=status.HTTP_404_NOT_FOUND,
detail="OAuth provider mode is not enabled. Set OAUTH_PROVIDER_ENABLED=true",
)
# ============================================================================
# Server Metadata (RFC 8414)
# ============================================================================
@wellknown_router.get(
"/.well-known/oauth-authorization-server",
response_model=OAuthServerMetadata,
summary="OAuth Server Metadata",
description="""
OAuth 2.0 Authorization Server Metadata (RFC 8414).
Returns server metadata including supported endpoints, scopes,
and capabilities. MCP clients use this to discover the server.
Note: This endpoint is at the root level per RFC 8414.
""",
operation_id="get_oauth_server_metadata",
tags=["OAuth Provider"],
)
async def get_server_metadata(
_: None = Depends(require_provider_enabled),
) -> OAuthServerMetadata:
"""Get OAuth 2.0 server metadata."""
base_url = settings.OAUTH_ISSUER.rstrip("/")
return OAuthServerMetadata(
issuer=base_url,
authorization_endpoint=f"{base_url}/api/v1/oauth/provider/authorize",
token_endpoint=f"{base_url}/api/v1/oauth/provider/token",
revocation_endpoint=f"{base_url}/api/v1/oauth/provider/revoke",
introspection_endpoint=f"{base_url}/api/v1/oauth/provider/introspect",
registration_endpoint=None, # Dynamic registration not supported
scopes_supported=[
"openid",
"profile",
"email",
"read:users",
"write:users",
"read:organizations",
"write:organizations",
"admin",
],
response_types_supported=["code"],
grant_types_supported=["authorization_code", "refresh_token"],
code_challenge_methods_supported=["S256"],
token_endpoint_auth_methods_supported=[
"client_secret_basic",
"client_secret_post",
"none", # For public clients with PKCE
],
)
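Per RFC 8414, a client never hardcodes these endpoints; it derives the metadata URL from the issuer and reads the endpoints out of the response. A small sketch of that discovery step (helper names are illustrative; only the well-known path is fixed by the RFC):

```python
def discovery_url(issuer: str) -> str:
    """RFC 8414 fixes the path: issuer + /.well-known/oauth-authorization-server."""
    return issuer.rstrip("/") + "/.well-known/oauth-authorization-server"


def pick_endpoints(metadata: dict) -> tuple[str, str]:
    """The two endpoints an authorization-code client needs from the metadata."""
    return metadata["authorization_endpoint"], metadata["token_endpoint"]
```

An MCP client would GET `discovery_url(issuer)` once at startup and cache the result.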
# ============================================================================
# Authorization Endpoint
# ============================================================================
@router.get(
"/provider/authorize",
summary="Authorization Endpoint",
description="""
OAuth 2.0 Authorization Endpoint.
Initiates the authorization code flow:
1. Validates client and parameters
2. Checks if user is authenticated (redirects to login if not)
3. Checks existing consent
4. Redirects to consent page if needed
5. Issues authorization code and redirects back to client
Required parameters:
- response_type: Must be "code"
- client_id: Registered client ID
- redirect_uri: Must match registered URI
Recommended parameters:
- state: CSRF protection
- code_challenge + code_challenge_method: PKCE (required for public clients)
- scope: Requested permissions
""",
operation_id="oauth_provider_authorize",
tags=["OAuth Provider"],
)
@limiter.limit("30/minute")
async def authorize(
request: Request,
response_type: str = Query(..., description="Must be 'code'"),
client_id: str = Query(..., description="OAuth client ID"),
redirect_uri: str = Query(..., description="Redirect URI"),
scope: str = Query(default="", description="Requested scopes (space-separated)"),
state: str = Query(default="", description="CSRF state parameter"),
code_challenge: str | None = Query(default=None, description="PKCE code challenge"),
code_challenge_method: str | None = Query(
default=None, description="PKCE method (S256)"
),
nonce: str | None = Query(default=None, description="OpenID Connect nonce"),
db: AsyncSession = Depends(get_db),
_: None = Depends(require_provider_enabled),
current_user: User | None = Depends(get_optional_current_user),
) -> Any:
"""
Authorization endpoint - initiates OAuth flow.
If user is not authenticated, redirects to login with return URL.
If user has not consented, redirects to consent page.
If all checks pass, generates code and redirects to client.
"""
# Validate response_type
if response_type != "code":
raise HTTPException(
status_code=status.HTTP_400_BAD_REQUEST,
detail="invalid_request: response_type must be 'code'",
)
# Validate PKCE method if provided - ONLY S256 is allowed (RFC 7636 Section 4.3)
# "plain" method provides no security benefit and MUST NOT be used
if code_challenge_method and code_challenge_method != "S256":
raise HTTPException(
status_code=status.HTTP_400_BAD_REQUEST,
detail="invalid_request: code_challenge_method must be 'S256' (plain is not supported)",
)
# Validate client
try:
client = await provider_service.get_client(db, client_id)
if not client:
raise provider_service.InvalidClientError("Unknown client_id")
provider_service.validate_redirect_uri(client, redirect_uri)
except provider_service.OAuthProviderError as e:
# For client/redirect errors, we can't safely redirect - show error
raise HTTPException(
status_code=status.HTTP_400_BAD_REQUEST,
detail=f"{e.error}: {e.error_description}",
)
# Validate and filter scopes
try:
requested_scopes = provider_service.parse_scope(scope)
valid_scopes = provider_service.validate_scopes(client, requested_scopes)
except provider_service.InvalidScopeError as e:
# Redirect with error
scope_error_params: dict[str, str] = {"error": e.error}
if e.error_description:
scope_error_params["error_description"] = e.error_description
if state:
scope_error_params["state"] = state
return RedirectResponse(
url=f"{redirect_uri}?{urlencode(scope_error_params)}",
status_code=status.HTTP_302_FOUND,
)
# Public clients MUST use PKCE
if client.client_type == "public":
if not code_challenge or code_challenge_method != "S256":
pkce_error_params: dict[str, str] = {
"error": "invalid_request",
"error_description": "PKCE with S256 is required for public clients",
}
if state:
pkce_error_params["state"] = state
return RedirectResponse(
url=f"{redirect_uri}?{urlencode(pkce_error_params)}",
status_code=status.HTTP_302_FOUND,
)
# If user is not authenticated, redirect to login
if not current_user:
# Store authorization request in session and redirect to login
# The frontend will handle the return URL
login_url = f"{settings.FRONTEND_URL}/login"
return_params = urlencode(
{
"oauth_authorize": "true",
"client_id": client_id,
"redirect_uri": redirect_uri,
"scope": " ".join(valid_scopes),
"state": state,
"code_challenge": code_challenge or "",
"code_challenge_method": code_challenge_method or "",
"nonce": nonce or "",
}
)
# Percent-encode the nested query string so return_to survives parsing
return RedirectResponse(
url=f"{login_url}?{urlencode({'return_to': f'/auth/consent?{return_params}'})}",
status_code=status.HTTP_302_FOUND,
)
# Check if user has already consented
has_consent = await provider_service.check_consent(
db, current_user.id, client_id, valid_scopes
)
if not has_consent:
# Redirect to consent page
consent_params = urlencode(
{
"client_id": client_id,
"client_name": client.client_name,
"redirect_uri": redirect_uri,
"scope": " ".join(valid_scopes),
"state": state,
"code_challenge": code_challenge or "",
"code_challenge_method": code_challenge_method or "",
"nonce": nonce or "",
}
)
return RedirectResponse(
url=f"{settings.FRONTEND_URL}/auth/consent?{consent_params}",
status_code=status.HTTP_302_FOUND,
)
# User is authenticated and has consented - issue authorization code
try:
code = await provider_service.create_authorization_code(
db=db,
client=client,
user=current_user,
redirect_uri=redirect_uri,
scope=" ".join(valid_scopes),
code_challenge=code_challenge,
code_challenge_method=code_challenge_method,
state=state,
nonce=nonce,
)
except provider_service.OAuthProviderError as e:
error_params: dict[str, str] = {"error": e.error}
if e.error_description:
error_params["error_description"] = e.error_description
if state:
error_params["state"] = state
return RedirectResponse(
url=f"{redirect_uri}?{urlencode(error_params)}",
status_code=status.HTTP_302_FOUND,
)
# Success - redirect with code
success_params = {"code": code}
if state:
success_params["state"] = state
return RedirectResponse(
url=f"{redirect_uri}?{urlencode(success_params)}",
status_code=status.HTTP_302_FOUND,
)
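Since this endpoint only accepts S256, a client has to hash its own verifier before ever contacting the server. A sketch of the client-side PKCE handshake (stdlib only; the endpoint path mirrors the route above, and the client_id/redirect_uri values are illustrative):

```python
import base64
import hashlib
import secrets
from urllib.parse import urlencode


def make_pkce_pair() -> tuple[str, str]:
    """Return (code_verifier, code_challenge) per RFC 7636 with S256."""
    verifier = base64.urlsafe_b64encode(secrets.token_bytes(32)).rstrip(b"=").decode()
    digest = hashlib.sha256(verifier.encode("ascii")).digest()
    challenge = base64.urlsafe_b64encode(digest).rstrip(b"=").decode()
    return verifier, challenge


verifier, challenge = make_pkce_pair()
# The verifier stays client-side; only the challenge travels in this request.
authorize_url = "/api/v1/oauth/provider/authorize?" + urlencode(
    {
        "response_type": "code",
        "client_id": "my-client",  # hypothetical registered client
        "redirect_uri": "https://app.example.com/callback",
        "scope": "openid profile",
        "state": secrets.token_urlsafe(16),
        "code_challenge": challenge,
        "code_challenge_method": "S256",
    }
)
```

The verifier is later sent to the token endpoint as `code_verifier`, where the server re-hashes it and compares against the stored challenge.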
@router.post(
"/provider/authorize/consent",
summary="Submit Authorization Consent",
description="""
Submit user consent for OAuth authorization.
Called by the consent page after user approves or denies.
""",
operation_id="oauth_provider_consent",
tags=["OAuth Provider"],
)
@limiter.limit("30/minute")
async def submit_consent(
request: Request,
approved: bool = Form(..., description="Whether user approved"),
client_id: str = Form(..., description="OAuth client ID"),
redirect_uri: str = Form(..., description="Redirect URI"),
scope: str = Form(default="", description="Granted scopes"),
state: str = Form(default="", description="CSRF state parameter"),
code_challenge: str | None = Form(default=None),
code_challenge_method: str | None = Form(default=None),
nonce: str | None = Form(default=None),
db: AsyncSession = Depends(get_db),
_: None = Depends(require_provider_enabled),
current_user: User = Depends(get_current_active_user),
) -> Any:
"""Process consent form submission."""
# Validate client
try:
client = await provider_service.get_client(db, client_id)
if not client:
raise provider_service.InvalidClientError("Unknown client_id")
provider_service.validate_redirect_uri(client, redirect_uri)
except provider_service.OAuthProviderError as e:
raise HTTPException(
status_code=status.HTTP_400_BAD_REQUEST,
detail=f"{e.error}: {e.error_description}",
)
# If user denied, redirect with error
if not approved:
denied_params: dict[str, str] = {
"error": "access_denied",
"error_description": "User denied authorization",
}
if state:
denied_params["state"] = state
return RedirectResponse(
url=f"{redirect_uri}?{urlencode(denied_params)}",
status_code=status.HTTP_302_FOUND,
)
# Parse and validate scopes
granted_scopes = provider_service.parse_scope(scope)
valid_scopes = provider_service.validate_scopes(client, granted_scopes)
# Record consent
await provider_service.grant_consent(db, current_user.id, client_id, valid_scopes)
# Generate authorization code
try:
code = await provider_service.create_authorization_code(
db=db,
client=client,
user=current_user,
redirect_uri=redirect_uri,
scope=" ".join(valid_scopes),
code_challenge=code_challenge,
code_challenge_method=code_challenge_method,
state=state,
nonce=nonce,
)
except provider_service.OAuthProviderError as e:
error_params: dict[str, str] = {"error": e.error}
if e.error_description:
error_params["error_description"] = e.error_description
if state:
error_params["state"] = state
return RedirectResponse(
url=f"{redirect_uri}?{urlencode(error_params)}",
status_code=status.HTTP_302_FOUND,
)
# Success
success_params = {"code": code}
if state:
success_params["state"] = state
return RedirectResponse(
url=f"{redirect_uri}?{urlencode(success_params)}",
status_code=status.HTTP_302_FOUND,
)
# ============================================================================
# Token Endpoint
# ============================================================================
@router.post(
"/provider/token",
response_model=OAuthTokenResponse,
summary="Token Endpoint",
description="""
OAuth 2.0 Token Endpoint.
Supports:
- authorization_code: Exchange code for tokens
- refresh_token: Refresh access token
Client authentication:
- Confidential clients: client_secret (Basic auth or POST body)
- Public clients: No secret, but PKCE code_verifier required
""",
operation_id="oauth_provider_token",
tags=["OAuth Provider"],
)
@limiter.limit("60/minute")
async def token(
request: Request,
grant_type: str = Form(..., description="Grant type"),
code: str | None = Form(default=None, description="Authorization code"),
redirect_uri: str | None = Form(default=None, description="Redirect URI"),
client_id: str | None = Form(default=None, description="Client ID"),
client_secret: str | None = Form(default=None, description="Client secret"),
code_verifier: str | None = Form(default=None, description="PKCE code verifier"),
refresh_token: str | None = Form(default=None, description="Refresh token"),
scope: str | None = Form(default=None, description="Scope (for refresh)"),
db: AsyncSession = Depends(get_db),
_: None = Depends(require_provider_enabled),
) -> OAuthTokenResponse:
"""Token endpoint - exchange code for tokens or refresh."""
# Extract client credentials from Basic auth if not in body
if not client_id:
auth_header = request.headers.get("Authorization", "")
if auth_header.startswith("Basic "):
import base64
try:
decoded = base64.b64decode(auth_header[6:]).decode()
client_id, client_secret = decoded.split(":", 1)
except Exception as e:
# Log malformed Basic auth for security monitoring
logger.warning(
"Malformed Basic auth header in token request: %s", type(e).__name__
)
# Fall back to form body
if not client_id:
raise HTTPException(
status_code=status.HTTP_401_UNAUTHORIZED,
detail="invalid_client: client_id required",
headers={"WWW-Authenticate": "Basic"},
)
# Get device info
device_info = request.headers.get("User-Agent", "")[:500]
ip_address = get_remote_address(request)
try:
if grant_type == "authorization_code":
if not code:
raise provider_service.InvalidRequestError("code required")
if not redirect_uri:
raise provider_service.InvalidRequestError("redirect_uri required")
result = await provider_service.exchange_authorization_code(
db=db,
code=code,
client_id=client_id,
redirect_uri=redirect_uri,
code_verifier=code_verifier,
client_secret=client_secret,
device_info=device_info,
ip_address=ip_address,
)
elif grant_type == "refresh_token":
if not refresh_token:
raise provider_service.InvalidRequestError("refresh_token required")
result = await provider_service.refresh_tokens(
db=db,
refresh_token=refresh_token,
client_id=client_id,
client_secret=client_secret,
scope=scope,
device_info=device_info,
ip_address=ip_address,
)
else:
raise HTTPException(
status_code=status.HTTP_400_BAD_REQUEST,
detail="unsupported_grant_type: Must be authorization_code or refresh_token",
)
return OAuthTokenResponse(**result)
except provider_service.InvalidClientError as e:
raise HTTPException(
status_code=status.HTTP_401_UNAUTHORIZED,
detail=f"{e.error}: {e.error_description}",
headers={"WWW-Authenticate": "Basic"},
)
except provider_service.OAuthProviderError as e:
raise HTTPException(
status_code=status.HTTP_400_BAD_REQUEST,
detail=f"{e.error}: {e.error_description}",
)
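The Basic-auth extraction above is the mirror image of what a confidential client sends: credentials in an HTTP Basic header, grant parameters in the form body. A sketch of the client side under `client_secret_basic` (the credential values and placeholder form fields are illustrative):

```python
import base64


def basic_auth_header(client_id: str, client_secret: str) -> str:
    """Build the Authorization header the token endpoint decodes."""
    creds = f"{client_id}:{client_secret}".encode()
    return "Basic " + base64.b64encode(creds).decode()


headers = {"Authorization": basic_auth_header("my-client", "s3cret")}
form = {
    "grant_type": "authorization_code",
    "code": "<code from the redirect>",
    "redirect_uri": "https://app.example.com/callback",
    "code_verifier": "<PKCE verifier, if one was used>",
}
```

A public client would omit the header and secret entirely, sending `client_id` in the form body and relying on the PKCE `code_verifier`.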
# ============================================================================
# Token Revocation (RFC 7009)
# ============================================================================
@router.post(
"/provider/revoke",
status_code=status.HTTP_200_OK,
summary="Token Revocation Endpoint",
description="""
OAuth 2.0 Token Revocation Endpoint (RFC 7009).
Revokes an access token or refresh token.
Always returns 200 OK (even if token is invalid) per spec.
""",
operation_id="oauth_provider_revoke",
tags=["OAuth Provider"],
)
@limiter.limit("30/minute")
async def revoke(
request: Request,
token: str = Form(..., description="Token to revoke"),
token_type_hint: str | None = Form(
default=None, description="Token type hint (access_token, refresh_token)"
),
client_id: str | None = Form(default=None, description="Client ID"),
client_secret: str | None = Form(default=None, description="Client secret"),
db: AsyncSession = Depends(get_db),
_: None = Depends(require_provider_enabled),
) -> dict[str, str]:
"""Revoke a token."""
# Extract client credentials from Basic auth if not in body
if not client_id:
auth_header = request.headers.get("Authorization", "")
if auth_header.startswith("Basic "):
import base64
try:
decoded = base64.b64decode(auth_header[6:]).decode()
client_id, client_secret = decoded.split(":", 1)
except Exception as e:
# Log malformed Basic auth for security monitoring
logger.warning(
"Malformed Basic auth header in revoke request: %s",
type(e).__name__,
)
# Fall back to form body
try:
await provider_service.revoke_token(
db=db,
token=token,
token_type_hint=token_type_hint,
client_id=client_id,
client_secret=client_secret,
)
except provider_service.InvalidClientError:
# Per RFC 7009, we should return 200 OK even for errors
# But client authentication errors can return 401
raise HTTPException(
status_code=status.HTTP_401_UNAUTHORIZED,
detail="invalid_client",
headers={"WWW-Authenticate": "Basic"},
)
except Exception as e:
# Log but don't expose errors per RFC 7009
logger.warning("Token revocation error: %s", e)
# Always return 200 OK per RFC 7009
return {"status": "ok"}
# ============================================================================
# Token Introspection (RFC 7662)
# ============================================================================
@router.post(
"/provider/introspect",
response_model=OAuthTokenIntrospectionResponse,
summary="Token Introspection Endpoint",
description="""
OAuth 2.0 Token Introspection Endpoint (RFC 7662).
Allows resource servers to query the authorization server
to determine the active state and metadata of a token.
""",
operation_id="oauth_provider_introspect",
tags=["OAuth Provider"],
)
@limiter.limit("120/minute")
async def introspect(
request: Request,
token: str = Form(..., description="Token to introspect"),
token_type_hint: str | None = Form(
default=None, description="Token type hint (access_token, refresh_token)"
),
client_id: str | None = Form(default=None, description="Client ID"),
client_secret: str | None = Form(default=None, description="Client secret"),
db: AsyncSession = Depends(get_db),
_: None = Depends(require_provider_enabled),
) -> OAuthTokenIntrospectionResponse:
"""Introspect a token."""
# Extract client credentials from Basic auth if not in body
if not client_id:
auth_header = request.headers.get("Authorization", "")
if auth_header.startswith("Basic "):
import base64
try:
decoded = base64.b64decode(auth_header[6:]).decode()
client_id, client_secret = decoded.split(":", 1)
except Exception as e:
# Log malformed Basic auth for security monitoring
logger.warning(
"Malformed Basic auth header in introspect request: %s",
type(e).__name__,
)
# Fall back to form body
try:
result = await provider_service.introspect_token(
db=db,
token=token,
token_type_hint=token_type_hint,
client_id=client_id,
client_secret=client_secret,
)
return OAuthTokenIntrospectionResponse(**result)
except provider_service.InvalidClientError:
raise HTTPException(
status_code=status.HTTP_401_UNAUTHORIZED,
detail="invalid_client",
headers={"WWW-Authenticate": "Basic"},
)
except Exception as e:
logger.warning("Token introspection error: %s", e)
return OAuthTokenIntrospectionResponse(active=False) # pyright: ignore[reportCallIssue]
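On the consuming side, a resource server only needs to trust two fields of the RFC 7662 response: `active`, and (when scopes matter) the space-delimited `scope`. A minimal sketch of that check, operating on the already-parsed response body:

```python
def token_allows(introspection: dict, required_scope: str) -> bool:
    """Decide whether an introspected token grants the required scope."""
    if not introspection.get("active"):
        return False  # invalid/expired/unknown tokens always come back active=false
    granted = set(introspection.get("scope", "").split())
    return required_scope in granted
```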
# ============================================================================
# Client Management (Admin)
# ============================================================================
@router.post(
"/provider/clients",
response_model=dict,
summary="Register OAuth Client",
description="""
Register a new OAuth client (admin only).
Creates an MCP client that can authenticate against this API.
Returns client_id and client_secret (for confidential clients).
**Important:** Store the client_secret securely - it won't be shown again!
""",
operation_id="register_oauth_client",
tags=["OAuth Provider Admin"],
)
async def register_client(
client_name: str = Form(..., description="Client application name"),
redirect_uris: str = Form(..., description="Comma-separated redirect URIs"),
client_type: str = Form(default="public", description="public or confidential"),
scopes: str = Form(
default="openid profile email",
description="Allowed scopes (space-separated)",
),
mcp_server_url: str | None = Form(default=None, description="MCP server URL"),
db: AsyncSession = Depends(get_db),
_: None = Depends(require_provider_enabled),
current_user: User = Depends(get_current_superuser),
) -> dict:
"""Register a new OAuth client."""
# Parse redirect URIs
uris = [uri.strip() for uri in redirect_uris.split(",") if uri.strip()]
if not uris:
raise HTTPException(
status_code=status.HTTP_400_BAD_REQUEST,
detail="At least one redirect_uri is required",
)
# Parse scopes
allowed_scopes = [s.strip() for s in scopes.split() if s.strip()]
client_data = OAuthClientCreate(
client_name=client_name,
client_description=None,
redirect_uris=uris,
allowed_scopes=allowed_scopes,
client_type=client_type,
)
client, secret = await provider_service.register_client(db, client_data)
# Update MCP server URL if provided
if mcp_server_url:
client.mcp_server_url = mcp_server_url
await db.commit()
result = {
"client_id": client.client_id,
"client_name": client.client_name,
"client_type": client.client_type,
"redirect_uris": client.redirect_uris,
"allowed_scopes": client.allowed_scopes,
}
if secret:
result["client_secret"] = secret
result["warning"] = (
"Store the client_secret securely! It will not be shown again."
)
return result
@router.get(
"/provider/clients",
response_model=list[OAuthClientResponse],
summary="List OAuth Clients",
description="List all registered OAuth clients (admin only).",
operation_id="list_oauth_clients",
tags=["OAuth Provider Admin"],
)
async def list_clients(
db: AsyncSession = Depends(get_db),
_: None = Depends(require_provider_enabled),
current_user: User = Depends(get_current_superuser),
) -> list[OAuthClientResponse]:
"""List all OAuth clients."""
clients = await provider_service.list_clients(db)
return [OAuthClientResponse.model_validate(c) for c in clients]
@router.delete(
"/provider/clients/{client_id}",
status_code=status.HTTP_204_NO_CONTENT,
summary="Delete OAuth Client",
description="Delete an OAuth client (admin only). Revokes all tokens.",
operation_id="delete_oauth_client",
tags=["OAuth Provider Admin"],
)
async def delete_client(
client_id: str,
db: AsyncSession = Depends(get_db),
_: None = Depends(require_provider_enabled),
current_user: User = Depends(get_current_superuser),
) -> None:
"""Delete an OAuth client."""
client = await provider_service.get_client(db, client_id)
if not client:
raise HTTPException(
status_code=status.HTTP_404_NOT_FOUND,
detail="Client not found",
)
await provider_service.delete_client_by_id(db, client_id=client_id)
# ============================================================================
# User Consent Management
# ============================================================================
@router.get(
"/provider/consents",
summary="List My Consents",
description="List OAuth applications the current user has authorized.",
operation_id="list_my_oauth_consents",
tags=["OAuth Provider"],
)
async def list_my_consents(
db: AsyncSession = Depends(get_db),
_: None = Depends(require_provider_enabled),
current_user: User = Depends(get_current_active_user),
) -> list[dict]:
"""List applications the user has authorized."""
return await provider_service.list_user_consents(db, user_id=current_user.id)
@router.delete(
"/provider/consents/{client_id}",
status_code=status.HTTP_204_NO_CONTENT,
summary="Revoke My Consent",
description="Revoke authorization for an OAuth application. Also revokes all tokens.",
operation_id="revoke_my_oauth_consent",
tags=["OAuth Provider"],
)
async def revoke_my_consent(
client_id: str,
db: AsyncSession = Depends(get_db),
_: None = Depends(require_provider_enabled),
current_user: User = Depends(get_current_active_user),
) -> None:
"""Revoke consent for an application."""
revoked = await provider_service.revoke_consent(db, current_user.id, client_id)
if not revoked:
raise HTTPException(
status_code=status.HTTP_404_NOT_FOUND,
detail="No consent found for this client",
)

View File

@@ -15,8 +15,6 @@ from sqlalchemy.ext.asyncio import AsyncSession
 from app.api.dependencies.auth import get_current_user
 from app.api.dependencies.permissions import require_org_admin, require_org_membership
 from app.core.database import get_db
-from app.core.exceptions import ErrorCode, NotFoundError
-from app.crud.organization import organization as organization_crud
 from app.models.user import User
 from app.schemas.common import (
     PaginatedResponse,
@@ -28,6 +26,7 @@ from app.schemas.organizations import (
     OrganizationResponse,
     OrganizationUpdate,
 )
+from app.services.organization_service import organization_service
 logger = logging.getLogger(__name__)
@@ -54,7 +53,7 @@ async def get_my_organizations(
     """
     try:
         # Get all org data in single query with JOIN and subquery
-        orgs_data = await organization_crud.get_user_organizations_with_details(
+        orgs_data = await organization_service.get_user_organizations_with_details(
             db, user_id=current_user.id, is_active=is_active
         )
@@ -78,7 +77,7 @@ async def get_my_organizations(
         return orgs_with_data
     except Exception as e:
-        logger.error(f"Error getting user organizations: {e!s}", exc_info=True)
+        logger.exception("Error getting user organizations: %s", e)
         raise
@@ -100,13 +99,7 @@ async def get_organization(
     User must be a member of the organization.
     """
     try:
-        org = await organization_crud.get(db, id=organization_id)
-        if not org:  # pragma: no cover - Permission check prevents this (see docs/UNREACHABLE_DEFENSIVE_CODE_ANALYSIS.md)
-            raise NotFoundError(
-                detail=f"Organization {organization_id} not found",
-                error_code=ErrorCode.NOT_FOUND,
-            )
+        org = await organization_service.get_organization(db, str(organization_id))
         org_dict = {
             "id": org.id,
             "name": org.name,
@@ -116,16 +109,14 @@ async def get_organization(
             "settings": org.settings,
             "created_at": org.created_at,
             "updated_at": org.updated_at,
-            "member_count": await organization_crud.get_member_count(
+            "member_count": await organization_service.get_member_count(
                 db, organization_id=org.id
             ),
         }
         return OrganizationResponse(**org_dict)
-    except NotFoundError:  # pragma: no cover - See above
-        raise
     except Exception as e:
-        logger.error(f"Error getting organization: {e!s}", exc_info=True)
+        logger.exception("Error getting organization: %s", e)
         raise
@@ -149,7 +140,7 @@ async def get_organization_members(
     User must be a member of the organization to view members.
     """
     try:
-        members, total = await organization_crud.get_organization_members(
+        members, total = await organization_service.get_organization_members(
             db,
             organization_id=organization_id,
             skip=pagination.offset,
@@ -169,7 +160,7 @@ async def get_organization_members(
         return PaginatedResponse(data=member_responses, pagination=pagination_meta)
     except Exception as e:
-        logger.error(f"Error getting organization members: {e!s}", exc_info=True)
+        logger.exception("Error getting organization members: %s", e)
         raise
@@ -192,16 +183,12 @@ async def update_organization(
     Requires owner or admin role in the organization.
     """
     try:
-        org = await organization_crud.get(db, id=organization_id)
-        if not org:  # pragma: no cover - Permission check prevents this (see docs/UNREACHABLE_DEFENSIVE_CODE_ANALYSIS.md)
-            raise NotFoundError(
-                detail=f"Organization {organization_id} not found",
-                error_code=ErrorCode.NOT_FOUND,
-            )
-        updated_org = await organization_crud.update(db, db_obj=org, obj_in=org_in)
+        org = await organization_service.get_organization(db, str(organization_id))
+        updated_org = await organization_service.update_organization(
+            db, org=org, obj_in=org_in
+        )
         logger.info(
-            f"User {current_user.email} updated organization {updated_org.name}"
+            "User %s updated organization %s", current_user.email, updated_org.name
         )
         org_dict = {
@@ -213,14 +200,12 @@ async def update_organization(
             "settings": updated_org.settings,
             "created_at": updated_org.created_at,
             "updated_at": updated_org.updated_at,
-            "member_count": await organization_crud.get_member_count(
+            "member_count": await organization_service.get_member_count(
                 db, organization_id=updated_org.id
            ),
         }
         return OrganizationResponse(**org_dict)
-    except NotFoundError:  # pragma: no cover - See above
-        raise
     except Exception as e:
-        logger.error(f"Error updating organization: {e!s}", exc_info=True)
+        logger.exception("Error updating organization: %s", e)
         raise

View File

@@ -17,10 +17,10 @@ from app.api.dependencies.auth import get_current_user
 from app.core.auth import decode_token
 from app.core.database import get_db
 from app.core.exceptions import AuthorizationError, ErrorCode, NotFoundError
-from app.crud.session import session as session_crud
 from app.models.user import User
 from app.schemas.common import MessageResponse
 from app.schemas.sessions import SessionListResponse, SessionResponse
+from app.services.session_service import session_service

 router = APIRouter()
 logger = logging.getLogger(__name__)
@@ -60,7 +60,7 @@ async def list_my_sessions(
     """
     try:
         # Get all active sessions for user
-        sessions = await session_crud.get_user_sessions(
+        sessions = await session_service.get_user_sessions(
             db, user_id=str(current_user.id), active_only=True
         )
@@ -74,9 +74,7 @@ async def list_my_sessions(
             # For now, we'll mark current based on most recent activity
         except Exception as e:
             # Optional token parsing - silently ignore failures
-            logger.debug(
-                f"Failed to decode access token for session marking: {e!s}"
-            )
+            logger.debug("Failed to decode access token for session marking: %s", e)

         # Convert to response format
         session_responses = []
@@ -98,7 +96,7 @@ async def list_my_sessions(
             session_responses.append(session_response)

         logger.info(
-            f"User {current_user.id} listed {len(session_responses)} active sessions"
+            "User %s listed %s active sessions", current_user.id, len(session_responses)
         )

         return SessionListResponse(
@@ -106,9 +104,7 @@ async def list_my_sessions(
         )
     except Exception as e:
-        logger.error(
-            f"Error listing sessions for user {current_user.id}: {e!s}", exc_info=True
-        )
+        logger.exception("Error listing sessions for user %s: %s", current_user.id, e)
        raise HTTPException(
            status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
            detail="Failed to retrieve sessions",
@@ -150,7 +146,7 @@ async def revoke_session(
     """
     try:
         # Get the session
-        session = await session_crud.get(db, id=str(session_id))
+        session = await session_service.get_session(db, str(session_id))

         if not session:
             raise NotFoundError(
@@ -161,8 +157,10 @@ async def revoke_session(
         # Verify session belongs to current user
         if str(session.user_id) != str(current_user.id):
             logger.warning(
-                f"User {current_user.id} attempted to revoke session {session_id} "
-                f"belonging to user {session.user_id}"
+                "User %s attempted to revoke session %s belonging to user %s",
+                current_user.id,
+                session_id,
+                session.user_id,
             )
             raise AuthorizationError(
                 message="You can only revoke your own sessions",
@@ -170,11 +168,13 @@ async def revoke_session(
             )

         # Deactivate the session
-        await session_crud.deactivate(db, session_id=str(session_id))
+        await session_service.deactivate(db, session_id=str(session_id))

         logger.info(
-            f"User {current_user.id} revoked session {session_id} "
-            f"({session.device_name})"
+            "User %s revoked session %s (%s)",
+            current_user.id,
+            session_id,
+            session.device_name,
         )

         return MessageResponse(
@@ -185,7 +185,7 @@ async def revoke_session(
     except (NotFoundError, AuthorizationError):
         raise
     except Exception as e:
-        logger.error(f"Error revoking session {session_id}: {e!s}", exc_info=True)
+        logger.exception("Error revoking session %s: %s", session_id, e)
        raise HTTPException(
            status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
            detail="Failed to revoke session",
@@ -224,12 +224,12 @@ async def cleanup_expired_sessions(
     """
     try:
         # Use optimized bulk DELETE instead of N individual deletes
-        deleted_count = await session_crud.cleanup_expired_for_user(
+        deleted_count = await session_service.cleanup_expired_for_user(
             db, user_id=str(current_user.id)
         )

         logger.info(
-            f"User {current_user.id} cleaned up {deleted_count} expired sessions"
+            "User %s cleaned up %s expired sessions", current_user.id, deleted_count
         )

         return MessageResponse(
@@ -237,9 +237,8 @@ async def cleanup_expired_sessions(
         )
     except Exception as e:
-        logger.error(
-            f"Error cleaning up sessions for user {current_user.id}: {e!s}",
-            exc_info=True,
-        )
+        logger.exception(
+            "Error cleaning up sessions for user %s: %s", current_user.id, e
+        )
        await db.rollback()
        raise HTTPException(

View File

@@ -1,5 +1,5 @@
 """
-User management endpoints for CRUD operations.
+User management endpoints for database operations.
 """

 import logging
@@ -13,8 +13,7 @@ from sqlalchemy.ext.asyncio import AsyncSession
 from app.api.dependencies.auth import get_current_superuser, get_current_user
 from app.core.database import get_db
-from app.core.exceptions import AuthorizationError, ErrorCode, NotFoundError
-from app.crud.user import user as user_crud
+from app.core.exceptions import AuthorizationError, ErrorCode
 from app.models.user import User
 from app.schemas.common import (
     MessageResponse,
@@ -25,6 +24,7 @@ from app.schemas.common import (
 )
 from app.schemas.users import PasswordChange, UserResponse, UserUpdate
 from app.services.auth_service import AuthenticationError, AuthService
+from app.services.user_service import user_service

 logger = logging.getLogger(__name__)
@@ -71,7 +71,7 @@ async def list_users(
             filters["is_superuser"] = is_superuser

         # Get paginated users with total count
-        users, total = await user_crud.get_multi_with_total(
+        users, total = await user_service.list_users(
             db,
             skip=pagination.offset,
             limit=pagination.limit,
@@ -90,7 +90,7 @@ async def list_users(
         return PaginatedResponse(data=users, pagination=pagination_meta)
     except Exception as e:
-        logger.error(f"Error listing users: {e!s}", exc_info=True)
+        logger.exception("Error listing users: %s", e)
        raise
@@ -107,7 +107,9 @@ async def list_users(
     """,
     operation_id="get_current_user_profile",
 )
-def get_current_user_profile(current_user: User = Depends(get_current_user)) -> Any:
+async def get_current_user_profile(
+    current_user: User = Depends(get_current_user),
+) -> Any:
     """Get current user's profile."""
     return current_user
@@ -138,18 +140,16 @@ async def update_current_user(
     Users cannot elevate their own permissions (protected by UserUpdate schema validator).
     """
     try:
-        updated_user = await user_crud.update(
-            db, db_obj=current_user, obj_in=user_update
-        )
-        logger.info(f"User {current_user.id} updated their profile")
+        updated_user = await user_service.update_user(
+            db, user=current_user, obj_in=user_update
+        )
+        logger.info("User %s updated their profile", current_user.id)
         return updated_user
     except ValueError as e:
-        logger.error(f"Error updating user {current_user.id}: {e!s}")
+        logger.error("Error updating user %s: %s", current_user.id, e)
         raise
     except Exception as e:
-        logger.error(
-            f"Unexpected error updating user {current_user.id}: {e!s}", exc_info=True
-        )
+        logger.exception("Unexpected error updating user %s: %s", current_user.id, e)
        raise
@@ -182,7 +182,9 @@ async def get_user_by_id(
     # Check permissions
     if str(user_id) != str(current_user.id) and not current_user.is_superuser:
         logger.warning(
-            f"User {current_user.id} attempted to access user {user_id} without permission"
+            "User %s attempted to access user %s without permission",
+            current_user.id,
+            user_id,
         )
         raise AuthorizationError(
             message="Not enough permissions to view this user",
@@ -190,13 +192,7 @@ async def get_user_by_id(
         )

     # Get user
-    user = await user_crud.get(db, id=str(user_id))
-    if not user:
-        raise NotFoundError(
-            message=f"User with id {user_id} not found",
-            error_code=ErrorCode.USER_NOT_FOUND,
-        )
+    user = await user_service.get_user(db, str(user_id))
     return user
@@ -233,7 +229,9 @@ async def update_user(
     if not is_own_profile and not current_user.is_superuser:
         logger.warning(
-            f"User {current_user.id} attempted to update user {user_id} without permission"
+            "User %s attempted to update user %s without permission",
+            current_user.id,
+            user_id,
         )
         raise AuthorizationError(
             message="Not enough permissions to update this user",
@@ -241,22 +239,17 @@ async def update_user(
         )

     # Get user
-    user = await user_crud.get(db, id=str(user_id))
-    if not user:
-        raise NotFoundError(
-            message=f"User with id {user_id} not found",
-            error_code=ErrorCode.USER_NOT_FOUND,
-        )
+    user = await user_service.get_user(db, str(user_id))

     try:
-        updated_user = await user_crud.update(db, db_obj=user, obj_in=user_update)
-        logger.info(f"User {user_id} updated by {current_user.id}")
+        updated_user = await user_service.update_user(db, user=user, obj_in=user_update)
+        logger.info("User %s updated by %s", user_id, current_user.id)
         return updated_user
     except ValueError as e:
-        logger.error(f"Error updating user {user_id}: {e!s}")
+        logger.error("Error updating user %s: %s", user_id, e)
         raise
     except Exception as e:
-        logger.error(f"Unexpected error updating user {user_id}: {e!s}", exc_info=True)
+        logger.exception("Unexpected error updating user %s: %s", user_id, e)
        raise
@@ -296,19 +289,19 @@ async def change_current_user_password(
         )

         if success:
-            logger.info(f"User {current_user.id} changed their password")
+            logger.info("User %s changed their password", current_user.id)
             return MessageResponse(
                 success=True, message="Password changed successfully"
             )
     except AuthenticationError as e:
         logger.warning(
-            f"Failed password change attempt for user {current_user.id}: {e!s}"
+            "Failed password change attempt for user %s: %s", current_user.id, e
         )
         raise AuthorizationError(
             message=str(e), error_code=ErrorCode.INVALID_CREDENTIALS
         )
     except Exception as e:
-        logger.error(f"Error changing password for user {current_user.id}: {e!s}")
+        logger.error("Error changing password for user %s: %s", current_user.id, e)
        raise
@@ -346,24 +339,19 @@ async def delete_user(
             error_code=ErrorCode.INSUFFICIENT_PERMISSIONS,
         )

-    # Get user
-    user = await user_crud.get(db, id=str(user_id))
-    if not user:
-        raise NotFoundError(
-            message=f"User with id {user_id} not found",
-            error_code=ErrorCode.USER_NOT_FOUND,
-        )
+    # Get user (raises NotFoundError if not found)
+    await user_service.get_user(db, str(user_id))

     try:
         # Use soft delete instead of hard delete
-        await user_crud.soft_delete(db, id=str(user_id))
-        logger.info(f"User {user_id} soft-deleted by {current_user.id}")
+        await user_service.soft_delete_user(db, str(user_id))
+        logger.info("User %s soft-deleted by %s", user_id, current_user.id)
         return MessageResponse(
             success=True, message=f"User {user_id} deleted successfully"
         )
     except ValueError as e:
-        logger.error(f"Error deleting user {user_id}: {e!s}")
+        logger.error("Error deleting user %s: %s", user_id, e)
         raise
     except Exception as e:
-        logger.error(f"Unexpected error deleting user {user_id}: {e!s}", exc_info=True)
+        logger.exception("Unexpected error deleting user %s: %s", user_id, e)
        raise

View File

@@ -1,23 +1,21 @@
 import asyncio
-import logging
 import uuid
 from datetime import UTC, datetime, timedelta
 from functools import partial
 from typing import Any

-from jose import JWTError, jwt
-from passlib.context import CryptContext
+import bcrypt
+import jwt
+from jwt.exceptions import (
+    ExpiredSignatureError,
+    InvalidTokenError,
+    MissingRequiredClaimError,
+)
 from pydantic import ValidationError

 from app.core.config import settings
 from app.schemas.users import TokenData, TokenPayload

-# Suppress passlib bcrypt warnings about ident
-logging.getLogger("passlib").setLevel(logging.ERROR)
-
-# Password hashing context
-pwd_context = CryptContext(schemes=["bcrypt"], deprecated="auto")

 # Custom exceptions for auth
 class AuthError(Exception):
@@ -37,13 +35,16 @@ class TokenMissingClaimError(AuthError):

 def verify_password(plain_password: str, hashed_password: str) -> bool:
-    """Verify a password against a hash."""
-    return pwd_context.verify(plain_password, hashed_password)
+    """Verify a password against a bcrypt hash."""
+    return bcrypt.checkpw(
+        plain_password.encode("utf-8"), hashed_password.encode("utf-8")
+    )


 def get_password_hash(password: str) -> str:
-    """Generate a password hash."""
-    return pwd_context.hash(password)
+    """Generate a bcrypt password hash."""
+    salt = bcrypt.gensalt()
+    return bcrypt.hashpw(password.encode("utf-8"), salt).decode("utf-8")


 async def verify_password_async(plain_password: str, hashed_password: str) -> bool:
@@ -60,9 +61,9 @@ async def verify_password_async(plain_password: str, hashed_password: str) -> bo
     Returns:
         True if password matches, False otherwise
     """
-    loop = asyncio.get_event_loop()
+    loop = asyncio.get_running_loop()
     return await loop.run_in_executor(
-        None, partial(pwd_context.verify, plain_password, hashed_password)
+        None, partial(verify_password, plain_password, hashed_password)
     )
@@ -80,8 +81,8 @@ async def get_password_hash_async(password: str) -> str:
     Returns:
         Hashed password string
     """
-    loop = asyncio.get_event_loop()
-    return await loop.run_in_executor(None, pwd_context.hash, password)
+    loop = asyncio.get_running_loop()
+    return await loop.run_in_executor(None, get_password_hash, password)


 def create_access_token(
@@ -121,11 +122,7 @@ def create_access_token(
     to_encode.update(claims)

     # Create the JWT
-    encoded_jwt = jwt.encode(
-        to_encode, settings.SECRET_KEY, algorithm=settings.ALGORITHM
-    )
-    return encoded_jwt
+    return jwt.encode(to_encode, settings.SECRET_KEY, algorithm=settings.ALGORITHM)


 def create_refresh_token(
@@ -154,11 +151,7 @@ def create_refresh_token(
         "type": "refresh",
     }

-    encoded_jwt = jwt.encode(
-        to_encode, settings.SECRET_KEY, algorithm=settings.ALGORITHM
-    )
-    return encoded_jwt
+    return jwt.encode(to_encode, settings.SECRET_KEY, algorithm=settings.ALGORITHM)


 def decode_token(token: str, verify_type: str | None = None) -> TokenPayload:
@@ -198,7 +191,7 @@ def decode_token(token: str, verify_type: str | None = None) -> TokenPayload:
         # Reject weak or unexpected algorithms
         # NOTE: These are defensive checks that provide defense-in-depth.
-        # The python-jose library rejects these tokens BEFORE we reach here,
+        # PyJWT rejects these tokens BEFORE we reach here,
         # but we keep these checks in case the library changes or is misconfigured.
         # Coverage: Marked as pragma since library catches first (see tests/core/test_auth_security.py)
         if token_algorithm == "NONE":  # pragma: no cover
@@ -219,10 +212,11 @@ def decode_token(token: str, verify_type: str | None = None) -> TokenPayload:
         token_data = TokenPayload(**payload)
         return token_data
-    except JWTError as e:
-        # Check if the error is due to an expired token
-        if "expired" in str(e).lower():
-            raise TokenExpiredError("Token has expired")
-        raise TokenInvalidError("Invalid authentication token")
+    except ExpiredSignatureError:
+        raise TokenExpiredError("Token has expired")
+    except MissingRequiredClaimError as e:
+        raise TokenMissingClaimError(f"Token missing required claim: {e}")
+    except InvalidTokenError:
+        raise TokenInvalidError("Invalid authentication token")
    except ValidationError:
        raise TokenInvalidError("Invalid token payload")

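The async helpers above also move from `asyncio.get_event_loop()` to `asyncio.get_running_loop()`, which raises instead of silently creating a loop when called outside a coroutine. A stdlib-only sketch of the same offloading pattern, with `hashlib.sha256` standing in for the bcrypt work:

```python
import asyncio
import hashlib
from functools import partial


def slow_hash(password: str) -> str:
    # Stand-in for bcrypt: any CPU-bound hash we keep off the event loop.
    return hashlib.sha256(password.encode("utf-8")).hexdigest()


async def hash_async(password: str) -> str:
    # get_running_loop() raises RuntimeError if no loop is running,
    # instead of implicitly creating one like get_event_loop() could.
    loop = asyncio.get_running_loop()
    return await loop.run_in_executor(None, partial(slow_hash, password))


digest = asyncio.run(hash_async("s3cret"))
```

Running the blocking call in the default thread-pool executor keeps other coroutines responsive while the hash computes.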
View File

@@ -5,7 +5,7 @@ from pydantic_settings import BaseSettings

 class Settings(BaseSettings):
-    PROJECT_NAME: str = "App"
+    PROJECT_NAME: str = "PragmaStack"
     VERSION: str = "1.0.0"
     API_V1_STR: str = "/api/v1"
@@ -14,6 +14,10 @@ class Settings(BaseSettings):
         default="development",
         description="Environment: development, staging, or production",
     )
+    DEMO_MODE: bool = Field(
+        default=False,
+        description="Enable demo mode (relaxed security, demo users)",
+    )

     # Security: Content Security Policy
     # Set to False to disable CSP entirely (not recommended)
@@ -72,6 +76,60 @@ class Settings(BaseSettings):
         description="Frontend application URL for email links",
     )

+    # OAuth Configuration
+    OAUTH_ENABLED: bool = Field(
+        default=False,
+        description="Enable OAuth authentication (social login)",
+    )
+    OAUTH_AUTO_LINK_BY_EMAIL: bool = Field(
+        default=True,
+        description="Automatically link OAuth accounts to existing users with matching email",
+    )
+    OAUTH_STATE_EXPIRE_MINUTES: int = Field(
+        default=10,
+        description="OAuth state parameter expiration time in minutes",
+    )
+
+    # Google OAuth
+    OAUTH_GOOGLE_CLIENT_ID: str | None = Field(
+        default=None,
+        description="Google OAuth client ID from Google Cloud Console",
+    )
+    OAUTH_GOOGLE_CLIENT_SECRET: str | None = Field(
+        default=None,
+        description="Google OAuth client secret from Google Cloud Console",
+    )
+
+    # GitHub OAuth
+    OAUTH_GITHUB_CLIENT_ID: str | None = Field(
+        default=None,
+        description="GitHub OAuth client ID from GitHub Developer Settings",
+    )
+    OAUTH_GITHUB_CLIENT_SECRET: str | None = Field(
+        default=None,
+        description="GitHub OAuth client secret from GitHub Developer Settings",
+    )
+
+    # OAuth Provider Mode (for MCP clients - skeleton)
+    OAUTH_PROVIDER_ENABLED: bool = Field(
+        default=False,
+        description="Enable OAuth provider mode (act as authorization server for MCP clients)",
+    )
+    OAUTH_ISSUER: str = Field(
+        default="http://localhost:8000",
+        description="OAuth issuer URL (your API base URL)",
+    )
+
+    @property
+    def enabled_oauth_providers(self) -> list[str]:
+        """Get list of enabled OAuth providers based on configured credentials."""
+        providers = []
+        if self.OAUTH_GOOGLE_CLIENT_ID and self.OAUTH_GOOGLE_CLIENT_SECRET:
+            providers.append("google")
+        if self.OAUTH_GITHUB_CLIENT_ID and self.OAUTH_GITHUB_CLIENT_SECRET:
+            providers.append("github")
+        return providers
+
     # Admin user
     FIRST_SUPERUSER_EMAIL: str | None = Field(
         default=None, description="Email for first superuser account"
@@ -110,11 +168,21 @@ class Settings(BaseSettings):
     @field_validator("FIRST_SUPERUSER_PASSWORD")
     @classmethod
-    def validate_superuser_password(cls, v: str | None) -> str | None:
+    def validate_superuser_password(cls, v: str | None, info) -> str | None:
         """Validate superuser password strength."""
         if v is None:
             return v
+
+        # Get environment from values if available
+        values_data = info.data if info.data else {}
+        demo_mode = values_data.get("DEMO_MODE", False)
+
+        if demo_mode:
+            # In demo mode, allow specific weak passwords for demo accounts
+            demo_passwords = {"Demo123!", "Admin123!"}
+            if v in demo_passwords:
+                return v
+
        if len(v) < 12:
            raise ValueError("FIRST_SUPERUSER_PASSWORD must be at least 12 characters")

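The `enabled_oauth_providers` property derives the provider list from whichever credential pairs are actually set, so there is no separate enable flag that can drift out of sync with the credentials. A dependency-free sketch of the same pattern (a plain class standing in for the pydantic `Settings`):

```python
class Settings:
    # Hypothetical minimal stand-in for the pydantic BaseSettings class.
    def __init__(self, google_id=None, google_secret=None,
                 github_id=None, github_secret=None):
        self.OAUTH_GOOGLE_CLIENT_ID = google_id
        self.OAUTH_GOOGLE_CLIENT_SECRET = google_secret
        self.OAUTH_GITHUB_CLIENT_ID = github_id
        self.OAUTH_GITHUB_CLIENT_SECRET = github_secret

    @property
    def enabled_oauth_providers(self) -> list:
        # A provider counts as enabled only when BOTH halves of its
        # credential pair are configured.
        providers = []
        if self.OAUTH_GOOGLE_CLIENT_ID and self.OAUTH_GOOGLE_CLIENT_SECRET:
            providers.append("google")
        if self.OAUTH_GITHUB_CLIENT_ID and self.OAUTH_GITHUB_CLIENT_SECRET:
            providers.append("github")
        return providers


settings = Settings(google_id="client-id", google_secret="client-secret")
```

Because the list is computed on access, setting or clearing one environment variable is enough to toggle a provider.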
View File

@@ -128,8 +128,8 @@ async def async_transaction_scope() -> AsyncGenerator[AsyncSession, None]:
     Usage:
         async with async_transaction_scope() as db:
-            user = await user_crud.create(db, obj_in=user_create)
-            profile = await profile_crud.create(db, obj_in=profile_create)
+            user = await user_repo.create(db, obj_in=user_create)
+            profile = await profile_repo.create(db, obj_in=profile_create)
             # Both operations committed together
     """
     async with SessionLocal() as session:
@@ -139,7 +139,7 @@ async def async_transaction_scope() -> AsyncGenerator[AsyncSession, None]:
             logger.debug("Async transaction committed successfully")
         except Exception as e:
             await session.rollback()
-            logger.error(f"Async transaction failed, rolling back: {e!s}")
+            logger.error("Async transaction failed, rolling back: %s", e)
             raise
         finally:
             await session.close()
@@ -155,7 +155,7 @@ async def check_async_database_health() -> bool:
         await db.execute(text("SELECT 1"))
         return True
     except Exception as e:
-        logger.error(f"Async database health check failed: {e!s}")
+        logger.error("Async database health check failed: %s", e)
        return False

View File

@@ -0,0 +1,366 @@
{
"organizations": [
{
"name": "Acme Corp",
"slug": "acme-corp",
"description": "A leading provider of coyote-catching equipment."
},
{
"name": "Globex Corporation",
"slug": "globex",
"description": "We own the East Coast."
},
{
"name": "Soylent Corp",
"slug": "soylent",
"description": "Making food for the future."
},
{
"name": "Initech",
"slug": "initech",
"description": "Software for the soul."
},
{
"name": "Umbrella Corporation",
"slug": "umbrella",
"description": "Our business is life itself."
},
{
"name": "Massive Dynamic",
"slug": "massive-dynamic",
"description": "What don't we do?"
}
],
"users": [
{
"email": "demo@example.com",
"password": "DemoPass1234!",
"first_name": "Demo",
"last_name": "User",
"is_superuser": false,
"organization_slug": "acme-corp",
"role": "member",
"is_active": true
},
{
"email": "alice@acme.com",
"password": "Demo123!",
"first_name": "Alice",
"last_name": "Smith",
"is_superuser": false,
"organization_slug": "acme-corp",
"role": "admin",
"is_active": true
},
{
"email": "bob@acme.com",
"password": "Demo123!",
"first_name": "Bob",
"last_name": "Jones",
"is_superuser": false,
"organization_slug": "acme-corp",
"role": "member",
"is_active": true
},
{
"email": "charlie@acme.com",
"password": "Demo123!",
"first_name": "Charlie",
"last_name": "Brown",
"is_superuser": false,
"organization_slug": "acme-corp",
"role": "member",
"is_active": false
},
{
"email": "diana@acme.com",
"password": "Demo123!",
"first_name": "Diana",
"last_name": "Prince",
"is_superuser": false,
"organization_slug": "acme-corp",
"role": "member",
"is_active": true
},
{
"email": "carol@globex.com",
"password": "Demo123!",
"first_name": "Carol",
"last_name": "Williams",
"is_superuser": false,
"organization_slug": "globex",
"role": "owner",
"is_active": true
},
{
"email": "dan@globex.com",
"password": "Demo123!",
"first_name": "Dan",
"last_name": "Miller",
"is_superuser": false,
"organization_slug": "globex",
"role": "member",
"is_active": true
},
{
"email": "ellen@globex.com",
"password": "Demo123!",
"first_name": "Ellen",
"last_name": "Ripley",
"is_superuser": false,
"organization_slug": "globex",
"role": "member",
"is_active": true
},
{
"email": "fred@globex.com",
"password": "Demo123!",
"first_name": "Fred",
"last_name": "Flintstone",
"is_superuser": false,
"organization_slug": "globex",
"role": "member",
"is_active": true
},
{
"email": "dave@soylent.com",
"password": "Demo123!",
"first_name": "Dave",
"last_name": "Brown",
"is_superuser": false,
"organization_slug": "soylent",
"role": "member",
"is_active": true
},
{
"email": "gina@soylent.com",
"password": "Demo123!",
"first_name": "Gina",
"last_name": "Torres",
"is_superuser": false,
"organization_slug": "soylent",
"role": "member",
"is_active": true
},
{
"email": "harry@soylent.com",
"password": "Demo123!",
"first_name": "Harry",
"last_name": "Potter",
"is_superuser": false,
"organization_slug": "soylent",
"role": "admin",
"is_active": true
},
{
"email": "eve@initech.com",
"password": "Demo123!",
"first_name": "Eve",
"last_name": "Davis",
"is_superuser": false,
"organization_slug": "initech",
"role": "admin",
"is_active": true
},
{
"email": "iris@initech.com",
"password": "Demo123!",
"first_name": "Iris",
"last_name": "West",
"is_superuser": false,
"organization_slug": "initech",
"role": "member",
"is_active": true
},
{
"email": "jack@initech.com",
"password": "Demo123!",
"first_name": "Jack",
"last_name": "Sparrow",
"is_superuser": false,
"organization_slug": "initech",
"role": "member",
"is_active": false
},
{
"email": "frank@umbrella.com",
"password": "Demo123!",
"first_name": "Frank",
"last_name": "Miller",
"is_superuser": false,
"organization_slug": "umbrella",
"role": "member",
"is_active": true
},
{
"email": "george@umbrella.com",
"password": "Demo123!",
"first_name": "George",
"last_name": "Costanza",
"is_superuser": false,
"organization_slug": "umbrella",
"role": "member",
"is_active": false
},
{
"email": "kate@umbrella.com",
"password": "Demo123!",
"first_name": "Kate",
"last_name": "Bishop",
"is_superuser": false,
"organization_slug": "umbrella",
"role": "member",
"is_active": true
},
{
"email": "leo@massive.com",
"password": "Demo123!",
"first_name": "Leo",
"last_name": "Messi",
"is_superuser": false,
"organization_slug": "massive-dynamic",
"role": "owner",
"is_active": true
},
{
"email": "mary@massive.com",
"password": "Demo123!",
"first_name": "Mary",
"last_name": "Jane",
"is_superuser": false,
"organization_slug": "massive-dynamic",
"role": "member",
"is_active": true
},
{
"email": "nathan@massive.com",
"password": "Demo123!",
"first_name": "Nathan",
"last_name": "Drake",
"is_superuser": false,
"organization_slug": "massive-dynamic",
"role": "member",
"is_active": true
},
{
"email": "olivia@massive.com",
"password": "Demo123!",
"first_name": "Olivia",
"last_name": "Dunham",
"is_superuser": false,
"organization_slug": "massive-dynamic",
"role": "admin",
"is_active": true
},
{
"email": "peter@massive.com",
"password": "Demo123!",
"first_name": "Peter",
"last_name": "Parker",
"is_superuser": false,
"organization_slug": "massive-dynamic",
"role": "member",
"is_active": true
},
{
"email": "quinn@massive.com",
"password": "Demo123!",
"first_name": "Quinn",
"last_name": "Mallory",
"is_superuser": false,
"organization_slug": "massive-dynamic",
"role": "member",
"is_active": true
},
{
"email": "grace@example.com",
"password": "Demo123!",
"first_name": "Grace",
"last_name": "Hopper",
"is_superuser": false,
"organization_slug": null,
"role": null,
"is_active": true
},
{
"email": "heidi@example.com",
"password": "Demo123!",
"first_name": "Heidi",
"last_name": "Klum",
"is_superuser": false,
"organization_slug": null,
"role": null,
"is_active": true
},
{
"email": "ivan@example.com",
"password": "Demo123!",
"first_name": "Ivan",
"last_name": "Drago",
"is_superuser": false,
"organization_slug": null,
"role": null,
"is_active": false
},
{
"email": "rachel@example.com",
"password": "Demo123!",
"first_name": "Rachel",
"last_name": "Green",
"is_superuser": false,
"organization_slug": null,
"role": null,
"is_active": true
},
{
"email": "sam@example.com",
"password": "Demo123!",
"first_name": "Sam",
"last_name": "Wilson",
"is_superuser": false,
"organization_slug": null,
"role": null,
"is_active": true
},
{
"email": "tony@example.com",
"password": "Demo123!",
"first_name": "Tony",
"last_name": "Stark",
"is_superuser": false,
"organization_slug": null,
"role": null,
"is_active": true
},
{
"email": "una@example.com",
"password": "Demo123!",
"first_name": "Una",
"last_name": "Chin-Riley",
"is_superuser": false,
"organization_slug": null,
"role": null,
"is_active": false
},
{
"email": "victor@example.com",
"password": "Demo123!",
"first_name": "Victor",
"last_name": "Von Doom",
"is_superuser": false,
"organization_slug": null,
"role": null,
"is_active": true
},
{
"email": "wanda@example.com",
"password": "Demo123!",
"first_name": "Wanda",
"last_name": "Maximoff",
"is_superuser": false,
"organization_slug": null,
"role": null,
"is_active": true
}
]
}

View File

@@ -143,8 +143,11 @@ async def api_exception_handler(request: Request, exc: APIException) -> JSONResp
     Returns a standardized error response with error code and message.
     """
     logger.warning(
-        f"API exception: {exc.error_code} - {exc.message} "
-        f"(status: {exc.status_code}, path: {request.url.path})"
+        "API exception: %s - %s (status: %s, path: %s)",
+        exc.error_code,
+        exc.message,
+        exc.status_code,
+        request.url.path,
     )

     error_response = ErrorResponse(
@@ -186,7 +189,9 @@ async def validation_exception_handler(
         )
     )

-    logger.warning(f"Validation error: {len(errors)} errors (path: {request.url.path})")
+    logger.warning(
+        "Validation error: %s errors (path: %s)", len(errors), request.url.path
+    )

     error_response = ErrorResponse(errors=errors)
@@ -218,11 +223,14 @@ async def http_exception_handler(request: Request, exc: HTTPException) -> JSONRe
     )

     logger.warning(
-        f"HTTP exception: {exc.status_code} - {exc.detail} (path: {request.url.path})"
+        "HTTP exception: %s - %s (path: %s)",
+        exc.status_code,
+        exc.detail,
+        request.url.path,
     )

     error_response = ErrorResponse(
-        errors=[ErrorDetail(code=error_code, message=str(exc.detail))]
+        errors=[ErrorDetail(code=error_code, message=str(exc.detail), field=None)]
     )

     return JSONResponse(
@@ -239,10 +247,11 @@ async def unhandled_exception_handler(request: Request, exc: Exception) -> JSONR
     Logs the full exception and returns a generic error response to avoid
     leaking sensitive information in production.
     """
-    logger.error(
-        f"Unhandled exception: {type(exc).__name__} - {exc!s} "
-        f"(path: {request.url.path})",
-        exc_info=True,
-    )
+    logger.exception(
+        "Unhandled exception: %s - %s (path: %s)",
+        type(exc).__name__,
+        exc,
+        request.url.path,
+    )

     # In production, don't expose internal error details
@@ -254,7 +263,7 @@ async def unhandled_exception_handler(request: Request, exc: Exception) -> JSONR
         message = f"{type(exc).__name__}: {exc!s}"

     error_response = ErrorResponse(
-        errors=[ErrorDetail(code=ErrorCode.INTERNAL_ERROR, message=message)]
+        errors=[ErrorDetail(code=ErrorCode.INTERNAL_ERROR, message=message, field=None)]
    )

    return JSONResponse(

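The hunk above swaps eager f-string interpolation for logging's lazy %-style arguments. A minimal sketch of why this matters (the logger name and `Expensive` class are illustrative, not from the codebase): with %-style placeholders the message is only formatted when the record is actually emitted, so suppressed log levels never pay the formatting cost.

```python
import logging

logging.basicConfig(level=logging.WARNING, format="%(message)s")
logger = logging.getLogger("api")

class Expensive:
    """Counts how many times it is stringified."""
    def __init__(self):
        self.calls = 0
    def __str__(self):
        self.calls += 1
        return "expensive"

obj = Expensive()
logger.debug("value: %s", obj)    # below WARNING: never formatted
logger.warning("value: %s", obj)  # emitted: formatted exactly once
print(obj.calls)  # 1
```

An f-string (`logger.debug(f"value: {obj}")`) would have stringified `obj` both times, emitted or not.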
View File

@@ -0,0 +1,26 @@
"""
Custom exceptions for the repository layer.
These exceptions allow services and routes to handle database-level errors
with proper semantics, without leaking SQLAlchemy internals.
"""
class RepositoryError(Exception):
"""Base for all repository-layer errors."""
class DuplicateEntryError(RepositoryError):
"""Raised on unique constraint violations. Maps to HTTP 409 Conflict."""
class IntegrityConstraintError(RepositoryError):
"""Raised on FK or check constraint violations."""
class RecordNotFoundError(RepositoryError):
"""Raised when an expected record doesn't exist."""
class InvalidInputError(RepositoryError):
"""Raised on bad pagination params, invalid UUIDs, or other invalid inputs."""

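The docstrings above hint at the intended HTTP mapping (e.g. `DuplicateEntryError` → 409). A hypothetical sketch of how a route or service layer might translate these exceptions into status codes — the `STATUS_MAP`/`status_for` names and the non-409 codes are assumptions for illustration, not taken from the actual handlers:

```python
# Stand-ins for the repository-layer exception hierarchy shown above.
class RepositoryError(Exception): ...
class DuplicateEntryError(RepositoryError): ...
class IntegrityConstraintError(RepositoryError): ...
class RecordNotFoundError(RepositoryError): ...
class InvalidInputError(RepositoryError): ...

# Hypothetical mapping; only the 409 for duplicates is stated in the source.
STATUS_MAP = {
    DuplicateEntryError: 409,       # Conflict (per the docstring)
    IntegrityConstraintError: 422,  # assumed
    RecordNotFoundError: 404,       # assumed
    InvalidInputError: 400,         # assumed
}

def status_for(exc: RepositoryError) -> int:
    """Fall back to 500 for unmapped repository errors."""
    return STATUS_MAP.get(type(exc), 500)

print(status_for(DuplicateEntryError()))  # 409
```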
View File

@@ -1,6 +0,0 @@
# app/crud/__init__.py
from .organization import organization
from .session import session as session_crud
from .user import user
__all__ = ["organization", "session_crud", "user"]

View File

@@ -6,12 +6,20 @@ Creates the first superuser if configured and doesn't already exist.
 """
 import asyncio
+import json
 import logging
+import random
+from datetime import UTC, datetime, timedelta
+from pathlib import Path
+from sqlalchemy import select, text
 from app.core.config import settings
 from app.core.database import SessionLocal, engine
-from app.crud.user import user as user_crud
+from app.models.organization import Organization
 from app.models.user import User
+from app.models.user_organization import UserOrganization
+from app.repositories.user import user_repo as user_repo
 from app.schemas.users import UserCreate
 logger = logging.getLogger(__name__)
@@ -26,21 +34,27 @@ async def init_db() -> User | None:
     """
     # Use default values if not set in environment variables
     superuser_email = settings.FIRST_SUPERUSER_EMAIL or "admin@example.com"
-    superuser_password = settings.FIRST_SUPERUSER_PASSWORD or "AdminPassword123!"
+    default_password = "AdminPassword123!"
+    if settings.DEMO_MODE:
+        default_password = "AdminPass1234!"
+    superuser_password = settings.FIRST_SUPERUSER_PASSWORD or default_password
     if not settings.FIRST_SUPERUSER_EMAIL or not settings.FIRST_SUPERUSER_PASSWORD:
         logger.warning(
             "First superuser credentials not configured in settings. "
-            f"Using defaults: {superuser_email}"
+            "Using defaults: %s",
+            superuser_email,
         )
     async with SessionLocal() as session:
         try:
             # Check if superuser already exists
-            existing_user = await user_crud.get_by_email(session, email=superuser_email)
+            existing_user = await user_repo.get_by_email(session, email=superuser_email)
             if existing_user:
-                logger.info(f"Superuser already exists: {existing_user.email}")
+                logger.info("Superuser already exists: %s", existing_user.email)
                 return existing_user
             # Create superuser if doesn't exist
@@ -52,19 +66,143 @@ async def init_db() -> User | None:
                 is_superuser=True,
             )
-            user = await user_crud.create(session, obj_in=user_in)
+            user = await user_repo.create(session, obj_in=user_in)
             await session.commit()
             await session.refresh(user)
-            logger.info(f"Created first superuser: {user.email}")
+            logger.info("Created first superuser: %s", user.email)
+            # Create demo data if in demo mode
+            if settings.DEMO_MODE:
+                await load_demo_data(session)
             return user
         except Exception as e:
             await session.rollback()
-            logger.error(f"Error initializing database: {e}")
+            logger.error("Error initializing database: %s", e)
             raise
+def _load_json_file(path: Path):
+    with open(path) as f:
+        return json.load(f)
+async def load_demo_data(session):
+    """Load demo data from JSON file."""
+    demo_data_path = Path(__file__).parent / "core" / "demo_data.json"
+    if not demo_data_path.exists():
+        logger.warning("Demo data file not found: %s", demo_data_path)
+        return
+    try:
+        # Use asyncio.to_thread to avoid blocking the event loop
+        data = await asyncio.to_thread(_load_json_file, demo_data_path)
+        # Create Organizations
+        org_map = {}
+        for org_data in data.get("organizations", []):
+            # Check if org exists
+            result = await session.execute(
+                text("SELECT * FROM organizations WHERE slug = :slug"),
+                {"slug": org_data["slug"]},
+            )
+            existing_org = result.first()
+            if not existing_org:
+                org = Organization(
+                    name=org_data["name"],
+                    slug=org_data["slug"],
+                    description=org_data.get("description"),
+                    is_active=True,
+                )
+                session.add(org)
+                await session.flush()  # Flush to get ID
+                org_map[org.slug] = org
+                logger.info("Created demo organization: %s", org.name)
+            else:
+                # We can't easily get the ORM object from raw SQL result for map without querying again or mapping
+                # So let's just query it properly if we need it for relationships
+                # But for simplicity in this script, let's just assume we created it or it exists.
+                # To properly map for users, we need the ID.
+                # Let's use a simpler approach: just try to create, if slug conflict, skip.
+                pass
+        # Re-query all orgs to build map for users
+        result = await session.execute(select(Organization))
+        orgs = result.scalars().all()
+        org_map = {org.slug: org for org in orgs}
+        # Create Users
+        for user_data in data.get("users", []):
+            existing_user = await user_repo.get_by_email(
+                session, email=user_data["email"]
+            )
+            if not existing_user:
+                # Create user
+                user_in = UserCreate(
+                    email=user_data["email"],
+                    password=user_data["password"],
+                    first_name=user_data["first_name"],
+                    last_name=user_data["last_name"],
+                    is_superuser=user_data["is_superuser"],
+                    is_active=user_data.get("is_active", True),
+                )
+                user = await user_repo.create(session, obj_in=user_in)
+                # Randomize created_at for demo data (last 30 days)
+                # This makes the charts look more realistic
+                days_ago = random.randint(0, 30)  # noqa: S311
+                random_time = datetime.now(UTC) - timedelta(days=days_ago)
+                # Add some random hours/minutes variation
+                random_time = random_time.replace(
+                    hour=random.randint(0, 23),  # noqa: S311
+                    minute=random.randint(0, 59),  # noqa: S311
+                )
+                # Update the timestamp and is_active directly in the database
+                # We do this to ensure the values are persisted correctly
+                await session.execute(
+                    text(
+                        "UPDATE users SET created_at = :created_at, is_active = :is_active WHERE id = :user_id"
+                    ),
+                    {
+                        "created_at": random_time,
+                        "is_active": user_data.get("is_active", True),
+                        "user_id": user.id,
+                    },
+                )
+                logger.info(
+                    "Created demo user: %s (created %s days ago, active=%s)",
+                    user.email,
+                    days_ago,
+                    user_data.get("is_active", True),
+                )
+                # Add to organization if specified
+                org_slug = user_data.get("organization_slug")
+                role = user_data.get("role")
+                if org_slug and org_slug in org_map and role:
+                    org = org_map[org_slug]
+                    # Check if membership exists (it shouldn't for new user)
+                    member = UserOrganization(
+                        user_id=user.id, organization_id=org.id, role=role
+                    )
+                    session.add(member)
+                    logger.info("Added %s to %s as %s", user.email, org.name, role)
+            else:
+                logger.info("Demo user already exists: %s", existing_user.email)
+        await session.commit()
+        logger.info("Demo data loaded successfully")
+    except Exception as e:
+        logger.error("Error loading demo data: %s", e)
+        raise
 async def main():
     """Main entry point for database initialization."""
     # Configure logging to show info logs

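The demo loader above wraps the synchronous `json.load` in `asyncio.to_thread` so the file read doesn't block the event loop. A self-contained sketch of the same pattern (the temp-file setup is illustrative):

```python
import asyncio
import json
import tempfile
from pathlib import Path

def _load_json_file(path: Path):
    # Plain blocking read; safe to call from a worker thread.
    with open(path) as f:
        return json.load(f)

async def main() -> dict:
    with tempfile.NamedTemporaryFile("w", suffix=".json", delete=False) as f:
        json.dump({"users": []}, f)
        path = Path(f.name)
    # asyncio.to_thread runs the blocking function in the default
    # thread-pool executor, keeping the event loop responsive.
    data = await asyncio.to_thread(_load_json_file, path)
    path.unlink()
    return data

result = asyncio.run(main())
print(result)  # {'users': []}
```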
View File

@@ -1,7 +1,7 @@
 import logging
 import os
 from contextlib import asynccontextmanager
-from datetime import datetime
+from datetime import UTC, datetime
 from typing import Any
 from apscheduler.schedulers.asyncio import AsyncIOScheduler
@@ -14,8 +14,9 @@ from slowapi.errors import RateLimitExceeded
 from slowapi.util import get_remote_address
 from app.api.main import api_router
+from app.api.routes.oauth_provider import wellknown_router as oauth_wellknown_router
 from app.core.config import settings
-from app.core.database import check_database_health
+from app.core.database import check_database_health, close_async_db
 from app.core.exceptions import (
     APIException,
     api_exception_handler,
@@ -71,6 +72,7 @@ async def lifespan(app: FastAPI):
     if os.getenv("IS_TEST", "False") != "True":
         scheduler.shutdown()
         logger.info("Scheduled jobs stopped")
+    await close_async_db()
     logger.info("Starting app!!!")
@@ -293,7 +295,7 @@ async def health_check() -> JSONResponse:
     """
     health_status: dict[str, Any] = {
         "status": "healthy",
-        "timestamp": datetime.utcnow().isoformat() + "Z",
+        "timestamp": datetime.now(UTC).isoformat().replace("+00:00", "Z"),
         "version": settings.VERSION,
         "environment": settings.ENVIRONMENT,
         "checks": {},
@@ -318,9 +320,13 @@
                 "message": f"Database connection failed: {e!s}",
             }
             response_status = status.HTTP_503_SERVICE_UNAVAILABLE
-            logger.error(f"Health check failed - database error: {e}")
+            logger.error("Health check failed - database error: %s", e)
     return JSONResponse(status_code=response_status, content=health_status)
 app.include_router(api_router, prefix=settings.API_V1_STR)
+# OAuth 2.0 well-known endpoint at root level per RFC 8414
+# This allows MCP clients to discover the OAuth server metadata at /.well-known/oauth-authorization-server
+app.include_router(oauth_wellknown_router)

View File

@@ -7,6 +7,15 @@ Imports all models to ensure they're registered with SQLAlchemy.
 from app.core.database import Base
 from .base import TimestampMixin, UUIDMixin
+# OAuth models (client mode - authenticate via Google/GitHub)
+from .oauth_account import OAuthAccount
+# OAuth provider models (server mode - act as authorization server for MCP)
+from .oauth_authorization_code import OAuthAuthorizationCode
+from .oauth_client import OAuthClient
+from .oauth_provider_token import OAuthConsent, OAuthProviderRefreshToken
+from .oauth_state import OAuthState
 from .organization import Organization
 # Import models
@@ -16,6 +25,12 @@ from .user_session import UserSession
 __all__ = [
     "Base",
+    "OAuthAccount",
+    "OAuthAuthorizationCode",
+    "OAuthClient",
+    "OAuthConsent",
+    "OAuthProviderRefreshToken",
+    "OAuthState",
     "Organization",
     "OrganizationRole",
     "TimestampMixin",

View File

@@ -0,0 +1,55 @@
"""OAuth account model for linking external OAuth providers to users."""
from sqlalchemy import Column, DateTime, ForeignKey, Index, String, UniqueConstraint
from sqlalchemy.dialects.postgresql import UUID
from sqlalchemy.orm import relationship
from .base import Base, TimestampMixin, UUIDMixin
class OAuthAccount(Base, UUIDMixin, TimestampMixin):
"""
Links OAuth provider accounts to users.
Supports multiple OAuth providers per user (e.g., user can have both
Google and GitHub connected). Each provider account is uniquely identified
by (provider, provider_user_id).
"""
__tablename__ = "oauth_accounts"
# Link to user
user_id = Column(
UUID(as_uuid=True),
ForeignKey("users.id", ondelete="CASCADE"),
nullable=False,
index=True,
)
# OAuth provider identification
provider = Column(
String(50), nullable=False, index=True
) # google, github, microsoft
provider_user_id = Column(String(255), nullable=False) # Provider's unique user ID
provider_email = Column(
String(255), nullable=True, index=True
) # Email from provider (for reference)
# Optional: store provider tokens for API access
# TODO: Encrypt these at rest in production (requires key management infrastructure)
access_token = Column(String(2048), nullable=True)
refresh_token = Column(String(2048), nullable=True)
token_expires_at = Column(DateTime(timezone=True), nullable=True)
# Relationship
user = relationship("User", back_populates="oauth_accounts")
__table_args__ = (
# Each provider account can only be linked to one user
UniqueConstraint("provider", "provider_user_id", name="uq_oauth_provider_user"),
# Index for finding all OAuth accounts for a user + provider
Index("ix_oauth_accounts_user_provider", "user_id", "provider"),
)
def __repr__(self):
return f"<OAuthAccount {self.provider}:{self.provider_user_id}>"

View File

@@ -0,0 +1,100 @@
"""OAuth authorization code model for OAuth provider mode."""
from datetime import UTC, datetime
from sqlalchemy import Boolean, Column, DateTime, ForeignKey, Index, String
from sqlalchemy.dialects.postgresql import UUID
from sqlalchemy.orm import relationship
from .base import Base, TimestampMixin, UUIDMixin
class OAuthAuthorizationCode(Base, UUIDMixin, TimestampMixin):
"""
OAuth 2.0 Authorization Code for the authorization code flow.
Authorization codes are:
- Single-use (marked as used after exchange)
- Short-lived (10 minutes default)
- Bound to specific client, user, redirect_uri
- Support PKCE (code_challenge/code_challenge_method)
Security considerations:
- Code must be cryptographically random (64 chars, URL-safe)
- Must validate redirect_uri matches exactly
- Must verify PKCE code_verifier for public clients
- Must be consumed within expiration time
Performance indexes (defined in migration 0002_add_performance_indexes.py):
- ix_perf_oauth_auth_codes_expires: expires_at WHERE used = false
"""
__tablename__ = "oauth_authorization_codes"
# The authorization code (cryptographically random, URL-safe)
code = Column(String(128), unique=True, nullable=False, index=True)
# Client that requested the code
client_id = Column(
String(64),
ForeignKey("oauth_clients.client_id", ondelete="CASCADE"),
nullable=False,
)
# User who authorized the request
user_id = Column(
UUID(as_uuid=True),
ForeignKey("users.id", ondelete="CASCADE"),
nullable=False,
)
# Redirect URI (must match exactly on token exchange)
redirect_uri = Column(String(2048), nullable=False)
# Granted scopes (space-separated)
scope = Column(String(1000), nullable=False, default="")
# PKCE support (required for public clients)
code_challenge = Column(String(128), nullable=True)
code_challenge_method = Column(String(10), nullable=True) # "S256" or "plain"
# State parameter (for CSRF protection, returned to client)
state = Column(String(256), nullable=True)
# Nonce (for OpenID Connect, included in ID token)
nonce = Column(String(256), nullable=True)
# Expiration (codes are short-lived, typically 10 minutes)
expires_at = Column(DateTime(timezone=True), nullable=False)
# Single-use flag (set to True after successful exchange)
used = Column(Boolean, default=False, nullable=False)
# Relationships
client = relationship("OAuthClient", backref="authorization_codes")
user = relationship("User", backref="oauth_authorization_codes")
# Indexes for efficient cleanup queries
__table_args__ = (
Index("ix_oauth_authorization_codes_expires_at", "expires_at"),
Index("ix_oauth_authorization_codes_client_user", "client_id", "user_id"),
)
def __repr__(self):
return f"<OAuthAuthorizationCode {self.code[:8]}... for {self.client_id}>"
@property
def is_expired(self) -> bool:
"""Check if the authorization code has expired."""
# Use timezone-aware comparison (datetime.utcnow() is deprecated)
now = datetime.now(UTC)
expires_at = self.expires_at
# Handle both timezone-aware and naive datetimes from DB
if expires_at.tzinfo is None:
expires_at = expires_at.replace(tzinfo=UTC)
return bool(now > expires_at)
@property
def is_valid(self) -> bool:
"""Check if the authorization code is valid (not used, not expired)."""
return not self.used and not self.is_expired

View File

@@ -0,0 +1,67 @@
"""OAuth client model for OAuth provider mode (MCP clients)."""
from sqlalchemy import Boolean, Column, ForeignKey, String
from sqlalchemy.dialects.postgresql import JSONB, UUID
from sqlalchemy.orm import relationship
from .base import Base, TimestampMixin, UUIDMixin
class OAuthClient(Base, UUIDMixin, TimestampMixin):
"""
Registered OAuth clients (for OAuth provider mode).
This model stores third-party applications that can authenticate
against this API using OAuth 2.0. Used for MCP (Model Context Protocol)
client authentication and API access.
NOTE: This is a skeleton implementation. The full OAuth provider
functionality (authorization endpoint, token endpoint, etc.) can be
expanded when needed.
"""
__tablename__ = "oauth_clients"
# Client credentials
client_id = Column(String(64), unique=True, nullable=False, index=True)
client_secret_hash = Column(
String(255), nullable=True
) # NULL for public clients (PKCE)
# Client metadata
client_name = Column(String(255), nullable=False)
client_description = Column(String(1000), nullable=True)
# Client type: "public" (SPA, mobile) or "confidential" (server-side)
client_type = Column(String(20), nullable=False, default="public")
# Allowed redirect URIs (JSON array)
redirect_uris = Column(JSONB, nullable=False, default=list)
# Allowed scopes (JSON array of scope names)
allowed_scopes = Column(JSONB, nullable=False, default=list)
# Token lifetimes (in seconds)
access_token_lifetime = Column(String(10), nullable=False, default="3600") # 1 hour
refresh_token_lifetime = Column(
String(10), nullable=False, default="604800"
) # 7 days
# Status
is_active = Column(Boolean, default=True, nullable=False, index=True)
# Optional: owner user (for user-registered applications)
owner_user_id = Column(
UUID(as_uuid=True),
ForeignKey("users.id", ondelete="SET NULL"),
nullable=True,
)
# MCP-specific: URL of the MCP server this client represents
mcp_server_url = Column(String(2048), nullable=True)
# Relationship
owner = relationship("User", backref="owned_oauth_clients")
def __repr__(self):
return f"<OAuthClient {self.client_name} ({self.client_id[:8]}...)>"

View File

@@ -0,0 +1,162 @@
"""OAuth provider token models for OAuth provider mode."""
from datetime import UTC, datetime
from sqlalchemy import Boolean, Column, DateTime, ForeignKey, Index, String
from sqlalchemy.dialects.postgresql import UUID
from sqlalchemy.orm import relationship
from .base import Base, TimestampMixin, UUIDMixin
class OAuthProviderRefreshToken(Base, UUIDMixin, TimestampMixin):
"""
OAuth 2.0 Refresh Token for the OAuth provider.
Refresh tokens are:
- Opaque (stored as hash in DB, actual token given to client)
- Long-lived (configurable, default 30 days)
- Revocable (via revoked flag or deletion)
- Bound to specific client, user, and scope
Access tokens are JWTs and not stored in DB (self-contained).
This model only tracks refresh tokens for revocation support.
Security considerations:
- Store token hash, not plaintext
- Support token rotation (new refresh token on use)
- Track last used time for security auditing
- Support revocation by user, client, or admin
Performance indexes (defined in migration 0002_add_performance_indexes.py):
- ix_perf_oauth_refresh_tokens_expires: expires_at WHERE revoked = false
"""
__tablename__ = "oauth_provider_refresh_tokens"
# Hash of the refresh token (SHA-256)
# We store hash, not plaintext, for security
token_hash = Column(String(64), unique=True, nullable=False, index=True)
# Unique token ID (JTI) - used in JWT access tokens to reference this refresh token
jti = Column(String(64), unique=True, nullable=False, index=True)
# Client that owns this token
client_id = Column(
String(64),
ForeignKey("oauth_clients.client_id", ondelete="CASCADE"),
nullable=False,
)
# User who authorized this token
user_id = Column(
UUID(as_uuid=True),
ForeignKey("users.id", ondelete="CASCADE"),
nullable=False,
)
# Granted scopes (space-separated)
scope = Column(String(1000), nullable=False, default="")
# Token expiration
expires_at = Column(DateTime(timezone=True), nullable=False)
# Revocation flag
revoked = Column(Boolean, default=False, nullable=False, index=True)
# Last used timestamp (for security auditing)
last_used_at = Column(DateTime(timezone=True), nullable=True)
# Device/session info (optional, for user visibility)
device_info = Column(String(500), nullable=True)
ip_address = Column(String(45), nullable=True)
# Relationships
client = relationship("OAuthClient", backref="refresh_tokens")
user = relationship("User", backref="oauth_provider_refresh_tokens")
# Indexes
__table_args__ = (
Index("ix_oauth_provider_refresh_tokens_expires_at", "expires_at"),
Index("ix_oauth_provider_refresh_tokens_client_user", "client_id", "user_id"),
Index(
"ix_oauth_provider_refresh_tokens_user_revoked",
"user_id",
"revoked",
),
)
def __repr__(self):
status = "revoked" if self.revoked else "active"
return f"<OAuthProviderRefreshToken {self.jti[:8]}... ({status})>"
@property
def is_expired(self) -> bool:
"""Check if the refresh token has expired."""
# Use timezone-aware comparison (datetime.utcnow() is deprecated)
now = datetime.now(UTC)
expires_at = self.expires_at
# Handle both timezone-aware and naive datetimes from DB
if expires_at.tzinfo is None:
expires_at = expires_at.replace(tzinfo=UTC)
return bool(now > expires_at)
@property
def is_valid(self) -> bool:
"""Check if the refresh token is valid (not revoked, not expired)."""
return not self.revoked and not self.is_expired
class OAuthConsent(Base, UUIDMixin, TimestampMixin):
"""
OAuth consent record - remembers user consent for a client.
When a user grants consent to an OAuth client, we store the record
so they don't have to re-consent on subsequent authorizations
(unless scopes change).
This enables a better UX - users only see consent screen once per client,
unless the client requests additional scopes.
"""
__tablename__ = "oauth_consents"
# User who granted consent
user_id = Column(
UUID(as_uuid=True),
ForeignKey("users.id", ondelete="CASCADE"),
nullable=False,
)
# Client that received consent
client_id = Column(
String(64),
ForeignKey("oauth_clients.client_id", ondelete="CASCADE"),
nullable=False,
)
# Granted scopes (space-separated)
granted_scopes = Column(String(1000), nullable=False, default="")
# Relationships
client = relationship("OAuthClient", backref="consents")
user = relationship("User", backref="oauth_consents")
# Unique constraint: one consent record per user+client
__table_args__ = (
Index(
"ix_oauth_consents_user_client",
"user_id",
"client_id",
unique=True,
),
)
def __repr__(self):
return f"<OAuthConsent user={self.user_id} client={self.client_id}>"
def has_scopes(self, requested_scopes: list[str]) -> bool:
"""Check if all requested scopes are already granted."""
granted = set(self.granted_scopes.split()) if self.granted_scopes else set()
requested = set(requested_scopes)
return requested.issubset(granted)

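The refresh-token model above stores a SHA-256 hash (a 64-character hex string, matching the `String(64)` column) rather than the plaintext token. A minimal sketch of that scheme — the helper name and the token length passed to `secrets.token_urlsafe` are assumptions, not taken from the service code:

```python
import hashlib
import secrets

def new_refresh_token() -> tuple[str, str]:
    """Return (plaintext for the client, hash for the DB)."""
    token = secrets.token_urlsafe(48)  # opaque token given to the client
    token_hash = hashlib.sha256(token.encode()).hexdigest()  # stored in DB
    return token, token_hash

token, token_hash = new_refresh_token()
print(len(token_hash))  # 64 hex chars, fitting the String(64) column
```

On refresh, the server hashes the presented token and looks it up by `token_hash`, so a leaked database never reveals usable refresh tokens.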
View File

@@ -0,0 +1,45 @@
"""OAuth state model for CSRF protection during OAuth flows."""
from sqlalchemy import Column, DateTime, String
from sqlalchemy.dialects.postgresql import UUID
from .base import Base, TimestampMixin, UUIDMixin
class OAuthState(Base, UUIDMixin, TimestampMixin):
"""
Temporary storage for OAuth state parameters.
Prevents CSRF attacks during OAuth flows by storing a random state
value that must match on callback. Also stores PKCE code_verifier
for the Authorization Code flow with PKCE.
These records are short-lived (10 minutes by default) and should
be deleted after use or expiration.
"""
__tablename__ = "oauth_states"
# Random state parameter (CSRF protection)
state = Column(String(255), unique=True, nullable=False, index=True)
# PKCE code_verifier (used to generate code_challenge)
code_verifier = Column(String(128), nullable=True)
# OIDC nonce for ID token replay protection
nonce = Column(String(255), nullable=True)
# OAuth provider (google, github, etc.)
provider = Column(String(50), nullable=False)
# Original redirect URI (for callback validation)
redirect_uri = Column(String(500), nullable=True)
# User ID if this is an account linking flow (user is already logged in)
user_id = Column(UUID(as_uuid=True), nullable=True)
# Expiration time
expires_at = Column(DateTime(timezone=True), nullable=False)
def __repr__(self):
return f"<OAuthState {self.state[:8]}... ({self.provider})>"

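The `code_verifier` stored above feeds the PKCE `code_challenge` referenced by the authorization-code model. A sketch of the RFC 7636 S256 derivation (the verifier generation here is illustrative; the spec allows 43–128 characters):

```python
import base64
import hashlib
import secrets

# S256: code_challenge = BASE64URL(SHA256(code_verifier)), no padding.
code_verifier = secrets.token_urlsafe(64)[:128]
digest = hashlib.sha256(code_verifier.encode("ascii")).digest()
code_challenge = base64.urlsafe_b64encode(digest).rstrip(b"=").decode("ascii")
print(len(code_challenge))  # 43: a 32-byte digest base64url-encodes to 43 chars
```

On token exchange, the server repeats this computation on the presented `code_verifier` and compares it to the stored `code_challenge`.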
View File

@@ -10,6 +10,9 @@ class Organization(Base, UUIDMixin, TimestampMixin):
     """
     Organization model for multi-tenant support.
     Users can belong to multiple organizations with different roles.
+    Performance indexes (defined in migration 0002_add_performance_indexes.py):
+    - ix_perf_organizations_slug_lower: LOWER(slug) WHERE is_active = true
     """
     __tablename__ = "organizations"

View File

@@ -6,22 +6,45 @@ from .base import Base, TimestampMixin, UUIDMixin
 class User(Base, UUIDMixin, TimestampMixin):
+    """
+    User model for authentication and profile data.
+    Performance indexes (defined in migration 0002_add_performance_indexes.py):
+    - ix_perf_users_email_lower: LOWER(email) WHERE deleted_at IS NULL
+    - ix_perf_users_active: is_active WHERE deleted_at IS NULL
+    """
     __tablename__ = "users"
     email = Column(String(255), unique=True, nullable=False, index=True)
-    password_hash = Column(String(255), nullable=False)
+    # Nullable to support OAuth-only users who never set a password
+    password_hash = Column(String(255), nullable=True)
     first_name = Column(String(100), nullable=False, default="user")
     last_name = Column(String(100), nullable=True)
     phone_number = Column(String(20))
     is_active = Column(Boolean, default=True, nullable=False, index=True)
     is_superuser = Column(Boolean, default=False, nullable=False, index=True)
     preferences = Column(JSONB)
+    locale = Column(String(10), nullable=True, index=True)
     deleted_at = Column(DateTime(timezone=True), nullable=True, index=True)
     # Relationships
     user_organizations = relationship(
         "UserOrganization", back_populates="user", cascade="all, delete-orphan"
     )
+    oauth_accounts = relationship(
+        "OAuthAccount", back_populates="user", cascade="all, delete-orphan"
+    )
+    @property
+    def has_password(self) -> bool:
+        """Check if user can login with password (not OAuth-only)."""
+        return self.password_hash is not None
+    @property
+    def can_remove_oauth(self) -> bool:
+        """Check if user can safely remove an OAuth account link."""
+        return self.has_password or len(self.oauth_accounts) > 1
     def __repr__(self):
         return f"<User {self.email}>"

View File

@@ -44,7 +44,7 @@ class UserOrganization(Base, TimestampMixin):
         Enum(OrganizationRole),
         default=OrganizationRole.MEMBER,
         nullable=False,
-        index=True,
+        # Note: index defined in __table_args__ as ix_user_org_role
     )
     is_active = Column(Boolean, default=True, nullable=False, index=True)

View File

@@ -22,6 +22,9 @@ class UserSession(Base, UUIDMixin, TimestampMixin):
     Each time a user logs in from a device, a new session is created.
     Sessions are identified by the refresh token JTI (JWT ID).
+    Performance indexes (defined in migration 0002_add_performance_indexes.py):
+    - ix_perf_user_sessions_expires: expires_at WHERE is_active = true
     """
     __tablename__ = "user_sessions"
@@ -73,7 +76,11 @@ class UserSession(Base, UUIDMixin, TimestampMixin):
         """Check if session has expired."""
         from datetime import datetime
-        return self.expires_at < datetime.now(UTC)
+        now = datetime.now(UTC)
+        expires_at = self.expires_at
+        if expires_at.tzinfo is None:
+            expires_at = expires_at.replace(tzinfo=UTC)
+        return bool(expires_at < now)
     def to_dict(self):
         """Convert session to dictionary for serialization."""

View File

@@ -0,0 +1,39 @@
# app/repositories/__init__.py
"""Repository layer — all database access goes through these classes."""
from app.repositories.oauth_account import OAuthAccountRepository, oauth_account_repo
from app.repositories.oauth_authorization_code import (
OAuthAuthorizationCodeRepository,
oauth_authorization_code_repo,
)
from app.repositories.oauth_client import OAuthClientRepository, oauth_client_repo
from app.repositories.oauth_consent import OAuthConsentRepository, oauth_consent_repo
from app.repositories.oauth_provider_token import (
OAuthProviderTokenRepository,
oauth_provider_token_repo,
)
from app.repositories.oauth_state import OAuthStateRepository, oauth_state_repo
from app.repositories.organization import OrganizationRepository, organization_repo
from app.repositories.session import SessionRepository, session_repo
from app.repositories.user import UserRepository, user_repo
__all__ = [
"OAuthAccountRepository",
"OAuthAuthorizationCodeRepository",
"OAuthClientRepository",
"OAuthConsentRepository",
"OAuthProviderTokenRepository",
"OAuthStateRepository",
"OrganizationRepository",
"SessionRepository",
"UserRepository",
"oauth_account_repo",
"oauth_authorization_code_repo",
"oauth_client_repo",
"oauth_consent_repo",
"oauth_provider_token_repo",
"oauth_state_repo",
"organization_repo",
"session_repo",
"user_repo",
]

View File

@@ -1,6 +1,6 @@
-# app/crud/base_async.py
+# app/repositories/base.py
 """
-Async CRUD operations base class using SQLAlchemy 2.0 async patterns.
+Base repository class for async database operations using SQLAlchemy 2.0 async patterns.
 Provides reusable create, read, update, and delete operations for all models.
 """
@@ -18,6 +18,11 @@ from sqlalchemy.ext.asyncio import AsyncSession
 from sqlalchemy.orm import Load
 from app.core.database import Base
+from app.core.repository_exceptions import (
+    DuplicateEntryError,
+    IntegrityConstraintError,
+    InvalidInputError,
+)
 logger = logging.getLogger(__name__)
@@ -26,16 +31,16 @@ CreateSchemaType = TypeVar("CreateSchemaType", bound=BaseModel)
 UpdateSchemaType = TypeVar("UpdateSchemaType", bound=BaseModel)
-class CRUDBase[
+class BaseRepository[
     ModelType: Base,
     CreateSchemaType: BaseModel,
     UpdateSchemaType: BaseModel,
 ]:
-    """Async CRUD operations for a model."""
+    """Async repository operations for a model."""
     def __init__(self, model: type[ModelType]):
         """
-        CRUD object with default async methods to Create, Read, Update, Delete.
+        Repository object with default async methods to Create, Read, Update, Delete.
         Parameters:
             model: A SQLAlchemy model class
@@ -56,26 +61,19 @@ class CRUDBase[
         Returns:
             Model instance or None if not found
-        Example:
-            # Eager load user relationship
-            from sqlalchemy.orm import joinedload
-            session = await session_crud.get(db, id=session_id, options=[joinedload(UserSession.user)])
         """
-        # Validate UUID format and convert to UUID object if string
         try:
             if isinstance(id, uuid.UUID):
                 uuid_obj = id
             else:
                 uuid_obj = uuid.UUID(str(id))
         except (ValueError, AttributeError, TypeError) as e:
-            logger.warning(f"Invalid UUID format: {id} - {e!s}")
+            logger.warning("Invalid UUID format: %s - %s", id, e)
             return None
         try:
             query = select(self.model).where(self.model.id == uuid_obj)
-            # Apply eager loading options if provided
             if options:
                 for option in options:
                     query = query.options(option)
@@ -83,7 +81,9 @@ class CRUDBase[
             result = await db.execute(query)
             return result.scalar_one_or_none()
         except Exception as e:
-            logger.error(f"Error retrieving {self.model.__name__} with id {id}: {e!s}")
+            logger.error(
+                "Error retrieving %s with id %s: %s", self.model.__name__, id, e
+            )
             raise
     async def get_multi(
@@ -96,28 +96,17 @@ class CRUDBase[
     ) -> list[ModelType]:
         """
         Get multiple records with pagination validation and optional eager loading.
-        Args:
-            db: Database session
-            skip: Number of records to skip
-            limit: Maximum number of records to return
-            options: Optional list of SQLAlchemy load options for eager loading
-        Returns:
-            List of model instances
         """
-        # Validate pagination parameters
         if skip < 0:
-            raise ValueError("skip must be non-negative")
+            raise InvalidInputError("skip must be non-negative")
         if limit < 0:
-            raise ValueError("limit must be non-negative")
+            raise InvalidInputError("limit must be non-negative")
         if limit > 1000:
-            raise ValueError("Maximum limit is 1000")
+            raise InvalidInputError("Maximum limit is 1000")
         try:
-            query = select(self.model).offset(skip).limit(limit)
+            query = select(self.model).order_by(self.model.id).offset(skip).limit(limit)
-            # Apply eager loading options if provided
             if options:
                 for option in options:
                     query = query.options(option)
@@ -126,7 +115,7 @@ class CRUDBase[
             return list(result.scalars().all())
         except Exception as e:
             logger.error(
-                f"Error retrieving multiple {self.model.__name__} records: {e!s}"
+                "Error retrieving multiple %s records: %s", self.model.__name__, e
             )
             raise
@@ -136,9 +125,8 @@ class CRUDBase[
         """Create a new record with error handling.
         NOTE: This method is defensive code that's never called in practice.
-        All CRUD subclasses (CRUDUser, CRUDOrganization, CRUDSession) override this method
-        with their own implementations, so the base implementation and its exception handlers
-        are never executed. Marked as pragma: no cover to avoid false coverage gaps.
+        All repository subclasses override this method with their own implementations.
+        Marked as pragma: no cover to avoid false coverage gaps.
         """
         try:  # pragma: no cover
             obj_in_data = jsonable_encoder(obj_in)
@@ -152,22 +140,24 @@ class CRUDBase[
             error_msg = str(e.orig) if hasattr(e, "orig") else str(e)
             if "unique" in error_msg.lower() or "duplicate" in error_msg.lower():
                 logger.warning(
-                    f"Duplicate entry attempted for {self.model.__name__}: {error_msg}"
+                    "Duplicate entry attempted for %s: %s",
+                    self.model.__name__,
+                    error_msg,
                 )
-                raise ValueError(
+                raise DuplicateEntryError(
                     f"A {self.model.__name__} with this data already exists"
                 )
-            logger.error(f"Integrity error creating {self.model.__name__}: {error_msg}")
-            raise ValueError(f"Database integrity error: {error_msg}")
+            logger.error(
+                "Integrity error creating %s: %s", self.model.__name__, error_msg
+            )
+            raise IntegrityConstraintError(f"Database integrity error: {error_msg}")
         except (OperationalError, DataError) as e:  # pragma: no cover
             await db.rollback()
-            logger.error(f"Database error creating {self.model.__name__}: {e!s}")
-            raise ValueError(f"Database operation failed: {e!s}")
+            logger.error("Database error creating %s: %s", self.model.__name__, e)
+            raise IntegrityConstraintError(f"Database operation failed: {e!s}")
         except Exception as e:  # pragma: no cover
             await db.rollback()
-            logger.error(
-                f"Unexpected error creating {self.model.__name__}: {e!s}", exc_info=True
-            )
+            logger.exception("Unexpected error creating %s: %s", self.model.__name__, e)
             raise
     async def update(
@@ -198,34 +188,35 @@ class CRUDBase[
             error_msg = str(e.orig) if hasattr(e, "orig") else str(e)
             if "unique" in error_msg.lower() or "duplicate" in error_msg.lower():
                 logger.warning(
-                    f"Duplicate entry attempted for {self.model.__name__}: {error_msg}"
+                    "Duplicate entry attempted for %s: %s",
+                    self.model.__name__,
+                    error_msg,
                 )
-                raise ValueError(
+                raise DuplicateEntryError(
                     f"A {self.model.__name__} with this data already exists"
                 )
-            logger.error(f"Integrity error updating {self.model.__name__}: {error_msg}")
-            raise ValueError(f"Database integrity error: {error_msg}")
+            logger.error(
+                "Integrity error updating %s: %s", self.model.__name__, error_msg
+            )
+            raise IntegrityConstraintError(f"Database integrity error: {error_msg}")
         except (OperationalError, DataError) as e:
             await db.rollback()
-            logger.error(f"Database error updating {self.model.__name__}: {e!s}")
-            raise ValueError(f"Database operation failed: {e!s}")
+            logger.error("Database error updating %s: %s", self.model.__name__, e)
+            raise IntegrityConstraintError(f"Database operation failed: {e!s}")
         except Exception as e:
             await db.rollback()
-            logger.error(
-                f"Unexpected error updating {self.model.__name__}: {e!s}", exc_info=True
-            )
+            logger.exception("Unexpected error updating %s: %s", self.model.__name__, e)
             raise
     async def remove(self, db: AsyncSession, *, id: str) -> ModelType | None:
         """Delete a record with error handling and null check."""
-        # Validate UUID format and convert to UUID object if string
         try:
             if isinstance(id, uuid.UUID):
                 uuid_obj = id
             else:
                 uuid_obj = uuid.UUID(str(id))
         except (ValueError, AttributeError, TypeError) as e:
-            logger.warning(f"Invalid UUID format for deletion: {id} - {e!s}")
+            logger.warning("Invalid UUID format for deletion: %s - %s", id, e)
             return None
         try:
@@ -236,7 +227,7 @@ class CRUDBase[
             if obj is None:
                 logger.warning(
-                    f"{self.model.__name__} with id {id} not found for deletion"
+                    "%s with id %s not found for deletion", self.model.__name__, id
                 )
                 return None
@@ -246,15 +237,16 @@ class CRUDBase[
         except IntegrityError as e:
             await db.rollback()
             error_msg = str(e.orig) if hasattr(e, "orig") else str(e)
-            logger.error(f"Integrity error deleting {self.model.__name__}: {error_msg}")
-            raise ValueError(
+            logger.error(
+                "Integrity error deleting %s: %s", self.model.__name__, error_msg
+            )
+            raise IntegrityConstraintError(
                 f"Cannot delete {self.model.__name__}: referenced by other records"
             )
         except Exception as e:
             await db.rollback()
-            logger.error(
-                f"Error deleting {self.model.__name__} with id {id}: {e!s}",
-                exc_info=True,
-            )
+            logger.exception(
+                "Error deleting %s with id %s: %s", self.model.__name__, id, e
            )
            raise
@@ -267,65 +259,53 @@ class CRUDBase[
         sort_by: str | None = None,
         sort_order: str = "asc",
         filters: dict[str, Any] | None = None,
-    ) -> tuple[list[ModelType], int]:
+    ) -> tuple[list[ModelType], int]:  # pragma: no cover
         """
         Get multiple records with total count, filtering, and sorting.
-        Args:
-            db: Database session
-            skip: Number of records to skip
-            limit: Maximum number of records to return
-            sort_by: Field name to sort by (must be a valid model attribute)
-            sort_order: Sort order ("asc" or "desc")
-            filters: Dictionary of filters (field_name: value)
-        Returns:
-            Tuple of (items, total_count)
+        NOTE: This method is defensive code that's never called in practice.
+        All repository subclasses override this method with their own implementations.
+        Marked as pragma: no cover to avoid false coverage gaps.
         """
-        # Validate pagination parameters
         if skip < 0:
-            raise ValueError("skip must be non-negative")
+            raise InvalidInputError("skip must be non-negative")
         if limit < 0:
-            raise ValueError("limit must be non-negative")
+            raise InvalidInputError("limit must be non-negative")
         if limit > 1000:
-            raise ValueError("Maximum limit is 1000")
+            raise InvalidInputError("Maximum limit is 1000")
         try:
-            # Build base query
             query = select(self.model)
-            # Exclude soft-deleted records by default
             if hasattr(self.model, "deleted_at"):
                 query = query.where(self.model.deleted_at.is_(None))
-            # Apply filters
             if filters:
                 for field, value in filters.items():
                     if hasattr(self.model, field) and value is not None:
                         query = query.where(getattr(self.model, field) == value)
-            # Get total count (before pagination)
             count_query = select(func.count()).select_from(query.alias())
             count_result = await db.execute(count_query)
             total = count_result.scalar_one()
-            # Apply sorting
             if sort_by and hasattr(self.model, sort_by):
                 sort_column = getattr(self.model, sort_by)
                 if sort_order.lower() == "desc":
                     query = query.order_by(sort_column.desc())
                 else:
                     query = query.order_by(sort_column.asc())
+            else:
+                query = query.order_by(self.model.id)
-            # Apply pagination
             query = query.offset(skip).limit(limit)
             items_result = await db.execute(query)
             items = list(items_result.scalars().all())
             return items, total
-        except Exception as e:
+        except Exception as e:  # pragma: no cover
             logger.error(
-                f"Error retrieving paginated {self.model.__name__} records: {e!s}"
+                "Error retrieving paginated %s records: %s", self.model.__name__, e
             )
             raise
@@ -335,7 +315,7 @@ class CRUDBase[
             result = await db.execute(select(func.count(self.model.id)))
             return result.scalar_one()
         except Exception as e:
-            logger.error(f"Error counting {self.model.__name__} records: {e!s}")
+            logger.error("Error counting %s records: %s", self.model.__name__, e)
             raise
     async def exists(self, db: AsyncSession, id: str) -> bool:
@@ -351,14 +331,13 @@ class CRUDBase[
         """
         from datetime import datetime
-        # Validate UUID format and convert to UUID object if string
        try:
            if isinstance(id, uuid.UUID):
                uuid_obj = id
            else:
                uuid_obj = uuid.UUID(str(id))
        except (ValueError, AttributeError, TypeError) as e:
-            logger.warning(f"Invalid UUID format for soft deletion: {id} - {e!s}")
+            logger.warning("Invalid UUID format for soft deletion: %s - %s", id, e)
            return None
        try:
@@ -369,18 +348,16 @@ class CRUDBase[
             if obj is None:
                 logger.warning(
-                    f"{self.model.__name__} with id {id} not found for soft deletion"
+                    "%s with id %s not found for soft deletion", self.model.__name__, id
                 )
                 return None
-            # Check if model supports soft deletes
             if not hasattr(self.model, "deleted_at"):
-                logger.error(f"{self.model.__name__} does not support soft deletes")
-                raise ValueError(
+                logger.error("%s does not support soft deletes", self.model.__name__)
+                raise InvalidInputError(
                     f"{self.model.__name__} does not have a deleted_at column"
                 )
-            # Set deleted_at timestamp
             obj.deleted_at = datetime.now(UTC)
             db.add(obj)
             await db.commit()
@@ -388,9 +365,8 @@ class CRUDBase[
             return obj
         except Exception as e:
             await db.rollback()
-            logger.error(
-                f"Error soft deleting {self.model.__name__} with id {id}: {e!s}",
-                exc_info=True,
-            )
+            logger.exception(
+                "Error soft deleting %s with id %s: %s", self.model.__name__, id, e
            )
            raise
@@ -400,18 +376,16 @@ class CRUDBase[
         Only works if the model has a 'deleted_at' column.
         """
-        # Validate UUID format
        try:
            if isinstance(id, uuid.UUID):
                uuid_obj = id
            else:
                uuid_obj = uuid.UUID(str(id))
        except (ValueError, AttributeError, TypeError) as e:
-            logger.warning(f"Invalid UUID format for restoration: {id} - {e!s}")
+            logger.warning("Invalid UUID format for restoration: %s - %s", id, e)
            return None
        try:
-            # Find the soft-deleted record
            if hasattr(self.model, "deleted_at"):
                result = await db.execute(
                    select(self.model).where(
@@ -420,18 +394,19 @@ class CRUDBase[
                 )
                 obj = result.scalar_one_or_none()
             else:
-                logger.error(f"{self.model.__name__} does not support soft deletes")
-                raise ValueError(
+                logger.error("%s does not support soft deletes", self.model.__name__)
+                raise InvalidInputError(
                     f"{self.model.__name__} does not have a deleted_at column"
                 )
             if obj is None:
                 logger.warning(
-                    f"Soft-deleted {self.model.__name__} with id {id} not found for restoration"
+                    "Soft-deleted %s with id %s not found for restoration",
+                    self.model.__name__,
+                    id,
                 )
                 return None
-            # Clear deleted_at timestamp
             obj.deleted_at = None
             db.add(obj)
             await db.commit()
@@ -439,8 +414,7 @@ class CRUDBase[
             return obj
         except Exception as e:
             await db.rollback()
-            logger.error(
-                f"Error restoring {self.model.__name__} with id {id}: {e!s}",
-                exc_info=True,
-            )
+            logger.exception(
                "Error restoring %s with id %s: %s", self.model.__name__, id, e
            )
            raise
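The diff swaps bare `ValueError` for domain-specific exceptions from `app.core.repository_exceptions`. A sketch of the pagination guard, under the assumption (the exception definitions are not shown in this diff) that `InvalidInputError` subclasses `ValueError`, which would keep older `except ValueError` call sites working:

```python
class InvalidInputError(ValueError):
    """Illustrative stand-in for app.core.repository_exceptions.InvalidInputError."""

def validate_pagination(skip: int, limit: int, max_limit: int = 1000) -> None:
    # Mirrors the guards at the top of get_multi / get_multi_paginated.
    if skip < 0:
        raise InvalidInputError("skip must be non-negative")
    if limit < 0:
        raise InvalidInputError("limit must be non-negative")
    if limit > max_limit:
        raise InvalidInputError(f"Maximum limit is {max_limit}")
```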

View File

@@ -0,0 +1,249 @@
# app/repositories/oauth_account.py
"""Repository for OAuthAccount model async database operations."""
import logging
from datetime import datetime
from uuid import UUID
from pydantic import BaseModel
from sqlalchemy import and_, delete, select
from sqlalchemy.exc import IntegrityError
from sqlalchemy.ext.asyncio import AsyncSession
from sqlalchemy.orm import joinedload
from app.core.repository_exceptions import DuplicateEntryError
from app.models.oauth_account import OAuthAccount
from app.repositories.base import BaseRepository
from app.schemas.oauth import OAuthAccountCreate
logger = logging.getLogger(__name__)
class EmptySchema(BaseModel):
"""Placeholder schema for repository operations that don't need update schemas."""
class OAuthAccountRepository(
BaseRepository[OAuthAccount, OAuthAccountCreate, EmptySchema]
):
"""Repository for OAuth account links."""
async def get_by_provider_id(
self,
db: AsyncSession,
*,
provider: str,
provider_user_id: str,
) -> OAuthAccount | None:
"""Get OAuth account by provider and provider user ID."""
try:
result = await db.execute(
select(OAuthAccount)
.where(
and_(
OAuthAccount.provider == provider,
OAuthAccount.provider_user_id == provider_user_id,
)
)
.options(joinedload(OAuthAccount.user))
)
return result.scalar_one_or_none()
except Exception as e: # pragma: no cover
logger.error(
"Error getting OAuth account for %s:%s: %s",
provider,
provider_user_id,
e,
)
raise
async def get_by_provider_email(
self,
db: AsyncSession,
*,
provider: str,
email: str,
) -> OAuthAccount | None:
"""Get OAuth account by provider and email."""
try:
result = await db.execute(
select(OAuthAccount)
.where(
and_(
OAuthAccount.provider == provider,
OAuthAccount.provider_email == email,
)
)
.options(joinedload(OAuthAccount.user))
)
return result.scalar_one_or_none()
except Exception as e: # pragma: no cover
logger.error(
"Error getting OAuth account for %s email %s: %s", provider, email, e
)
raise
async def get_user_accounts(
self,
db: AsyncSession,
*,
user_id: str | UUID,
) -> list[OAuthAccount]:
"""Get all OAuth accounts linked to a user."""
try:
user_uuid = UUID(str(user_id)) if isinstance(user_id, str) else user_id
result = await db.execute(
select(OAuthAccount)
.where(OAuthAccount.user_id == user_uuid)
.order_by(OAuthAccount.created_at.desc())
)
return list(result.scalars().all())
except Exception as e: # pragma: no cover
logger.error("Error getting OAuth accounts for user %s: %s", user_id, e)
raise
async def get_user_account_by_provider(
self,
db: AsyncSession,
*,
user_id: str | UUID,
provider: str,
) -> OAuthAccount | None:
"""Get a specific OAuth account for a user and provider."""
try:
user_uuid = UUID(str(user_id)) if isinstance(user_id, str) else user_id
result = await db.execute(
select(OAuthAccount).where(
and_(
OAuthAccount.user_id == user_uuid,
OAuthAccount.provider == provider,
)
)
)
return result.scalar_one_or_none()
except Exception as e: # pragma: no cover
logger.error(
"Error getting OAuth account for user %s, provider %s: %s",
user_id,
provider,
e,
)
raise
async def create_account(
self, db: AsyncSession, *, obj_in: OAuthAccountCreate
) -> OAuthAccount:
"""Create a new OAuth account link."""
try:
db_obj = OAuthAccount(
user_id=obj_in.user_id,
provider=obj_in.provider,
provider_user_id=obj_in.provider_user_id,
provider_email=obj_in.provider_email,
access_token=obj_in.access_token,
refresh_token=obj_in.refresh_token,
token_expires_at=obj_in.token_expires_at,
)
db.add(db_obj)
await db.commit()
await db.refresh(db_obj)
logger.info(
"OAuth account created: %s linked to user %s",
obj_in.provider,
obj_in.user_id,
)
return db_obj
except IntegrityError as e: # pragma: no cover
await db.rollback()
error_msg = str(e.orig) if hasattr(e, "orig") else str(e)
if "uq_oauth_provider_user" in error_msg.lower():
logger.warning(
"OAuth account already exists: %s:%s",
obj_in.provider,
obj_in.provider_user_id,
)
raise DuplicateEntryError(
f"This {obj_in.provider} account is already linked to another user"
)
logger.error("Integrity error creating OAuth account: %s", error_msg)
raise DuplicateEntryError(f"Failed to create OAuth account: {error_msg}")
except Exception as e: # pragma: no cover
await db.rollback()
logger.exception("Error creating OAuth account: %s", e)
raise
async def delete_account(
self,
db: AsyncSession,
*,
user_id: str | UUID,
provider: str,
) -> bool:
"""Delete an OAuth account link."""
try:
user_uuid = UUID(str(user_id)) if isinstance(user_id, str) else user_id
result = await db.execute(
delete(OAuthAccount).where(
and_(
OAuthAccount.user_id == user_uuid,
OAuthAccount.provider == provider,
)
)
)
await db.commit()
deleted = result.rowcount > 0
if deleted:
logger.info(
"OAuth account deleted: %s unlinked from user %s", provider, user_id
)
else:
logger.warning(
"OAuth account not found for deletion: %s for user %s",
provider,
user_id,
)
return deleted
except Exception as e: # pragma: no cover
await db.rollback()
logger.error(
"Error deleting OAuth account %s for user %s: %s", provider, user_id, e
)
raise
async def update_tokens(
self,
db: AsyncSession,
*,
account: OAuthAccount,
access_token: str | None = None,
refresh_token: str | None = None,
token_expires_at: datetime | None = None,
) -> OAuthAccount:
"""Update OAuth tokens for an account."""
try:
if access_token is not None:
account.access_token = access_token
if refresh_token is not None:
account.refresh_token = refresh_token
if token_expires_at is not None:
account.token_expires_at = token_expires_at
db.add(account)
await db.commit()
await db.refresh(account)
return account
except Exception as e: # pragma: no cover
await db.rollback()
logger.error("Error updating OAuth tokens: %s", e)
raise
# Singleton instance
oauth_account_repo = OAuthAccountRepository(OAuthAccount)

View File

@@ -0,0 +1,108 @@
# app/repositories/oauth_authorization_code.py
"""Repository for OAuthAuthorizationCode model."""
import logging
from datetime import UTC, datetime
from uuid import UUID
from sqlalchemy import and_, delete, select, update
from sqlalchemy.ext.asyncio import AsyncSession
from app.models.oauth_authorization_code import OAuthAuthorizationCode
logger = logging.getLogger(__name__)
class OAuthAuthorizationCodeRepository:
"""Repository for OAuth 2.0 authorization codes."""
async def create_code(
self,
db: AsyncSession,
*,
code: str,
client_id: str,
user_id: UUID,
redirect_uri: str,
scope: str,
expires_at: datetime,
code_challenge: str | None = None,
code_challenge_method: str | None = None,
state: str | None = None,
nonce: str | None = None,
) -> OAuthAuthorizationCode:
"""Create and persist a new authorization code."""
auth_code = OAuthAuthorizationCode(
code=code,
client_id=client_id,
user_id=user_id,
redirect_uri=redirect_uri,
scope=scope,
code_challenge=code_challenge,
code_challenge_method=code_challenge_method,
state=state,
nonce=nonce,
expires_at=expires_at,
used=False,
)
db.add(auth_code)
await db.commit()
return auth_code
async def consume_code_atomically(
self, db: AsyncSession, *, code: str
) -> UUID | None:
"""
Atomically mark a code as used and return its UUID.
Returns the UUID if the code was found and not yet used, None otherwise.
This prevents race conditions per RFC 6749 Section 4.1.2.
"""
stmt = (
update(OAuthAuthorizationCode)
.where(
and_(
OAuthAuthorizationCode.code == code,
OAuthAuthorizationCode.used == False, # noqa: E712
)
)
.values(used=True)
.returning(OAuthAuthorizationCode.id)
)
result = await db.execute(stmt)
row_id = result.scalar_one_or_none()
if row_id is not None:
await db.commit()
return row_id
async def get_by_id(
self, db: AsyncSession, *, code_id: UUID
) -> OAuthAuthorizationCode | None:
"""Get authorization code by its UUID primary key."""
result = await db.execute(
select(OAuthAuthorizationCode).where(OAuthAuthorizationCode.id == code_id)
)
return result.scalar_one_or_none()
async def get_by_code(
self, db: AsyncSession, *, code: str
) -> OAuthAuthorizationCode | None:
"""Get authorization code by the code string value."""
result = await db.execute(
select(OAuthAuthorizationCode).where(OAuthAuthorizationCode.code == code)
)
return result.scalar_one_or_none()
async def cleanup_expired(self, db: AsyncSession) -> int:
"""Delete all expired authorization codes. Returns count deleted."""
result = await db.execute(
delete(OAuthAuthorizationCode).where(
OAuthAuthorizationCode.expires_at < datetime.now(UTC)
)
)
await db.commit()
return result.rowcount # type: ignore[attr-defined]
# Singleton instance
oauth_authorization_code_repo = OAuthAuthorizationCodeRepository()
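`consume_code_atomically` folds the "is it unused?" check and the "mark it used" write into a single `UPDATE ... WHERE used = false ... RETURNING` statement, which is what closes the double-redemption race RFC 6749 warns about. The same compare-and-set semantics in plain SQL with the stdlib `sqlite3` module (using `rowcount` instead of `RETURNING` for portability across SQLite builds; the table here is a toy, not the real schema):

```python
import sqlite3

# The WHERE clause only matches an unused code, so concurrent redeemers race on
# a single UPDATE and exactly one of them wins.
conn = sqlite3.connect(":memory:")
conn.execute(
    "CREATE TABLE codes (id INTEGER PRIMARY KEY, code TEXT UNIQUE,"
    " used INTEGER NOT NULL DEFAULT 0)"
)
conn.execute("INSERT INTO codes (code) VALUES ('abc123')")
conn.commit()

def consume(code: str) -> bool:
    """Return True only for the first successful redemption of `code`."""
    cur = conn.execute(
        "UPDATE codes SET used = 1 WHERE code = ? AND used = 0", (code,)
    )
    conn.commit()
    return cur.rowcount == 1
```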

View File

@@ -0,0 +1,201 @@
# app/repositories/oauth_client.py
"""Repository for OAuthClient model async database operations."""
import logging
import secrets
from uuid import UUID
from pydantic import BaseModel
from sqlalchemy import and_, delete, select
from sqlalchemy.exc import IntegrityError
from sqlalchemy.ext.asyncio import AsyncSession
from app.core.repository_exceptions import DuplicateEntryError
from app.models.oauth_client import OAuthClient
from app.repositories.base import BaseRepository
from app.schemas.oauth import OAuthClientCreate
logger = logging.getLogger(__name__)
class EmptySchema(BaseModel):
"""Placeholder schema for repository operations that don't need update schemas."""
class OAuthClientRepository(
BaseRepository[OAuthClient, OAuthClientCreate, EmptySchema]
):
"""Repository for OAuth clients (provider mode)."""
async def get_by_client_id(
self, db: AsyncSession, *, client_id: str
) -> OAuthClient | None:
"""Get OAuth client by client_id."""
try:
result = await db.execute(
select(OAuthClient).where(
and_(
OAuthClient.client_id == client_id,
OAuthClient.is_active == True, # noqa: E712
)
)
)
return result.scalar_one_or_none()
except Exception as e: # pragma: no cover
logger.error("Error getting OAuth client %s: %s", client_id, e)
raise
async def create_client(
self,
db: AsyncSession,
*,
obj_in: OAuthClientCreate,
owner_user_id: UUID | None = None,
) -> tuple[OAuthClient, str | None]:
"""Create a new OAuth client."""
try:
client_id = secrets.token_urlsafe(32)
client_secret = None
client_secret_hash = None
if obj_in.client_type == "confidential":
client_secret = secrets.token_urlsafe(48)
from app.core.auth import get_password_hash
client_secret_hash = get_password_hash(client_secret)
db_obj = OAuthClient(
client_id=client_id,
client_secret_hash=client_secret_hash,
client_name=obj_in.client_name,
client_description=obj_in.client_description,
client_type=obj_in.client_type,
redirect_uris=obj_in.redirect_uris,
allowed_scopes=obj_in.allowed_scopes,
owner_user_id=owner_user_id,
is_active=True,
)
db.add(db_obj)
await db.commit()
await db.refresh(db_obj)
logger.info(
"OAuth client created: %s (%s...)", obj_in.client_name, client_id[:8]
)
return db_obj, client_secret
except IntegrityError as e: # pragma: no cover
await db.rollback()
error_msg = str(e.orig) if hasattr(e, "orig") else str(e)
logger.error("Error creating OAuth client: %s", error_msg)
raise DuplicateEntryError(f"Failed to create OAuth client: {error_msg}")
except Exception as e: # pragma: no cover
await db.rollback()
logger.exception("Error creating OAuth client: %s", e)
raise
async def deactivate_client(
self, db: AsyncSession, *, client_id: str
) -> OAuthClient | None:
"""Deactivate an OAuth client."""
try:
client = await self.get_by_client_id(db, client_id=client_id)
if client is None:
return None
client.is_active = False
db.add(client)
await db.commit()
await db.refresh(client)
logger.info("OAuth client deactivated: %s", client.client_name)
return client
except Exception as e: # pragma: no cover
await db.rollback()
logger.error("Error deactivating OAuth client %s: %s", client_id, e)
raise
async def validate_redirect_uri(
self, db: AsyncSession, *, client_id: str, redirect_uri: str
) -> bool:
"""Validate that a redirect URI is allowed for a client."""
try:
client = await self.get_by_client_id(db, client_id=client_id)
if client is None:
return False
return redirect_uri in (client.redirect_uris or [])
except Exception as e: # pragma: no cover
logger.error("Error validating redirect URI: %s", e)
return False
async def verify_client_secret(
self, db: AsyncSession, *, client_id: str, client_secret: str
) -> bool:
"""Verify client credentials."""
try:
result = await db.execute(
select(OAuthClient).where(
and_(
OAuthClient.client_id == client_id,
OAuthClient.is_active == True, # noqa: E712
)
)
)
client = result.scalar_one_or_none()
if client is None or client.client_secret_hash is None:
return False
from app.core.auth import verify_password
stored_hash: str = str(client.client_secret_hash)
if stored_hash.startswith("$2"):
return verify_password(client_secret, stored_hash)
else:
import hashlib
secret_hash = hashlib.sha256(client_secret.encode()).hexdigest()
return secrets.compare_digest(stored_hash, secret_hash)
except Exception as e: # pragma: no cover
logger.error("Error verifying client secret: %s", e)
return False
async def get_all_clients(
self, db: AsyncSession, *, include_inactive: bool = False
) -> list[OAuthClient]:
"""Get all OAuth clients."""
try:
query = select(OAuthClient).order_by(OAuthClient.created_at.desc())
if not include_inactive:
query = query.where(OAuthClient.is_active == True) # noqa: E712
result = await db.execute(query)
return list(result.scalars().all())
except Exception as e: # pragma: no cover
logger.error("Error getting all OAuth clients: %s", e)
raise
async def delete_client(self, db: AsyncSession, *, client_id: str) -> bool:
"""Delete an OAuth client permanently."""
try:
result = await db.execute(
delete(OAuthClient).where(OAuthClient.client_id == client_id)
)
await db.commit()
deleted = result.rowcount > 0
if deleted:
logger.info("OAuth client deleted: %s", client_id)
else:
logger.warning("OAuth client not found for deletion: %s", client_id)
return deleted
except Exception as e: # pragma: no cover
await db.rollback()
logger.error("Error deleting OAuth client %s: %s", client_id, e)
raise
# Singleton instance
oauth_client_repo = OAuthClientRepository(OAuthClient)
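The dual-format branch in `verify_client_secret` — bcrypt for new hashes, hex SHA-256 with a constant-time compare for legacy ones — can be sketched standalone. This is a minimal sketch, not the repository code itself; the helper names below are illustrative:

```python
import hashlib
import secrets


def is_bcrypt_hash(stored_hash: str) -> bool:
    # bcrypt digests carry a "$2" version prefix ("$2a$", "$2b$", ...),
    # which is how the repository decides which verifier to use.
    return stored_hash.startswith("$2")


def verify_legacy_secret(candidate: str, stored_hash: str) -> bool:
    # Hex SHA-256 of the candidate, compared in constant time so the
    # comparison does not leak how many leading characters matched.
    candidate_hash = hashlib.sha256(candidate.encode()).hexdigest()
    return secrets.compare_digest(stored_hash, candidate_hash)
```

Bcrypt-prefixed hashes would be routed to a bcrypt verifier (here, `verify_password` from `app.core.auth`) rather than through the SHA-256 path.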


@@ -0,0 +1,113 @@
# app/repositories/oauth_consent.py
"""Repository for OAuthConsent model."""
import logging
from typing import Any
from uuid import UUID

from sqlalchemy import and_, delete, select
from sqlalchemy.ext.asyncio import AsyncSession

from app.models.oauth_client import OAuthClient
from app.models.oauth_provider_token import OAuthConsent

logger = logging.getLogger(__name__)


class OAuthConsentRepository:
    """Repository for OAuth consent records (user grants to clients)."""

    async def get_consent(
        self, db: AsyncSession, *, user_id: UUID, client_id: str
    ) -> OAuthConsent | None:
        """Get the consent record for a user-client pair, or None if not found."""
        result = await db.execute(
            select(OAuthConsent).where(
                and_(
                    OAuthConsent.user_id == user_id,
                    OAuthConsent.client_id == client_id,
                )
            )
        )
        return result.scalar_one_or_none()

    async def grant_consent(
        self,
        db: AsyncSession,
        *,
        user_id: UUID,
        client_id: str,
        scopes: list[str],
    ) -> OAuthConsent:
        """
        Create or update consent for a user-client pair.

        If consent already exists, the new scopes are merged with existing ones.
        Returns the created or updated consent record.
        """
        consent = await self.get_consent(db, user_id=user_id, client_id=client_id)
        if consent:
            existing = (
                set(consent.granted_scopes.split()) if consent.granted_scopes else set()
            )
            merged = existing | set(scopes)
            consent.granted_scopes = " ".join(sorted(merged))  # type: ignore[assignment]
        else:
            consent = OAuthConsent(
                user_id=user_id,
                client_id=client_id,
                granted_scopes=" ".join(sorted(set(scopes))),
            )
            db.add(consent)
        await db.commit()
        await db.refresh(consent)
        return consent

    async def get_user_consents_with_clients(
        self, db: AsyncSession, *, user_id: UUID
    ) -> list[dict[str, Any]]:
        """Get all consent records for a user joined with client details."""
        result = await db.execute(
            select(OAuthConsent, OAuthClient)
            .join(OAuthClient, OAuthConsent.client_id == OAuthClient.client_id)
            .where(OAuthConsent.user_id == user_id)
        )
        rows = result.all()
        return [
            {
                "client_id": consent.client_id,
                "client_name": client.client_name,
                "client_description": client.client_description,
                "granted_scopes": consent.granted_scopes.split()
                if consent.granted_scopes
                else [],
                "granted_at": consent.created_at.isoformat(),
            }
            for consent, client in rows
        ]

    async def revoke_consent(
        self, db: AsyncSession, *, user_id: UUID, client_id: str
    ) -> bool:
        """
        Delete the consent record for a user-client pair.

        Returns True if a record was found and deleted.
        Note: Callers are responsible for also revoking associated tokens.
        """
        result = await db.execute(
            delete(OAuthConsent).where(
                and_(
                    OAuthConsent.user_id == user_id,
                    OAuthConsent.client_id == client_id,
                )
            )
        )
        await db.commit()
        return result.rowcount > 0  # type: ignore[attr-defined]


# Singleton instance
oauth_consent_repo = OAuthConsentRepository()


@@ -0,0 +1,142 @@
# app/repositories/oauth_provider_token.py
"""Repository for OAuthProviderRefreshToken model."""
import logging
from datetime import UTC, datetime, timedelta
from uuid import UUID

from sqlalchemy import and_, delete, select, update
from sqlalchemy.ext.asyncio import AsyncSession

from app.models.oauth_provider_token import OAuthProviderRefreshToken

logger = logging.getLogger(__name__)


class OAuthProviderTokenRepository:
    """Repository for OAuth provider refresh tokens."""

    async def create_token(
        self,
        db: AsyncSession,
        *,
        token_hash: str,
        jti: str,
        client_id: str,
        user_id: UUID,
        scope: str,
        expires_at: datetime,
        device_info: str | None = None,
        ip_address: str | None = None,
    ) -> OAuthProviderRefreshToken:
        """Create and persist a new refresh token record."""
        token = OAuthProviderRefreshToken(
            token_hash=token_hash,
            jti=jti,
            client_id=client_id,
            user_id=user_id,
            scope=scope,
            expires_at=expires_at,
            device_info=device_info,
            ip_address=ip_address,
        )
        db.add(token)
        await db.commit()
        return token

    async def get_by_token_hash(
        self, db: AsyncSession, *, token_hash: str
    ) -> OAuthProviderRefreshToken | None:
        """Get refresh token record by SHA-256 token hash."""
        result = await db.execute(
            select(OAuthProviderRefreshToken).where(
                OAuthProviderRefreshToken.token_hash == token_hash
            )
        )
        return result.scalar_one_or_none()

    async def get_by_jti(
        self, db: AsyncSession, *, jti: str
    ) -> OAuthProviderRefreshToken | None:
        """Get refresh token record by JWT ID (JTI)."""
        result = await db.execute(
            select(OAuthProviderRefreshToken).where(
                OAuthProviderRefreshToken.jti == jti
            )
        )
        return result.scalar_one_or_none()

    async def revoke(
        self, db: AsyncSession, *, token: OAuthProviderRefreshToken
    ) -> None:
        """Mark a specific token record as revoked."""
        token.revoked = True  # type: ignore[assignment]
        token.last_used_at = datetime.now(UTC)  # type: ignore[assignment]
        await db.commit()

    async def revoke_all_for_user_client(
        self, db: AsyncSession, *, user_id: UUID, client_id: str
    ) -> int:
        """
        Revoke all active tokens for a specific user-client pair.

        Used when security incidents are detected (e.g., authorization code reuse).
        Returns the number of tokens revoked.
        """
        result = await db.execute(
            update(OAuthProviderRefreshToken)
            .where(
                and_(
                    OAuthProviderRefreshToken.user_id == user_id,
                    OAuthProviderRefreshToken.client_id == client_id,
                    OAuthProviderRefreshToken.revoked == False,  # noqa: E712
                )
            )
            .values(revoked=True)
        )
        count = result.rowcount  # type: ignore[attr-defined]
        if count > 0:
            await db.commit()
        return count

    async def revoke_all_for_user(self, db: AsyncSession, *, user_id: UUID) -> int:
        """
        Revoke all active tokens for a user across all clients.

        Used when user changes password or logs out everywhere.
        Returns the number of tokens revoked.
        """
        result = await db.execute(
            update(OAuthProviderRefreshToken)
            .where(
                and_(
                    OAuthProviderRefreshToken.user_id == user_id,
                    OAuthProviderRefreshToken.revoked == False,  # noqa: E712
                )
            )
            .values(revoked=True)
        )
        count = result.rowcount  # type: ignore[attr-defined]
        if count > 0:
            await db.commit()
        return count

    async def cleanup_expired(self, db: AsyncSession, *, cutoff_days: int = 7) -> int:
        """
        Delete expired refresh tokens older than cutoff_days.

        Should be called periodically (e.g., daily).
        Returns the number of tokens deleted.
        """
        cutoff = datetime.now(UTC) - timedelta(days=cutoff_days)
        result = await db.execute(
            delete(OAuthProviderRefreshToken).where(
                OAuthProviderRefreshToken.expires_at < cutoff
            )
        )
        await db.commit()
        return result.rowcount  # type: ignore[attr-defined]


# Singleton instance
oauth_provider_token_repo = OAuthProviderTokenRepository()
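Looking up tokens by `token_hash` implies the raw refresh token is never stored, only its SHA-256 digest. A minimal issuance sketch under that assumption (the function name is illustrative; the actual issuance code is not shown in this diff):

```python
import hashlib
import secrets


def issue_refresh_token() -> tuple[str, str]:
    """Generate a raw refresh token plus the digest that gets persisted.

    Only the hash is written to the database; the raw token is returned to
    the client and re-hashed on every lookup, so a leaked table does not
    leak usable tokens.
    """
    raw = secrets.token_urlsafe(48)  # ~64 URL-safe characters of entropy
    token_hash = hashlib.sha256(raw.encode()).hexdigest()
    return raw, token_hash
```

On refresh, the presented token would be hashed the same way and passed to `get_by_token_hash` for an exact match.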


@@ -0,0 +1,113 @@
# app/repositories/oauth_state.py
"""Repository for OAuthState model async database operations."""
import logging
from datetime import UTC, datetime

from pydantic import BaseModel
from sqlalchemy import delete, select
from sqlalchemy.exc import IntegrityError
from sqlalchemy.ext.asyncio import AsyncSession

from app.core.repository_exceptions import DuplicateEntryError
from app.models.oauth_state import OAuthState
from app.repositories.base import BaseRepository
from app.schemas.oauth import OAuthStateCreate

logger = logging.getLogger(__name__)


class EmptySchema(BaseModel):
    """Placeholder schema for repository operations that don't need update schemas."""


class OAuthStateRepository(BaseRepository[OAuthState, OAuthStateCreate, EmptySchema]):
    """Repository for OAuth state (CSRF protection)."""

    async def create_state(
        self, db: AsyncSession, *, obj_in: OAuthStateCreate
    ) -> OAuthState:
        """Create a new OAuth state for CSRF protection."""
        try:
            db_obj = OAuthState(
                state=obj_in.state,
                code_verifier=obj_in.code_verifier,
                nonce=obj_in.nonce,
                provider=obj_in.provider,
                redirect_uri=obj_in.redirect_uri,
                user_id=obj_in.user_id,
                expires_at=obj_in.expires_at,
            )
            db.add(db_obj)
            await db.commit()
            await db.refresh(db_obj)
            logger.debug("OAuth state created for %s", obj_in.provider)
            return db_obj
        except IntegrityError as e:  # pragma: no cover
            await db.rollback()
            error_msg = str(e.orig) if hasattr(e, "orig") else str(e)
            logger.error("OAuth state collision: %s", error_msg)
            raise DuplicateEntryError("Failed to create OAuth state, please retry")
        except Exception as e:  # pragma: no cover
            await db.rollback()
            logger.exception("Error creating OAuth state: %s", e)
            raise

    async def get_and_consume_state(
        self, db: AsyncSession, *, state: str
    ) -> OAuthState | None:
        """Get and delete OAuth state (consume it)."""
        try:
            result = await db.execute(
                select(OAuthState).where(OAuthState.state == state)
            )
            db_obj = result.scalar_one_or_none()
            if db_obj is None:
                logger.warning("OAuth state not found: %s...", state[:8])
                return None

            now = datetime.now(UTC)
            expires_at = db_obj.expires_at
            if expires_at.tzinfo is None:
                expires_at = expires_at.replace(tzinfo=UTC)
            if expires_at < now:
                logger.warning("OAuth state expired: %s...", state[:8])
                await db.delete(db_obj)
                await db.commit()
                return None

            await db.delete(db_obj)
            await db.commit()
            logger.debug("OAuth state consumed: %s...", state[:8])
            return db_obj
        except Exception as e:  # pragma: no cover
            await db.rollback()
            logger.error("Error consuming OAuth state: %s", e)
            raise

    async def cleanup_expired(self, db: AsyncSession) -> int:
        """Clean up expired OAuth states."""
        try:
            now = datetime.now(UTC)
            stmt = delete(OAuthState).where(OAuthState.expires_at < now)
            result = await db.execute(stmt)
            await db.commit()
            count = result.rowcount
            if count > 0:
                logger.info("Cleaned up %s expired OAuth states", count)
            return count
        except Exception as e:  # pragma: no cover
            await db.rollback()
            logger.error("Error cleaning up expired OAuth states: %s", e)
            raise


# Singleton instance
oauth_state_repo = OAuthStateRepository(OAuthState)


@@ -1,5 +1,5 @@
-# app/crud/organization_async.py
-"""Async CRUD operations for Organization model using SQLAlchemy 2.0 patterns."""
+# app/repositories/organization.py
+"""Repository for Organization model async database operations using SQLAlchemy 2.0 patterns."""
 import logging
 from typing import Any
@@ -9,10 +9,11 @@ from sqlalchemy import and_, case, func, or_, select
 from sqlalchemy.exc import IntegrityError
 from sqlalchemy.ext.asyncio import AsyncSession
-from app.crud.base import CRUDBase
+from app.core.repository_exceptions import DuplicateEntryError, IntegrityConstraintError
 from app.models.organization import Organization
 from app.models.user import User
 from app.models.user_organization import OrganizationRole, UserOrganization
+from app.repositories.base import BaseRepository
 from app.schemas.organizations import (
     OrganizationCreate,
     OrganizationUpdate,
@@ -21,8 +22,10 @@ from app.schemas.organizations import (
 logger = logging.getLogger(__name__)
-class CRUDOrganization(CRUDBase[Organization, OrganizationCreate, OrganizationUpdate]):
-    """Async CRUD operations for Organization model."""
+class OrganizationRepository(
+    BaseRepository[Organization, OrganizationCreate, OrganizationUpdate]
+):
+    """Repository for Organization model."""
     async def get_by_slug(self, db: AsyncSession, *, slug: str) -> Organization | None:
         """Get organization by slug."""
@@ -32,7 +35,7 @@ class CRUDOrganization(CRUDBase[Organization, OrganizationCreate, OrganizationUp
             )
             return result.scalar_one_or_none()
         except Exception as e:
-            logger.error(f"Error getting organization by slug {slug}: {e!s}")
+            logger.error("Error getting organization by slug %s: %s", slug, e)
             raise
     async def create(
@@ -54,18 +57,20 @@ class CRUDOrganization(CRUDBase[Organization, OrganizationCreate, OrganizationUp
         except IntegrityError as e:
             await db.rollback()
             error_msg = str(e.orig) if hasattr(e, "orig") else str(e)
-            if "slug" in error_msg.lower():
-                logger.warning(f"Duplicate slug attempted: {obj_in.slug}")
-                raise ValueError(
+            if (
+                "slug" in error_msg.lower()
+                or "unique" in error_msg.lower()
+                or "duplicate" in error_msg.lower()
+            ):
+                logger.warning("Duplicate slug attempted: %s", obj_in.slug)
+                raise DuplicateEntryError(
                     f"Organization with slug '{obj_in.slug}' already exists"
                 )
-            logger.error(f"Integrity error creating organization: {error_msg}")
-            raise ValueError(f"Database integrity error: {error_msg}")
+            logger.error("Integrity error creating organization: %s", error_msg)
+            raise IntegrityConstraintError(f"Database integrity error: {error_msg}")
         except Exception as e:
             await db.rollback()
-            logger.error(
-                f"Unexpected error creating organization: {e!s}", exc_info=True
-            )
+            logger.exception("Unexpected error creating organization: %s", e)
             raise
     async def get_multi_with_filters(
@@ -79,16 +84,10 @@ class CRUDOrganization(CRUDBase[Organization, OrganizationCreate, OrganizationUp
         sort_by: str = "created_at",
         sort_order: str = "desc",
     ) -> tuple[list[Organization], int]:
-        """
-        Get multiple organizations with filtering, searching, and sorting.
-        Returns:
-            Tuple of (organizations list, total count)
-        """
+        """Get multiple organizations with filtering, searching, and sorting."""
         try:
             query = select(Organization)
-            # Apply filters
             if is_active is not None:
                 query = query.where(Organization.is_active == is_active)
@@ -100,26 +99,23 @@ class CRUDOrganization(CRUDBase[Organization, OrganizationCreate, OrganizationUp
             )
             query = query.where(search_filter)
-            # Get total count before pagination
             count_query = select(func.count()).select_from(query.alias())
             count_result = await db.execute(count_query)
             total = count_result.scalar_one()
-            # Apply sorting
             sort_column = getattr(Organization, sort_by, Organization.created_at)
             if sort_order == "desc":
                 query = query.order_by(sort_column.desc())
             else:
                 query = query.order_by(sort_column.asc())
-            # Apply pagination
             query = query.offset(skip).limit(limit)
             result = await db.execute(query)
             organizations = list(result.scalars().all())
             return organizations, total
         except Exception as e:
-            logger.error(f"Error getting organizations with filters: {e!s}")
+            logger.error("Error getting organizations with filters: %s", e)
             raise
     async def get_member_count(self, db: AsyncSession, *, organization_id: UUID) -> int:
@@ -136,7 +132,7 @@ class CRUDOrganization(CRUDBase[Organization, OrganizationCreate, OrganizationUp
             return result.scalar_one() or 0
         except Exception as e:
             logger.error(
-                f"Error getting member count for organization {organization_id}: {e!s}"
+                "Error getting member count for organization %s: %s", organization_id, e
             )
             raise
@@ -149,16 +145,8 @@ class CRUDOrganization(CRUDBase[Organization, OrganizationCreate, OrganizationUp
         is_active: bool | None = None,
         search: str | None = None,
     ) -> tuple[list[dict[str, Any]], int]:
-        """
-        Get organizations with member counts in a SINGLE QUERY using JOIN and GROUP BY.
-        This eliminates the N+1 query problem.
-        Returns:
-            Tuple of (list of dicts with org and member_count, total count)
-        """
+        """Get organizations with member counts in a SINGLE QUERY using JOIN and GROUP BY."""
         try:
-            # Build base query with LEFT JOIN and GROUP BY
-            # Use CASE statement to count only active members
             query = (
                 select(
                     Organization,
@@ -181,10 +169,10 @@ class CRUDOrganization(CRUDBase[Organization, OrganizationCreate, OrganizationUp
                 .group_by(Organization.id)
             )
-            # Apply filters
             if is_active is not None:
                 query = query.where(Organization.is_active == is_active)
+            search_filter = None
             if search:
                 search_filter = or_(
                     Organization.name.ilike(f"%{search}%"),
@@ -193,17 +181,15 @@ class CRUDOrganization(CRUDBase[Organization, OrganizationCreate, OrganizationUp
                 )
                 query = query.where(search_filter)
-            # Get total count
             count_query = select(func.count(Organization.id))
             if is_active is not None:
                 count_query = count_query.where(Organization.is_active == is_active)
-            if search:
+            if search_filter is not None:
                 count_query = count_query.where(search_filter)
             count_result = await db.execute(count_query)
             total = count_result.scalar_one()
-            # Apply pagination and ordering
             query = (
                 query.order_by(Organization.created_at.desc()).offset(skip).limit(limit)
             )
@@ -211,7 +197,6 @@ class CRUDOrganization(CRUDBase[Organization, OrganizationCreate, OrganizationUp
             result = await db.execute(query)
             rows = result.all()
-            # Convert to list of dicts
             orgs_with_counts = [
                 {"organization": org, "member_count": member_count}
                 for org, member_count in rows
@@ -220,9 +205,7 @@ class CRUDOrganization(CRUDBase[Organization, OrganizationCreate, OrganizationUp
             return orgs_with_counts, total
         except Exception as e:
-            logger.error(
-                f"Error getting organizations with member counts: {e!s}", exc_info=True
-            )
+            logger.exception("Error getting organizations with member counts: %s", e)
             raise
     async def add_user(
@@ -236,7 +219,6 @@ class CRUDOrganization(CRUDBase[Organization, OrganizationCreate, OrganizationUp
     ) -> UserOrganization:
         """Add a user to an organization with a specific role."""
         try:
-            # Check if relationship already exists
             result = await db.execute(
                 select(UserOrganization).where(
                     and_(
@@ -248,7 +230,6 @@ class CRUDOrganization(CRUDBase[Organization, OrganizationCreate, OrganizationUp
             existing = result.scalar_one_or_none()
             if existing:
-                # Reactivate if inactive, or raise error if already active
                 if not existing.is_active:
                     existing.is_active = True
                     existing.role = role
@@ -257,9 +238,10 @@ class CRUDOrganization(CRUDBase[Organization, OrganizationCreate, OrganizationUp
                     await db.refresh(existing)
                     return existing
                 else:
-                    raise ValueError("User is already a member of this organization")
+                    raise DuplicateEntryError(
+                        "User is already a member of this organization"
+                    )
-            # Create new relationship
             user_org = UserOrganization(
                 user_id=user_id,
                 organization_id=organization_id,
@@ -273,11 +255,11 @@ class CRUDOrganization(CRUDBase[Organization, OrganizationCreate, OrganizationUp
             return user_org
         except IntegrityError as e:
             await db.rollback()
-            logger.error(f"Integrity error adding user to organization: {e!s}")
-            raise ValueError("Failed to add user to organization")
+            logger.error("Integrity error adding user to organization: %s", e)
+            raise IntegrityConstraintError("Failed to add user to organization")
         except Exception as e:
             await db.rollback()
-            logger.error(f"Error adding user to organization: {e!s}", exc_info=True)
+            logger.exception("Error adding user to organization: %s", e)
             raise
     async def remove_user(
@@ -303,7 +285,7 @@ class CRUDOrganization(CRUDBase[Organization, OrganizationCreate, OrganizationUp
             return True
         except Exception as e:
             await db.rollback()
-            logger.error(f"Error removing user from organization: {e!s}", exc_info=True)
+            logger.exception("Error removing user from organization: %s", e)
             raise
     async def update_user_role(
@@ -338,7 +320,7 @@ class CRUDOrganization(CRUDBase[Organization, OrganizationCreate, OrganizationUp
             return user_org
         except Exception as e:
             await db.rollback()
-            logger.error(f"Error updating user role: {e!s}", exc_info=True)
+            logger.exception("Error updating user role: %s", e)
             raise
     async def get_organization_members(
@@ -348,16 +330,10 @@ class CRUDOrganization(CRUDBase[Organization, OrganizationCreate, OrganizationUp
         organization_id: UUID,
         skip: int = 0,
         limit: int = 100,
-        is_active: bool = True,
+        is_active: bool | None = True,
     ) -> tuple[list[dict[str, Any]], int]:
-        """
-        Get members of an organization with user details.
-        Returns:
-            Tuple of (members list with user details, total count)
-        """
+        """Get members of an organization with user details."""
         try:
-            # Build query with join
             query = (
                 select(UserOrganization, User)
                 .join(User, UserOrganization.user_id == User.id)
@@ -367,7 +343,6 @@ class CRUDOrganization(CRUDBase[Organization, OrganizationCreate, OrganizationUp
             if is_active is not None:
                 query = query.where(UserOrganization.is_active == is_active)
-            # Get total count
             count_query = select(func.count()).select_from(
                 select(UserOrganization)
                 .where(UserOrganization.organization_id == organization_id)
@@ -381,7 +356,6 @@ class CRUDOrganization(CRUDBase[Organization, OrganizationCreate, OrganizationUp
             count_result = await db.execute(count_query)
             total = count_result.scalar_one()
-            # Apply ordering and pagination
             query = (
                 query.order_by(UserOrganization.created_at.desc())
                 .offset(skip)
@@ -406,11 +380,11 @@ class CRUDOrganization(CRUDBase[Organization, OrganizationCreate, OrganizationUp
             return members, total
         except Exception as e:
-            logger.error(f"Error getting organization members: {e!s}")
+            logger.error("Error getting organization members: %s", e)
             raise
     async def get_user_organizations(
-        self, db: AsyncSession, *, user_id: UUID, is_active: bool = True
+        self, db: AsyncSession, *, user_id: UUID, is_active: bool | None = True
     ) -> list[Organization]:
         """Get all organizations a user belongs to."""
         try:
@@ -429,21 +403,14 @@ class CRUDOrganization(CRUDBase[Organization, OrganizationCreate, OrganizationUp
             result = await db.execute(query)
             return list(result.scalars().all())
         except Exception as e:
-            logger.error(f"Error getting user organizations: {e!s}")
+            logger.error("Error getting user organizations: %s", e)
             raise
     async def get_user_organizations_with_details(
-        self, db: AsyncSession, *, user_id: UUID, is_active: bool = True
+        self, db: AsyncSession, *, user_id: UUID, is_active: bool | None = True
     ) -> list[dict[str, Any]]:
-        """
-        Get user's organizations with role and member count in SINGLE QUERY.
-        Eliminates N+1 problem by using subquery for member counts.
-        Returns:
-            List of dicts with organization, role, and member_count
-        """
+        """Get user's organizations with role and member count in SINGLE QUERY."""
         try:
-            # Subquery to get member counts for each organization
             member_count_subq = (
                 select(
                     UserOrganization.organization_id,
@@ -454,7 +421,6 @@ class CRUDOrganization(CRUDBase[Organization, OrganizationCreate, OrganizationUp
                 .subquery()
             )
-            # Main query with JOIN to get org, role, and member count
             query = (
                 select(
                     Organization,
@@ -486,9 +452,7 @@ class CRUDOrganization(CRUDBase[Organization, OrganizationCreate, OrganizationUp
             ]
         except Exception as e:
-            logger.error(
-                f"Error getting user organizations with details: {e!s}", exc_info=True
-            )
+            logger.exception("Error getting user organizations with details: %s", e)
             raise
     async def get_user_role_in_org(
@@ -507,9 +471,9 @@ class CRUDOrganization(CRUDBase[Organization, OrganizationCreate, OrganizationUp
             )
             user_org = result.scalar_one_or_none()
-            return user_org.role if user_org else None
+            return user_org.role if user_org else None  # pyright: ignore[reportReturnType]
         except Exception as e:
-            logger.error(f"Error getting user role in org: {e!s}")
+            logger.error("Error getting user role in org: %s", e)
             raise
     async def is_user_org_owner(
@@ -531,5 +495,5 @@ class CRUDOrganization(CRUDBase[Organization, OrganizationCreate, OrganizationUp
         return role in [OrganizationRole.OWNER, OrganizationRole.ADMIN]
-# Create a singleton instance for use across the application
-organization = CRUDOrganization(Organization)
+# Singleton instance
+organization_repo = OrganizationRepository(Organization)


@@ -1,6 +1,5 @@
""" # app/repositories/session.py
Async CRUD operations for user sessions using SQLAlchemy 2.0 patterns. """Repository for UserSession model async database operations using SQLAlchemy 2.0 patterns."""
"""
import logging import logging
import uuid import uuid
@@ -11,49 +10,32 @@ from sqlalchemy import and_, delete, func, select, update
from sqlalchemy.ext.asyncio import AsyncSession from sqlalchemy.ext.asyncio import AsyncSession
from sqlalchemy.orm import joinedload from sqlalchemy.orm import joinedload
from app.crud.base import CRUDBase from app.core.repository_exceptions import IntegrityConstraintError, InvalidInputError
from app.models.user_session import UserSession from app.models.user_session import UserSession
from app.repositories.base import BaseRepository
from app.schemas.sessions import SessionCreate, SessionUpdate from app.schemas.sessions import SessionCreate, SessionUpdate
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
class CRUDSession(CRUDBase[UserSession, SessionCreate, SessionUpdate]): class SessionRepository(BaseRepository[UserSession, SessionCreate, SessionUpdate]):
"""Async CRUD operations for user sessions.""" """Repository for UserSession model."""
async def get_by_jti(self, db: AsyncSession, *, jti: str) -> UserSession | None: async def get_by_jti(self, db: AsyncSession, *, jti: str) -> UserSession | None:
""" """Get session by refresh token JTI."""
Get session by refresh token JTI.
Args:
db: Database session
jti: Refresh token JWT ID
Returns:
UserSession if found, None otherwise
"""
try: try:
result = await db.execute( result = await db.execute(
select(UserSession).where(UserSession.refresh_token_jti == jti) select(UserSession).where(UserSession.refresh_token_jti == jti)
) )
return result.scalar_one_or_none() return result.scalar_one_or_none()
except Exception as e: except Exception as e:
logger.error(f"Error getting session by JTI {jti}: {e!s}") logger.error("Error getting session by JTI %s: %s", jti, e)
raise raise
async def get_active_by_jti( async def get_active_by_jti(
self, db: AsyncSession, *, jti: str self, db: AsyncSession, *, jti: str
) -> UserSession | None: ) -> UserSession | None:
""" """Get active session by refresh token JTI."""
Get active session by refresh token JTI.
Args:
db: Database session
jti: Refresh token JWT ID
Returns:
Active UserSession if found, None otherwise
"""
try: try:
result = await db.execute( result = await db.execute(
select(UserSession).where( select(UserSession).where(
@@ -65,7 +47,7 @@ class CRUDSession(CRUDBase[UserSession, SessionCreate, SessionUpdate]):
) )
return result.scalar_one_or_none() return result.scalar_one_or_none()
except Exception as e: except Exception as e:
logger.error(f"Error getting active session by JTI {jti}: {e!s}") logger.error("Error getting active session by JTI %s: %s", jti, e)
raise raise
async def get_user_sessions( async def get_user_sessions(
@@ -76,25 +58,12 @@ class CRUDSession(CRUDBase[UserSession, SessionCreate, SessionUpdate]):
active_only: bool = True, active_only: bool = True,
with_user: bool = False, with_user: bool = False,
) -> list[UserSession]: ) -> list[UserSession]:
""" """Get all sessions for a user with optional eager loading."""
Get all sessions for a user with optional eager loading.
Args:
db: Database session
user_id: User ID
active_only: If True, return only active sessions
with_user: If True, eager load user relationship to prevent N+1
Returns:
List of UserSession objects
"""
try: try:
# Convert user_id string to UUID if needed
user_uuid = UUID(user_id) if isinstance(user_id, str) else user_id user_uuid = UUID(user_id) if isinstance(user_id, str) else user_id
query = select(UserSession).where(UserSession.user_id == user_uuid) query = select(UserSession).where(UserSession.user_id == user_uuid)
# Add eager loading if requested to prevent N+1 queries
if with_user: if with_user:
query = query.options(joinedload(UserSession.user)) query = query.options(joinedload(UserSession.user))
@@ -105,25 +74,13 @@ class CRUDSession(CRUDBase[UserSession, SessionCreate, SessionUpdate]):
             result = await db.execute(query)
             return list(result.scalars().all())
         except Exception as e:
-            logger.error(f"Error getting sessions for user {user_id}: {e!s}")
+            logger.error("Error getting sessions for user %s: %s", user_id, e)
             raise

     async def create_session(
         self, db: AsyncSession, *, obj_in: SessionCreate
     ) -> UserSession:
-        """
-        Create a new user session.
-        Args:
-            db: Database session
-            obj_in: SessionCreate schema with session data
-        Returns:
-            Created UserSession
-        Raises:
-            ValueError: If session creation fails
-        """
+        """Create a new user session."""
         try:
             db_obj = UserSession(
                 user_id=obj_in.user_id,
@@ -143,33 +100,26 @@ class CRUDSession(CRUDBase[UserSession, SessionCreate, SessionUpdate]):
             await db.refresh(db_obj)
             logger.info(
-                f"Session created for user {obj_in.user_id} from {obj_in.device_name} "
-                f"(IP: {obj_in.ip_address})"
+                "Session created for user %s from %s (IP: %s)",
+                obj_in.user_id,
+                obj_in.device_name,
+                obj_in.ip_address,
             )
             return db_obj
         except Exception as e:
             await db.rollback()
-            logger.error(f"Error creating session: {e!s}", exc_info=True)
-            raise ValueError(f"Failed to create session: {e!s}")
+            logger.exception("Error creating session: %s", e)
+            raise IntegrityConstraintError(f"Failed to create session: {e!s}")
     async def deactivate(
         self, db: AsyncSession, *, session_id: str
     ) -> UserSession | None:
-        """
-        Deactivate a session (logout from device).
-        Args:
-            db: Database session
-            session_id: Session UUID
-        Returns:
-            Deactivated UserSession if found, None otherwise
-        """
+        """Deactivate a session (logout from device)."""
         try:
             session = await self.get(db, id=session_id)
             if not session:
-                logger.warning(f"Session {session_id} not found for deactivation")
+                logger.warning("Session %s not found for deactivation", session_id)
                 return None
             session.is_active = False
@@ -178,31 +128,23 @@ class CRUDSession(CRUDBase[UserSession, SessionCreate, SessionUpdate]):
             await db.refresh(session)
             logger.info(
-                f"Session {session_id} deactivated for user {session.user_id} "
-                f"({session.device_name})"
+                "Session %s deactivated for user %s (%s)",
+                session_id,
+                session.user_id,
+                session.device_name,
             )
             return session
         except Exception as e:
             await db.rollback()
-            logger.error(f"Error deactivating session {session_id}: {e!s}")
+            logger.error("Error deactivating session %s: %s", session_id, e)
             raise
     async def deactivate_all_user_sessions(
         self, db: AsyncSession, *, user_id: str
     ) -> int:
-        """
-        Deactivate all active sessions for a user (logout from all devices).
-        Args:
-            db: Database session
-            user_id: User ID
-        Returns:
-            Number of sessions deactivated
-        """
+        """Deactivate all active sessions for a user (logout from all devices)."""
         try:
-            # Convert user_id string to UUID if needed
             user_uuid = UUID(user_id) if isinstance(user_id, str) else user_id
             stmt = (
@@ -216,27 +158,18 @@ class CRUDSession(CRUDBase[UserSession, SessionCreate, SessionUpdate]):
             count = result.rowcount
-            logger.info(f"Deactivated {count} sessions for user {user_id}")
+            logger.info("Deactivated %s sessions for user %s", count, user_id)
             return count
         except Exception as e:
             await db.rollback()
-            logger.error(f"Error deactivating all sessions for user {user_id}: {e!s}")
+            logger.error("Error deactivating all sessions for user %s: %s", user_id, e)
             raise
     async def update_last_used(
         self, db: AsyncSession, *, session: UserSession
     ) -> UserSession:
-        """
-        Update the last_used_at timestamp for a session.
-        Args:
-            db: Database session
-            session: UserSession object
-        Returns:
-            Updated UserSession
-        """
+        """Update the last_used_at timestamp for a session."""
         try:
             session.last_used_at = datetime.now(UTC)
             db.add(session)
@@ -245,7 +178,7 @@ class CRUDSession(CRUDBase[UserSession, SessionCreate, SessionUpdate]):
             return session
         except Exception as e:
             await db.rollback()
-            logger.error(f"Error updating last_used for session {session.id}: {e!s}")
+            logger.error("Error updating last_used for session %s: %s", session.id, e)
             raise
     async def update_refresh_token(
@@ -256,20 +189,7 @@ class CRUDSession(CRUDBase[UserSession, SessionCreate, SessionUpdate]):
         new_jti: str,
         new_expires_at: datetime,
     ) -> UserSession:
-        """
-        Update session with new refresh token JTI and expiration.
-        Called during token refresh.
-        Args:
-            db: Database session
-            session: UserSession object
-            new_jti: New refresh token JTI
-            new_expires_at: New expiration datetime
-        Returns:
-            Updated UserSession
-        """
+        """Update session with new refresh token JTI and expiration."""
         try:
             session.refresh_token_jti = new_jti
             session.expires_at = new_expires_at
@@ -281,32 +201,16 @@ class CRUDSession(CRUDBase[UserSession, SessionCreate, SessionUpdate]):
         except Exception as e:
             await db.rollback()
             logger.error(
-                f"Error updating refresh token for session {session.id}: {e!s}"
+                "Error updating refresh token for session %s: %s", session.id, e
             )
             raise
     async def cleanup_expired(self, db: AsyncSession, *, keep_days: int = 30) -> int:
-        """
-        Clean up expired sessions using optimized bulk DELETE.
-        Deletes sessions that are:
-        - Expired AND inactive
-        - Older than keep_days
-        Uses single DELETE query instead of N individual deletes for efficiency.
-        Args:
-            db: Database session
-            keep_days: Keep inactive sessions for this many days (for audit)
-        Returns:
-            Number of sessions deleted
-        """
+        """Clean up expired sessions using optimized bulk DELETE."""
         try:
             cutoff_date = datetime.now(UTC) - timedelta(days=keep_days)
             now = datetime.now(UTC)
-            # Use bulk DELETE with WHERE clause - single query
             stmt = delete(UserSession).where(
                 and_(
                     UserSession.is_active == False,  # noqa: E712
@@ -321,38 +225,25 @@ class CRUDSession(CRUDBase[UserSession, SessionCreate, SessionUpdate]):
             count = result.rowcount
             if count > 0:
-                logger.info(f"Cleaned up {count} expired sessions using bulk DELETE")
+                logger.info("Cleaned up %s expired sessions using bulk DELETE", count)
             return count
         except Exception as e:
             await db.rollback()
-            logger.error(f"Error cleaning up expired sessions: {e!s}")
+            logger.error("Error cleaning up expired sessions: %s", e)
             raise
     async def cleanup_expired_for_user(self, db: AsyncSession, *, user_id: str) -> int:
-        """
-        Clean up expired and inactive sessions for a specific user.
-        Uses single bulk DELETE query for efficiency instead of N individual deletes.
-        Args:
-            db: Database session
-            user_id: User ID to cleanup sessions for
-        Returns:
-            Number of sessions deleted
-        """
+        """Clean up expired and inactive sessions for a specific user."""
         try:
-            # Validate UUID
             try:
                 uuid_obj = uuid.UUID(user_id)
             except (ValueError, AttributeError):
-                logger.error(f"Invalid UUID format: {user_id}")
-                raise ValueError(f"Invalid user ID format: {user_id}")
+                logger.error("Invalid UUID format: %s", user_id)
+                raise InvalidInputError(f"Invalid user ID format: {user_id}")
             now = datetime.now(UTC)
-            # Use bulk DELETE with WHERE clause - single query
             stmt = delete(UserSession).where(
                 and_(
                     UserSession.user_id == uuid_obj,
@@ -368,30 +259,22 @@ class CRUDSession(CRUDBase[UserSession, SessionCreate, SessionUpdate]):
             if count > 0:
                 logger.info(
-                    f"Cleaned up {count} expired sessions for user {user_id} using bulk DELETE"
+                    "Cleaned up %s expired sessions for user %s using bulk DELETE",
+                    count,
+                    user_id,
                 )
             return count
         except Exception as e:
             await db.rollback()
             logger.error(
-                f"Error cleaning up expired sessions for user {user_id}: {e!s}"
+                "Error cleaning up expired sessions for user %s: %s", user_id, e
             )
             raise
     async def get_user_session_count(self, db: AsyncSession, *, user_id: str) -> int:
-        """
-        Get count of active sessions for a user.
-        Args:
-            db: Database session
-            user_id: User ID
-        Returns:
-            Number of active sessions
-        """
+        """Get count of active sessions for a user."""
         try:
-            # Convert user_id string to UUID if needed
             user_uuid = UUID(user_id) if isinstance(user_id, str) else user_id
             result = await db.execute(
@@ -401,7 +284,7 @@ class CRUDSession(CRUDBase[UserSession, SessionCreate, SessionUpdate]):
             )
             return result.scalar_one()
         except Exception as e:
-            logger.error(f"Error counting sessions for user {user_id}: {e!s}")
+            logger.error("Error counting sessions for user %s: %s", user_id, e)
             raise
     async def get_all_sessions(
@@ -413,31 +296,16 @@ class CRUDSession(CRUDBase[UserSession, SessionCreate, SessionUpdate]):
         active_only: bool = True,
         with_user: bool = True,
     ) -> tuple[list[UserSession], int]:
-        """
-        Get all sessions across all users with pagination (admin only).
-        Args:
-            db: Database session
-            skip: Number of records to skip
-            limit: Maximum number of records to return
-            active_only: If True, return only active sessions
-            with_user: If True, eager load user relationship to prevent N+1
-        Returns:
-            Tuple of (list of UserSession objects, total count)
-        """
+        """Get all sessions across all users with pagination (admin only)."""
         try:
-            # Build query
             query = select(UserSession)
-            # Add eager loading if requested to prevent N+1 queries
             if with_user:
                 query = query.options(joinedload(UserSession.user))
             if active_only:
                 query = query.where(UserSession.is_active)
-            # Get total count
             count_query = select(func.count(UserSession.id))
             if active_only:
                 count_query = count_query.where(UserSession.is_active)
@@ -445,7 +313,6 @@ class CRUDSession(CRUDBase[UserSession, SessionCreate, SessionUpdate]):
             count_result = await db.execute(count_query)
             total = count_result.scalar_one()
-            # Apply pagination and ordering
             query = (
                 query.order_by(UserSession.last_used_at.desc())
                 .offset(skip)
@@ -458,9 +325,9 @@ class CRUDSession(CRUDBase[UserSession, SessionCreate, SessionUpdate]):
             return sessions, total
         except Exception as e:
-            logger.error(f"Error getting all sessions: {e!s}", exc_info=True)
+            logger.exception("Error getting all sessions: %s", e)
             raise

-# Create singleton instance
-session = CRUDSession(UserSession)
+# Singleton instance
+session_repo = SessionRepository(UserSession)
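A recurring change across this file swaps eager f-string log messages for %-style lazy formatting. The point: with `logger.error("… %s", arg)`, the argument is only stringified if the record actually passes the level filter. A self-contained sketch (the `CountingStr` helper is hypothetical, used just to observe formatting):

```python
import logging


class CountingStr:
    """Records whether __str__ was ever called, i.e. whether formatting happened."""

    def __init__(self) -> None:
        self.formatted = False

    def __str__(self) -> str:
        self.formatted = True
        return "expensive repr"


logger = logging.getLogger("lazy_demo")
logger.setLevel(logging.WARNING)  # DEBUG records are filtered out

arg = CountingStr()
logger.debug("value: %s", arg)  # record dropped, args never formatted
lazy_skipped = not arg.formatted

arg2 = CountingStr()
logger.debug(f"value: {arg2}")  # f-string formats eagerly, before the call
eager_formatted = arg2.formatted
```

With lazy formatting the suppressed DEBUG call never pays the formatting cost; the f-string version pays it even though the record is discarded.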

@@ -1,5 +1,5 @@
-# app/crud/user_async.py
-"""Async CRUD operations for User model using SQLAlchemy 2.0 patterns."""
+# app/repositories/user.py
+"""Repository for User model async database operations using SQLAlchemy 2.0 patterns."""
 import logging
 from datetime import UTC, datetime
@@ -11,15 +11,16 @@ from sqlalchemy.exc import IntegrityError
 from sqlalchemy.ext.asyncio import AsyncSession
 from app.core.auth import get_password_hash_async
-from app.crud.base import CRUDBase
+from app.core.repository_exceptions import DuplicateEntryError, InvalidInputError
 from app.models.user import User
+from app.repositories.base import BaseRepository
 from app.schemas.users import UserCreate, UserUpdate

 logger = logging.getLogger(__name__)

-class CRUDUser(CRUDBase[User, UserCreate, UserUpdate]):
-    """Async CRUD operations for User model."""
+class UserRepository(BaseRepository[User, UserCreate, UserUpdate]):
+    """Repository for User model."""

     async def get_by_email(self, db: AsyncSession, *, email: str) -> User | None:
         """Get user by email address."""
@@ -27,13 +28,12 @@ class CRUDUser(CRUDBase[User, UserCreate, UserUpdate]):
             result = await db.execute(select(User).where(User.email == email))
             return result.scalar_one_or_none()
         except Exception as e:
-            logger.error(f"Error getting user by email {email}: {e!s}")
+            logger.error("Error getting user by email %s: %s", email, e)
             raise

     async def create(self, db: AsyncSession, *, obj_in: UserCreate) -> User:
         """Create a new user with async password hashing and error handling."""
         try:
-            # Hash password asynchronously to avoid blocking event loop
             password_hash = await get_password_hash_async(obj_in.password)
             db_obj = User(
@@ -57,13 +57,49 @@ class CRUDUser(CRUDBase[User, UserCreate, UserUpdate]):
             await db.rollback()
             error_msg = str(e.orig) if hasattr(e, "orig") else str(e)
             if "email" in error_msg.lower():
-                logger.warning(f"Duplicate email attempted: {obj_in.email}")
-                raise ValueError(f"User with email {obj_in.email} already exists")
-            logger.error(f"Integrity error creating user: {error_msg}")
-            raise ValueError(f"Database integrity error: {error_msg}")
+                logger.warning("Duplicate email attempted: %s", obj_in.email)
+                raise DuplicateEntryError(
+                    f"User with email {obj_in.email} already exists"
+                )
+            logger.error("Integrity error creating user: %s", error_msg)
+            raise DuplicateEntryError(f"Database integrity error: {error_msg}")
         except Exception as e:
             await db.rollback()
-            logger.error(f"Unexpected error creating user: {e!s}", exc_info=True)
+            logger.exception("Unexpected error creating user: %s", e)
+            raise
+
+    async def create_oauth_user(
+        self,
+        db: AsyncSession,
+        *,
+        email: str,
+        first_name: str = "User",
+        last_name: str | None = None,
+    ) -> User:
+        """Create a new passwordless user for OAuth sign-in."""
+        try:
+            db_obj = User(
+                email=email,
+                password_hash=None,  # OAuth-only user
+                first_name=first_name,
+                last_name=last_name,
+                is_active=True,
+                is_superuser=False,
+            )
+            db.add(db_obj)
+            await db.flush()  # Get user.id without committing
+            return db_obj
+        except IntegrityError as e:
+            await db.rollback()
+            error_msg = str(e.orig) if hasattr(e, "orig") else str(e)
+            if "email" in error_msg.lower():
+                logger.warning("Duplicate email attempted: %s", email)
+                raise DuplicateEntryError(f"User with email {email} already exists")
+            logger.error("Integrity error creating OAuth user: %s", error_msg)
+            raise DuplicateEntryError(f"Database integrity error: {error_msg}")
+        except Exception as e:
+            await db.rollback()
+            logger.exception("Unexpected error creating OAuth user: %s", e)
             raise
     async def update(
@@ -75,8 +111,6 @@ class CRUDUser(CRUDBase[User, UserCreate, UserUpdate]):
         else:
             update_data = obj_in.model_dump(exclude_unset=True)
-        # Handle password separately if it exists in update data
-        # Hash password asynchronously to avoid blocking event loop
         if "password" in update_data:
             update_data["password_hash"] = await get_password_hash_async(
                 update_data["password"]
@@ -85,6 +119,15 @@ class CRUDUser(CRUDBase[User, UserCreate, UserUpdate]):
         return await super().update(db, db_obj=db_obj, obj_in=update_data)

+    async def update_password(
+        self, db: AsyncSession, *, user: User, password_hash: str
+    ) -> User:
+        """Set a new password hash on a user and commit."""
+        user.password_hash = password_hash
+        await db.commit()
+        await db.refresh(user)
+        return user
     async def get_multi_with_total(
         self,
         db: AsyncSession,
@@ -96,43 +139,23 @@ class CRUDUser(CRUDBase[User, UserCreate, UserUpdate]):
         filters: dict[str, Any] | None = None,
         search: str | None = None,
     ) -> tuple[list[User], int]:
-        """
-        Get multiple users with total count, filtering, sorting, and search.
-        Args:
-            db: Database session
-            skip: Number of records to skip
-            limit: Maximum number of records to return
-            sort_by: Field name to sort by
-            sort_order: Sort order ("asc" or "desc")
-            filters: Dictionary of filters (field_name: value)
-            search: Search term to match against email, first_name, last_name
-        Returns:
-            Tuple of (users list, total count)
-        """
-        # Validate pagination
+        """Get multiple users with total count, filtering, sorting, and search."""
         if skip < 0:
-            raise ValueError("skip must be non-negative")
+            raise InvalidInputError("skip must be non-negative")
         if limit < 0:
-            raise ValueError("limit must be non-negative")
+            raise InvalidInputError("limit must be non-negative")
         if limit > 1000:
-            raise ValueError("Maximum limit is 1000")
+            raise InvalidInputError("Maximum limit is 1000")
         try:
-            # Build base query
             query = select(User)
-            # Exclude soft-deleted users
             query = query.where(User.deleted_at.is_(None))
-            # Apply filters
             if filters:
                 for field, value in filters.items():
                     if hasattr(User, field) and value is not None:
                         query = query.where(getattr(User, field) == value)
-            # Apply search
             if search:
                 search_filter = or_(
                     User.email.ilike(f"%{search}%"),
@@ -141,14 +164,12 @@ class CRUDUser(CRUDBase[User, UserCreate, UserUpdate]):
                 )
                 query = query.where(search_filter)
-            # Get total count
             from sqlalchemy import func
             count_query = select(func.count()).select_from(query.alias())
             count_result = await db.execute(count_query)
             total = count_result.scalar_one()
-            # Apply sorting
             if sort_by and hasattr(User, sort_by):
                 sort_column = getattr(User, sort_by)
                 if sort_order.lower() == "desc":
@@ -156,7 +177,6 @@ class CRUDUser(CRUDBase[User, UserCreate, UserUpdate]):
                 else:
                     query = query.order_by(sort_column.asc())
-            # Apply pagination
             query = query.offset(skip).limit(limit)
             result = await db.execute(query)
             users = list(result.scalars().all())
@@ -164,32 +184,21 @@ class CRUDUser(CRUDBase[User, UserCreate, UserUpdate]):
             return users, total
         except Exception as e:
-            logger.error(f"Error retrieving paginated users: {e!s}")
+            logger.error("Error retrieving paginated users: %s", e)
             raise
     async def bulk_update_status(
         self, db: AsyncSession, *, user_ids: list[UUID], is_active: bool
     ) -> int:
-        """
-        Bulk update is_active status for multiple users.
-        Args:
-            db: Database session
-            user_ids: List of user IDs to update
-            is_active: New active status
-        Returns:
-            Number of users updated
-        """
+        """Bulk update is_active status for multiple users."""
         try:
             if not user_ids:
                 return 0
-            # Use UPDATE with WHERE IN for efficiency
             stmt = (
                 update(User)
                 .where(User.id.in_(user_ids))
-                .where(User.deleted_at.is_(None))  # Don't update deleted users
+                .where(User.deleted_at.is_(None))
                 .values(is_active=is_active, updated_at=datetime.now(UTC))
             )
@@ -197,12 +206,14 @@ class CRUDUser(CRUDBase[User, UserCreate, UserUpdate]):
             await db.commit()
             updated_count = result.rowcount
-            logger.info(f"Bulk updated {updated_count} users to is_active={is_active}")
+            logger.info(
+                "Bulk updated %s users to is_active=%s", updated_count, is_active
+            )
             return updated_count
         except Exception as e:
             await db.rollback()
-            logger.error(f"Error bulk updating user status: {e!s}", exc_info=True)
+            logger.exception("Error bulk updating user status: %s", e)
             raise
     async def bulk_soft_delete(
@@ -212,34 +223,20 @@ class CRUDUser(CRUDBase[User, UserCreate, UserUpdate]):
         user_ids: list[UUID],
         exclude_user_id: UUID | None = None,
     ) -> int:
-        """
-        Bulk soft delete multiple users.
-        Args:
-            db: Database session
-            user_ids: List of user IDs to delete
-            exclude_user_id: Optional user ID to exclude (e.g., the admin performing the action)
-        Returns:
-            Number of users deleted
-        """
+        """Bulk soft delete multiple users."""
         try:
             if not user_ids:
                 return 0
-            # Remove excluded user from list
             filtered_ids = [uid for uid in user_ids if uid != exclude_user_id]
             if not filtered_ids:
                 return 0
-            # Use UPDATE with WHERE IN for efficiency
             stmt = (
                 update(User)
                 .where(User.id.in_(filtered_ids))
-                .where(
-                    User.deleted_at.is_(None)
-                )  # Don't re-delete already deleted users
+                .where(User.deleted_at.is_(None))
                 .values(
                     deleted_at=datetime.now(UTC),
                     is_active=False,
@@ -251,22 +248,22 @@ class CRUDUser(CRUDBase[User, UserCreate, UserUpdate]):
             await db.commit()
             deleted_count = result.rowcount
-            logger.info(f"Bulk soft deleted {deleted_count} users")
+            logger.info("Bulk soft deleted %s users", deleted_count)
             return deleted_count
         except Exception as e:
             await db.rollback()
-            logger.error(f"Error bulk deleting users: {e!s}", exc_info=True)
+            logger.exception("Error bulk deleting users: %s", e)
             raise
     def is_active(self, user: User) -> bool:
         """Check if user is active."""
-        return user.is_active
+        return bool(user.is_active)

     def is_superuser(self, user: User) -> bool:
         """Check if user is a superuser."""
-        return user.is_superuser
+        return bool(user.is_superuser)

-# Create a singleton instance for use across the application
-user = CRUDUser(User)
+# Singleton instance
+user_repo = UserRepository(User)

@@ -0,0 +1,395 @@
"""
Pydantic schemas for OAuth authentication.
"""
from datetime import datetime
from uuid import UUID
from pydantic import BaseModel, ConfigDict, Field
# ============================================================================
# OAuth Provider Info (for frontend to display available providers)
# ============================================================================
class OAuthProviderInfo(BaseModel):
"""Information about an available OAuth provider."""
provider: str = Field(..., description="Provider identifier (google, github)")
name: str = Field(..., description="Human-readable provider name")
icon: str | None = Field(None, description="Icon identifier for frontend")
class OAuthProvidersResponse(BaseModel):
"""Response containing list of enabled OAuth providers."""
enabled: bool = Field(..., description="Whether OAuth is globally enabled")
providers: list[OAuthProviderInfo] = Field(
default_factory=list, description="List of enabled providers"
)
model_config = ConfigDict(
json_schema_extra={
"example": {
"enabled": True,
"providers": [
{"provider": "google", "name": "Google", "icon": "google"},
{"provider": "github", "name": "GitHub", "icon": "github"},
],
}
}
)
# ============================================================================
# OAuth Account (linked provider accounts)
# ============================================================================

class OAuthAccountBase(BaseModel):
    """Base schema for OAuth accounts."""

    provider: str = Field(..., max_length=50, description="OAuth provider name")
    provider_email: str | None = Field(
        None, max_length=255, description="Email from OAuth provider"
    )

class OAuthAccountCreate(OAuthAccountBase):
    """Schema for creating an OAuth account link (internal use)."""

    user_id: UUID
    provider_user_id: str = Field(..., max_length=255)
    access_token: str | None = None
    refresh_token: str | None = None
    token_expires_at: datetime | None = None

class OAuthAccountResponse(OAuthAccountBase):
    """Schema for OAuth account response to clients."""

    id: UUID
    created_at: datetime

    model_config = ConfigDict(
        from_attributes=True,
        json_schema_extra={
            "example": {
                "id": "123e4567-e89b-12d3-a456-426614174000",
                "provider": "google",
                "provider_email": "user@gmail.com",
                "created_at": "2025-11-24T12:00:00Z",
            }
        },
    )

class OAuthAccountsListResponse(BaseModel):
    """Response containing list of linked OAuth accounts."""

    accounts: list[OAuthAccountResponse]

    model_config = ConfigDict(
        json_schema_extra={
            "example": {
                "accounts": [
                    {
                        "id": "123e4567-e89b-12d3-a456-426614174000",
                        "provider": "google",
                        "provider_email": "user@gmail.com",
                        "created_at": "2025-11-24T12:00:00Z",
                    }
                ]
            }
        }
    )
# ============================================================================
# OAuth Flow (authorization, callback, etc.)
# ============================================================================

class OAuthAuthorizeRequest(BaseModel):
    """Request parameters for OAuth authorization."""

    provider: str = Field(..., description="OAuth provider (google, github)")
    redirect_uri: str | None = Field(
        None, description="Frontend callback URL after OAuth"
    )
    mode: str = Field(
        default="login",
        description="OAuth mode: login, register, or link",
        pattern="^(login|register|link)$",
    )

class OAuthCallbackRequest(BaseModel):
    """Request parameters for OAuth callback."""

    code: str = Field(..., description="Authorization code from provider")
    state: str = Field(..., description="State parameter for CSRF protection")

class OAuthCallbackResponse(BaseModel):
    """Response after successful OAuth authentication."""

    access_token: str = Field(..., description="JWT access token")
    refresh_token: str = Field(..., description="JWT refresh token")
    token_type: str = Field(default="bearer")
    expires_in: int = Field(..., description="Token expiration in seconds")
    is_new_user: bool = Field(
        default=False, description="Whether a new user was created"
    )

    model_config = ConfigDict(
        json_schema_extra={
            "example": {
                "access_token": "eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9...",
                "refresh_token": "eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9...",
                "token_type": "bearer",
                "expires_in": 900,
                "is_new_user": False,
            }
        }
    )

class OAuthUnlinkResponse(BaseModel):
    """Response after unlinking an OAuth account."""

    success: bool = Field(..., description="Whether the unlink was successful")
    message: str = Field(..., description="Status message")

    model_config = ConfigDict(
        json_schema_extra={
            "example": {"success": True, "message": "Google account unlinked"}
        }
    )
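The `state` field above carries the CSRF token for the authorization round-trip. A stdlib-only sketch of generating and checking it (function names are illustrative, not from this codebase):

```python
import secrets


def new_oauth_state() -> str:
    """Generate an unguessable state value for the OAuth authorization request."""
    return secrets.token_urlsafe(32)


def verify_state(expected: str, received: str) -> bool:
    """Compare the stored state against the callback's state parameter.

    secrets.compare_digest is constant-time, so the check does not leak
    how many leading characters matched.
    """
    return secrets.compare_digest(expected, received)
```

The server stores the generated state (here, the `OAuthStateCreate` row below) before redirecting, then rejects any callback whose `state` fails `verify_state`.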
# ============================================================================
# OAuth State (CSRF protection - internal use)
# ============================================================================

class OAuthStateCreate(BaseModel):
    """Schema for creating OAuth state (internal use)."""

    state: str = Field(..., max_length=255)
    code_verifier: str | None = Field(None, max_length=128)
    nonce: str | None = Field(None, max_length=255)
    provider: str = Field(..., max_length=50)
    redirect_uri: str | None = Field(None, max_length=500)
    user_id: UUID | None = None
    expires_at: datetime
# ============================================================================
# OAuth Client (Provider Mode - MCP clients)
# ============================================================================

class OAuthClientBase(BaseModel):
    """Base schema for OAuth clients."""

    client_name: str = Field(..., max_length=255, description="Client application name")
    client_description: str | None = Field(
        None, max_length=1000, description="Client description"
    )
    redirect_uris: list[str] = Field(
        default_factory=list, description="Allowed redirect URIs"
    )
    allowed_scopes: list[str] = Field(
        default_factory=list, description="Allowed OAuth scopes"
    )

class OAuthClientCreate(OAuthClientBase):
    """Schema for creating an OAuth client."""

    client_type: str = Field(
        default="public",
        description="Client type: public or confidential",
        pattern="^(public|confidential)$",
    )

class OAuthClientResponse(OAuthClientBase):
    """Schema for OAuth client response."""

    id: UUID
    client_id: str = Field(..., description="OAuth client ID")
    client_type: str
    is_active: bool
    created_at: datetime

    model_config = ConfigDict(
        from_attributes=True,
        json_schema_extra={
            "example": {
                "id": "123e4567-e89b-12d3-a456-426614174000",
                "client_id": "abc123def456",
                "client_name": "My MCP App",
                "client_description": "My application that uses MCP",
                "client_type": "public",
                "redirect_uris": ["http://localhost:3000/callback"],
                "allowed_scopes": ["read:users", "write:users"],
                "is_active": True,
                "created_at": "2025-11-24T12:00:00Z",
            }
        },
    )

class OAuthClientWithSecret(OAuthClientResponse):
    """Schema for OAuth client response including secret (only shown once)."""

    client_secret: str | None = Field(
        None, description="Client secret (only shown once for confidential clients)"
    )

    model_config = ConfigDict(
        from_attributes=True,
        json_schema_extra={
            "example": {
                "id": "123e4567-e89b-12d3-a456-426614174000",
                "client_id": "abc123def456",
                "client_secret": "secret_xyz789",
                "client_name": "My MCP App",
                "client_type": "confidential",
                "redirect_uris": ["http://localhost:3000/callback"],
                "allowed_scopes": ["read:users"],
                "is_active": True,
                "created_at": "2025-11-24T12:00:00Z",
            }
        },
    )
# ============================================================================
# OAuth Provider Discovery (RFC 8414 - skeleton)
# ============================================================================

class OAuthServerMetadata(BaseModel):
    """OAuth 2.0 Authorization Server Metadata (RFC 8414)."""

    issuer: str = Field(..., description="Authorization server issuer URL")
    authorization_endpoint: str = Field(..., description="Authorization endpoint URL")
    token_endpoint: str = Field(..., description="Token endpoint URL")
    registration_endpoint: str | None = Field(
        None, description="Dynamic client registration endpoint"
    )
    revocation_endpoint: str | None = Field(
        None, description="Token revocation endpoint"
    )
    introspection_endpoint: str | None = Field(
        None, description="Token introspection endpoint (RFC 7662)"
    )
    scopes_supported: list[str] = Field(
        default_factory=list, description="Supported scopes"
    )
    response_types_supported: list[str] = Field(
        default_factory=lambda: ["code"], description="Supported response types"
    )
    grant_types_supported: list[str] = Field(
        default_factory=lambda: ["authorization_code", "refresh_token"],
        description="Supported grant types",
    )
    code_challenge_methods_supported: list[str] = Field(
        default_factory=lambda: ["S256"], description="Supported PKCE methods"
    )
    token_endpoint_auth_methods_supported: list[str] = Field(
        default_factory=lambda: ["client_secret_basic", "client_secret_post", "none"],
        description="Supported client authentication methods",
    )

    model_config = ConfigDict(
        json_schema_extra={
            "example": {
                "issuer": "https://api.example.com",
                "authorization_endpoint": "https://api.example.com/oauth/authorize",
                "token_endpoint": "https://api.example.com/oauth/token",
                "revocation_endpoint": "https://api.example.com/oauth/revoke",
                "introspection_endpoint": "https://api.example.com/oauth/introspect",
                "scopes_supported": ["openid", "profile", "email", "read:users"],
                "response_types_supported": ["code"],
                "grant_types_supported": ["authorization_code", "refresh_token"],
                "code_challenge_methods_supported": ["S256"],
                "token_endpoint_auth_methods_supported": [
                    "client_secret_basic",
                    "client_secret_post",
                    "none",
                ],
            }
        }
    )
# ============================================================================
# OAuth Token Responses (RFC 6749)
# ============================================================================
class OAuthTokenResponse(BaseModel):
"""OAuth 2.0 Token Response (RFC 6749 Section 5.1)."""
access_token: str = Field(..., description="The access token issued by the server")
token_type: str = Field(
default="Bearer", description="The type of token (typically 'Bearer')"
)
expires_in: int | None = Field(None, description="Token lifetime in seconds")
refresh_token: str | None = Field(
None, description="Refresh token for obtaining new access tokens"
)
scope: str | None = Field(
None, description="Space-separated list of granted scopes"
)
model_config = ConfigDict(
json_schema_extra={
"example": {
"access_token": "eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9...",
"token_type": "Bearer",
"expires_in": 3600,
"refresh_token": "dGhpcyBpcyBhIHJlZnJlc2ggdG9rZW4...",
"scope": "openid profile email",
}
}
)
class OAuthTokenIntrospectionResponse(BaseModel):
"""OAuth 2.0 Token Introspection Response (RFC 7662)."""
active: bool = Field(..., description="Whether the token is currently active")
scope: str | None = Field(None, description="Space-separated list of scopes")
client_id: str | None = Field(None, description="Client identifier for the token")
username: str | None = Field(
None, description="Human-readable identifier for the resource owner"
)
token_type: str | None = Field(
None, description="Type of the token (e.g., 'Bearer')"
)
exp: int | None = Field(None, description="Token expiration time (Unix timestamp)")
iat: int | None = Field(None, description="Token issue time (Unix timestamp)")
nbf: int | None = Field(None, description="Token not-before time (Unix timestamp)")
sub: str | None = Field(None, description="Subject of the token (user ID)")
aud: str | None = Field(None, description="Intended audience of the token")
iss: str | None = Field(None, description="Issuer of the token")
model_config = ConfigDict(
json_schema_extra={
"example": {
"active": True,
"scope": "openid profile",
"client_id": "client123",
"username": "user@example.com",
"token_type": "Bearer",
"exp": 1735689600,
"iat": 1735686000,
"sub": "user-uuid-here",
}
}
)

View File

@@ -48,7 +48,7 @@ class OrganizationCreate(OrganizationBase):
"""Schema for creating a new organization.""" """Schema for creating a new organization."""
name: str = Field(..., min_length=1, max_length=255) name: str = Field(..., min_length=1, max_length=255)
slug: str = Field(..., min_length=1, max_length=255) slug: str = Field(..., min_length=1, max_length=255) # pyright: ignore[reportIncompatibleVariableOverride]
class OrganizationUpdate(BaseModel): class OrganizationUpdate(BaseModel):

View File

@@ -23,6 +23,7 @@ class UserBase(BaseModel):
class UserCreate(UserBase):
    password: str
    is_superuser: bool = False
+    is_active: bool = True

    @field_validator("password")
    @classmethod
@@ -37,6 +38,13 @@ class UserUpdate(BaseModel):
    phone_number: str | None = None
    password: str | None = None
    preferences: dict[str, Any] | None = None
+    locale: str | None = Field(
+        None,
+        max_length=10,
+        pattern=r"^[a-z]{2}(-[A-Z]{2})?$",
+        description="User's preferred locale (BCP 47 format: en, it, en-US, it-IT)",
+        examples=["en", "it", "en-US", "it-IT"],
+    )
    is_active: bool | None = (
        None  # Changed default from True to None to avoid unintended updates
    )
@@ -55,6 +63,24 @@ class UserUpdate(BaseModel):
            return v
        return validate_password_strength(v)

+    @field_validator("locale")
+    @classmethod
+    def validate_locale(cls, v: str | None) -> str | None:
+        """Validate locale against supported locales."""
+        if v is None:
+            return v
+        # Only support English and Italian for template showcase
+        # Note: Locales stored in lowercase for case-insensitive matching
+        supported_locales = {"en", "it", "en-us", "en-gb", "it-it"}
+        # Normalize to lowercase for comparison and storage
+        v_lower = v.lower()
+        if v_lower not in supported_locales:
+            raise ValueError(
+                f"Unsupported locale '{v}'. Supported locales: {sorted(supported_locales)}"
+            )
+        # Return normalized lowercase version for consistency
+        return v_lower

    @field_validator("is_superuser")
    @classmethod
    def prevent_superuser_modification(cls, v: bool | None) -> bool | None:
@@ -70,6 +96,7 @@ class UserInDB(UserBase):
    is_superuser: bool
    created_at: datetime
    updated_at: datetime | None = None
+    locale: str | None = None

    model_config = ConfigDict(from_attributes=True)
@@ -80,6 +107,7 @@ class UserResponse(UserBase):
    is_superuser: bool
    created_at: datetime
    updated_at: datetime | None = None
+    locale: str | None = None

    model_config = ConfigDict(from_attributes=True)

View File

@@ -60,6 +60,15 @@ def validate_password_strength(password: str) -> str:
        >>> validate_password_strength("MySecureP@ss123")  # Valid
        >>> validate_password_strength("password1")  # Invalid - too weak
    """
+    # Check if we are in demo mode
+    from app.core.config import settings
+
+    if settings.DEMO_MODE:
+        # In demo mode, allow specific weak passwords for demo accounts
+        demo_passwords = {"Demo123!", "Admin123!"}
+        if password in demo_passwords:
+            return password
+
    # Check minimum length
    if len(password) < 12:
        raise ValueError("Password must be at least 12 characters long")

View File

@@ -0,0 +1,19 @@
# app/services/__init__.py
from . import oauth_provider_service
from .auth_service import AuthService
from .oauth_service import OAuthService
from .organization_service import OrganizationService, organization_service
from .session_service import SessionService, session_service
from .user_service import UserService, user_service
__all__ = [
"AuthService",
"OAuthService",
"OrganizationService",
"SessionService",
"UserService",
"oauth_provider_service",
"organization_service",
"session_service",
"user_service",
]

View File

@@ -2,7 +2,6 @@
import logging
from uuid import UUID

-from sqlalchemy import select
from sqlalchemy.ext.asyncio import AsyncSession

from app.core.auth import (
@@ -14,12 +13,18 @@ from app.core.auth import (
    verify_password_async,
)
from app.core.config import settings
-from app.core.exceptions import AuthenticationError
+from app.core.exceptions import AuthenticationError, DuplicateError
+from app.core.repository_exceptions import DuplicateEntryError
from app.models.user import User
+from app.repositories.user import user_repo
from app.schemas.users import Token, UserCreate, UserResponse

logger = logging.getLogger(__name__)

+# Pre-computed bcrypt hash used for constant-time comparison when user is not found,
+# preventing timing attacks that could enumerate valid email addresses.
+_DUMMY_HASH = "$2b$12$EixZaYVK1fsbw1ZfbX3OXePaWxn96p36zLFbnJHfxPSEFBzXKiHia"
class AuthService:
    """Service for handling authentication operations"""
@@ -39,10 +44,12 @@ class AuthService:
        Returns:
            User if authenticated, None otherwise
        """
-        result = await db.execute(select(User).where(User.email == email))
-        user = result.scalar_one_or_none()
+        user = await user_repo.get_by_email(db, email=email)

        if not user:
+            # Perform a dummy verification to match timing of a real bcrypt check,
+            # preventing email enumeration via response-time differences.
+            await verify_password_async(password, _DUMMY_HASH)
            return None

        # Verify password asynchronously to avoid blocking event loop
@@ -71,40 +78,23 @@ class AuthService:
""" """
try: try:
# Check if user already exists # Check if user already exists
result = await db.execute(select(User).where(User.email == user_data.email)) existing_user = await user_repo.get_by_email(db, email=user_data.email)
existing_user = result.scalar_one_or_none()
if existing_user: if existing_user:
raise AuthenticationError("User with this email already exists") raise DuplicateError("User with this email already exists")
# Create new user with async password hashing # Delegate creation (hashing + commit) to the repository
# Hash password asynchronously to avoid blocking event loop user = await user_repo.create(db, obj_in=user_data)
hashed_password = await get_password_hash_async(user_data.password)
# Create user object from model logger.info("User created successfully: %s", user.email)
user = User(
email=user_data.email,
password_hash=hashed_password,
first_name=user_data.first_name,
last_name=user_data.last_name,
phone_number=user_data.phone_number,
is_active=True,
is_superuser=False,
)
db.add(user)
await db.commit()
await db.refresh(user)
logger.info(f"User created successfully: {user.email}")
return user return user
except AuthenticationError: except (AuthenticationError, DuplicateError):
# Re-raise authentication errors without rollback # Re-raise API exceptions without rollback
raise raise
except DuplicateEntryError as e:
raise DuplicateError(str(e))
except Exception as e: except Exception as e:
# Rollback on any database errors logger.exception("Error creating user: %s", e)
await db.rollback()
logger.error(f"Error creating user: {e!s}", exc_info=True)
raise AuthenticationError(f"Failed to create user: {e!s}") raise AuthenticationError(f"Failed to create user: {e!s}")
@staticmethod @staticmethod
@@ -168,8 +158,7 @@ class AuthService:
            user_id = token_data.user_id

            # Get user from database
-            result = await db.execute(select(User).where(User.id == user_id))
-            user = result.scalar_one_or_none()
+            user = await user_repo.get(db, id=str(user_id))

            if not user or not user.is_active:
                raise TokenInvalidError("Invalid user or inactive account")
@@ -177,7 +166,7 @@ class AuthService:
            return AuthService.create_tokens(user)
        except (TokenExpiredError, TokenInvalidError) as e:
-            logger.warning(f"Token refresh failed: {e!s}")
+            logger.warning("Token refresh failed: %s", e)
            raise

    @staticmethod
@@ -200,8 +189,7 @@ class AuthService:
            AuthenticationError: If current password is incorrect or update fails
        """
        try:
-            result = await db.execute(select(User).where(User.id == user_id))
-            user = result.scalar_one_or_none()
+            user = await user_repo.get(db, id=str(user_id))

            if not user:
                raise AuthenticationError("User not found")
@@ -210,10 +198,10 @@ class AuthService:
raise AuthenticationError("Current password is incorrect") raise AuthenticationError("Current password is incorrect")
# Hash new password asynchronously to avoid blocking event loop # Hash new password asynchronously to avoid blocking event loop
user.password_hash = await get_password_hash_async(new_password) new_hash = await get_password_hash_async(new_password)
await db.commit() await user_repo.update_password(db, user=user, password_hash=new_hash)
logger.info(f"Password changed successfully for user {user_id}") logger.info("Password changed successfully for user %s", user_id)
return True return True
except AuthenticationError: except AuthenticationError:
@@ -222,7 +210,34 @@ class AuthService:
        except Exception as e:
            # Rollback on any database errors
            await db.rollback()
-            logger.error(
-                f"Error changing password for user {user_id}: {e!s}", exc_info=True
-            )
+            logger.exception("Error changing password for user %s: %s", user_id, e)
            raise AuthenticationError(f"Failed to change password: {e!s}")

+    @staticmethod
+    async def reset_password(
+        db: AsyncSession, *, email: str, new_password: str
+    ) -> User:
+        """
+        Reset a user's password without requiring the current password.
+
+        Args:
+            db: Database session
+            email: User email address
+            new_password: New password to set
+
+        Returns:
+            Updated user
+
+        Raises:
+            AuthenticationError: If user not found or inactive
+        """
+        user = await user_repo.get_by_email(db, email=email)
+        if not user:
+            raise AuthenticationError("User not found")
+        if not user.is_active:
+            raise AuthenticationError("User account is inactive")
+
+        new_hash = await get_password_hash_async(new_password)
+        user = await user_repo.update_password(db, user=user, password_hash=new_hash)
+        logger.info("Password reset successfully for %s", email)
+        return user

View File

@@ -58,8 +58,8 @@ class ConsoleEmailBackend(EmailBackend):
logger.info("=" * 80) logger.info("=" * 80)
logger.info("EMAIL SENT (Console Backend)") logger.info("EMAIL SENT (Console Backend)")
logger.info("=" * 80) logger.info("=" * 80)
logger.info(f"To: {', '.join(to)}") logger.info("To: %s", ", ".join(to))
logger.info(f"Subject: {subject}") logger.info("Subject: %s", subject)
logger.info("-" * 80) logger.info("-" * 80)
if text_content: if text_content:
logger.info("Plain Text Content:") logger.info("Plain Text Content:")
@@ -199,7 +199,7 @@ The {settings.PROJECT_NAME} Team
                text_content=text_content,
            )
        except Exception as e:
-            logger.error(f"Failed to send password reset email to {to_email}: {e!s}")
+            logger.error("Failed to send password reset email to %s: %s", to_email, e)
            return False

    async def send_email_verification(
@@ -287,7 +287,7 @@ The {settings.PROJECT_NAME} Team
                text_content=text_content,
            )
        except Exception as e:
-            logger.error(f"Failed to send verification email to {to_email}: {e!s}")
+            logger.error("Failed to send verification email to %s: %s", to_email, e)
            return False

View File

@@ -0,0 +1,970 @@
"""
OAuth Provider Service for MCP integration.
Implements OAuth 2.0 Authorization Server functionality:
- Authorization code flow with PKCE
- Token issuance (JWT access tokens, opaque refresh tokens)
- Token refresh
- Token revocation
- Consent management
Security features:
- PKCE required for public clients (S256)
- Short-lived authorization codes (10 minutes)
- JWT access tokens (self-contained, no DB lookup)
- Secure refresh token storage (hashed)
- Token rotation on refresh
- Comprehensive validation
"""
import base64
import hashlib
import logging
import secrets
from datetime import UTC, datetime, timedelta
from typing import Any
from uuid import UUID
import jwt
from jwt.exceptions import ExpiredSignatureError, InvalidTokenError
from sqlalchemy.ext.asyncio import AsyncSession
from app.core.config import settings
from app.models.oauth_client import OAuthClient
from app.models.user import User
from app.repositories.oauth_authorization_code import oauth_authorization_code_repo
from app.repositories.oauth_client import oauth_client_repo
from app.repositories.oauth_consent import oauth_consent_repo
from app.repositories.oauth_provider_token import oauth_provider_token_repo
from app.repositories.user import user_repo
from app.schemas.oauth import OAuthClientCreate
logger = logging.getLogger(__name__)
# Constants
AUTHORIZATION_CODE_EXPIRY_MINUTES = 10
ACCESS_TOKEN_EXPIRY_MINUTES = 60 # 1 hour for MCP clients
REFRESH_TOKEN_EXPIRY_DAYS = 30
class OAuthProviderError(Exception):
"""Base exception for OAuth provider errors."""
def __init__(
self,
error: str,
error_description: str | None = None,
error_uri: str | None = None,
):
self.error = error
self.error_description = error_description
self.error_uri = error_uri
super().__init__(error_description or error)
class InvalidClientError(OAuthProviderError):
"""Client authentication failed."""
def __init__(self, description: str = "Invalid client credentials"):
super().__init__("invalid_client", description)
class InvalidGrantError(OAuthProviderError):
"""Invalid authorization grant."""
def __init__(self, description: str = "Invalid grant"):
super().__init__("invalid_grant", description)
class InvalidRequestError(OAuthProviderError):
"""Invalid request parameters."""
def __init__(self, description: str = "Invalid request"):
super().__init__("invalid_request", description)
class InvalidScopeError(OAuthProviderError):
"""Invalid scope requested."""
def __init__(self, description: str = "Invalid scope"):
super().__init__("invalid_scope", description)
class UnauthorizedClientError(OAuthProviderError):
"""Client not authorized for this grant type."""
def __init__(self, description: str = "Unauthorized client"):
super().__init__("unauthorized_client", description)
class AccessDeniedError(OAuthProviderError):
"""User denied authorization."""
def __init__(self, description: str = "Access denied"):
super().__init__("access_denied", description)
# ============================================================================
# Helper Functions
# ============================================================================
def generate_code() -> str:
"""Generate a cryptographically secure authorization code."""
return secrets.token_urlsafe(64)
def generate_token() -> str:
"""Generate a cryptographically secure token."""
return secrets.token_urlsafe(48)
def generate_jti() -> str:
"""Generate a unique JWT ID."""
return secrets.token_urlsafe(32)
def hash_token(token: str) -> str:
"""Hash a token using SHA-256."""
return hashlib.sha256(token.encode()).hexdigest()
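The opaque refresh tokens issued by this service are never stored in plaintext; only their SHA-256 digest is persisted, so a database leak does not yield usable tokens. A minimal stdlib-only sketch of the issue/lookup round trip, where the `issued` dict is a stand-in for the `oauth_provider_token_repo` table:

```python
import hashlib
import secrets


def hash_token(token: str) -> str:
    """Hash a token using SHA-256 (hex digest), matching the helper above."""
    return hashlib.sha256(token.encode()).hexdigest()


# Issue: the client receives the raw token; the server stores only the hash.
issued: dict[str, str] = {}  # token_hash -> user_id (stand-in for the DB table)
raw_token = secrets.token_urlsafe(48)
issued[hash_token(raw_token)] = "user-123"

# Lookup on refresh: hash the presented token and match against storage.
assert issued.get(hash_token(raw_token)) == "user-123"
assert hash_token("tampered") not in issued
```

A fast unsalted hash is acceptable here (unlike for passwords) because the token itself is high-entropy random data that cannot be brute-forced from its digest.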
def verify_pkce(code_verifier: str, code_challenge: str, method: str) -> bool:
"""
Verify PKCE code_verifier against stored code_challenge.
SECURITY: Only S256 method is supported. The 'plain' method provides
no security benefit and is explicitly rejected per RFC 7636 Section 4.3.
"""
if method != "S256":
# SECURITY: Reject any method other than S256
# 'plain' method provides no security against code interception attacks
logger.warning("PKCE verification rejected for unsupported method: %s", method)
return False
# SHA-256 hash, then base64url encode (RFC 7636 Section 4.2)
digest = hashlib.sha256(code_verifier.encode()).digest()
computed = base64.urlsafe_b64encode(digest).rstrip(b"=").decode()
return secrets.compare_digest(computed, code_challenge)
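For reference, the client side of RFC 7636 derives the challenge from the verifier with the same S256 transform that `verify_pkce` recomputes; a self-contained sketch of both halves (method check omitted for brevity):

```python
import base64
import hashlib
import secrets


def s256_challenge(code_verifier: str) -> str:
    """base64url(SHA-256(verifier)) without padding, per RFC 7636 Section 4.2."""
    digest = hashlib.sha256(code_verifier.encode()).digest()
    return base64.urlsafe_b64encode(digest).rstrip(b"=").decode()


# Client: generate a random verifier and send the derived challenge on /authorize.
verifier = secrets.token_urlsafe(32)
challenge = s256_challenge(verifier)

# Server (at /token): recompute from the presented verifier and compare in
# constant time, as verify_pkce() above does.
assert secrets.compare_digest(s256_challenge(verifier), challenge)
assert not secrets.compare_digest(s256_challenge("wrong-verifier"), challenge)
```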
def parse_scope(scope: str) -> list[str]:
"""Parse space-separated scope string into list."""
return [s.strip() for s in scope.split() if s.strip()]
def join_scope(scopes: list[str]) -> str:
"""Join scope list into space-separated string."""
return " ".join(sorted(set(scopes)))
# ============================================================================
# Client Validation
# ============================================================================
async def get_client(db: AsyncSession, client_id: str) -> OAuthClient | None:
"""Get OAuth client by client_id."""
return await oauth_client_repo.get_by_client_id(db, client_id=client_id)
async def validate_client(
db: AsyncSession,
client_id: str,
client_secret: str | None = None,
require_secret: bool = False,
) -> OAuthClient:
"""
Validate OAuth client credentials.
Args:
db: Database session
client_id: Client identifier
client_secret: Client secret (required for confidential clients)
require_secret: Whether to require secret validation
Returns:
Validated OAuthClient
Raises:
InvalidClientError: If client validation fails
"""
client = await get_client(db, client_id)
if not client:
raise InvalidClientError("Unknown client_id")
# Confidential clients must provide valid secret
if client.client_type == "confidential" or require_secret:
if not client_secret:
raise InvalidClientError("Client secret required")
if not client.client_secret_hash:
raise InvalidClientError("Client not configured with secret")
# SECURITY: Verify secret using bcrypt
from app.core.auth import verify_password
stored_hash = str(client.client_secret_hash)
if not stored_hash.startswith("$2"):
raise InvalidClientError(
"Client secret uses deprecated hash format. "
"Please regenerate your client credentials."
)
if not verify_password(client_secret, stored_hash):
raise InvalidClientError("Invalid client secret")
return client
def validate_redirect_uri(client: OAuthClient, redirect_uri: str) -> None:
"""
Validate redirect_uri against client's registered URIs.
Raises:
InvalidRequestError: If redirect_uri is not registered
"""
if not client.redirect_uris:
raise InvalidRequestError("Client has no registered redirect URIs")
if redirect_uri not in client.redirect_uris:
raise InvalidRequestError("Invalid redirect_uri")
def validate_scopes(client: OAuthClient, requested_scopes: list[str]) -> list[str]:
"""
Validate requested scopes against client's allowed scopes.
Returns:
List of valid scopes (intersection of requested and allowed)
Raises:
InvalidScopeError: If no valid scopes
"""
allowed = set(client.allowed_scopes or [])
requested = set(requested_scopes)
# If no scopes requested, use all allowed scopes
if not requested:
return list(allowed)
valid = requested & allowed
if not valid:
raise InvalidScopeError(
"None of the requested scopes are allowed for this client"
)
# Warn if some scopes were filtered out
invalid = requested - allowed
if invalid:
logger.warning(
"Client %s requested invalid scopes: %s", client.client_id, invalid
)
return list(valid)
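The scope-narrowing rule above (grant the intersection of requested and allowed, fall back to everything allowed when nothing is requested, and reject when the intersection is empty) is plain set arithmetic; a stdlib sketch of the decision table, using a hypothetical `narrow_scopes` helper in place of the client-aware `validate_scopes`:

```python
def narrow_scopes(allowed: set[str], requested: set[str]) -> set[str]:
    """Intersection of requested and allowed; all allowed if none requested."""
    if not requested:
        return set(allowed)
    valid = requested & allowed
    if not valid:
        # Mirrors InvalidScopeError ("invalid_scope" per RFC 6749)
        raise ValueError("invalid_scope")
    return valid


allowed = {"read:users", "write:users"}
assert narrow_scopes(allowed, set()) == allowed                      # default grant
assert narrow_scopes(allowed, {"read:users", "admin"}) == {"read:users"}  # filtered
try:
    narrow_scopes(allowed, {"admin"})  # nothing valid -> error
except ValueError as e:
    assert str(e) == "invalid_scope"
```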
# ============================================================================
# Authorization Code Flow
# ============================================================================
async def create_authorization_code(
db: AsyncSession,
client: OAuthClient,
user: User,
redirect_uri: str,
scope: str,
code_challenge: str | None = None,
code_challenge_method: str | None = None,
state: str | None = None,
nonce: str | None = None,
) -> str:
"""
Create an authorization code for the authorization code flow.
Args:
db: Database session
client: Validated OAuth client
user: Authenticated user
redirect_uri: Validated redirect URI
scope: Granted scopes (space-separated)
code_challenge: PKCE code challenge
code_challenge_method: PKCE method (S256)
state: CSRF state parameter
nonce: OpenID Connect nonce
Returns:
Authorization code string
"""
# Public clients MUST use PKCE
if client.client_type == "public":
if not code_challenge or code_challenge_method != "S256":
raise InvalidRequestError("PKCE with S256 is required for public clients")
code = generate_code()
expires_at = datetime.now(UTC) + timedelta(
minutes=AUTHORIZATION_CODE_EXPIRY_MINUTES
)
await oauth_authorization_code_repo.create_code(
db,
code=code,
client_id=client.client_id,
user_id=user.id,
redirect_uri=redirect_uri,
scope=scope,
expires_at=expires_at,
code_challenge=code_challenge,
code_challenge_method=code_challenge_method,
state=state,
nonce=nonce,
)
logger.info(
"Created authorization code for user %s and client %s",
user.id,
client.client_id,
)
return code
async def exchange_authorization_code(
db: AsyncSession,
code: str,
client_id: str,
redirect_uri: str,
code_verifier: str | None = None,
client_secret: str | None = None,
device_info: str | None = None,
ip_address: str | None = None,
) -> dict[str, Any]:
"""
Exchange authorization code for tokens.
Args:
db: Database session
code: Authorization code
client_id: Client identifier
redirect_uri: Must match the original redirect_uri
code_verifier: PKCE code verifier
client_secret: Client secret (for confidential clients)
device_info: Optional device information
ip_address: Optional IP address
Returns:
Token response dict with access_token, refresh_token, etc.
Raises:
InvalidGrantError: If code is invalid, expired, or already used
InvalidClientError: If client validation fails
"""
# Atomically mark code as used and fetch it (prevents race condition)
# RFC 6749 Section 4.1.2: Authorization codes MUST be single-use
updated_id = await oauth_authorization_code_repo.consume_code_atomically(
db, code=code
)
if not updated_id:
# Either code doesn't exist or was already used
# Check if it exists to provide appropriate error
existing_code = await oauth_authorization_code_repo.get_by_code(db, code=code)
if existing_code and existing_code.used:
# Code reuse is a security incident - revoke all tokens for this grant
logger.warning(
"Authorization code reuse detected for client %s",
existing_code.client_id,
)
await revoke_tokens_for_user_client(
db, UUID(str(existing_code.user_id)), str(existing_code.client_id)
)
raise InvalidGrantError("Authorization code has already been used")
else:
raise InvalidGrantError("Invalid authorization code")
# Now fetch the full auth code record
auth_code = await oauth_authorization_code_repo.get_by_id(db, code_id=updated_id)
if auth_code is None:
raise InvalidGrantError("Authorization code not found after consumption")
if auth_code.is_expired:
raise InvalidGrantError("Authorization code has expired")
if auth_code.client_id != client_id:
raise InvalidGrantError("Authorization code was not issued to this client")
if auth_code.redirect_uri != redirect_uri:
raise InvalidGrantError("redirect_uri mismatch")
# Validate client - ALWAYS require secret for confidential clients
client = await get_client(db, client_id)
if not client:
raise InvalidClientError("Unknown client_id")
# Confidential clients MUST authenticate (RFC 6749 Section 3.2.1)
if client.client_type == "confidential":
if not client_secret:
raise InvalidClientError("Client secret required for confidential clients")
client = await validate_client(
db, client_id, client_secret, require_secret=True
)
elif client_secret:
# Public client provided secret - validate it if given
client = await validate_client(
db, client_id, client_secret, require_secret=True
)
# Verify PKCE
if auth_code.code_challenge:
if not code_verifier:
raise InvalidGrantError("code_verifier required")
if not verify_pkce(
code_verifier,
str(auth_code.code_challenge),
str(auth_code.code_challenge_method or "S256"),
):
raise InvalidGrantError("Invalid code_verifier")
elif client.client_type == "public":
# Public clients without PKCE - this shouldn't happen if we validated on authorize
raise InvalidGrantError("PKCE required for public clients")
# Get user
user = await user_repo.get(db, id=str(auth_code.user_id))
if not user or not user.is_active:
raise InvalidGrantError("User not found or inactive")
# Generate tokens
return await create_tokens(
db=db,
client=client,
user=user,
scope=str(auth_code.scope),
nonce=str(auth_code.nonce) if auth_code.nonce else None,
device_info=device_info,
ip_address=ip_address,
)
# ============================================================================
# Token Generation
# ============================================================================
async def create_tokens(
db: AsyncSession,
client: OAuthClient,
user: User,
scope: str,
nonce: str | None = None,
device_info: str | None = None,
ip_address: str | None = None,
) -> dict[str, Any]:
"""
Create access and refresh tokens.
Args:
db: Database session
client: OAuth client
user: User
scope: Granted scopes
nonce: OpenID Connect nonce (included in ID token)
device_info: Optional device information
ip_address: Optional IP address
Returns:
Token response dict
"""
now = datetime.now(UTC)
jti = generate_jti()
# Access token expiry
access_token_lifetime = int(client.access_token_lifetime or "3600")
access_expires = now + timedelta(seconds=access_token_lifetime)
# Refresh token expiry
refresh_token_lifetime = int(
client.refresh_token_lifetime or str(REFRESH_TOKEN_EXPIRY_DAYS * 86400)
)
refresh_expires = now + timedelta(seconds=refresh_token_lifetime)
# Create JWT access token
# SECURITY: Include all standard JWT claims per RFC 7519
access_token_payload = {
"iss": settings.OAUTH_ISSUER,
"sub": str(user.id),
"aud": client.client_id,
"exp": int(access_expires.timestamp()),
"iat": int(now.timestamp()),
"nbf": int(now.timestamp()), # Not Before - token is valid immediately
"jti": jti,
"scope": scope,
"client_id": client.client_id,
# User info (basic claims)
"email": user.email,
"name": f"{user.first_name or ''} {user.last_name or ''}".strip() or user.email,
}
# Add nonce for OpenID Connect
if nonce:
access_token_payload["nonce"] = nonce
access_token = jwt.encode(
access_token_payload,
settings.SECRET_KEY,
algorithm=settings.ALGORITHM,
)
# Create opaque refresh token
refresh_token = generate_token()
refresh_token_hash = hash_token(refresh_token)
# Store refresh token in database
await oauth_provider_token_repo.create_token(
db,
token_hash=refresh_token_hash,
jti=jti,
client_id=client.client_id,
user_id=user.id,
scope=scope,
expires_at=refresh_expires,
device_info=device_info,
ip_address=ip_address,
)
logger.info("Issued tokens for user %s to client %s", user.id, client.client_id)
return {
"access_token": access_token,
"token_type": "Bearer",
"expires_in": access_token_lifetime,
"refresh_token": refresh_token,
"scope": scope,
}
async def refresh_tokens(
db: AsyncSession,
refresh_token: str,
client_id: str,
client_secret: str | None = None,
scope: str | None = None,
device_info: str | None = None,
ip_address: str | None = None,
) -> dict[str, Any]:
"""
Refresh access token using refresh token.
Implements token rotation - old refresh token is invalidated,
new refresh token is issued.
Args:
db: Database session
refresh_token: Refresh token
client_id: Client identifier
client_secret: Client secret (for confidential clients)
scope: Optional reduced scope
device_info: Optional device information
ip_address: Optional IP address
Returns:
New token response dict
Raises:
InvalidGrantError: If refresh token is invalid
"""
# Find refresh token
token_hash = hash_token(refresh_token)
token_record = await oauth_provider_token_repo.get_by_token_hash(
db, token_hash=token_hash
)
if not token_record:
raise InvalidGrantError("Invalid refresh token")
if token_record.revoked:
# Token reuse after revocation - security incident
logger.warning(
"Revoked refresh token reuse detected for client %s", token_record.client_id
)
raise InvalidGrantError("Refresh token has been revoked")
if token_record.is_expired:
raise InvalidGrantError("Refresh token has expired")
if token_record.client_id != client_id:
raise InvalidGrantError("Refresh token was not issued to this client")
# Validate client
client = await validate_client(
db,
client_id,
client_secret,
require_secret=(client_secret is not None),
)
# Get user
user = await user_repo.get(db, id=str(token_record.user_id))
if not user or not user.is_active:
raise InvalidGrantError("User not found or inactive")
# Validate scope (can only reduce, not expand)
token_scope = str(token_record.scope) if token_record.scope else ""
original_scopes = set(parse_scope(token_scope))
if scope:
requested_scopes = set(parse_scope(scope))
if not requested_scopes.issubset(original_scopes):
raise InvalidScopeError("Cannot expand scope on refresh")
final_scope = join_scope(list(requested_scopes))
else:
final_scope = token_scope
# Revoke old refresh token (token rotation)
await oauth_provider_token_repo.revoke(db, token=token_record)
# Issue new tokens
device = str(token_record.device_info) if token_record.device_info else None
ip_addr = str(token_record.ip_address) if token_record.ip_address else None
return await create_tokens(
db=db,
client=client,
user=user,
scope=final_scope,
device_info=device_info or device,
ip_address=ip_address or ip_addr,
)
# ============================================================================
# Token Revocation
# ============================================================================
async def revoke_token(
db: AsyncSession,
token: str,
token_type_hint: str | None = None,
client_id: str | None = None,
client_secret: str | None = None,
) -> bool:
"""
Revoke a token (access or refresh).
For refresh tokens: marks as revoked in database
For access tokens: we can't truly revoke JWTs, but we can revoke
the associated refresh token to prevent further refreshes
Args:
db: Database session
token: Token to revoke
token_type_hint: "access_token" or "refresh_token"
client_id: Client identifier (for validation)
client_secret: Client secret (for confidential clients)
Returns:
True if token was revoked, False if not found
"""
# Try as refresh token first (more likely)
if token_type_hint != "access_token":
token_hash = hash_token(token)
refresh_record = await oauth_provider_token_repo.get_by_token_hash(
db, token_hash=token_hash
)
if refresh_record:
# Validate client if provided
if client_id and refresh_record.client_id != client_id:
raise InvalidClientError("Token was not issued to this client")
await oauth_provider_token_repo.revoke(db, token=refresh_record)
logger.info("Revoked refresh token %s...", refresh_record.jti[:8])
return True
# Try as access token (JWT)
if token_type_hint != "refresh_token":
try:
payload = jwt.decode(
token,
settings.SECRET_KEY,
algorithms=[settings.ALGORITHM],
options={
"verify_exp": False,
"verify_aud": False,
}, # Allow expired tokens
)
jti = payload.get("jti")
if jti:
# Find and revoke the associated refresh token
refresh_record = await oauth_provider_token_repo.get_by_jti(db, jti=jti)
if refresh_record:
if client_id and refresh_record.client_id != client_id:
raise InvalidClientError("Token was not issued to this client")
await oauth_provider_token_repo.revoke(db, token=refresh_record)
logger.info(
"Revoked refresh token via access token JTI %s...", jti[:8]
)
return True
except InvalidTokenError:
pass
except Exception: # noqa: S110 - Intentional: invalid JWT not an error
pass
return False
async def revoke_tokens_for_user_client(
db: AsyncSession,
user_id: UUID,
client_id: str,
) -> int:
"""
Revoke all tokens for a specific user-client pair.
Used when security incidents are detected (e.g., code reuse).
Args:
db: Database session
user_id: User identifier
client_id: Client identifier
Returns:
Number of tokens revoked
"""
count = await oauth_provider_token_repo.revoke_all_for_user_client(
db, user_id=user_id, client_id=client_id
)
if count > 0:
logger.warning(
"Revoked %s tokens for user %s and client %s", count, user_id, client_id
)
return count
async def revoke_all_user_tokens(db: AsyncSession, user_id: UUID) -> int:
"""
Revoke all OAuth provider tokens for a user.
Used when user changes password or explicitly logs out everywhere.
Args:
db: Database session
user_id: User identifier
Returns:
Number of tokens revoked
"""
count = await oauth_provider_token_repo.revoke_all_for_user(db, user_id=user_id)
if count > 0:
logger.info("Revoked %s OAuth provider tokens for user %s", count, user_id)
return count
# ============================================================================
# Token Introspection (RFC 7662)
# ============================================================================
async def introspect_token(
db: AsyncSession,
token: str,
token_type_hint: str | None = None,
client_id: str | None = None,
client_secret: str | None = None,
) -> dict[str, Any]:
"""
Introspect a token to determine its validity and metadata.
Implements RFC 7662 Token Introspection.
Args:
db: Database session
token: Token to introspect
token_type_hint: "access_token" or "refresh_token"
client_id: Client requesting introspection
client_secret: Client secret
Returns:
Introspection response dict
"""
# Validate client if credentials provided
if client_id:
await validate_client(db, client_id, client_secret)
# Try as access token (JWT) first
if token_type_hint != "refresh_token":
try:
payload = jwt.decode(
token,
settings.SECRET_KEY,
algorithms=[settings.ALGORITHM],
options={
"verify_aud": False
}, # Don't require audience match for introspection
)
# Check if associated refresh token is revoked
jti = payload.get("jti")
if jti:
refresh_record = await oauth_provider_token_repo.get_by_jti(db, jti=jti)
if refresh_record and refresh_record.revoked:
return {"active": False}
return {
"active": True,
"scope": payload.get("scope", ""),
"client_id": payload.get("client_id"),
"username": payload.get("email"),
"token_type": "Bearer",
"exp": payload.get("exp"),
"iat": payload.get("iat"),
"sub": payload.get("sub"),
"aud": payload.get("aud"),
"iss": payload.get("iss"),
}
except ExpiredSignatureError:
return {"active": False}
except InvalidTokenError:
pass
except Exception: # noqa: S110 - Intentional: invalid JWT falls through to refresh token check
pass
# Try as refresh token
if token_type_hint != "access_token":
token_hash = hash_token(token)
refresh_record = await oauth_provider_token_repo.get_by_token_hash(
db, token_hash=token_hash
)
if refresh_record and refresh_record.is_valid:
return {
"active": True,
"scope": refresh_record.scope,
"client_id": refresh_record.client_id,
"token_type": "refresh_token",
"exp": int(refresh_record.expires_at.timestamp()),
"iat": int(refresh_record.created_at.timestamp()),
"sub": str(refresh_record.user_id),
}
return {"active": False}
# ============================================================================
# Consent Management
# ============================================================================
async def get_consent(
db: AsyncSession,
user_id: UUID,
client_id: str,
):
"""Get existing consent record for user-client pair."""
return await oauth_consent_repo.get_consent(
db, user_id=user_id, client_id=client_id
)
async def check_consent(
db: AsyncSession,
user_id: UUID,
client_id: str,
requested_scopes: list[str],
) -> bool:
"""
Check if user has already consented to the requested scopes.
Returns True if all requested scopes are already granted.
"""
consent = await get_consent(db, user_id, client_id)
if not consent:
return False
return consent.has_scopes(requested_scopes)
async def grant_consent(
db: AsyncSession,
user_id: UUID,
client_id: str,
scopes: list[str],
):
"""
Grant or update consent for a user-client pair.
If consent already exists, updates the granted scopes.
"""
return await oauth_consent_repo.grant_consent(
db, user_id=user_id, client_id=client_id, scopes=scopes
)
async def revoke_consent(
db: AsyncSession,
user_id: UUID,
client_id: str,
) -> bool:
"""
Revoke consent and all tokens for a user-client pair.
Returns True if consent was found and revoked.
"""
# Revoke all tokens first
await revoke_tokens_for_user_client(db, user_id, client_id)
# Delete consent record
return await oauth_consent_repo.revoke_consent(
db, user_id=user_id, client_id=client_id
)
# ============================================================================
# Cleanup
# ============================================================================
async def register_client(db: AsyncSession, client_data: OAuthClientCreate) -> tuple:
"""Create a new OAuth client. Returns (client, secret)."""
return await oauth_client_repo.create_client(db, obj_in=client_data)
async def list_clients(db: AsyncSession) -> list:
"""List all registered OAuth clients."""
return await oauth_client_repo.get_all_clients(db)
async def delete_client_by_id(db: AsyncSession, client_id: str) -> None:
"""Delete an OAuth client by client_id."""
await oauth_client_repo.delete_client(db, client_id=client_id)
async def list_user_consents(db: AsyncSession, user_id: UUID) -> list[dict]:
"""Get all OAuth consents for a user with client details."""
return await oauth_consent_repo.get_user_consents_with_clients(db, user_id=user_id)
async def cleanup_expired_codes(db: AsyncSession) -> int:
"""
Delete expired authorization codes.
Should be called periodically (e.g., every hour).
Returns:
Number of codes deleted
"""
return await oauth_authorization_code_repo.cleanup_expired(db)
async def cleanup_expired_tokens(db: AsyncSession) -> int:
"""
Delete expired and revoked refresh tokens.
Should be called periodically (e.g., daily).
Returns:
Number of tokens deleted
"""
return await oauth_provider_token_repo.cleanup_expired(db, cutoff_days=7)
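The opaque refresh tokens above are never persisted in plaintext: the service hashes the token before storage and looks it up by re-hashing on every use. A minimal standalone sketch of that pattern, assuming `generate_token`/`hash_token` are the usual `secrets` + SHA-256 pair (the project's real helpers may differ):

```python
import hashlib
import secrets


def generate_token() -> str:
    # URL-safe random string; only this plaintext is ever sent to the client
    return secrets.token_urlsafe(48)


def hash_token(token: str) -> str:
    # Deterministic digest; only this value is persisted, so a database
    # leak does not expose usable refresh tokens
    return hashlib.sha256(token.encode()).hexdigest()


plaintext = generate_token()
stored = hash_token(plaintext)
assert stored == hash_token(plaintext)  # lookup by re-hashing works
assert stored != plaintext              # DB never holds the raw token
assert len(stored) == 64                # SHA-256 hex digest
```

Because the digest is deterministic, `get_by_token_hash` can index on the stored column directly, which is what makes the rotation-reuse check in `refresh_tokens` cheap.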


@@ -0,0 +1,744 @@
"""
OAuth Service for handling social authentication flows.
Supports:
- Google OAuth (OpenID Connect)
- GitHub OAuth
Features:
- PKCE support for public clients
- State parameter for CSRF protection
- Auto-linking by email (configurable)
- Account linking for existing users
"""
import logging
import secrets
from datetime import UTC, datetime, timedelta
from typing import TypedDict, cast
from uuid import UUID
from authlib.integrations.httpx_client import AsyncOAuth2Client
from sqlalchemy.ext.asyncio import AsyncSession
from app.core.auth import create_access_token, create_refresh_token
from app.core.config import settings
from app.core.exceptions import AuthenticationError
from app.models.user import User
from app.repositories.oauth_account import oauth_account_repo as oauth_account
from app.repositories.oauth_state import oauth_state_repo as oauth_state
from app.repositories.user import user_repo
from app.schemas.oauth import (
OAuthAccountCreate,
OAuthCallbackResponse,
OAuthProviderInfo,
OAuthProvidersResponse,
OAuthStateCreate,
)
logger = logging.getLogger(__name__)
class _OAuthProviderConfigRequired(TypedDict):
name: str
icon: str
authorize_url: str
token_url: str
userinfo_url: str
scopes: list[str]
supports_pkce: bool
class OAuthProviderConfig(_OAuthProviderConfigRequired, total=False):
"""Type definition for OAuth provider configuration."""
email_url: str # Optional, GitHub-only
# Provider configurations
OAUTH_PROVIDERS: dict[str, OAuthProviderConfig] = {
"google": {
"name": "Google",
"icon": "google",
"authorize_url": "https://accounts.google.com/o/oauth2/v2/auth",
"token_url": "https://oauth2.googleapis.com/token",
"userinfo_url": "https://www.googleapis.com/oauth2/v3/userinfo",
"scopes": ["openid", "email", "profile"],
"supports_pkce": True,
},
"github": {
"name": "GitHub",
"icon": "github",
"authorize_url": "https://github.com/login/oauth/authorize",
"token_url": "https://github.com/login/oauth/access_token",
"userinfo_url": "https://api.github.com/user",
"email_url": "https://api.github.com/user/emails",
"scopes": ["read:user", "user:email"],
"supports_pkce": False, # GitHub doesn't support PKCE
},
}
class OAuthService:
"""Service for handling OAuth authentication flows."""
@staticmethod
def get_enabled_providers() -> OAuthProvidersResponse:
"""
Get list of enabled OAuth providers.
Returns:
OAuthProvidersResponse with enabled providers
"""
providers = []
for provider_id in settings.enabled_oauth_providers:
if provider_id in OAUTH_PROVIDERS:
config = OAUTH_PROVIDERS[provider_id]
providers.append(
OAuthProviderInfo(
provider=provider_id,
name=config["name"],
icon=config["icon"],
)
)
return OAuthProvidersResponse(
enabled=settings.OAUTH_ENABLED and len(providers) > 0,
providers=providers,
)
@staticmethod
def _get_provider_credentials(provider: str) -> tuple[str, str]:
"""Get client ID and secret for a provider."""
if provider == "google":
client_id = settings.OAUTH_GOOGLE_CLIENT_ID
client_secret = settings.OAUTH_GOOGLE_CLIENT_SECRET
elif provider == "github":
client_id = settings.OAUTH_GITHUB_CLIENT_ID
client_secret = settings.OAUTH_GITHUB_CLIENT_SECRET
else:
raise AuthenticationError(f"Unknown OAuth provider: {provider}")
if not client_id or not client_secret:
raise AuthenticationError(f"OAuth provider {provider} is not configured")
return client_id, client_secret
@staticmethod
async def create_authorization_url(
db: AsyncSession,
*,
provider: str,
redirect_uri: str,
user_id: str | None = None,
) -> tuple[str, str]:
"""
Create OAuth authorization URL with state and optional PKCE.
Args:
db: Database session
provider: OAuth provider (google, github)
redirect_uri: Callback URL after OAuth
user_id: User ID if linking account (user is logged in)
Returns:
Tuple of (authorization_url, state)
Raises:
AuthenticationError: If provider is not configured
"""
if not settings.OAUTH_ENABLED:
raise AuthenticationError("OAuth is not enabled")
if provider not in OAUTH_PROVIDERS:
raise AuthenticationError(f"Unknown OAuth provider: {provider}")
if provider not in settings.enabled_oauth_providers:
raise AuthenticationError(f"OAuth provider {provider} is not enabled")
config = OAUTH_PROVIDERS[provider]
client_id, client_secret = OAuthService._get_provider_credentials(provider)
# Generate state for CSRF protection
state = secrets.token_urlsafe(32)
# Generate PKCE code verifier and challenge if supported
code_verifier = None
code_challenge = None
if config.get("supports_pkce"):
code_verifier = secrets.token_urlsafe(64)
# Create code_challenge using S256 method
import base64
import hashlib
code_challenge_bytes = hashlib.sha256(code_verifier.encode()).digest()
code_challenge = (
base64.urlsafe_b64encode(code_challenge_bytes).decode().rstrip("=")
)
# Generate nonce for OIDC (Google)
nonce = secrets.token_urlsafe(32) if provider == "google" else None
# Store state in database

state_data = OAuthStateCreate(
state=state,
code_verifier=code_verifier,
nonce=nonce,
provider=provider,
redirect_uri=redirect_uri,
user_id=UUID(user_id) if user_id else None,
expires_at=datetime.now(UTC)
+ timedelta(minutes=settings.OAUTH_STATE_EXPIRE_MINUTES),
)
await oauth_state.create_state(db, obj_in=state_data)
# Build authorization URL
async with AsyncOAuth2Client(
client_id=client_id,
client_secret=client_secret,
redirect_uri=redirect_uri,
) as client:
# Prepare authorization params
auth_params = {
"state": state,
"scope": " ".join(config["scopes"]),
}
if code_challenge:
auth_params["code_challenge"] = code_challenge
auth_params["code_challenge_method"] = "S256"
if nonce:
auth_params["nonce"] = nonce
url, _ = client.create_authorization_url(
config["authorize_url"],
**auth_params,
)
logger.info("OAuth authorization URL created for %s", provider)
return url, state
@staticmethod
async def handle_callback(
db: AsyncSession,
*,
code: str,
state: str,
redirect_uri: str,
) -> OAuthCallbackResponse:
"""
Handle OAuth callback and authenticate/create user.
Args:
db: Database session
code: Authorization code from provider
state: State parameter for CSRF verification
redirect_uri: Callback URL (must match authorization request)
Returns:
OAuthCallbackResponse with tokens
Raises:
AuthenticationError: If authentication fails
"""
# Validate and consume state
state_record = await oauth_state.get_and_consume_state(db, state=state)
if not state_record:
raise AuthenticationError("Invalid or expired OAuth state")
# SECURITY: Validate redirect_uri matches the one from authorization request
# This prevents authorization code injection attacks (RFC 6749 Section 10.6)
if state_record.redirect_uri != redirect_uri:
logger.warning(
"OAuth redirect_uri mismatch: expected %s, got %s",
state_record.redirect_uri,
redirect_uri,
)
raise AuthenticationError("Redirect URI mismatch")
# Extract provider from state record (str for type safety)
provider: str = str(state_record.provider)
if provider not in OAUTH_PROVIDERS:
raise AuthenticationError(f"Unknown OAuth provider: {provider}")
config = OAUTH_PROVIDERS[provider]
client_id, client_secret = OAuthService._get_provider_credentials(provider)
# Exchange code for tokens
async with AsyncOAuth2Client(
client_id=client_id,
client_secret=client_secret,
redirect_uri=redirect_uri,
) as client:
try:
# Prepare token request params
token_params: dict[str, str] = {"code": code}
if state_record.code_verifier:
token_params["code_verifier"] = str(state_record.code_verifier)
token = await client.fetch_token(
config["token_url"],
**token_params,
)
# SECURITY: Validate ID token signature and nonce for OpenID Connect
# This prevents token forgery and replay attacks (OIDC Core 3.1.3.7)
if provider == "google" and state_record.nonce:
id_token = token.get("id_token")
if id_token:
await OAuthService._verify_google_id_token(
id_token=str(id_token),
expected_nonce=str(state_record.nonce),
client_id=client_id,
)
except AuthenticationError:
raise
except Exception as e:
logger.error("OAuth token exchange failed: %s", e)
raise AuthenticationError("Failed to exchange authorization code") from e
# Get user info from provider
try:
access_token = token.get("access_token")
if not access_token:
raise AuthenticationError("No access token received")
user_info = await OAuthService._get_user_info(
client, provider, config, access_token
)
except Exception as e:
logger.error("Failed to get user info: %s", e)
raise AuthenticationError(
"Failed to get user information from provider"
) from e
# Process user info and create/link account
# Guard before casting: str(None) would be the truthy string "None",
# which would slip past the `if not provider_user_id` check below
raw_provider_user_id = user_info.get("id") or user_info.get("sub")
provider_user_id = str(raw_provider_user_id) if raw_provider_user_id is not None else ""
# Email can be None if user didn't grant email permission
# SECURITY: Normalize email (lowercase, strip) to prevent case-based account duplication
email_raw = user_info.get("email")
provider_email: str | None = (
str(email_raw).lower().strip() if email_raw else None
)
if not provider_user_id:
raise AuthenticationError("Provider did not return user ID")
# Check if this OAuth account already exists
existing_oauth = await oauth_account.get_by_provider_id(
db, provider=provider, provider_user_id=provider_user_id
)
is_new_user = False
if existing_oauth:
# Existing OAuth account - login
user = existing_oauth.user
if not user.is_active:
raise AuthenticationError("User account is inactive")
# Update tokens if stored
if token.get("access_token"):
await oauth_account.update_tokens(
db,
account=existing_oauth,
access_token=token.get("access_token"),
refresh_token=token.get("refresh_token"),
token_expires_at=datetime.now(UTC)
+ timedelta(seconds=token.get("expires_in", 3600)),
)
logger.info("OAuth login successful for %s via %s", user.email, provider)
elif state_record.user_id:
# Account linking flow (user is already logged in)
user = await user_repo.get(db, id=str(state_record.user_id))
if not user:
raise AuthenticationError("User not found for account linking")
# Check if user already has this provider linked
user_id = cast(UUID, user.id)
existing_provider = await oauth_account.get_user_account_by_provider(
db, user_id=user_id, provider=provider
)
if existing_provider:
raise AuthenticationError(
f"You already have a {provider} account linked"
)
# Create OAuth account link
oauth_create = OAuthAccountCreate(
user_id=user_id,
provider=provider,
provider_user_id=provider_user_id,
provider_email=provider_email,
access_token=token.get("access_token"),
refresh_token=token.get("refresh_token"),
token_expires_at=datetime.now(UTC)
+ timedelta(seconds=token.get("expires_in", 3600))
if token.get("expires_in")
else None,
)
await oauth_account.create_account(db, obj_in=oauth_create)
logger.info("OAuth account linked: %s -> %s", provider, user.email)
else:
# New OAuth login - check for existing user by email
user = None
if provider_email and settings.OAUTH_AUTO_LINK_BY_EMAIL:
user = await user_repo.get_by_email(db, email=provider_email)
if user:
# Auto-link to existing user
if not user.is_active:
raise AuthenticationError("User account is inactive")
# Check if user already has this provider linked
user_id = cast(UUID, user.id)
existing_provider = await oauth_account.get_user_account_by_provider(
db, user_id=user_id, provider=provider
)
if existing_provider:
# This shouldn't happen if we got here, but safety check
logger.warning(
"OAuth account already linked (race condition?): %s -> %s",
provider,
user.email,
)
else:
# Create OAuth account link
oauth_create = OAuthAccountCreate(
user_id=user_id,
provider=provider,
provider_user_id=provider_user_id,
provider_email=provider_email,
access_token=token.get("access_token"),
refresh_token=token.get("refresh_token"),
token_expires_at=datetime.now(UTC)
+ timedelta(seconds=token.get("expires_in", 3600))
if token.get("expires_in")
else None,
)
await oauth_account.create_account(db, obj_in=oauth_create)
logger.info(
"OAuth auto-linked by email: %s -> %s", provider, user.email
)
else:
# Create new user
if not provider_email:
raise AuthenticationError(
f"Email is required for registration. "
f"Please grant email permission to {provider}."
)
user = await OAuthService._create_oauth_user(
db,
email=provider_email,
provider=provider,
provider_user_id=provider_user_id,
user_info=user_info,
token=token,
)
is_new_user = True
logger.info("New user created via OAuth: %s (%s)", user.email, provider)
# Generate JWT tokens
claims = {
"is_superuser": user.is_superuser,
"email": user.email,
"first_name": user.first_name,
}
access_token_jwt = create_access_token(subject=str(user.id), claims=claims)
refresh_token_jwt = create_refresh_token(subject=str(user.id))
return OAuthCallbackResponse(
access_token=access_token_jwt,
refresh_token=refresh_token_jwt,
token_type="bearer",
expires_in=settings.ACCESS_TOKEN_EXPIRE_MINUTES * 60,
is_new_user=is_new_user,
)
@staticmethod
async def _get_user_info(
client: AsyncOAuth2Client,
provider: str,
config: OAuthProviderConfig,
access_token: str,
) -> dict[str, object]:
"""Get user info from OAuth provider."""
headers = {"Authorization": f"Bearer {access_token}"}
if provider == "github":
# GitHub returns JSON with Accept header
headers["Accept"] = "application/vnd.github+json"
resp = await client.get(config["userinfo_url"], headers=headers)
resp.raise_for_status()
user_info = resp.json()
# GitHub requires separate request for email
if provider == "github" and not user_info.get("email"):
email_resp = await client.get(
config["email_url"], # pyright: ignore[reportTypedDictNotRequiredAccess]
headers=headers,
)
email_resp.raise_for_status()
emails = email_resp.json()
# Find primary verified email
for email_data in emails:
if email_data.get("primary") and email_data.get("verified"):
user_info["email"] = email_data["email"]
break
return user_info
# Google's OIDC configuration endpoints
GOOGLE_JWKS_URL = "https://www.googleapis.com/oauth2/v3/certs"
GOOGLE_ISSUERS = ("https://accounts.google.com", "accounts.google.com")
@staticmethod
async def _verify_google_id_token(
id_token: str,
expected_nonce: str,
client_id: str,
) -> dict[str, object]:
"""
Verify Google ID token signature and claims.
SECURITY: This properly verifies the ID token by:
1. Fetching Google's public keys (JWKS)
2. Verifying the JWT signature against the public key
3. Validating issuer, audience, expiry, and nonce claims
Args:
id_token: The ID token JWT string
expected_nonce: The nonce we sent in the authorization request
client_id: Our OAuth client ID (expected audience)
Returns:
Decoded ID token payload
Raises:
AuthenticationError: If verification fails
"""
import httpx
import jwt as pyjwt
from jwt.algorithms import RSAAlgorithm
from jwt.exceptions import InvalidTokenError
try:
# Fetch Google's public keys (JWKS)
# In production, consider caching this with TTL matching Cache-Control header
async with httpx.AsyncClient() as client:
jwks_response = await client.get(
OAuthService.GOOGLE_JWKS_URL,
timeout=10.0,
)
jwks_response.raise_for_status()
jwks = jwks_response.json()
# Get the key ID from the token header
unverified_header = pyjwt.get_unverified_header(id_token)
kid = unverified_header.get("kid")
if not kid:
raise AuthenticationError("ID token missing key ID (kid)")
# Find the matching public key
jwk_data = None
for key in jwks.get("keys", []):
if key.get("kid") == kid:
jwk_data = key
break
if not jwk_data:
raise AuthenticationError("ID token signed with unknown key")
# Convert JWK to a public key object for PyJWT
public_key = RSAAlgorithm.from_jwk(jwk_data)
# Verify the token signature and decode claims
# PyJWT will verify signature against the RSA public key
payload = pyjwt.decode(
id_token,
public_key,
algorithms=["RS256"], # Google uses RS256
audience=client_id,
issuer=OAuthService.GOOGLE_ISSUERS,
options={
"verify_signature": True,
"verify_aud": True,
"verify_iss": True,
"verify_exp": True,
"verify_iat": True,
},
)
# Verify nonce (OIDC replay attack protection)
token_nonce = payload.get("nonce")
if token_nonce != expected_nonce:
logger.warning(
"OAuth ID token nonce mismatch: expected %s, got %s",
expected_nonce,
token_nonce,
)
raise AuthenticationError("Invalid ID token nonce")
logger.debug("Google ID token verified successfully")
return payload
except InvalidTokenError as e:
logger.warning("Google ID token verification failed: %s", e)
raise AuthenticationError("Invalid ID token signature") from e
except httpx.HTTPError as e:
logger.error("Failed to fetch Google JWKS: %s", e)
# If we can't verify the ID token, fail closed for security
raise AuthenticationError("Failed to verify ID token") from e
except Exception as e:
logger.error("Unexpected error verifying Google ID token: %s", e)
raise AuthenticationError("ID token verification error") from e
@staticmethod
async def _create_oauth_user(
db: AsyncSession,
*,
email: str,
provider: str,
provider_user_id: str,
user_info: dict,
token: dict,
) -> User:
"""Create a new user from OAuth provider data."""
# Extract name from user_info
first_name = "User"
last_name = None
if provider == "google":
first_name = user_info.get("given_name") or user_info.get("name", "User")
last_name = user_info.get("family_name")
elif provider == "github":
# GitHub has full name, try to split
name = user_info.get("name") or user_info.get("login", "User")
parts = name.split(" ", 1)
first_name = parts[0]
last_name = parts[1] if len(parts) > 1 else None
# Create user (no password for OAuth-only users)
user = User(
email=email,
password_hash=None, # OAuth-only user
first_name=first_name,
last_name=last_name,
is_active=True,
is_superuser=False,
)
db.add(user)
await db.flush() # Get user.id
# Create OAuth account link
user_id = cast(UUID, user.id)
oauth_create = OAuthAccountCreate(
user_id=user_id,
provider=provider,
provider_user_id=provider_user_id,
provider_email=email,
access_token=token.get("access_token"),
refresh_token=token.get("refresh_token"),
token_expires_at=datetime.now(UTC)
+ timedelta(seconds=token.get("expires_in", 3600))
if token.get("expires_in")
else None,
)
await oauth_account.create_account(db, obj_in=oauth_create)
await db.refresh(user)
return user
@staticmethod
async def unlink_provider(
db: AsyncSession,
*,
user: User,
provider: str,
) -> bool:
"""
Unlink an OAuth provider from a user account.
Args:
db: Database session
user: User to unlink from
provider: Provider to unlink
Returns:
True if unlinked successfully
Raises:
AuthenticationError: If unlinking would leave user without login method
"""
# Check if user can safely remove this OAuth account
# Note: We query directly instead of using user.can_remove_oauth property
# because the property uses lazy loading which doesn't work in async context
user_id = cast(UUID, user.id)
has_password = user.password_hash is not None
oauth_accounts = await oauth_account.get_user_accounts(db, user_id=user_id)
can_remove = has_password or len(oauth_accounts) > 1
if not can_remove:
raise AuthenticationError(
"Cannot unlink OAuth account. You must have either a password set "
"or at least one other OAuth provider linked."
)
deleted = await oauth_account.delete_account(
db, user_id=user_id, provider=provider
)
if not deleted:
raise AuthenticationError(f"No {provider} account found to unlink")
logger.info("OAuth provider unlinked: %s from %s", provider, user.email)
return True
@staticmethod
async def get_user_accounts(db: AsyncSession, *, user_id: UUID) -> list:
"""Get all OAuth accounts linked to a user."""
return await oauth_account.get_user_accounts(db, user_id=user_id)
@staticmethod
async def get_user_account_by_provider(
db: AsyncSession, *, user_id: UUID, provider: str
):
"""Get a specific OAuth account for a user and provider."""
return await oauth_account.get_user_account_by_provider(
db, user_id=user_id, provider=provider
)
@staticmethod
async def cleanup_expired_states(db: AsyncSession) -> int:
"""
Clean up expired OAuth states.
Should be called periodically (e.g., by a background task).
Args:
db: Database session
Returns:
Number of states cleaned up
"""
return await oauth_state.cleanup_expired(db)
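The S256 derivation inside `create_authorization_url` can be sketched in isolation. Per RFC 7636 the challenge is `BASE64URL(SHA256(verifier))` with padding stripped; the helper name below is illustrative, not part of the service:

```python
import base64
import hashlib
import secrets


def make_pkce_pair() -> tuple[str, str]:
    """Return a (code_verifier, code_challenge) pair using the S256 method."""
    verifier = secrets.token_urlsafe(64)
    digest = hashlib.sha256(verifier.encode("ascii")).digest()
    # base64url-encode and strip the "=" padding, as RFC 7636 requires
    challenge = base64.urlsafe_b64encode(digest).decode("ascii").rstrip("=")
    return verifier, challenge


verifier, challenge = make_pkce_pair()
# A 32-byte SHA-256 digest encodes to 43 base64url chars once padding is removed
assert len(challenge) == 43
assert "=" not in challenge
```

The authorization server later repeats exactly this transform on the `code_verifier` sent with the token request and compares it to the stored `code_challenge`, which is why the verifier itself is kept in the `oauth_state` record.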


@@ -0,0 +1,155 @@
# app/services/organization_service.py
"""Service layer for organization operations — delegates to OrganizationRepository."""
import logging
from typing import Any
from uuid import UUID

from sqlalchemy.ext.asyncio import AsyncSession

from app.core.exceptions import NotFoundError
from app.models.organization import Organization
from app.models.user_organization import OrganizationRole, UserOrganization
from app.repositories.organization import OrganizationRepository, organization_repo
from app.schemas.organizations import OrganizationCreate, OrganizationUpdate

logger = logging.getLogger(__name__)


class OrganizationService:
    """Service for organization management operations."""

    def __init__(
        self, organization_repository: OrganizationRepository | None = None
    ) -> None:
        self._repo = organization_repository or organization_repo

    async def get_organization(self, db: AsyncSession, org_id: str) -> Organization:
        """Get organization by ID, raising NotFoundError if not found."""
        org = await self._repo.get(db, id=org_id)
        if not org:
            raise NotFoundError(f"Organization {org_id} not found")
        return org

    async def create_organization(
        self, db: AsyncSession, *, obj_in: OrganizationCreate
    ) -> Organization:
        """Create a new organization."""
        return await self._repo.create(db, obj_in=obj_in)

    async def update_organization(
        self,
        db: AsyncSession,
        *,
        org: Organization,
        obj_in: OrganizationUpdate | dict[str, Any],
    ) -> Organization:
        """Update an existing organization."""
        return await self._repo.update(db, db_obj=org, obj_in=obj_in)

    async def remove_organization(self, db: AsyncSession, org_id: str) -> None:
        """Permanently delete an organization by ID."""
        await self._repo.remove(db, id=org_id)

    async def get_member_count(self, db: AsyncSession, *, organization_id: UUID) -> int:
        """Get number of active members in an organization."""
        return await self._repo.get_member_count(db, organization_id=organization_id)

    async def get_multi_with_member_counts(
        self,
        db: AsyncSession,
        *,
        skip: int = 0,
        limit: int = 100,
        is_active: bool | None = None,
        search: str | None = None,
    ) -> tuple[list[dict[str, Any]], int]:
        """List organizations with member counts and pagination."""
        return await self._repo.get_multi_with_member_counts(
            db, skip=skip, limit=limit, is_active=is_active, search=search
        )

    async def get_user_organizations_with_details(
        self,
        db: AsyncSession,
        *,
        user_id: UUID,
        is_active: bool | None = None,
    ) -> list[dict[str, Any]]:
        """Get all organizations a user belongs to, with membership details."""
        return await self._repo.get_user_organizations_with_details(
            db, user_id=user_id, is_active=is_active
        )

    async def get_organization_members(
        self,
        db: AsyncSession,
        *,
        organization_id: UUID,
        skip: int = 0,
        limit: int = 100,
        is_active: bool | None = True,
    ) -> tuple[list[dict[str, Any]], int]:
        """Get members of an organization with pagination."""
        return await self._repo.get_organization_members(
            db,
            organization_id=organization_id,
            skip=skip,
            limit=limit,
            is_active=is_active,
        )

    async def add_member(
        self,
        db: AsyncSession,
        *,
        organization_id: UUID,
        user_id: UUID,
        role: OrganizationRole = OrganizationRole.MEMBER,
    ) -> UserOrganization:
        """Add a user to an organization."""
        return await self._repo.add_user(
            db, organization_id=organization_id, user_id=user_id, role=role
        )

    async def remove_member(
        self,
        db: AsyncSession,
        *,
        organization_id: UUID,
        user_id: UUID,
    ) -> bool:
        """Remove a user from an organization. Returns True if found and removed."""
        return await self._repo.remove_user(
            db, organization_id=organization_id, user_id=user_id
        )

    async def get_user_role_in_org(
        self, db: AsyncSession, *, user_id: UUID, organization_id: UUID
    ) -> OrganizationRole | None:
        """Get the role of a user in an organization."""
        return await self._repo.get_user_role_in_org(
            db, user_id=user_id, organization_id=organization_id
        )

    async def get_org_distribution(
        self, db: AsyncSession, *, limit: int = 6
    ) -> list[dict[str, Any]]:
        """Return top organizations by member count for admin dashboard."""
        from sqlalchemy import func, select

        result = await db.execute(
            select(
                Organization.name,
                func.count(UserOrganization.user_id).label("count"),
            )
            .join(UserOrganization, Organization.id == UserOrganization.organization_id)
            .group_by(Organization.name)
            .order_by(func.count(UserOrganization.user_id).desc())
            .limit(limit)
        )
        return [{"name": row.name, "value": row.count} for row in result.all()]


# Default singleton
organization_service = OrganizationService()


@@ -8,7 +8,7 @@ import logging
 from datetime import UTC, datetime
 from app.core.database import SessionLocal
-from app.crud.session import session as session_crud
+from app.repositories.session import session_repo as session_repo
 logger = logging.getLogger(__name__)
@@ -32,15 +32,15 @@ async def cleanup_expired_sessions(keep_days: int = 30) -> int:
     async with SessionLocal() as db:
         try:
-            # Use CRUD method to cleanup
-            count = await session_crud.cleanup_expired(db, keep_days=keep_days)
-            logger.info(f"Session cleanup complete: {count} sessions deleted")
+            # Use repository method to cleanup
+            count = await session_repo.cleanup_expired(db, keep_days=keep_days)
+            logger.info("Session cleanup complete: %s sessions deleted", count)
             return count
         except Exception as e:
-            logger.error(f"Error during session cleanup: {e!s}", exc_info=True)
+            logger.exception("Error during session cleanup: %s", e)
             return 0
@@ -79,10 +79,10 @@ async def get_session_statistics() -> dict:
             "expired": expired_sessions,
         }
-        logger.info(f"Session statistics: {stats}")
+        logger.info("Session statistics: %s", stats)
         return stats
     except Exception as e:
-        logger.error(f"Error getting session statistics: {e!s}", exc_info=True)
+        logger.exception("Error getting session statistics: %s", e)
         return {}


@@ -0,0 +1,97 @@
# app/services/session_service.py
"""Service layer for session operations — delegates to SessionRepository."""
import logging
from datetime import datetime

from sqlalchemy.ext.asyncio import AsyncSession

from app.models.user_session import UserSession
from app.repositories.session import SessionRepository, session_repo
from app.schemas.sessions import SessionCreate

logger = logging.getLogger(__name__)


class SessionService:
    """Service for user session management operations."""

    def __init__(self, session_repository: SessionRepository | None = None) -> None:
        self._repo = session_repository or session_repo

    async def create_session(
        self, db: AsyncSession, *, obj_in: SessionCreate
    ) -> UserSession:
        """Create a new session record."""
        return await self._repo.create_session(db, obj_in=obj_in)

    async def get_session(
        self, db: AsyncSession, session_id: str
    ) -> UserSession | None:
        """Get session by ID."""
        return await self._repo.get(db, id=session_id)

    async def get_user_sessions(
        self, db: AsyncSession, *, user_id: str, active_only: bool = True
    ) -> list[UserSession]:
        """Get all sessions for a user."""
        return await self._repo.get_user_sessions(
            db, user_id=user_id, active_only=active_only
        )

    async def get_active_by_jti(
        self, db: AsyncSession, *, jti: str
    ) -> UserSession | None:
        """Get active session by refresh token JTI."""
        return await self._repo.get_active_by_jti(db, jti=jti)

    async def get_by_jti(self, db: AsyncSession, *, jti: str) -> UserSession | None:
        """Get session by refresh token JTI (active or inactive)."""
        return await self._repo.get_by_jti(db, jti=jti)

    async def deactivate(
        self, db: AsyncSession, *, session_id: str
    ) -> UserSession | None:
        """Deactivate a session (logout from device)."""
        return await self._repo.deactivate(db, session_id=session_id)

    async def deactivate_all_user_sessions(
        self, db: AsyncSession, *, user_id: str
    ) -> int:
        """Deactivate all sessions for a user. Returns count deactivated."""
        return await self._repo.deactivate_all_user_sessions(db, user_id=user_id)

    async def update_refresh_token(
        self,
        db: AsyncSession,
        *,
        session: UserSession,
        new_jti: str,
        new_expires_at: datetime,
    ) -> UserSession:
        """Update session with a rotated refresh token."""
        return await self._repo.update_refresh_token(
            db, session=session, new_jti=new_jti, new_expires_at=new_expires_at
        )

    async def cleanup_expired_for_user(self, db: AsyncSession, *, user_id: str) -> int:
        """Remove expired sessions for a user. Returns count removed."""
        return await self._repo.cleanup_expired_for_user(db, user_id=user_id)

    async def get_all_sessions(
        self,
        db: AsyncSession,
        *,
        skip: int = 0,
        limit: int = 100,
        active_only: bool = True,
        with_user: bool = True,
    ) -> tuple[list[UserSession], int]:
        """Get all sessions with pagination (admin only)."""
        return await self._repo.get_all_sessions(
            db, skip=skip, limit=limit, active_only=active_only, with_user=with_user
        )


# Default singleton
session_service = SessionService()


@@ -0,0 +1,120 @@
# app/services/user_service.py
"""Service layer for user operations — delegates to UserRepository."""
import logging
from typing import Any
from uuid import UUID

from sqlalchemy.ext.asyncio import AsyncSession

from app.core.exceptions import NotFoundError
from app.models.user import User
from app.repositories.user import UserRepository, user_repo
from app.schemas.users import UserCreate, UserUpdate

logger = logging.getLogger(__name__)


class UserService:
    """Service for user management operations."""

    def __init__(self, user_repository: UserRepository | None = None) -> None:
        self._repo = user_repository or user_repo

    async def get_user(self, db: AsyncSession, user_id: str) -> User:
        """Get user by ID, raising NotFoundError if not found."""
        user = await self._repo.get(db, id=user_id)
        if not user:
            raise NotFoundError(f"User {user_id} not found")
        return user

    async def get_by_email(self, db: AsyncSession, email: str) -> User | None:
        """Get user by email address."""
        return await self._repo.get_by_email(db, email=email)

    async def create_user(self, db: AsyncSession, user_data: UserCreate) -> User:
        """Create a new user."""
        return await self._repo.create(db, obj_in=user_data)

    async def update_user(
        self, db: AsyncSession, *, user: User, obj_in: UserUpdate | dict[str, Any]
    ) -> User:
        """Update an existing user."""
        return await self._repo.update(db, db_obj=user, obj_in=obj_in)

    async def soft_delete_user(self, db: AsyncSession, user_id: str) -> None:
        """Soft-delete a user by ID."""
        await self._repo.soft_delete(db, id=user_id)

    async def list_users(
        self,
        db: AsyncSession,
        *,
        skip: int = 0,
        limit: int = 100,
        sort_by: str | None = None,
        sort_order: str = "asc",
        filters: dict[str, Any] | None = None,
        search: str | None = None,
    ) -> tuple[list[User], int]:
        """List users with pagination, sorting, filtering, and search."""
        return await self._repo.get_multi_with_total(
            db,
            skip=skip,
            limit=limit,
            sort_by=sort_by,
            sort_order=sort_order,
            filters=filters,
            search=search,
        )

    async def bulk_update_status(
        self, db: AsyncSession, *, user_ids: list[UUID], is_active: bool
    ) -> int:
        """Bulk update active status for multiple users. Returns count updated."""
        return await self._repo.bulk_update_status(
            db, user_ids=user_ids, is_active=is_active
        )

    async def bulk_soft_delete(
        self,
        db: AsyncSession,
        *,
        user_ids: list[UUID],
        exclude_user_id: UUID | None = None,
    ) -> int:
        """Bulk soft-delete multiple users. Returns count deleted."""
        return await self._repo.bulk_soft_delete(
            db, user_ids=user_ids, exclude_user_id=exclude_user_id
        )

    async def get_stats(self, db: AsyncSession) -> dict[str, Any]:
        """Return user stats needed for the admin dashboard."""
        from sqlalchemy import func, select

        total_users = (
            await db.execute(select(func.count()).select_from(User))
        ).scalar() or 0
        active_count = (
            await db.execute(
                select(func.count()).select_from(User).where(User.is_active)
            )
        ).scalar() or 0
        inactive_count = (
            await db.execute(
                select(func.count()).select_from(User).where(User.is_active.is_(False))
            )
        ).scalar() or 0
        all_users = list(
            (await db.execute(select(User).order_by(User.created_at))).scalars().all()
        )
        return {
            "total_users": total_users,
            "active_count": active_count,
            "inactive_count": inactive_count,
            "all_users": all_users,
        }


# Default singleton
user_service = UserService()
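The optional repository argument in the constructors above is the seam that makes these services testable without a database. A minimal sketch of that pattern with a stubbed repository — the condensed `UserServiceSketch` and `FakeUserRepo` names are illustrative stand-ins, not part of the codebase:

```python
import asyncio


class NotFoundError(Exception):
    """Stand-in for app.core.exceptions.NotFoundError."""


class FakeUserRepo:
    """In-memory stub playing the role of UserRepository in a unit test."""

    def __init__(self, users: dict):
        self._users = users

    async def get(self, db, *, id):
        return self._users.get(id)


class UserServiceSketch:
    """Condensed copy of the get_user path, showing the injection seam."""

    def __init__(self, user_repository):
        self._repo = user_repository

    async def get_user(self, db, user_id):
        user = await self._repo.get(db, id=user_id)
        if not user:
            raise NotFoundError(f"User {user_id} not found")
        return user


service = UserServiceSketch(FakeUserRepo({"u1": {"email": "a@example.com"}}))
print(asyncio.run(service.get_user(None, "u1")))  # {'email': 'a@example.com'}
```

In the real suite, the same trick lets `tests/services/` exercise business logic with a fake repository while `tests/repositories/` hits the database.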


@@ -65,10 +65,10 @@ async def setup_async_test_db():
     async with test_engine.begin() as conn:
         await conn.run_sync(Base.metadata.create_all)
-    AsyncTestingSessionLocal = sessionmaker(
+    AsyncTestingSessionLocal = sessionmaker(  # pyright: ignore[reportCallIssue]
         autocommit=False,
         autoflush=False,
-        bind=test_engine,
+        bind=test_engine,  # pyright: ignore[reportArgumentType]
         expire_on_commit=False,
         class_=AsyncSession,
     )


@@ -79,12 +79,13 @@ This FastAPI backend application follows a **clean layered architecture** patter
 ### Authentication & Security
-- **python-jose**: JWT token generation and validation
-  - Cryptographic signing
+- **PyJWT**: JWT token generation and validation
+  - Cryptographic signing (HS256, RS256)
   - Token expiration handling
   - Claims validation
+  - JWK support for Google ID token verification
-- **passlib + bcrypt**: Password hashing
+- **bcrypt**: Password hashing
   - Industry-standard bcrypt algorithm
   - Configurable cost factor
   - Salt generation
@@ -117,7 +118,8 @@ backend/
 │ ├── api/ # API layer
 │ │ ├── dependencies/ # Dependency injection
 │ │ │ ├── auth.py # Authentication dependencies
-│ │ │ └── permissions.py # Authorization dependencies
+│ │ │ ├── permissions.py # Authorization dependencies
+│ │ │ └── services.py # Service singleton injection
 │ │ ├── routes/ # API endpoints
 │ │ │ ├── auth.py # Authentication routes
 │ │ │ ├── users.py # User management routes
@@ -131,13 +133,14 @@ backend/
 │ │ ├── config.py # Application configuration
 │ │ ├── database.py # Database connection
 │ │ ├── exceptions.py # Custom exception classes
+│ │ ├── repository_exceptions.py # Repository-level exception hierarchy
 │ │ └── middleware.py # Custom middleware
 │ │
-│ ├── crud/ # Database operations
-│ │ ├── base.py # Generic CRUD base class
-│ │ ├── user.py # User CRUD operations
-│ │ ├── session.py # Session CRUD operations
-│ │ └── organization.py # Organization CRUD
+│ ├── repositories/ # Data access layer
+│ │ ├── base.py # Generic repository base class
+│ │ ├── user.py # User repository
+│ │ ├── session.py # Session repository
+│ │ └── organization.py # Organization repository
 │ │
 │ ├── models/ # SQLAlchemy models
 │ │ ├── base.py # Base model with mixins
@@ -153,8 +156,11 @@ backend/
 │ │ ├── sessions.py # Session schemas
 │ │ └── organizations.py # Organization schemas
 │ │
-│ ├── services/ # Business logic
+│ ├── services/ # Business logic layer
 │ │ ├── auth_service.py # Authentication service
+│ │ ├── user_service.py # User management service
+│ │ ├── session_service.py # Session management service
+│ │ ├── organization_service.py # Organization service
 │ │ ├── email_service.py # Email service
 │ │ └── session_cleanup.py # Background cleanup
 │ │
@@ -168,20 +174,25 @@ backend/
 ├── tests/ # Test suite
 │ ├── api/ # Integration tests
-│ ├── crud/ # CRUD tests
+│ ├── repositories/ # Repository unit tests
+│ ├── services/ # Service unit tests
 │ ├── models/ # Model tests
-│ ├── services/ # Service tests
 │ └── conftest.py # Test configuration
 ├── docs/ # Documentation
 │ ├── ARCHITECTURE.md # This file
 │ ├── CODING_STANDARDS.md # Coding standards
+│ ├── COMMON_PITFALLS.md # Common mistakes to avoid
+│ ├── E2E_TESTING.md # E2E testing guide
 │ └── FEATURE_EXAMPLE.md # Feature implementation guide
-├── requirements.txt # Python dependencies
-├── pytest.ini # Pytest configuration
-├── .coveragerc # Coverage configuration
-└── alembic.ini # Alembic configuration
+├── pyproject.toml # Dependencies, tool configs (Ruff, pytest, coverage, Pyright)
+├── uv.lock # Locked dependency versions (commit to git)
+├── Makefile # Development commands (quality, security, testing)
+├── .pre-commit-config.yaml # Pre-commit hook configuration
+├── .secrets.baseline # detect-secrets baseline (known false positives)
+├── alembic.ini # Alembic configuration
+└── migrate.py # Migration helper script
 ```
 ## Layered Architecture
@@ -214,11 +225,11 @@ The application follows a strict 5-layer architecture:
 └──────────────────────────┬──────────────────────────────────┘
                            │ calls
 ┌──────────────────────────▼──────────────────────────────────┐
-│ CRUD Layer (crud/) │
+│ Repository Layer (repositories/) │
 │ - Database operations │
 │ - Query building │
-│ - Transaction management │
-│ - Error handling │
+│ - Custom repository exceptions │
+│ - No business logic │
 └──────────────────────────┬──────────────────────────────────┘
                            │ uses
 ┌──────────────────────────▼──────────────────────────────────┐
@@ -262,7 +273,7 @@ async def get_current_user_info(
 **Rules**:
 - Should NOT contain business logic
-- Should NOT directly perform database operations (use CRUD or services)
+- Should NOT directly call repositories (use services injected via `dependencies/services.py`)
 - Must validate all input via Pydantic schemas
 - Must specify response models
 - Should apply appropriate rate limits
@@ -279,9 +290,9 @@ async def get_current_user_info(
 **Example**:
 ```python
-def get_current_user(
+async def get_current_user(
     token: str = Depends(oauth2_scheme),
-    db: Session = Depends(get_db)
+    db: AsyncSession = Depends(get_db)
 ) -> User:
     """
     Extract and validate user from JWT token.
@@ -295,7 +306,7 @@ def get_current_user(
     except Exception:
         raise AuthenticationError("Invalid authentication credentials")
-    user = user_crud.get(db, id=user_id)
+    user = await user_repo.get(db, id=user_id)
     if not user:
         raise AuthenticationError("User not found")
@@ -313,7 +324,7 @@ def get_current_user(
 **Responsibility**: Implement complex business logic
 **Key Functions**:
-- Orchestrate multiple CRUD operations
+- Orchestrate multiple repository operations
 - Implement business rules
 - Handle external service integration
 - Coordinate transactions
@@ -323,9 +334,9 @@ def get_current_user(
 class AuthService:
     """Authentication service with business logic."""
-    def login(
+    async def login(
         self,
-        db: Session,
+        db: AsyncSession,
         email: str,
         password: str,
         request: Request
@@ -339,8 +350,8 @@ class AuthService:
         3. Generate tokens
         4. Return tokens and user info
         """
-        # Validate credentials
-        user = user_crud.get_by_email(db, email=email)
+        # Validate credentials via repository
+        user = await user_repo.get_by_email(db, email=email)
         if not user or not verify_password(password, user.hashed_password):
             raise AuthenticationError("Invalid credentials")
@@ -350,11 +361,10 @@ class AuthService:
# Extract device info
         device_info = extract_device_info(request)
-        # Create session
-        session = session_crud.create_session(
-            db,
-            user_id=user.id,
-            device_info=device_info
-        )
+        # Create session via repository
+        session = await session_repo.create(
+            db,
+            obj_in=SessionCreate(user_id=user.id, **device_info)
+        )
         # Generate tokens
@@ -373,75 +383,60 @@ class AuthService:
 **Rules**:
 - Contains business logic, not just data operations
-- Can call multiple CRUD operations
+- Can call multiple repository operations
 - Should handle complex workflows
 - Must maintain data consistency
 - Should use transactions when needed

-#### 4. CRUD Layer (`app/crud/`)
+#### 4. Repository Layer (`app/repositories/`)

-**Responsibility**: Database operations and queries
+**Responsibility**: Database operations and queries — no business logic

 **Key Functions**:
 - Create, read, update, delete operations
 - Build database queries
-- Handle database errors
+- Raise custom repository exceptions (`DuplicateEntryError`, `IntegrityConstraintError`)
 - Manage soft deletes
 - Implement pagination and filtering

 **Example**:
 ```python
-class CRUDSession(CRUDBase[UserSession, SessionCreate, SessionUpdate]):
-    """CRUD operations for user sessions."""
-    def get_by_jti(self, db: Session, jti: UUID) -> Optional[UserSession]:
+class SessionRepository(RepositoryBase[UserSession, SessionCreate, SessionUpdate]):
+    """Repository for user sessions — database operations only."""
+    async def get_by_jti(self, db: AsyncSession, *, jti: str) -> UserSession | None:
         """Get session by refresh token JTI."""
-        try:
-            return (
-                db.query(UserSession)
-                .filter(UserSession.refresh_token_jti == jti)
-                .first()
-            )
-        except Exception as e:
-            logger.error(f"Error getting session by JTI: {str(e)}")
-            return None
-    def get_active_by_jti(
-        self,
-        db: Session,
-        jti: UUID
-    ) -> Optional[UserSession]:
-        """Get active session by refresh token JTI."""
-        session = self.get_by_jti(db, jti=jti)
-        if session and session.is_active and not session.is_expired:
-            return session
-        return None
-    def deactivate(self, db: Session, session_id: UUID) -> bool:
+        result = await db.execute(
+            select(UserSession).where(UserSession.refresh_token_jti == jti)
+        )
+        return result.scalar_one_or_none()
+    async def deactivate(self, db: AsyncSession, *, session_id: UUID) -> bool:
         """Deactivate a session (logout)."""
         try:
-            session = self.get(db, id=session_id)
+            session = await self.get(db, id=session_id)
             if not session:
                 return False
             session.is_active = False
-            db.commit()
+            await db.commit()
             logger.info(f"Session {session_id} deactivated")
             return True
         except Exception as e:
-            db.rollback()
+            await db.rollback()
             logger.error(f"Error deactivating session: {str(e)}")
             return False
 ```

 **Rules**:
 - Should NOT contain business logic
-- Must handle database exceptions
-- Must use parameterized queries (SQLAlchemy does this)
+- Must raise custom repository exceptions (not raw `ValueError`/`IntegrityError`)
+- Must use async SQLAlchemy 2.0 `select()` API (never `db.query()`)
 - Should log all database errors
 - Must rollback on errors
 - Should use soft deletes when possible
+- **Never imported directly by routes** — always called through services

 #### 5. Data Layer (`app/models/` + `app/schemas/`)
@@ -546,51 +541,23 @@ SessionLocal = sessionmaker(
 #### Dependency Injection Pattern
 ```python
-def get_db() -> Generator[Session, None, None]:
+async def get_db() -> AsyncGenerator[AsyncSession, None]:
     """
-    Database session dependency for FastAPI routes.
-    Automatically commits on success, rolls back on error.
+    Async database session dependency for FastAPI routes.
+    The session is passed to service methods; commit/rollback is
+    managed inside service or repository methods.
     """
-    db = SessionLocal()
-    try:
+    async with AsyncSessionLocal() as db:
         yield db
-    finally:
-        db.close()

-# Usage in routes
+# Usage in routes — always through a service, never direct repository
 @router.get("/users")
-def list_users(db: Session = Depends(get_db)):
-    return user_crud.get_multi(db)
-```
+async def list_users(
+    user_service: UserService = Depends(get_user_service),
+    db: AsyncSession = Depends(get_db),
+):
+    return await user_service.get_users(db)

-#### Context Manager Pattern
-```python
-@contextmanager
-def transaction_scope() -> Generator[Session, None, None]:
-    """
-    Context manager for database transactions.
-    Use for complex operations requiring multiple steps.
-    Automatically commits on success, rolls back on error.
-    """
-    db = SessionLocal()
-    try:
-        yield db
-        db.commit()
-    except Exception:
-        db.rollback()
-        raise
-    finally:
-        db.close()

-# Usage in services
-def complex_operation():
-    with transaction_scope() as db:
-        user = user_crud.create(db, obj_in=user_data)
-        session = session_crud.create(db, session_data)
-        return user, session
 ```
 ### Model Mixins
@@ -782,22 +749,15 @@ def get_profile(
 ```python
 @router.delete("/sessions/{session_id}")
-def revoke_session(
+async def revoke_session(
     session_id: UUID,
     current_user: User = Depends(get_current_user),
-    db: Session = Depends(get_db)
+    session_service: SessionService = Depends(get_session_service),
+    db: AsyncSession = Depends(get_db),
 ):
     """Users can only revoke their own sessions."""
-    session = session_crud.get(db, id=session_id)
-    if not session:
-        raise NotFoundError("Session not found")
-    # Check ownership
-    if session.user_id != current_user.id:
-        raise AuthorizationError("You can only revoke your own sessions")
-    session_crud.deactivate(db, session_id=session_id)
+    # SessionService verifies ownership and raises NotFoundError / AuthorizationError
+    await session_service.revoke_session(db, session_id=session_id, user_id=current_user.id)
     return MessageResponse(success=True, message="Session revoked")
 ```
@@ -818,6 +778,84 @@ def add_member(
     pass
 ```
### OAuth Integration
The system supports two OAuth modes:
#### OAuth Consumer Mode (Social Login)
Users can authenticate via Google or GitHub OAuth providers:
```python
# Get authorization URL with PKCE support
GET /oauth/authorize/{provider}?redirect_uri=https://yourapp.com/callback
# Handle callback and exchange code for tokens
POST /oauth/callback/{provider}
{
"code": "authorization_code_from_provider",
"state": "csrf_state_token"
}
```
**Security Features:**
- PKCE (S256) for Google
- State parameter for CSRF protection
- Nonce for Google OIDC replay attack prevention
- Google ID token signature verification via JWKS
- Email normalization to prevent account duplication
- Auto-linking by email (configurable)
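The PKCE (S256) step pairs a random `code_verifier` with a derived `code_challenge`. A minimal sketch of the RFC 7636 derivation, independent of this codebase's actual helpers:

```python
import base64
import hashlib
import secrets


def make_pkce_pair() -> tuple[str, str]:
    """Generate a PKCE code_verifier and its S256 code_challenge (RFC 7636)."""
    # 32 random bytes -> 43-char base64url verifier (within the 43-128 char limit)
    verifier = base64.urlsafe_b64encode(secrets.token_bytes(32)).rstrip(b"=").decode("ascii")
    # challenge = BASE64URL(SHA256(verifier)), without '=' padding
    digest = hashlib.sha256(verifier.encode("ascii")).digest()
    challenge = base64.urlsafe_b64encode(digest).rstrip(b"=").decode("ascii")
    return verifier, challenge


verifier, challenge = make_pkce_pair()
print(len(verifier), challenge)  # verifier length is 43
```

The client sends `code_challenge` with the authorize request and `code_verifier` with the token request; the server recomputes the challenge and rejects a mismatch.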
#### OAuth Provider Mode (MCP Integration)
Full OAuth 2.0 Authorization Server for third-party clients (RFC compliant):
```
┌─────────────┐ ┌─────────────┐
│ MCP Client │ │ Backend │
└──────┬──────┘ └──────┬──────┘
│ │
│ GET /.well-known/oauth-authorization-server│
│─────────────────────────────────────────────>│
│ {metadata} │
│<─────────────────────────────────────────────│
│ │
│ GET /oauth/provider/authorize │
│ ?response_type=code&client_id=... │
│ &redirect_uri=...&code_challenge=... │
│─────────────────────────────────────────────>│
│ │
│ (User consents) │
│ │
│ 302 redirect_uri?code=AUTH_CODE&state=... │
│<─────────────────────────────────────────────│
│ │
│ POST /oauth/provider/token │
│ {grant_type=authorization_code, │
│ code=AUTH_CODE, code_verifier=...} │
│─────────────────────────────────────────────>│
│ │
│ {access_token, refresh_token, expires_in} │
│<─────────────────────────────────────────────│
│ │
```
**Endpoints:**
- `GET /.well-known/oauth-authorization-server` - RFC 8414 metadata
- `GET /oauth/provider/authorize` - Authorization endpoint
- `POST /oauth/provider/token` - Token endpoint (authorization_code, refresh_token)
- `POST /oauth/provider/revoke` - RFC 7009 token revocation
- `POST /oauth/provider/introspect` - RFC 7662 token introspection
**Security Features:**
- PKCE S256 required for public clients (plain method rejected)
- Authorization codes are single-use with 10-minute expiry
- Code reuse detection triggers security incident (all tokens revoked)
- Refresh token rotation on use
- Opaque refresh tokens (hashed in database)
- JWT access tokens with standard claims
- Consent management per client
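Storing opaque refresh tokens as digests means a database leak does not expose usable tokens. A minimal sketch of that idea, with hypothetical function names (not this project's actual implementation):

```python
import hashlib
import secrets


def issue_refresh_token() -> tuple[str, str]:
    """Return (raw_token, stored_hash); only the hash is persisted."""
    raw = secrets.token_urlsafe(48)  # opaque token handed to the client
    stored = hashlib.sha256(raw.encode("ascii")).hexdigest()  # digest kept in the DB
    return raw, stored


def verify_refresh_token(presented: str, stored_hash: str) -> bool:
    """Hash the presented token and compare in constant time."""
    digest = hashlib.sha256(presented.encode("ascii")).hexdigest()
    return secrets.compare_digest(digest, stored_hash)


raw, stored = issue_refresh_token()
print(verify_refresh_token(raw, stored))  # True
```

Combined with rotation, a presented token that hashes to an already-rotated record is the reuse signal that triggers revoking the whole token family.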
 ## Error Handling

 ### Exception Hierarchy
@@ -983,23 +1021,27 @@ from app.services.session_cleanup import cleanup_expired_sessions
 scheduler = AsyncIOScheduler()

-@app.on_event("startup")
-async def startup_event():
-    """Start background jobs on application startup."""
-    if not settings.IS_TEST:  # Don't run in tests
+@asynccontextmanager
+async def lifespan(app: FastAPI):
+    """Application lifespan context manager."""
+    # Startup
+    if os.getenv("IS_TEST", "False") != "True":
         scheduler.add_job(
             cleanup_expired_sessions,
             "cron",
             hour=2,  # Run at 2 AM daily
-            id="cleanup_expired_sessions"
+            id="cleanup_expired_sessions",
+            replace_existing=True,
         )
         scheduler.start()
         logger.info("Background jobs started")

-@app.on_event("shutdown")
-async def shutdown_event():
-    """Stop background jobs on application shutdown."""
-    scheduler.shutdown()
+    yield
+
+    # Shutdown
+    if os.getenv("IS_TEST", "False") != "True":
+        scheduler.shutdown()
+        await close_async_db()  # Dispose database engine connections
 ```
 ### Job Implementation
@@ -1014,8 +1056,8 @@ async def cleanup_expired_sessions():
     Runs daily at 2 AM. Removes sessions expired for more than 30 days.
     """
     try:
-        with transaction_scope() as db:
-            count = session_crud.cleanup_expired(db, keep_days=30)
+        async with AsyncSessionLocal() as db:
+            count = await session_repo.cleanup_expired(db, keep_days=30)
         logger.info(f"Cleaned up {count} expired sessions")
     except Exception as e:
         logger.error(f"Error cleaning up sessions: {str(e)}", exc_info=True)
@@ -1032,7 +1074,7 @@ async def cleanup_expired_sessions():
 │Integration  │  ← API endpoint tests
 │   Tests     │
 ├─────────────┤
-│    Unit     │  ← CRUD, services, utilities
+│    Unit     │  ← repositories, services, utilities
 │   Tests     │
 └─────────────┘
 ```
@@ -1127,6 +1169,8 @@ app.add_middleware(
 ## Performance Considerations

+> 📖 For the full benchmarking guide (how to run, read results, write new benchmarks, and manage baselines), see **[BENCHMARKS.md](BENCHMARKS.md)**.

 ### Database Connection Pooling
 - Pool size: 20 connections

backend/docs/BENCHMARKS.md (new file)

@@ -0,0 +1,311 @@
# Performance Benchmarks Guide
Automated performance benchmarking infrastructure using **pytest-benchmark** to detect latency regressions in critical API endpoints.
## Table of Contents
- [Why Benchmark?](#why-benchmark)
- [Quick Start](#quick-start)
- [How It Works](#how-it-works)
- [Understanding Results](#understanding-results)
- [Test Organization](#test-organization)
- [Writing Benchmark Tests](#writing-benchmark-tests)
- [Baseline Management](#baseline-management)
- [CI/CD Integration](#cicd-integration)
- [Troubleshooting](#troubleshooting)
---
## Why Benchmark?
Performance regressions are silent bugs — they don't break tests or cause errors, but they degrade the user experience over time. Common causes include:
- **Unintended N+1 queries** after adding a relationship
- **Heavier serialization** after adding new fields to a response model
- **Middleware overhead** from new security headers or logging
- **Dependency upgrades** that introduce slower code paths
Without automated benchmarks, these regressions go unnoticed until users complain. Performance benchmarks serve as an **early warning system** — they measure endpoint latency on every run and flag significant deviations from an established baseline.
### What benchmarks give you
| Benefit | Description |
|---------|-------------|
| **Regression detection** | Automatically flags when an endpoint becomes significantly slower |
| **Baseline tracking** | Stores known-good performance numbers for comparison |
| **Confidence in refactors** | Verify that code changes don't degrade response times |
| **Visibility** | Makes performance a first-class, measurable quality attribute |
---
## Quick Start
```bash
# Run benchmarks (no comparison, just see current numbers)
make benchmark
# Save current results as the baseline
make benchmark-save
# Run benchmarks and compare against the saved baseline
make benchmark-check
```
---
## How It Works
The benchmarking system has three layers:
### 1. pytest-benchmark integration
[pytest-benchmark](https://pytest-benchmark.readthedocs.io/) is a pytest plugin that provides a `benchmark` fixture. It handles:
- **Calibration**: Automatically determines how many iterations to run for statistical significance
- **Timing**: Uses `time.perf_counter` for high-resolution measurements
- **Statistics**: Computes min, max, mean, median, standard deviation, IQR, and outlier detection
- **Comparison**: Compares current results against saved baselines and flags regressions
### 2. Benchmark types
The test suite includes two categories of performance tests:
| Type | How it works | Examples |
|------|-------------|----------|
| **pytest-benchmark tests** | Uses the `benchmark` fixture for precise, multi-round timing | `test_health_endpoint_performance`, `test_openapi_schema_performance`, `test_password_hashing_performance`, `test_password_verification_performance`, `test_access_token_creation_performance`, `test_refresh_token_creation_performance`, `test_token_decode_performance` |
| **Manual latency tests** | Uses `time.perf_counter` with explicit thresholds (for async endpoints that pytest-benchmark doesn't support natively) | `test_login_latency`, `test_get_current_user_latency`, `test_register_latency`, `test_token_refresh_latency`, `test_sessions_list_latency`, `test_user_profile_update_latency` |
### 3. Regression detection
When running `make benchmark-check`, the system:
1. Runs all benchmark tests
2. Compares results against the saved baseline (`.benchmarks/` directory)
3. **Fails the build** if any test's mean time has increased by more than **200%** over the baseline (i.e., has become more than 3× slower)
The `200%` threshold in `--benchmark-compare-fail=mean:200%` means "fail if the mean increased by more than 200% relative to the baseline." This is deliberately generous to avoid false positives from normal run-to-run variance while still catching real regressions.
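The rule behind that flag can be written out as a one-line comparison. The sketch below is illustrative only (pytest-benchmark implements this internally); the function name is hypothetical:

```python
# Sketch of the rule behind --benchmark-compare-fail=mean:200%
# (illustrative; pytest-benchmark performs this comparison itself).
def mean_regressed(baseline_mean: float, current_mean: float, threshold_pct: float = 200.0) -> bool:
    """Return True when the mean grew by more than threshold_pct percent."""
    increase_pct = (current_mean - baseline_mean) / baseline_mean * 100
    return increase_pct > threshold_pct

# 1.0ms baseline vs 2.5ms current: +150%, within the 200% allowance
print(mean_regressed(1.0, 2.5))  # → False
# 1.0ms baseline vs 3.5ms current: +250%, fails the check
print(mean_regressed(1.0, 3.5))  # → True
```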
---
## Understanding Results
A typical benchmark output looks like this:
```
--------------------------------------------------------------------------------------- benchmark: 2 tests --------------------------------------------------------------------------------------
Name (time in ms) Min Max Mean StdDev Median IQR Outliers OPS Rounds Iterations
-------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------
test_health_endpoint_performance 0.9841 (1.0) 1.5513 (1.0) 1.1390 (1.0) 0.1098 (1.0) 1.1151 (1.0) 0.1672 (1.0) 39;2 877.9666 (1.0) 133 1
test_openapi_schema_performance 1.6523 (1.68) 2.0892 (1.35) 1.7843 (1.57) 0.1553 (1.41) 1.7200 (1.54) 0.1727 (1.03) 2;0 560.4471 (0.64) 10 1
-------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------
```
### Column reference
| Column | Meaning |
|--------|---------|
| **Min** | Fastest single execution |
| **Max** | Slowest single execution |
| **Mean** | Average across all rounds — the primary metric for regression detection |
| **StdDev** | How much results vary between rounds (lower = more stable) |
| **Median** | Middle value, less sensitive to outliers than mean |
| **IQR** | Interquartile range — spread of the middle 50% of results |
| **Outliers** | Format `A;B` — A = rounds more than 1 StdDev from the mean, B = rounds more than 1.5×IQR outside the quartiles |
| **OPS** | Operations per second (`1 / Mean`) |
| **Rounds** | How many times the test was executed (auto-calibrated) |
| **Iterations** | Iterations per round (usually 1 for ms-scale tests) |
### The ratio numbers `(1.0)`, `(1.68)`, etc.
These show how each test compares **to the best result in that column**. The fastest test is always `(1.0)`, and others show their relative factor. For example, `(1.68)` means "1.68× slower than the fastest."
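Recomputing one of those factors from the sample output above makes the arithmetic concrete:

```python
# The (1.68), (1.57), ... factors are each value divided by the column's best value.
fastest_mean = 1.1390  # Mean of test_health_endpoint_performance (ms)
slower_mean = 1.7843   # Mean of test_openapi_schema_performance (ms)
print(round(slower_mean / fastest_mean, 2))  # → 1.57, matching the table
```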
### Color coding
- **Green**: The fastest (best) value in each column
- **Red**: The slowest (worst) value in each column
This is a **relative ranking within the current run** — red does NOT mean the test failed or that performance is bad. It simply highlights which endpoint is the slower one in the group.
### What's "normal"?
For this project's current endpoints:
| Test | Expected range | Why |
|------|---------------|-----|
| `GET /health` | ~1–1.5ms | Minimal logic, mocked DB check |
| `GET /api/v1/openapi.json` | ~1.5–2.5ms | Serializes entire API schema |
| `get_password_hash` | ~200ms | CPU-bound bcrypt hashing |
| `verify_password` | ~200ms | CPU-bound bcrypt verification |
| `create_access_token` | ~17–20µs | JWT encoding with HMAC-SHA256 |
| `create_refresh_token` | ~17–20µs | JWT encoding with HMAC-SHA256 |
| `decode_token` | ~20–25µs | JWT decoding and claim validation |
| `POST /api/v1/auth/login` | < 500ms threshold | Includes bcrypt password verification |
| `POST /api/v1/auth/register` | < 500ms threshold | Includes bcrypt password hashing |
| `POST /api/v1/auth/refresh` | < 200ms threshold | Token rotation + DB session update |
| `GET /api/v1/users/me` | < 200ms threshold | DB lookup + token validation |
| `GET /api/v1/sessions/me` | < 200ms threshold | Session list query + token validation |
| `PATCH /api/v1/users/me` | < 200ms threshold | DB update + token validation |
---
## Test Organization
```
backend/tests/
├── benchmarks/
│ └── test_endpoint_performance.py # All performance benchmark tests
backend/.benchmarks/ # Saved baselines (auto-generated)
└── Linux-CPython-3.12-64bit/
└── 0001_baseline.json # Platform-specific baseline file
```
### Test markers
All benchmark tests use the `@pytest.mark.benchmark` marker. The `--benchmark-only` flag ensures that only tests using the `benchmark` fixture are executed during benchmark runs, while manual latency tests (async) are skipped.
---
## Writing Benchmark Tests
### Stateless endpoint (using pytest-benchmark fixture)
```python
import pytest
from fastapi.testclient import TestClient
def test_my_endpoint_performance(sync_client, benchmark):
"""Benchmark: GET /my-endpoint should respond within acceptable latency."""
result = benchmark(sync_client.get, "/my-endpoint")
assert result.status_code == 200
```
The `benchmark` fixture handles all timing, calibration, and statistics automatically. Just pass it the callable and arguments.
### Async / DB-dependent endpoint (manual timing)
For async endpoints that require database access, use manual timing with an explicit threshold:
```python
import time
import pytest
MAX_RESPONSE_MS = 300
@pytest.mark.asyncio
async def test_my_async_endpoint_latency(client, setup_fixture):
"""Performance: endpoint must respond under threshold."""
iterations = 5
total_ms = 0.0
for _ in range(iterations):
start = time.perf_counter()
response = await client.get("/api/v1/my-endpoint")
elapsed_ms = (time.perf_counter() - start) * 1000
total_ms += elapsed_ms
assert response.status_code == 200
mean_ms = total_ms / iterations
assert mean_ms < MAX_RESPONSE_MS, (
f"Latency regression: {mean_ms:.1f}ms exceeds {MAX_RESPONSE_MS}ms threshold"
)
```
### Guidelines for new benchmarks
1. **Benchmark critical paths** — endpoints users hit frequently or where latency matters most
2. **Mock external dependencies** for stateless tests to isolate endpoint overhead
3. **Set generous thresholds** for manual tests — account for CI variability
4. **Keep benchmarks fast** — they run on every check, so avoid heavy setup
---
## Baseline Management
### Saving a baseline
```bash
make benchmark-save
```
This runs all benchmarks and saves results to `.benchmarks/<platform>/0001_baseline.json`. The baseline captures:
- Mean, min, max, median, stddev for each test
- Machine info (CPU, OS, Python version)
- Timestamp
### Comparing against baseline
```bash
make benchmark-check
```
If no baseline exists, this command automatically creates one and prints a warning. On subsequent runs, it compares current results against the saved baseline.
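The baseline is plain JSON, so it can be inspected directly. A minimal sketch, assuming the standard pytest-benchmark JSON layout (a top-level `"benchmarks"` list with per-test `"stats"`); in practice you would load `.benchmarks/<platform>/0001_baseline.json` instead of the inline sample:

```python
import json

# Inline stand-in for a saved baseline file (values from the sample output above).
baseline = json.loads("""
{
  "benchmarks": [
    {"name": "test_health_endpoint_performance", "stats": {"mean": 0.0011390}},
    {"name": "test_openapi_schema_performance", "stats": {"mean": 0.0017843}}
  ]
}
""")
for bench in baseline["benchmarks"]:
    # stats are recorded in seconds; convert to milliseconds for readability
    print(f'{bench["name"]}: mean {bench["stats"]["mean"] * 1000:.3f} ms')
```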
### When to update the baseline
- **After intentional performance changes** (e.g., you optimized an endpoint — save the new, faster baseline)
- **After infrastructure changes** (e.g., new CI runner, different hardware)
- **After adding new benchmark tests** (the new tests need a baseline entry)
```bash
# Update the baseline after intentional changes
make benchmark-save
```
### Version control
The `.benchmarks/` directory can be committed to version control so that CI pipelines can compare against a known-good baseline. However, since benchmark results are machine-specific, you may prefer to generate baselines in CI rather than committing local results.
---
## CI/CD Integration
Add benchmark checking to your CI pipeline to catch regressions on every PR:
```yaml
# Example GitHub Actions step
- name: Performance regression check
run: |
cd backend
make benchmark-save # Create baseline from main branch
# ... apply PR changes ...
make benchmark-check # Compare PR against baseline
```
A more robust approach:
1. Save the baseline on the `main` branch after each merge
2. On PR branches, run `make benchmark-check` against the `main` baseline
3. The pipeline fails if any endpoint regresses beyond the 200% threshold
---
## Troubleshooting
### "No benchmark baseline found" warning
```
⚠️ No benchmark baseline found. Run 'make benchmark-save' first to create one.
```
This means no baseline file exists yet. The command will auto-create one. Future runs of `make benchmark-check` will compare against it.
### Machine info mismatch warning
```
WARNING: benchmark machine_info is different
```
This is expected when comparing baselines generated on a different machine or OS. The comparison still works, but absolute numbers may differ. Re-save the baseline on the current machine if needed.
### High variance (large StdDev)
If StdDev is high relative to the Mean, results may be unreliable. Common causes:
- System under load during benchmark run
- Garbage collection interference
- Thermal throttling
Try running benchmarks on an idle system or increasing `min_rounds` in `pyproject.toml`.
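One possible `pyproject.toml` fragment for this (option names follow pytest-benchmark's documented CLI flags; the values shown are placeholders, and adding them to global `addopts` affects every pytest run):

```toml
[tool.pytest.ini_options]
addopts = "--benchmark-min-rounds=10 --benchmark-warmup=on"
```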
### Only 7 of 13 tests run
The async tests (`test_login_latency`, `test_get_current_user_latency`, `test_register_latency`, `test_token_refresh_latency`, `test_sessions_list_latency`, `test_user_profile_update_latency`) are skipped during `--benchmark-only` runs because they don't use the `benchmark` fixture. They run as part of the normal test suite (`make test`) with manual threshold assertions.


@@ -8,6 +8,7 @@ This document outlines the coding standards and best practices for the FastAPI b
- [Code Organization](#code-organization)
- [Naming Conventions](#naming-conventions)
- [Error Handling](#error-handling)
- [Data Models and Migrations](#data-models-and-migrations)
- [Database Operations](#database-operations)
- [API Endpoints](#api-endpoints)
- [Authentication & Security](#authentication--security)
@@ -74,15 +75,14 @@ def create_user(db: Session, user_in: UserCreate) -> User:
### 4. Code Formatting
Use automated formatters:
- **Ruff**: Code formatting and linting (replaces Black, isort, flake8)
- **pyright**: Static type checking
Run before committing (or use `make validate`):
```bash
uv run ruff format app tests
uv run ruff check app tests
uv run pyright app
```
## Code Organization
@@ -93,19 +93,17 @@ Follow the 5-layer architecture strictly:
```
API Layer (routes/)
    ↓ calls (via service injected from dependencies/services.py)
Dependencies (dependencies/)
    ↓ injects
Service Layer (services/)
    ↓ calls
Repository Layer (repositories/)
    ↓ uses
Models & Schemas (models/, schemas/)
```
**Rules:**
- Routes must NEVER import repositories directly — always use a service
- Services call repositories; repositories contain only database operations
- Models should NOT import from higher layers
- Each layer should only depend on the layer directly below it
@@ -124,7 +122,7 @@ from sqlalchemy.orm import Session
# 3. Local application imports
from app.api.dependencies.auth import get_current_user
from app.api.dependencies.services import get_user_service
from app.models.user import User
from app.schemas.users import UserResponse, UserCreate
```
@@ -216,7 +214,7 @@ if not user:
### Error Handling Pattern
Always follow this pattern in repository operations (Async version):
```python
from sqlalchemy.exc import IntegrityError, OperationalError, DataError
@@ -282,9 +280,154 @@ All error responses follow this structure:
}
```
## Data Models and Migrations
### Model Definition Best Practices
To ensure Alembic autogenerate works reliably without drift, follow these rules:
#### 1. Simple Indexes: Use Column-Level or `__table_args__`, Not Both
```python
# ❌ BAD - Creates DUPLICATE indexes with different names
class User(Base):
role = Column(String(50), index=True) # Creates ix_users_role
__table_args__ = (
Index("ix_user_role", "role"), # Creates ANOTHER index!
)
# ✅ GOOD - Choose ONE approach
class User(Base):
role = Column(String(50)) # No index=True
__table_args__ = (
Index("ix_user_role", "role"), # Single index with explicit name
)
# ✅ ALSO GOOD - For simple single-column indexes
class User(Base):
role = Column(String(50), index=True) # Auto-named ix_users_role
```
#### 2. Composite Indexes: Always Use `__table_args__`
```python
class UserOrganization(Base):
__tablename__ = "user_organizations"
user_id = Column(UUID, nullable=False)
organization_id = Column(UUID, nullable=False)
is_active = Column(Boolean, default=True, nullable=False, index=True)
__table_args__ = (
Index("ix_user_org_user_active", "user_id", "is_active"),
Index("ix_user_org_org_active", "organization_id", "is_active"),
)
```
#### 3. Functional/Partial Indexes: Use `ix_perf_` Prefix
Alembic **cannot** auto-detect:
- **Functional indexes**: `LOWER(column)`, `UPPER(column)`, expressions
- **Partial indexes**: Indexes with `WHERE` clauses
**Solution**: Use the `ix_perf_` naming prefix. Any index with this prefix is automatically excluded from autogenerate by `env.py`.
```python
# In migration file (NOT in model) - use ix_perf_ prefix:
op.create_index(
"ix_perf_users_email_lower", # <-- ix_perf_ prefix!
"users",
[sa.text("LOWER(email)")], # Functional
postgresql_where=sa.text("deleted_at IS NULL"), # Partial
)
```
**No need to update `env.py`** - the prefix convention handles it automatically:
```python
# env.py - already configured:
def include_object(object, name, type_, reflected, compare_to):
if type_ == "index" and name:
if name.startswith("ix_perf_"): # Auto-excluded!
return False
return True
```
**To add new performance indexes:**
1. Create a new migration file
2. Name your indexes with `ix_perf_` prefix
3. Done - Alembic will ignore them automatically
#### 4. Use Correct Types
```python
# ✅ GOOD - PostgreSQL-native types
from sqlalchemy.dialects.postgresql import JSONB, UUID
class User(Base):
id = Column(UUID(as_uuid=True), primary_key=True)
preferences = Column(JSONB) # Not JSON!
# ❌ BAD - Generic types may cause migration drift
from sqlalchemy import JSON
preferences = Column(JSON) # May detect as different from JSONB
```
### Migration Workflow
#### Creating Migrations
```bash
# Generate autogenerate migration:
python migrate.py generate "Add new field"
# Or inside Docker:
docker exec -w /app backend uv run alembic revision --autogenerate -m "Add new field"
# Apply migration:
python migrate.py apply
# Or: docker exec -w /app backend uv run alembic upgrade head
```
#### Testing for Drift
After any model changes, verify no unintended drift:
```bash
# Generate test migration
docker exec -w /app backend uv run alembic revision --autogenerate -m "test_drift"
# Check the generated file - should be empty (just 'pass')
# If it has operations, investigate why
# Delete test file
rm backend/app/alembic/versions/*_test_drift.py
```
#### Migration File Structure
```
backend/app/alembic/versions/
├── cbddc8aa6eda_initial_models.py # Auto-generated, tracks all models
├── 0002_performance_indexes.py # Manual, functional/partial indexes
└── __init__.py
```
### Summary: What Goes Where
| Index Type | In Model? | Alembic Detects? | Where to Define |
|------------|-----------|------------------|-----------------|
| Simple column (`index=True`) | Yes | Yes | Column definition |
| Composite (`col1, col2`) | Yes | Yes | `__table_args__` |
| Unique composite | Yes | Yes | `__table_args__` with `unique=True` |
| Functional (`LOWER(col)`) | No | No | Migration with `ix_perf_` prefix |
| Partial (`WHERE ...`) | No | No | Migration with `ix_perf_` prefix |
## Database Operations
### Async Repository Pattern
**IMPORTANT**: This application uses **async SQLAlchemy** with modern patterns for better performance and testability.
@@ -296,19 +439,19 @@ All error responses follow this structure:
4. **Testability**: Easy to mock and test
5. **Consistent Ordering**: Always order queries for pagination
### Use the Async Repository Base Class
Always inherit from `RepositoryBase` for database operations:
```python
from sqlalchemy.ext.asyncio import AsyncSession
from sqlalchemy import select
from app.repositories.base import RepositoryBase
from app.models.user import User
from app.schemas.users import UserCreate, UserUpdate

class UserRepository(RepositoryBase[User, UserCreate, UserUpdate]):
    """Repository for User model — database operations only."""

    async def get_by_email(
        self,
@@ -321,7 +464,7 @@ class CRUDUser(CRUDBase[User, UserCreate, UserUpdate]):
        )
        return result.scalar_one_or_none()

user_repo = UserRepository(User)
```
**Key Points:**
@@ -330,6 +473,7 @@ user_crud = CRUDUser(User)
- Use `await db.execute()` for queries
- Use `.scalar_one_or_none()` instead of `.first()`
- Use `T | None` instead of `Optional[T]`
- Repository instances are used internally by services — never import them in routes
### Modern SQLAlchemy Patterns
@@ -417,13 +561,13 @@ async def create_user(
    The database session is automatically managed by FastAPI.
    Commit on success, rollback on error.
    """
    return await user_service.create_user(db, obj_in=user_in)
```
**Key Points:**
- Route functions must be `async def`
- Database parameter is `AsyncSession`
- Always `await` repository operations
#### In Services (Multiple Operations)
@@ -436,12 +580,11 @@ async def complex_operation(
""" """
Perform multiple database operations atomically. Perform multiple database operations atomically.
The session automatically commits on success or rolls back on error. Services call repositories; commit/rollback is handled inside
each repository method.
""" """
user = await user_crud.create(db, obj_in=user_data) user = await user_repo.create(db, obj_in=user_data)
session = await session_crud.create(db, obj_in=session_data) session = await session_repo.create(db, obj_in=session_data)
# Commit is handled by the route's dependency
return user, session return user, session
``` ```
@@ -451,10 +594,10 @@ Prefer soft deletes over hard deletes for audit trails:
```python
# Good - Soft delete (sets deleted_at)
await user_repo.soft_delete(db, id=user_id)

# Acceptable only when required - Hard delete
await user_repo.remove(db, id=user_id)
```
### Query Patterns
@@ -594,9 +737,10 @@ Always implement pagination for list endpoints:
from app.schemas.common import PaginationParams, PaginatedResponse

@router.get("/users", response_model=PaginatedResponse[UserResponse])
async def list_users(
    pagination: PaginationParams = Depends(),
    user_service: UserService = Depends(get_user_service),
    db: AsyncSession = Depends(get_db),
):
    """
    List all users with pagination.
@@ -604,10 +748,8 @@ def list_users(
    Default page size: 20
    Maximum page size: 100
    """
    users, total = await user_service.get_users(
        db, skip=pagination.offset, limit=pagination.limit
    )
    return PaginatedResponse(data=users, pagination=pagination.create_meta(total))
```
@@ -670,19 +812,17 @@ def admin_route(
    pass

# Check ownership
async def delete_resource(
    resource_id: UUID,
    current_user: User = Depends(get_current_user),
    resource_service: ResourceService = Depends(get_resource_service),
    db: AsyncSession = Depends(get_db),
):
    # Service handles ownership check and raises appropriate errors
    await resource_service.delete_resource(
        db, resource_id=resource_id, user_id=current_user.id,
        is_superuser=current_user.is_superuser,
    )
```
### Input Validation
@@ -716,9 +856,9 @@ tests/
├── api/            # Integration tests
│   ├── test_users.py
│   └── test_auth.py
├── repositories/   # Unit tests for repositories
├── services/       # Unit tests for services
└── models/         # Model tests
```
### Async Testing with pytest-asyncio
@@ -781,7 +921,7 @@ async def test_user(db_session: AsyncSession) -> User:
@pytest.mark.asyncio
async def test_get_user(db_session: AsyncSession, test_user: User):
    """Test retrieving a user by ID."""
    user = await user_repo.get(db_session, id=test_user.id)
    assert user is not None
    assert user.email == test_user.email
```


@@ -334,14 +334,14 @@ def login(request: Request, credentials: OAuth2PasswordRequestForm):
# ❌ WRONG - Returns password hash!
@router.get("/users/{user_id}")
def get_user(user_id: UUID, db: Session = Depends(get_db)) -> User:
    return user_repo.get(db, id=user_id)  # Returns ORM model with ALL fields!
```
```python
# ✅ CORRECT - Use response schema
@router.get("/users/{user_id}", response_model=UserResponse)
def get_user(user_id: UUID, db: Session = Depends(get_db)):
    user = user_repo.get(db, id=user_id)
    if not user:
        raise HTTPException(status_code=404, detail="User not found")
    return user  # Pydantic filters to only UserResponse fields
@@ -506,8 +506,8 @@ def revoke_session(
    current_user: User = Depends(get_current_user),
    db: Session = Depends(get_db)
):
    session = session_repo.get(db, id=session_id)
    session_repo.deactivate(db, session_id=session_id)
    # BUG: User can revoke ANYONE'S session!
    return {"message": "Session revoked"}
```
@@ -520,7 +520,7 @@ def revoke_session(
    current_user: User = Depends(get_current_user),
    db: Session = Depends(get_db)
):
    session = session_repo.get(db, id=session_id)
    if not session:
        raise NotFoundError("Session not found")
@@ -529,7 +529,7 @@ def revoke_session(
    if session.user_id != current_user.id:
        raise AuthorizationError("You can only revoke your own sessions")

    session_repo.deactivate(db, session_id=session_id)
    return {"message": "Session revoked"}
```
@@ -616,7 +616,43 @@ def create_user(
    return user
```
**Rule**: Add type hints to ALL functions. Use `pyright` to enforce type checking (`make type-check`).
---
### ❌ PITFALL #19: Importing Repositories Directly in Routes
**Issue**: Routes should never call repositories directly. The layered architecture requires all business operations to go through the service layer.
```python
# ❌ WRONG - Route bypasses service layer
from app.repositories.session import session_repo
@router.get("/sessions/me")
async def list_sessions(
current_user: User = Depends(get_current_active_user),
db: AsyncSession = Depends(get_db),
):
return await session_repo.get_user_sessions(db, user_id=current_user.id)
```
```python
# ✅ CORRECT - Route calls service injected via dependency
from app.api.dependencies.services import get_session_service
from app.services.session_service import SessionService
@router.get("/sessions/me")
async def list_sessions(
current_user: User = Depends(get_current_active_user),
session_service: SessionService = Depends(get_session_service),
db: AsyncSession = Depends(get_db),
):
return await session_service.get_user_sessions(db, user_id=current_user.id)
```
**Rule**: Routes import from `app.api.dependencies.services`, never from `app.repositories.*`. Services are the only callers of repositories.
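A provider in `app/api/dependencies/services.py` can be as small as the hypothetical sketch below. The class bodies are stand-ins (the real `SessionService` and repository live in their own modules); only the provider-function shape is the point, and caching it makes the service an app-wide singleton:

```python
from functools import lru_cache

class SessionRepository:
    """Stand-in for app.repositories.session (hypothetical)."""

class SessionService:
    """Stand-in for app.services.session_service (hypothetical)."""
    def __init__(self, repo: SessionRepository) -> None:
        self.repo = repo

@lru_cache
def get_session_service() -> SessionService:
    # Routes resolve this via Depends(get_session_service); lru_cache
    # returns the same instance on every call.
    return SessionService(SessionRepository())
```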
---
@@ -649,6 +685,11 @@ Use this checklist to catch issues before code review:
- [ ] Resource ownership verification
- [ ] CORS configured (no wildcards in production)
### Architecture
- [ ] Routes never import repositories directly (only services)
- [ ] Services call repositories; repositories call database only
- [ ] New service registered in `app/api/dependencies/services.py`
### Python
- [ ] Use `==` not `is` for value comparison
- [ ] No mutable default arguments
@@ -661,21 +702,18 @@ Use this checklist to catch issues before code review:
### Pre-commit Checks
Add these to your development workflow (or use `make validate`):
```bash
# Format + lint (Ruff replaces Black, isort, flake8)
uv run ruff format app tests
uv run ruff check app tests

# Type checking
uv run pyright app

# Run tests
IS_TEST=True uv run pytest --cov=app --cov-report=term-missing

# Check coverage (should be 80%+)
coverage report --fail-under=80
@@ -693,6 +731,6 @@ Add new entries when:
---
**Last Updated**: 2026-02-28
**Issues Cataloged**: 19 common pitfalls
**Remember**: This document exists because these issues HAVE occurred. Don't skip it.

backend/docs/E2E_TESTING.md Normal file

@@ -0,0 +1,348 @@
# Backend E2E Testing Guide
End-to-end testing infrastructure using **Testcontainers** (real PostgreSQL) and **Schemathesis** (OpenAPI contract testing).
## Table of Contents
- [Quick Start](#quick-start)
- [Requirements](#requirements)
- [How It Works](#how-it-works)
- [Test Organization](#test-organization)
- [Writing E2E Tests](#writing-e2e-tests)
- [Running Tests](#running-tests)
- [Troubleshooting](#troubleshooting)
- [CI/CD Integration](#cicd-integration)
- [Best Practices](#best-practices)
- [Further Reading](#further-reading)
---
## Quick Start
```bash
# 1. Install E2E dependencies
make install-e2e
# 2. Ensure Docker is running
make check-docker
# 3. Run E2E tests
make test-e2e
```
---
## Requirements
### Docker
E2E tests use Testcontainers to spin up real PostgreSQL containers. Docker must be running:
- **macOS/Windows**: Docker Desktop
- **Linux**: Docker Engine (`sudo systemctl start docker`)
### Dependencies
E2E tests require additional packages beyond the standard dev dependencies:
```bash
# Install E2E dependencies
make install-e2e
# Or manually:
uv sync --extra dev --extra e2e
```
This installs:
- `testcontainers[postgres]>=4.0.0` - Docker container management
- `schemathesis>=3.30.0` - OpenAPI contract testing
---
## How It Works
### Testcontainers
Testcontainers automatically manages Docker containers for tests:
1. **Session-scoped container**: A single PostgreSQL 17 container starts once per test session
2. **Function-scoped isolation**: Each test gets fresh tables (drop + recreate)
3. **Automatic cleanup**: Container is destroyed when tests complete
This approach catches bugs that SQLite-based tests miss:
- PostgreSQL-specific SQL behavior
- Real constraint violations
- Actual transaction semantics
- JSONB column behavior
### Schemathesis
Schemathesis generates test cases from your OpenAPI schema:
1. **Schema loading**: Reads `/api/v1/openapi.json` from your FastAPI app
2. **Test generation**: Creates test cases for each endpoint
3. **Response validation**: Verifies responses match documented schema
This catches:
- Undocumented response codes
- Schema mismatches (wrong types, missing fields)
- Edge cases in input validation
---
## Test Organization
```
backend/tests/
├── e2e/ # E2E tests (PostgreSQL, Docker required)
│ ├── __init__.py
│ ├── conftest.py # Testcontainers fixtures
│ ├── test_api_contracts.py # Schemathesis schema tests
│ └── test_database_workflows.py # PostgreSQL workflow tests
├── api/ # Integration tests (SQLite, fast)
├── repositories/ # Repository unit tests
└── conftest.py # Standard fixtures
```
### Test Markers
Tests use pytest markers for filtering:
| Marker | Description |
|--------|-------------|
| `@pytest.mark.e2e` | End-to-end test requiring Docker |
| `@pytest.mark.postgres` | PostgreSQL-specific test |
| `@pytest.mark.schemathesis` | Schemathesis schema test |
---
## Writing E2E Tests
### Basic E2E Test
```python
import pytest
from uuid import uuid4
@pytest.mark.e2e
@pytest.mark.postgres
@pytest.mark.asyncio
async def test_user_workflow(e2e_client):
"""Test user registration with real PostgreSQL."""
email = f"test-{uuid4().hex[:8]}@example.com"
response = await e2e_client.post(
"/api/v1/auth/register",
json={
"email": email,
"password": "SecurePassword123!",
"first_name": "Test",
"last_name": "User",
},
)
assert response.status_code in [200, 201]
assert response.json()["email"] == email
```
### Available Fixtures
| Fixture | Scope | Description |
|---------|-------|-------------|
| `postgres_container` | session | Raw Testcontainers PostgreSQL container |
| `async_postgres_url` | session | Asyncpg-compatible connection URL |
| `e2e_db_session` | function | SQLAlchemy AsyncSession with fresh tables |
| `e2e_client` | function | httpx AsyncClient connected to real DB |
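Testcontainers hands back a synchronous (psycopg2-style) connection URL, while the app's engine needs the asyncpg driver. The work done by the `async_postgres_url` fixture boils down to a driver swap; a minimal sketch with a hypothetical helper name, assuming the default URL format Testcontainers produces:

```python
def to_asyncpg_url(sync_url: str) -> str:
    """Swap the sync psycopg2 driver for asyncpg in a SQLAlchemy URL."""
    # Hypothetical helper for illustration; the real fixture lives in
    # tests/e2e/conftest.py and may differ.
    return sync_url.replace("postgresql+psycopg2://", "postgresql+asyncpg://")

# Example: the kind of URL PostgresContainer.get_connection_url() returns,
# with a random mapped host port.
url = to_asyncpg_url("postgresql+psycopg2://test:test@localhost:32768/test")
# -> "postgresql+asyncpg://test:test@localhost:32768/test"
```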
### Schemathesis Test
```python
import pytest
import schemathesis
from hypothesis import settings, Phase
from app.main import app
schema = schemathesis.from_asgi("/api/v1/openapi.json", app=app)
@pytest.mark.e2e
@pytest.mark.schemathesis
@schema.parametrize(endpoint="/api/v1/auth/register")
@settings(max_examples=20)
def test_registration_schema(case):
"""Test registration endpoint conforms to schema."""
response = case.call_asgi()
case.validate_response(response)
```
---
## Running Tests
### Commands
```bash
# Run all E2E tests
make test-e2e
# Run only Schemathesis schema tests
make test-e2e-schema
# Run all tests (unit + integration + E2E)
make test-all
# Check Docker availability
make check-docker
```
### Direct pytest
```bash
# All E2E tests
IS_TEST=True PYTHONPATH=. uv run pytest tests/e2e/ -v
# Only PostgreSQL tests
IS_TEST=True PYTHONPATH=. uv run pytest tests/e2e/ -v -m postgres
# Only Schemathesis tests
IS_TEST=True PYTHONPATH=. uv run pytest tests/e2e/ -v -m schemathesis
```
---
## Troubleshooting
### Docker Not Running
**Error:**
```
Docker is not running!
E2E tests require Docker to be running.
```
**Solution:**
```bash
# macOS/Windows
# Open Docker Desktop
# Linux
sudo systemctl start docker
```
### Testcontainers Not Installed
**Error:**
```
SKIPPED: testcontainers not installed - run: make install-e2e
```
**Solution:**
```bash
make install-e2e
# Or: uv sync --extra dev --extra e2e
```
### Container Startup Timeout
**Error:**
```
testcontainers.core.waiting_utils.UnexpectedResponse
```
**Solutions:**
1. Increase Docker resources (memory, CPU)
2. Pull the image manually: `docker pull postgres:17-alpine`
3. Check the container's logs: `docker logs <container-id>`
### Port Conflicts
**Error:**
```
Error starting container: port is already allocated
```
**Solution:**
Testcontainers assigns random host ports, so conflicts are rare. If one occurs:
1. Stop other PostgreSQL containers: `docker stop $(docker ps -q)`
2. Check for orphaned containers: `docker container prune`
### Ryuk/Reaper Port 8080 Issues
**Error:**
```
ConnectionError: Port mapping for container ... and port 8080 is not available
```
**Solution:**
This error comes from the Testcontainers reaper (Ryuk), the sidecar container that handles automatic cleanup.
The `conftest.py` disables Ryuk automatically to avoid it. If the error persists, make sure you are using the
latest `conftest.py`, or set the environment variable yourself:
```bash
export TESTCONTAINERS_RYUK_DISABLED=true
```
### Parallel Test Execution Issues
**Error:**
```
ScopeMismatch: ... cannot use a higher-scoped fixture 'postgres_container'
```
**Solution:**
E2E tests must run sequentially (not in parallel) because they share a session-scoped
PostgreSQL container. The Makefile commands use `-n 0` to disable parallel execution.
If running pytest directly, add `-n 0`:
```bash
IS_TEST=True PYTHONPATH=. uv run pytest tests/e2e/ -v -n 0
```
---
## CI/CD Integration
### GitHub Actions
A workflow template is provided at `.github/workflows/backend-e2e-tests.yml.template`.
To enable:
1. Rename to `backend-e2e-tests.yml`
2. Push to repository
The workflow:
- Runs on pushes to `main`/`develop` affecting `backend/`
- Uses `continue-on-error: true` (E2E failures don't block merge)
- Caches uv dependencies for speed
### Local CI Simulation
```bash
# Run what CI runs
make test-all
```
---
## Best Practices
### DO
- Use unique emails per test: `f"test-{uuid4().hex[:8]}@example.com"`
- Mark tests with appropriate markers: `@pytest.mark.e2e`
- Keep E2E tests focused on critical workflows
- Use `e2e_client` fixture for most tests
### DON'T
- Share state between tests (each test gets fresh tables)
- Test every endpoint in E2E (use unit tests for edge cases)
- Skip the `IS_TEST=True` environment variable
- Run E2E tests without Docker
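The unique-email rule from the DO list is easy to centralize in a tiny helper so every test gets a collision-free address (illustrative only; `unique_email` is not part of the repo's fixtures):

```python
from uuid import uuid4

def unique_email(prefix: str = "test") -> str:
    """Return an address that cannot collide across test runs or xdist workers."""
    # uuid4().hex[:8] gives ~4 billion combinations per prefix, more than
    # enough to keep repeated registrations from tripping unique constraints.
    return f"{prefix}-{uuid4().hex[:8]}@example.com"
```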
---
## Further Reading
- [Testcontainers Documentation](https://testcontainers.com/guides/getting-started-with-testcontainers-for-python/)
- [Schemathesis Documentation](https://schemathesis.readthedocs.io/)
- [pytest-asyncio Documentation](https://pytest-asyncio.readthedocs.io/)

File diff suppressed because it is too large

backend/entrypoint.sh Normal file → Executable file

@@ -1,12 +1,22 @@
-#!/bin/bash
+#!/bin/sh
set -e
echo "Starting Backend"
+# Ensure the project's virtualenv binaries are on PATH so commands like
+# 'uvicorn' work even when not prefixed by 'uv run'. This matches how uv
+# installs the env into /app/.venv in our containers.
+if [ -d "/app/.venv/bin" ]; then
+    export PATH="/app/.venv/bin:$PATH"
+fi
# Apply database migrations
-uv run alembic upgrade head
+# Avoid installing the project in editable mode (which tries to write egg-info)
+# when running inside a bind-mounted volume with restricted permissions.
+# See: https://github.com/astral-sh/uv (use --no-project to skip project build)
+uv run --no-project alembic upgrade head
# Initialize database (creates first superuser if needed)
-uv run python app/init_db.py
+uv run --no-project python app/init_db.py
# Execute the command passed to docker run
exec "$@"


@@ -2,8 +2,32 @@
""" """
Database migration helper script. Database migration helper script.
Provides convenient commands for generating and applying Alembic migrations. Provides convenient commands for generating and applying Alembic migrations.
Usage:
# Generate migration (auto-increments revision ID: 0001, 0002, etc.)
python migrate.py --local generate "Add new field"
python migrate.py --local auto "Add new field"
# Apply migrations
python migrate.py --local apply
# Show next revision ID
python migrate.py next
# Reset after deleting migrations (clears alembic_version table)
python migrate.py --local reset
# Override auto-increment with custom revision ID
python migrate.py --local generate "initial_models" --rev-id custom_id
# Generate empty migration template without database (no autogenerate)
python migrate.py generate "Add performance indexes" --offline
# Inside Docker (without --local flag):
python migrate.py auto "Add new field"
""" """
import argparse import argparse
import os
import subprocess import subprocess
import sys import sys
from pathlib import Path from pathlib import Path
@@ -13,15 +37,21 @@ project_root = Path(__file__).resolve().parent
if str(project_root) not in sys.path:
    sys.path.append(str(project_root))
-try:
-    # Import settings to check if configuration is working
-    from app.core.config import settings
-    print(f"Using database URL: {settings.database_url}")
-except ImportError as e:
-    print(f"Error importing settings: {e}")
-    print("Make sure your Python path includes the project root.")
-    sys.exit(1)
+def setup_database_url(use_local: bool) -> str:
+    """Setup database URL, optionally using localhost for local development."""
+    if use_local:
+        # Override DATABASE_URL to use localhost instead of Docker hostname
+        local_url = os.environ.get(
+            "LOCAL_DATABASE_URL",
+            "postgresql://postgres:postgres@localhost:5432/app"
+        )
+        os.environ["DATABASE_URL"] = local_url
+        return local_url
+    # Use the configured DATABASE_URL from environment/.env
+    from app.core.config import settings
+    return settings.database_url
def check_models():
@@ -40,11 +70,30 @@ def check_models():
return False
-def generate_migration(message):
+def generate_migration(message, rev_id=None, auto_rev_id=True, offline=False):
-    """Generate an Alembic migration with the given message"""
+    """Generate an Alembic migration with the given message.
Args:
message: Migration message
rev_id: Custom revision ID (overrides auto_rev_id)
auto_rev_id: If True and rev_id is None, auto-generate sequential ID
offline: If True, generate empty migration without database (no autogenerate)
"""
# Auto-generate sequential revision ID if not provided
if rev_id is None and auto_rev_id:
rev_id = get_next_rev_id()
print(f"Generating migration: {message}") print(f"Generating migration: {message}")
if rev_id:
print(f"Using revision ID: {rev_id}")
if offline:
# Generate migration file directly without database connection
return generate_offline_migration(message, rev_id)
cmd = ["alembic", "revision", "--autogenerate", "-m", message] cmd = ["alembic", "revision", "--autogenerate", "-m", message]
if rev_id:
cmd.extend(["--rev-id", rev_id])
result = subprocess.run(cmd, capture_output=True, text=True)
print(result.stdout)
@@ -64,8 +113,9 @@ def generate_migration(message):
if len(part) >= 12 and all(c in "0123456789abcdef" for c in part[:12]):
    revision = part[:12]
    break
-except Exception:
-    pass
+except Exception as e:
+    # If parsing fails, we can still proceed without a detected revision
+    print(f"Warning: could not parse revision from line '{line}': {e}")
if revision:
    print(f"Generated revision: {revision}")
@@ -131,8 +181,14 @@ def check_database_connection():
from sqlalchemy.exc import SQLAlchemyError
try:
-    engine = create_engine(settings.database_url)
-    with engine.connect() as conn:
+    # Use DATABASE_URL from environment (set by setup_database_url)
+    db_url = os.environ.get("DATABASE_URL")
+    if not db_url:
+        from app.core.config import settings
+        db_url = settings.database_url
+    engine = create_engine(db_url)
+    with engine.connect():
        print("✓ Database connection successful!")
        return True
except SQLAlchemyError as e:
@@ -140,16 +196,172 @@ def check_database_connection():
return False
def get_next_rev_id():
"""Get the next sequential revision ID based on existing migrations."""
import re
versions_dir = project_root / "app" / "alembic" / "versions"
if not versions_dir.exists():
return "0001"
# Find all migration files with numeric prefixes
max_num = 0
pattern = re.compile(r"^(\d{4})_.*\.py$")
for f in versions_dir.iterdir():
if f.is_file() and f.suffix == ".py":
match = pattern.match(f.name)
if match:
num = int(match.group(1))
max_num = max(max_num, num)
next_num = max_num + 1
return f"{next_num:04d}"
def get_current_rev_id():
"""Get the current (latest) revision ID from existing migrations."""
import re
versions_dir = project_root / "app" / "alembic" / "versions"
if not versions_dir.exists():
return None
# Find all migration files with numeric prefixes and get the highest
max_num = 0
max_rev_id = None
pattern = re.compile(r"^(\d{4})_.*\.py$")
for f in versions_dir.iterdir():
if f.is_file() and f.suffix == ".py":
match = pattern.match(f.name)
if match:
num = int(match.group(1))
if num > max_num:
max_num = num
max_rev_id = match.group(1)
return max_rev_id
def generate_offline_migration(message, rev_id):
"""Generate a migration file without database connection.
Creates an empty migration template that can be filled in manually.
Useful for performance indexes or when database is not available.
"""
from datetime import datetime
versions_dir = project_root / "app" / "alembic" / "versions"
versions_dir.mkdir(parents=True, exist_ok=True)
# Slugify the message for filename
slug = message.lower().replace(" ", "_").replace("-", "_")
slug = "".join(c for c in slug if c.isalnum() or c == "_")
filename = f"{rev_id}_{slug}.py"
filepath = versions_dir / filename
# Get the previous revision ID
down_revision = get_current_rev_id()
down_rev_str = f'"{down_revision}"' if down_revision else "None"
# Generate the migration file content
content = f'''"""{message}
Revision ID: {rev_id}
Revises: {down_revision or ''}
Create Date: {datetime.now().strftime('%Y-%m-%d %H:%M:%S.%f')}
"""
from collections.abc import Sequence
import sqlalchemy as sa
from alembic import op
# revision identifiers, used by Alembic.
revision: str = "{rev_id}"
down_revision: str | None = {down_rev_str}
branch_labels: str | Sequence[str] | None = None
depends_on: str | Sequence[str] | None = None
def upgrade() -> None:
# TODO: Add your upgrade operations here
pass
def downgrade() -> None:
# TODO: Add your downgrade operations here
pass
'''
filepath.write_text(content)
print(f"Generated offline migration: {filepath}")
return rev_id
def show_next_rev_id():
"""Show the next sequential revision ID."""
next_id = get_next_rev_id()
print(f"Next revision ID: {next_id}")
print(f"\nUsage:")
print(f" python migrate.py --local generate 'your_message' --rev-id {next_id}")
print(f" python migrate.py --local auto 'your_message' --rev-id {next_id}")
return next_id
def reset_alembic_version():
"""Reset the alembic_version table (for fresh start after deleting migrations)."""
from sqlalchemy import create_engine, text
from sqlalchemy.exc import SQLAlchemyError
db_url = os.environ.get("DATABASE_URL")
if not db_url:
from app.core.config import settings
db_url = settings.database_url
try:
engine = create_engine(db_url)
with engine.connect() as conn:
conn.execute(text("DROP TABLE IF EXISTS alembic_version"))
conn.commit()
print("✓ Alembic version table reset successfully")
print(" You can now run migrations from scratch")
return True
except SQLAlchemyError as e:
print(f"✗ Error resetting alembic version: {e}")
return False
def main():
    """Main function"""
    parser = argparse.ArgumentParser(
-        description='Database migration helper for FastNext template'
+        description='Database migration helper for Generative Models Arena'
    )
# Global options
parser.add_argument(
'--local', '-l',
action='store_true',
help='Use localhost instead of Docker hostname (for local development)'
)
subparsers = parser.add_subparsers(dest='command', help='Command to run')
# Generate command
generate_parser = subparsers.add_parser('generate', help='Generate a migration')
generate_parser.add_argument('message', help='Migration message')
generate_parser.add_argument(
'--rev-id',
help='Custom revision ID (e.g., 0001, 0002 for sequential naming)'
)
generate_parser.add_argument(
'--offline',
action='store_true',
help='Generate empty migration template without database connection'
)
# Apply command
apply_parser = subparsers.add_parser('apply', help='Apply migrations')
@@ -164,15 +376,56 @@ def main():
# Check command
subparsers.add_parser('check', help='Check database connection and models')
# Next command (show next revision ID)
subparsers.add_parser('next', help='Show the next sequential revision ID')
# Reset command (clear alembic_version table)
subparsers.add_parser(
'reset',
help='Reset alembic_version table (use after deleting all migrations)'
)
# Auto command (generate and apply)
auto_parser = subparsers.add_parser('auto', help='Generate and apply migration')
auto_parser.add_argument('message', help='Migration message')
auto_parser.add_argument(
'--rev-id',
help='Custom revision ID (e.g., 0001, 0002 for sequential naming)'
)
auto_parser.add_argument(
'--offline',
action='store_true',
help='Generate empty migration template without database connection'
)
args = parser.parse_args()
# Commands that don't need database connection
if args.command == 'next':
show_next_rev_id()
return
# Check if offline mode is requested
offline = getattr(args, 'offline', False)
# Offline generate doesn't need database or model check
if args.command == 'generate' and offline:
generate_migration(args.message, rev_id=args.rev_id, offline=True)
return
if args.command == 'auto' and offline:
generate_migration(args.message, rev_id=args.rev_id, offline=True)
print("\nOffline migration generated. Apply it later with:")
print(f" python migrate.py --local apply")
return
# Setup database URL (must be done before importing settings elsewhere)
db_url = setup_database_url(args.local)
print(f"Using database URL: {db_url}")
if args.command == 'generate':
    check_models()
-    generate_migration(args.message)
+    generate_migration(args.message, rev_id=args.rev_id)
elif args.command == 'apply':
    apply_migration(args.revision)
@@ -187,11 +440,14 @@ def main():
check_database_connection()
check_models()
elif args.command == 'reset':
reset_alembic_version()
elif args.command == 'auto':
    check_models()
-    revision = generate_migration(args.message)
+    revision = generate_migration(args.message, rev_id=args.rev_id)
    if revision:
-        proceed = input("\nPress Enter to apply migration or Ctrl+C to abort... ")
+        input("\nPress Enter to apply migration or Ctrl+C to abort... ")
        apply_migration()
else:


@@ -20,40 +20,36 @@ dependencies = [
"uvicorn>=0.34.0", "uvicorn>=0.34.0",
"pydantic>=2.10.6", "pydantic>=2.10.6",
"pydantic-settings>=2.2.1", "pydantic-settings>=2.2.1",
"python-multipart>=0.0.19", "python-multipart>=0.0.22",
"fastapi-utils==0.8.0", "fastapi-utils==0.8.0",
# Database # Database
"sqlalchemy>=2.0.29", "sqlalchemy>=2.0.29",
"alembic>=1.14.1", "alembic>=1.14.1",
"psycopg2-binary>=2.9.9", "psycopg2-binary>=2.9.9",
"asyncpg>=0.29.0", "asyncpg>=0.29.0",
"aiosqlite==0.21.0", "aiosqlite==0.21.0",
# Environment configuration # Environment configuration
"python-dotenv>=1.0.1", "python-dotenv>=1.0.1",
# API utilities # API utilities
"email-validator>=2.1.0.post1", "email-validator>=2.1.0.post1",
"ujson>=5.9.0", "ujson>=5.9.0",
# CORS and security # CORS and security
"starlette>=0.40.0", "starlette>=0.40.0",
"starlette-csrf>=1.4.5", "starlette-csrf>=1.4.5",
"slowapi>=0.1.9", "slowapi>=0.1.9",
# Utilities # Utilities
"httpx>=0.27.0", "httpx>=0.27.0",
"tenacity>=8.2.3", "tenacity>=8.2.3",
"pytz>=2024.1", "pytz>=2024.1",
"pillow>=10.3.0", "pillow>=12.1.1",
"apscheduler==3.11.0", "apscheduler==3.11.0",
# Security and authentication
# Security and authentication (pinned for reproducibility) "PyJWT>=2.9.0",
"python-jose==3.4.0",
"passlib==1.7.4",
"bcrypt==4.2.1", "bcrypt==4.2.1",
"cryptography==44.0.1", "cryptography>=46.0.5",
# OAuth authentication
"authlib>=1.6.6",
"urllib3>=2.6.3",
] ]
# Development dependencies
@@ -69,7 +65,24 @@ dev = [
# Development tools
"ruff>=0.8.0",  # All-in-one: linting, formatting, import sorting
-"mypy>=1.8.0",  # Type checking
+"pyright>=1.1.390",  # Type checking
# Security auditing
"pip-audit>=2.7.0", # Dependency vulnerability scanning (PyPA/OSV)
"pip-licenses>=4.0.0", # License compliance checking
"detect-secrets>=1.5.0", # Hardcoded secrets detection
# Performance benchmarking
"pytest-benchmark>=4.0.0", # Performance regression detection
# Pre-commit hooks
"pre-commit>=4.0.0", # Git pre-commit hook framework
]
# E2E testing with real PostgreSQL (requires Docker)
e2e = [
"testcontainers[postgres]>=4.0.0",
"schemathesis>=3.30.0",
]
# ============================================================================
@@ -122,6 +135,8 @@ select = [
"RUF", # Ruff-specific "RUF", # Ruff-specific
"ASYNC", # flake8-async "ASYNC", # flake8-async
"S", # flake8-bandit (security) "S", # flake8-bandit (security)
"G", # flake8-logging-format (logging best practices)
"T20", # flake8-print (no print statements in production code)
]
# Ignore specific rules
@@ -145,11 +160,13 @@ unfixable = []
[tool.ruff.lint.per-file-ignores]
"app/alembic/env.py" = ["E402", "F403", "F405"]  # Alembic requires specific import order
"app/alembic/versions/*.py" = ["E402"]  # Migration files have specific structure
-"tests/**/*.py" = ["S101", "N806", "B017", "N817", "S110", "ASYNC251", "RUF043"]  # pytest: asserts, CamelCase fixtures, blind exceptions, try-pass patterns, and async test helpers are intentional
+"tests/**/*.py" = ["S101", "N806", "B017", "N817", "ASYNC251", "RUF043", "T20"]  # pytest: asserts, CamelCase fixtures, blind exceptions, async test helpers, and print for debugging are intentional
"app/models/__init__.py" = ["F401"]  # __init__ files re-export modules
"app/models/base.py" = ["F401"]  # Re-exports Base for use by other models
"app/utils/test_utils.py" = ["N806"]  # SQLAlchemy session factories use CamelCase convention
"app/main.py" = ["N806"]  # Constants use UPPER_CASE convention
"app/init_db.py" = ["T20"] # CLI script uses print for user-facing output
"migrate.py" = ["T20"] # CLI script uses print for user-facing output
# ============================================================================
# Ruff Import Sorting (isort replacement)
@@ -176,116 +193,6 @@ indent-style = "space"
skip-magic-trailing-comma = false
line-ending = "lf"
# ============================================================================
# mypy Configuration - Type Checking
# ============================================================================
[tool.mypy]
python_version = "3.12"
warn_return_any = false # SQLAlchemy queries return Any - overly strict
warn_unused_configs = true
disallow_untyped_defs = false # Gradual typing - enable later
disallow_incomplete_defs = false
check_untyped_defs = true
no_implicit_optional = true
warn_redundant_casts = true
warn_unused_ignores = true
warn_no_return = true
strict_equality = true
ignore_missing_imports = false
explicit_package_bases = true
namespace_packages = true
# Pydantic plugin for better validation
plugins = ["pydantic.mypy"]
# Per-module options
[[tool.mypy.overrides]]
module = "alembic.*"
ignore_errors = true
[[tool.mypy.overrides]]
module = "app.alembic.*"
ignore_errors = true
[[tool.mypy.overrides]]
module = "sqlalchemy.*"
ignore_missing_imports = true
[[tool.mypy.overrides]]
module = "fastapi_utils.*"
ignore_missing_imports = true
[[tool.mypy.overrides]]
module = "slowapi.*"
ignore_missing_imports = true
[[tool.mypy.overrides]]
module = "jose.*"
ignore_missing_imports = true
[[tool.mypy.overrides]]
module = "passlib.*"
ignore_missing_imports = true
[[tool.mypy.overrides]]
module = "pydantic_settings.*"
ignore_missing_imports = true
[[tool.mypy.overrides]]
module = "fastapi.*"
ignore_missing_imports = true
[[tool.mypy.overrides]]
module = "apscheduler.*"
ignore_missing_imports = true
[[tool.mypy.overrides]]
module = "starlette.*"
ignore_missing_imports = true
# SQLAlchemy ORM models - Column descriptors cause type confusion
[[tool.mypy.overrides]]
module = "app.models.*"
disable_error_code = ["assignment", "arg-type", "return-value"]
# CRUD operations - Generic ModelType and SQLAlchemy Result issues
[[tool.mypy.overrides]]
module = "app.crud.*"
disable_error_code = ["attr-defined", "assignment", "arg-type", "return-value"]
# API routes - SQLAlchemy Column to Pydantic schema conversions
[[tool.mypy.overrides]]
module = "app.api.routes.*"
disable_error_code = ["arg-type", "call-arg", "call-overload", "assignment"]
# API dependencies - Similar SQLAlchemy Column issues
[[tool.mypy.overrides]]
module = "app.api.dependencies.*"
disable_error_code = ["arg-type"]
# FastAPI exception handlers have correct signatures despite mypy warnings
[[tool.mypy.overrides]]
module = "app.main"
disable_error_code = ["arg-type"]
# Auth service - SQLAlchemy Column issues
[[tool.mypy.overrides]]
module = "app.services.auth_service"
disable_error_code = ["assignment", "arg-type"]
# Test utils - Testing patterns
[[tool.mypy.overrides]]
module = "app.utils.auth_test_utils"
disable_error_code = ["assignment", "arg-type"]
# ============================================================================
# Pydantic mypy plugin configuration
# ============================================================================
[tool.pydantic-mypy]
init_forbid_extra = true
init_typed = true
warn_required_dynamic_aliases = true
# ============================================================================
# Pytest Configuration
# ============================================================================
@@ -302,10 +209,15 @@ addopts = [
"--cov=app", "--cov=app",
"--cov-report=term-missing", "--cov-report=term-missing",
"--cov-report=html", "--cov-report=html",
"--ignore=tests/benchmarks", # benchmarks are incompatible with xdist; run via 'make benchmark'
"-p", "no:benchmark", # disable pytest-benchmark plugin during normal runs (conflicts with xdist)
]
markers = [
"sqlite: marks tests that should run on SQLite (mocked).",
"postgres: marks tests that require a real PostgreSQL database.",
"e2e: marks end-to-end tests requiring Docker containers.",
"schemathesis: marks Schemathesis-generated API tests.",
"benchmark: marks performance benchmark tests.",
]
asyncio_default_fixture_loop_scope = "function"
@@ -319,6 +231,7 @@ omit = [
"*/__pycache__/*", "*/__pycache__/*",
"*/alembic/versions/*", "*/alembic/versions/*",
"*/.venv/*", "*/.venv/*",
"app/init_db.py", # CLI script for database initialization
]
branch = true


@@ -0,0 +1,23 @@
{
"include": ["app"],
"exclude": ["app/alembic"],
"pythonVersion": "3.12",
"venvPath": ".",
"venv": ".venv",
"typeCheckingMode": "standard",
"reportMissingImports": true,
"reportMissingTypeStubs": false,
"reportUnknownMemberType": false,
"reportUnknownVariableType": false,
"reportUnknownArgumentType": false,
"reportUnknownParameterType": false,
"reportUnknownLambdaType": false,
"reportReturnType": true,
"reportUnusedImport": false,
"reportGeneralTypeIssues": false,
"reportAttributeAccessIssue": false,
"reportArgumentType": false,
"strictListInference": false,
"strictDictionaryInference": false,
"strictSetInference": false
}


@@ -0,0 +1,313 @@
# tests/api/dependencies/test_locale_dependencies.py
import uuid
from unittest.mock import MagicMock
import pytest
import pytest_asyncio
from app.api.dependencies.locale import (
DEFAULT_LOCALE,
SUPPORTED_LOCALES,
get_locale,
parse_accept_language,
)
from app.core.auth import get_password_hash
from app.models.user import User
class TestParseAcceptLanguage:
"""Tests for parse_accept_language helper function"""
def test_parse_empty_header(self):
"""Test with empty Accept-Language header"""
result = parse_accept_language("")
assert result is None
def test_parse_none_header(self):
"""Test with None Accept-Language header"""
result = parse_accept_language(None)
assert result is None
def test_parse_italian_exact_match(self):
"""Test parsing Italian with exact match"""
result = parse_accept_language("it-IT,it;q=0.9,en;q=0.8")
assert result == "it-it"
def test_parse_italian_language_code_only(self):
"""Test parsing Italian with only language code"""
result = parse_accept_language("it,en;q=0.8")
assert result == "it"
def test_parse_english_us(self):
"""Test parsing English (US)"""
result = parse_accept_language("en-US,en;q=0.9")
assert result == "en-us"
def test_parse_english_language_code(self):
"""Test parsing English with only language code"""
result = parse_accept_language("en")
assert result == "en"
def test_parse_unsupported_language(self):
"""Test parsing unsupported language (French)"""
result = parse_accept_language("fr-FR,fr;q=0.9,de;q=0.8")
assert result is None
def test_parse_mixed_supported_unsupported(self):
"""Test with mix of supported and unsupported, should pick first supported"""
# French first (unsupported), then Italian (supported)
result = parse_accept_language("fr-FR,fr;q=0.9,it;q=0.8")
assert result == "it"
def test_parse_quality_values(self):
"""Test that quality values are respected (first = highest priority)"""
# English has higher implicit priority (no q value means q=1.0)
result = parse_accept_language("en,it;q=0.9")
assert result == "en"
def test_parse_complex_header(self):
"""Test complex Accept-Language header with multiple locales"""
result = parse_accept_language("it-IT,it;q=0.9,en-US;q=0.8,en;q=0.7,fr;q=0.6")
assert result == "it-it"
def test_parse_whitespace_handling(self):
"""Test that whitespace is handled correctly"""
result = parse_accept_language(" it-IT , it ; q=0.9 , en ; q=0.8 ")
assert result == "it-it"
def test_parse_case_insensitive(self):
"""Test that locale matching is case-insensitive"""
result = parse_accept_language("IT-it,EN-us;q=0.9")
# Should normalize to lowercase
assert result == "it-it"
def test_parse_fallback_to_language_code(self):
"""Test fallback from region-specific to language code"""
# it-CH (Switzerland) not supported, but "it" is
result = parse_accept_language("it-CH,en;q=0.8")
assert result == "it"
@pytest_asyncio.fixture
async def async_user_with_locale_en(async_test_db):
"""Async fixture to create a user with 'en' locale preference"""
_test_engine, AsyncTestingSessionLocal = async_test_db
async with AsyncTestingSessionLocal() as session:
user = User(
id=uuid.uuid4(),
email="user_en@example.com",
password_hash=get_password_hash("password123"),
first_name="English",
last_name="User",
is_active=True,
is_superuser=False,
locale="en",
)
session.add(user)
await session.commit()
await session.refresh(user)
return user
@pytest_asyncio.fixture
async def async_user_with_locale_it(async_test_db):
"""Async fixture to create a user with 'it' locale preference"""
_test_engine, AsyncTestingSessionLocal = async_test_db
async with AsyncTestingSessionLocal() as session:
user = User(
id=uuid.uuid4(),
email="user_it@example.com",
password_hash=get_password_hash("password123"),
first_name="Italian",
last_name="User",
is_active=True,
is_superuser=False,
locale="it",
)
session.add(user)
await session.commit()
await session.refresh(user)
return user
@pytest_asyncio.fixture
async def async_user_without_locale(async_test_db):
"""Async fixture to create a user without locale preference"""
_test_engine, AsyncTestingSessionLocal = async_test_db
async with AsyncTestingSessionLocal() as session:
user = User(
id=uuid.uuid4(),
email="user_no_locale@example.com",
password_hash=get_password_hash("password123"),
first_name="No",
last_name="Locale",
is_active=True,
is_superuser=False,
locale=None,
)
session.add(user)
await session.commit()
await session.refresh(user)
return user
class TestGetLocale:
"""Tests for get_locale dependency"""
@pytest.mark.asyncio
async def test_locale_from_user_preference_en(self, async_user_with_locale_en):
"""Test locale detection from authenticated user's saved preference (en)"""
# Mock request with no Accept-Language header
mock_request = MagicMock()
mock_request.headers = {}
result = await get_locale(
request=mock_request, current_user=async_user_with_locale_en
)
assert result == "en"
@pytest.mark.asyncio
async def test_locale_from_user_preference_it(self, async_user_with_locale_it):
"""Test locale detection from authenticated user's saved preference (it)"""
# Mock request with no Accept-Language header
mock_request = MagicMock()
mock_request.headers = {}
result = await get_locale(
request=mock_request, current_user=async_user_with_locale_it
)
assert result == "it"
@pytest.mark.asyncio
async def test_user_preference_overrides_accept_language(
self, async_user_with_locale_en
):
"""Test that user preference takes precedence over Accept-Language header"""
# Mock request with Italian Accept-Language, but user has English preference
mock_request = MagicMock()
mock_request.headers = {"accept-language": "it-IT,it;q=0.9"}
result = await get_locale(
request=mock_request, current_user=async_user_with_locale_en
)
# Should return user preference, not Accept-Language
assert result == "en"
@pytest.mark.asyncio
async def test_locale_from_accept_language_header(self, async_user_without_locale):
"""Test locale detection from Accept-Language header when user has no preference"""
# Mock request with Italian Accept-Language (it-IT has highest priority)
mock_request = MagicMock()
mock_request.headers = {"accept-language": "it-IT,it;q=0.9,en;q=0.8"}
result = await get_locale(
request=mock_request, current_user=async_user_without_locale
)
# Should return "it-it" (normalized from "it-IT", the first/highest priority locale)
assert result == "it-it"
@pytest.mark.asyncio
async def test_locale_from_accept_language_unauthenticated(self):
"""Test locale detection from Accept-Language header for unauthenticated user"""
# Mock request with Italian Accept-Language (it-IT has highest priority)
mock_request = MagicMock()
mock_request.headers = {"accept-language": "it-IT,it;q=0.9,en;q=0.8"}
result = await get_locale(request=mock_request, current_user=None)
# Should return "it-it" (normalized from "it-IT", the first/highest priority locale)
assert result == "it-it"
@pytest.mark.asyncio
async def test_default_locale_no_user_no_header(self):
"""Test fallback to default locale when no user and no Accept-Language header"""
# Mock request with no Accept-Language header
mock_request = MagicMock()
mock_request.headers = {}
result = await get_locale(request=mock_request, current_user=None)
assert result == DEFAULT_LOCALE
assert result == "en"
@pytest.mark.asyncio
async def test_default_locale_unsupported_language(self):
"""Test fallback to default when Accept-Language has only unsupported languages"""
# Mock request with French (unsupported)
mock_request = MagicMock()
mock_request.headers = {"accept-language": "fr-FR,fr;q=0.9,de;q=0.8"}
result = await get_locale(request=mock_request, current_user=None)
assert result == DEFAULT_LOCALE
assert result == "en"
@pytest.mark.asyncio
async def test_validate_supported_locale_in_db(self, async_user_with_locale_it):
"""Test that saved locale is validated against SUPPORTED_LOCALES"""
# This test verifies the locale in DB is actually supported
assert async_user_with_locale_it.locale in SUPPORTED_LOCALES
mock_request = MagicMock()
mock_request.headers = {}
result = await get_locale(
request=mock_request, current_user=async_user_with_locale_it
)
assert result == "it"
assert result in SUPPORTED_LOCALES
@pytest.mark.asyncio
async def test_accept_language_case_variations(self):
"""Test different case variations in Accept-Language header"""
# All return values are lowercase for consistency
test_cases = [
("it-IT,en;q=0.8", "it-it"),
("IT-it,en;q=0.8", "it-it"),
("en-US,it;q=0.8", "en-us"),
("EN,it;q=0.8", "en"),
]
for accept_lang, expected in test_cases:
mock_request = MagicMock()
mock_request.headers = {"accept-language": accept_lang}
result = await get_locale(request=mock_request, current_user=None)
assert result == expected
@pytest.mark.asyncio
async def test_accept_language_with_quality_values(self):
"""Test Accept-Language parsing respects quality values (priority)"""
# English has implicit q=1.0, Italian has q=0.9
mock_request = MagicMock()
mock_request.headers = {"accept-language": "en,it;q=0.9"}
result = await get_locale(request=mock_request, current_user=None)
# Should return English (higher priority)
assert result == "en"
@pytest.mark.asyncio
async def test_supported_locales_constant(self):
"""Test that SUPPORTED_LOCALES contains expected locales"""
# Note: SUPPORTED_LOCALES uses lowercase for case-insensitive matching
assert "en" in SUPPORTED_LOCALES
assert "it" in SUPPORTED_LOCALES
assert "en-us" in SUPPORTED_LOCALES
assert "en-gb" in SUPPORTED_LOCALES
assert "it-it" in SUPPORTED_LOCALES
# Verify total count matches implementation plan (5 locales for EN/IT showcase)
assert len(SUPPORTED_LOCALES) == 5
@pytest.mark.asyncio
async def test_default_locale_constant(self):
"""Test that DEFAULT_LOCALE is English"""
assert DEFAULT_LOCALE == "en"
assert DEFAULT_LOCALE in SUPPORTED_LOCALES


@@ -147,7 +147,7 @@ class TestAdminCreateUser:
            headers={"Authorization": f"Bearer {superuser_token}"},
        )
-        assert response.status_code == status.HTTP_404_NOT_FOUND
+        assert response.status_code == status.HTTP_409_CONFLICT
class TestAdminGetUser:
@@ -565,7 +565,7 @@ class TestAdminCreateOrganization:
            headers={"Authorization": f"Bearer {superuser_token}"},
        )
-        assert response.status_code == status.HTTP_404_NOT_FOUND
+        assert response.status_code == status.HTTP_409_CONFLICT
class TestAdminGetOrganization:
@@ -923,6 +923,27 @@ class TestAdminRemoveOrganizationMember:
        assert response.status_code == status.HTTP_404_NOT_FOUND
@pytest.mark.asyncio
async def test_admin_remove_organization_member_user_not_found(
self, client, async_test_superuser, async_test_db, superuser_token
):
"""Test removing non-existent user from organization."""
_test_engine, AsyncTestingSessionLocal = async_test_db
# Create organization
async with AsyncTestingSessionLocal() as session:
org = Organization(name="User Not Found Org", slug="user-not-found-org")
session.add(org)
await session.commit()
org_id = org.id
response = await client.delete(
f"/api/v1/admin/organizations/{org_id}/members/{uuid4()}",
headers={"Authorization": f"Bearer {superuser_token}"},
)
assert response.status_code == status.HTTP_404_NOT_FOUND
# ===== SESSION MANAGEMENT TESTS =====
@@ -1097,3 +1118,102 @@ class TestAdminListSessions:
        )
        assert response.status_code == status.HTTP_403_FORBIDDEN
# ===== ADMIN STATS TESTS =====
class TestAdminStats:
"""Tests for GET /admin/stats endpoint."""
@pytest.mark.asyncio
async def test_admin_get_stats_with_data(
self,
client,
async_test_superuser,
async_test_user,
async_test_db,
superuser_token,
):
"""Test getting admin stats with real data in database."""
_test_engine, AsyncTestingSessionLocal = async_test_db
# Create multiple users and organizations with members
async with AsyncTestingSessionLocal() as session:
from app.core.auth import get_password_hash
from app.models.user import User
# Create several users
for i in range(5):
user = User(
email=f"statsuser{i}@example.com",
password_hash=get_password_hash("TestPassword123!"),
first_name=f"Stats{i}",
last_name="User",
is_active=i % 2 == 0, # Mix of active/inactive
)
session.add(user)
await session.commit()
# Create organizations with members
async with AsyncTestingSessionLocal() as session:
orgs = []
for i in range(3):
org = Organization(name=f"Stats Org {i}", slug=f"stats-org-{i}")
session.add(org)
orgs.append(org)
await session.flush()
# Add some members to organizations
user_org = UserOrganization(
user_id=async_test_user.id,
organization_id=orgs[0].id,
role=OrganizationRole.MEMBER,
is_active=True,
)
session.add(user_org)
await session.commit()
response = await client.get(
"/api/v1/admin/stats",
headers={"Authorization": f"Bearer {superuser_token}"},
)
assert response.status_code == status.HTTP_200_OK
data = response.json()
# Verify response structure
assert "user_growth" in data
assert "organization_distribution" in data
assert "registration_activity" in data
assert "user_status" in data
# Verify user_growth has 30 days of data
assert len(data["user_growth"]) == 30
for item in data["user_growth"]:
assert "date" in item
assert "total_users" in item
assert "active_users" in item
# Verify registration_activity has 14 days of data
assert len(data["registration_activity"]) == 14
for item in data["registration_activity"]:
assert "date" in item
assert "registrations" in item
# Verify user_status has active/inactive counts
assert len(data["user_status"]) == 2
status_names = {item["name"] for item in data["user_status"]}
assert status_names == {"Active", "Inactive"}
@pytest.mark.asyncio
async def test_admin_get_stats_unauthorized(
self, client, async_test_user, user_token
):
"""Test that non-admin users cannot access stats endpoint."""
response = await client.get(
"/api/v1/admin/stats",
headers={"Authorization": f"Bearer {user_token}"},
)
assert response.status_code == status.HTTP_403_FORBIDDEN


@@ -45,7 +45,7 @@ class TestAdminListUsersFilters:
    async def test_list_users_database_error_propagates(self, client, superuser_token):
        """Test that database errors propagate correctly (covers line 118-120)."""
        with patch(
-            "app.api.routes.admin.user_crud.get_multi_with_total",
+            "app.api.routes.admin.user_service.list_users",
            side_effect=Exception("DB error"),
        ):
            with pytest.raises(Exception):
@@ -74,8 +74,8 @@ class TestAdminCreateUserErrors:
            },
        )
-        # Should get error for duplicate email
-        assert response.status_code == status.HTTP_404_NOT_FOUND
+        # Should get conflict for duplicate email
+        assert response.status_code == status.HTTP_409_CONFLICT
    @pytest.mark.asyncio
    async def test_create_user_unexpected_error_propagates(
@@ -83,7 +83,7 @@ class TestAdminCreateUserErrors:
    ):
        """Test unexpected errors during user creation (covers line 151-153)."""
        with patch(
-            "app.api.routes.admin.user_crud.create",
+            "app.api.routes.admin.user_service.create_user",
            side_effect=RuntimeError("Unexpected error"),
        ):
            with pytest.raises(RuntimeError):
@@ -135,7 +135,7 @@ class TestAdminUpdateUserErrors:
    ):
        """Test unexpected errors during user update (covers line 206-208)."""
        with patch(
-            "app.api.routes.admin.user_crud.update",
+            "app.api.routes.admin.user_service.update_user",
            side_effect=RuntimeError("Update failed"),
        ):
            with pytest.raises(RuntimeError):
@@ -166,7 +166,7 @@ class TestAdminDeleteUserErrors:
    ):
        """Test unexpected errors during user deletion (covers line 238-240)."""
        with patch(
-            "app.api.routes.admin.user_crud.soft_delete",
+            "app.api.routes.admin.user_service.soft_delete_user",
            side_effect=Exception("Delete failed"),
        ):
            with pytest.raises(Exception):
@@ -196,7 +196,7 @@ class TestAdminActivateUserErrors:
    ):
        """Test unexpected errors during user activation (covers line 282-284)."""
        with patch(
-            "app.api.routes.admin.user_crud.update",
+            "app.api.routes.admin.user_service.update_user",
            side_effect=Exception("Activation failed"),
        ):
            with pytest.raises(Exception):
@@ -238,7 +238,7 @@ class TestAdminDeactivateUserErrors:
    ):
        """Test unexpected errors during user deactivation (covers line 326-328)."""
        with patch(
-            "app.api.routes.admin.user_crud.update",
+            "app.api.routes.admin.user_service.update_user",
            side_effect=Exception("Deactivation failed"),
        ):
            with pytest.raises(Exception):
@@ -258,7 +258,7 @@ class TestAdminListOrganizationsErrors:
    async def test_list_organizations_database_error(self, client, superuser_token):
        """Test list organizations with database error (covers line 427-456)."""
        with patch(
-            "app.api.routes.admin.organization_crud.get_multi_with_member_counts",
+            "app.api.routes.admin.organization_service.get_multi_with_member_counts",
            side_effect=Exception("DB error"),
        ):
            with pytest.raises(Exception):
@@ -299,14 +299,14 @@ class TestAdminCreateOrganizationErrors:
            },
        )
-        # Should get error for duplicate slug
-        assert response.status_code == status.HTTP_404_NOT_FOUND
+        # Should get conflict for duplicate slug
+        assert response.status_code == status.HTTP_409_CONFLICT
    @pytest.mark.asyncio
    async def test_create_organization_unexpected_error(self, client, superuser_token):
        """Test unexpected errors during organization creation (covers line 484-485)."""
        with patch(
-            "app.api.routes.admin.organization_crud.create",
+            "app.api.routes.admin.organization_service.create_organization",
            side_effect=RuntimeError("Creation failed"),
        ):
            with pytest.raises(RuntimeError):
@@ -367,7 +367,7 @@ class TestAdminUpdateOrganizationErrors:
        org_id = org.id
        with patch(
-            "app.api.routes.admin.organization_crud.update",
+            "app.api.routes.admin.organization_service.update_organization",
            side_effect=Exception("Update failed"),
        ):
            with pytest.raises(Exception):
@@ -412,7 +412,7 @@ class TestAdminDeleteOrganizationErrors:
        org_id = org.id
        with patch(
-            "app.api.routes.admin.organization_crud.remove",
+            "app.api.routes.admin.organization_service.remove_organization",
            side_effect=Exception("Delete failed"),
        ):
            with pytest.raises(Exception):
@@ -456,7 +456,7 @@ class TestAdminListOrganizationMembersErrors:
        org_id = org.id
        with patch(
-            "app.api.routes.admin.organization_crud.get_organization_members",
+            "app.api.routes.admin.organization_service.get_organization_members",
            side_effect=Exception("DB error"),
        ):
            with pytest.raises(Exception):
@@ -531,7 +531,7 @@ class TestAdminAddOrganizationMemberErrors:
        org_id = org.id
        with patch(
-            "app.api.routes.admin.organization_crud.add_user",
+            "app.api.routes.admin.organization_service.add_member",
            side_effect=Exception("Add failed"),
        ):
            with pytest.raises(Exception):
@@ -587,7 +587,7 @@ class TestAdminRemoveOrganizationMemberErrors:
        org_id = org.id
        with patch(
-            "app.api.routes.admin.organization_crud.remove_user",
+            "app.api.routes.admin.organization_service.remove_member",
            side_effect=Exception("Remove failed"),
        ):
            with pytest.raises(Exception):

Some files were not shown because too many files have changed in this diff.