From 68275b1dd3ec1fc8d431dd03a9d496c05ab2dd4e Mon Sep 17 00:00:00 2001 From: Felipe Cardoso Date: Sun, 1 Mar 2026 11:13:51 +0100 Subject: [PATCH] refactor(docs): update architecture to reflect repository migration - Rename CRUD layer to Repository layer throughout architecture documentation. - Update dependency injection examples to use repository classes. - Add async SQLAlchemy pattern for Repository methods (`select()` and transactions). - Replace CRUD references in FEATURE_EXAMPLE.md with Repository-focused implementation details. - Highlight repository class responsibilities and remove outdated CRUD patterns. --- backend/docs/ARCHITECTURE.md | 178 +++---- backend/docs/CODING_STANDARDS.md | 90 ++-- backend/docs/COMMON_PITFALLS.md | 62 ++- backend/docs/FEATURE_EXAMPLE.md | 792 ++++++++----------------------- 4 files changed, 349 insertions(+), 773 deletions(-) diff --git a/backend/docs/ARCHITECTURE.md b/backend/docs/ARCHITECTURE.md index 0172ed4..0e994ca 100644 --- a/backend/docs/ARCHITECTURE.md +++ b/backend/docs/ARCHITECTURE.md @@ -117,7 +117,8 @@ backend/ │ ├── api/ # API layer │ │ ├── dependencies/ # Dependency injection │ │ │ ├── auth.py # Authentication dependencies -│ │ │ └── permissions.py # Authorization dependencies +│ │ │ ├── permissions.py # Authorization dependencies +│ │ │ └── services.py # Service singleton injection │ │ ├── routes/ # API endpoints │ │ │ ├── auth.py # Authentication routes │ │ │ ├── users.py # User management routes @@ -131,13 +132,14 @@ backend/ │ │ ├── config.py # Application configuration │ │ ├── database.py # Database connection │ │ ├── exceptions.py # Custom exception classes +│ │ ├── repository_exceptions.py # Repository-level exception hierarchy │ │ └── middleware.py # Custom middleware │ │ -│ ├── crud/ # Database operations -│ │ ├── base.py # Generic CRUD base class -│ │ ├── user.py # User CRUD operations -│ │ ├── session.py # Session CRUD operations -│ │ └── organization.py # Organization CRUD +│ ├── repositories/ # Data access layer +│ │ ├── base.py # Generic repository base class +│ │ ├── user.py # User repository +│ │ ├── session.py # Session repository +│ │ └── organization.py # Organization repository │ │ │ ├── models/ # SQLAlchemy models │ │ ├── base.py # Base model with mixins @@ -153,8 +155,11 @@ backend/ │ │ ├── sessions.py # Session schemas │ │ └── organizations.py # Organization schemas │ │ -│ ├── services/ # Business logic +│ ├── services/ # Business logic layer │ │ ├── auth_service.py # Authentication service +│ │ ├── user_service.py # User management service +│ │ ├── session_service.py # Session management service +│ │ ├── organization_service.py # Organization service │ │ ├── email_service.py # Email service │ │ └── session_cleanup.py # Background cleanup │ │ @@ -168,9 +173,9 @@ backend/ │ ├── tests/ # Test suite │ ├── api/ # Integration tests -│ ├── crud/ # CRUD tests +│ ├── repositories/ # Repository unit tests +│ ├── services/ # Service unit tests │ ├── models/ # Model tests -│ ├── services/ # Service tests │ └── conftest.py # Test configuration │ ├── docs/ # Documentation @@ -214,11 +219,11 @@ The application follows a strict 5-layer architecture: └──────────────────────────┬──────────────────────────────────┘ │ calls ┌──────────────────────────▼──────────────────────────────────┐ -│ CRUD Layer (crud/) │ +│ Repository Layer (repositories/) │ │ - Database operations │ │ - Query building │ -│ - Transaction management │ -│ - Error handling │ +│ - Custom repository exceptions │ +│ - No business logic │ └──────────────────────────┬──────────────────────────────────┘ │ uses ┌──────────────────────────▼──────────────────────────────────┐ @@ -262,7 +267,7 @@ async def get_current_user_info( **Rules**: - Should NOT contain business logic -- Should NOT directly perform database operations (use CRUD or services) +- Should NOT directly call repositories (use services injected via `dependencies/services.py`) - Must validate all input via Pydantic schemas - Must specify response models - Should apply appropriate rate limits @@ -279,9 +284,9 @@ async def get_current_user_info( **Example**: ```python -def get_current_user( +async def get_current_user( token: str = Depends(oauth2_scheme), - db: Session = Depends(get_db) + db: AsyncSession = Depends(get_db) ) -> User: """ Extract and validate user from JWT token. @@ -295,7 +300,7 @@ def get_current_user( except Exception: raise AuthenticationError("Invalid authentication credentials") - user = user_crud.get(db, id=user_id) + user = await user_repo.get(db, id=user_id) if not user: raise AuthenticationError("User not found") @@ -313,7 +318,7 @@ def get_current_user( **Responsibility**: Implement complex business logic **Key Functions**: -- Orchestrate multiple CRUD operations +- Orchestrate multiple repository operations - Implement business rules - Handle external service integration - Coordinate transactions @@ -323,9 +328,9 @@ def get_current_user( class AuthService: """Authentication service with business logic.""" - def login( + async def login( self, - db: Session, + db: AsyncSession, email: str, password: str, request: Request @@ -339,8 +344,8 @@ class AuthService: 3. Generate tokens 4. Return tokens and user info """ - # Validate credentials - user = user_crud.get_by_email(db, email=email) + # Validate credentials via repository + user = await user_repo.get_by_email(db, email=email) if not user or not verify_password(password, user.hashed_password): raise AuthenticationError("Invalid credentials") @@ -350,11 +355,10 @@ class AuthService: # Extract device info device_info = extract_device_info(request) - # Create session - session = session_crud.create_session( + # Create session via repository + session = await session_repo.create( db, - user_id=user.id, - device_info=device_info + obj_in=SessionCreate(user_id=user.id, **device_info) ) # Generate tokens @@ -373,75 +377,60 @@ class AuthService: **Rules**: - Contains business logic, not just data operations -- Can call multiple CRUD operations +- Can call multiple repository operations - Should handle complex workflows - Must maintain data consistency - Should use transactions when needed -#### 4. CRUD Layer (`app/crud/`) +#### 4. Repository Layer (`app/repositories/`) -**Responsibility**: Database operations and queries +**Responsibility**: Database operations and queries — no business logic **Key Functions**: - Create, read, update, delete operations - Build database queries -- Handle database errors +- Raise custom repository exceptions (`DuplicateEntryError`, `IntegrityConstraintError`) - Manage soft deletes - Implement pagination and filtering **Example**: ```python -class CRUDSession(CRUDBase[UserSession, SessionCreate, SessionUpdate]): - """CRUD operations for user sessions.""" +class SessionRepository(RepositoryBase[UserSession, SessionCreate, SessionUpdate]): + """Repository for user sessions — database operations only.""" - def get_by_jti(self, db: Session, jti: UUID) -> Optional[UserSession]: + async def get_by_jti(self, db: AsyncSession, *, jti: str) -> UserSession | None: """Get session by refresh token JTI.""" - try: - return ( - db.query(UserSession) - .filter(UserSession.refresh_token_jti == jti) - .first() - ) - except Exception as e: - logger.error(f"Error getting session by JTI: {str(e)}") - return None + result = await db.execute( + select(UserSession).where(UserSession.refresh_token_jti == jti) + ) + return result.scalar_one_or_none() - def get_active_by_jti( - self, - db: Session, - jti: UUID - ) -> Optional[UserSession]: - """Get active session by refresh token JTI.""" - session = self.get_by_jti(db, jti=jti) - if session and session.is_active and not session.is_expired: - return session - return None - - def deactivate(self, db: Session, session_id: UUID) -> bool: + async def deactivate(self, db: AsyncSession, *, session_id: UUID) -> bool: """Deactivate a session (logout).""" try: - session = self.get(db, id=session_id) + session = await self.get(db, id=session_id) if not session: return False session.is_active = False - db.commit() + await db.commit() logger.info(f"Session {session_id} deactivated") return True except Exception as e: - db.rollback() + await db.rollback() logger.error(f"Error deactivating session: {str(e)}") return False ``` **Rules**: - Should NOT contain business logic -- Must handle database exceptions -- Must use parameterized queries (SQLAlchemy does this) +- Must raise custom repository exceptions (not raw `ValueError`/`IntegrityError`) +- Must use async SQLAlchemy 2.0 `select()` API (never `db.query()`) - Should log all database errors - Must rollback on errors - Should use soft deletes when possible +- **Never imported directly by routes** — always called through services #### 5. Data Layer (`app/models/` + `app/schemas/`) @@ -546,51 +535,23 @@ SessionLocal = sessionmaker( #### Dependency Injection Pattern ```python -def get_db() -> Generator[Session, None, None]: +async def get_db() -> AsyncGenerator[AsyncSession, None]: """ - Database session dependency for FastAPI routes. + Async database session dependency for FastAPI routes. - Automatically commits on success, rolls back on error. + The session is passed to service methods; commit/rollback is + managed inside service or repository methods. """ - db = SessionLocal() - try: + async with AsyncSessionLocal() as db: yield db - finally: - db.close() -# Usage in routes +# Usage in routes — always through a service, never direct repository @router.get("/users") -def list_users(db: Session = Depends(get_db)): - return user_crud.get_multi(db) -``` - -#### Context Manager Pattern - -```python -@contextmanager -def transaction_scope() -> Generator[Session, None, None]: - """ - Context manager for database transactions. - - Use for complex operations requiring multiple steps. - Automatically commits on success, rolls back on error. - """ - db = SessionLocal() - try: - yield db - db.commit() - except Exception: - db.rollback() - raise - finally: - db.close() - -# Usage in services -def complex_operation(): - with transaction_scope() as db: - user = user_crud.create(db, obj_in=user_data) - session = session_crud.create(db, session_data) - return user, session +async def list_users( + user_service: UserService = Depends(get_user_service), + db: AsyncSession = Depends(get_db), +): + return await user_service.get_users(db) ``` ### Model Mixins @@ -782,22 +743,15 @@ def get_profile( ```python @router.delete("/sessions/{session_id}") -def revoke_session( +async def revoke_session( session_id: UUID, current_user: User = Depends(get_current_user), - db: Session = Depends(get_db) + session_service: SessionService = Depends(get_session_service), + db: AsyncSession = Depends(get_db), ): """Users can only revoke their own sessions.""" - session = session_crud.get(db, id=session_id) - - if not session: - raise NotFoundError("Session not found") - - # Check ownership - if session.user_id != current_user.id: - raise AuthorizationError("You can only revoke your own sessions") - - session_crud.deactivate(db, session_id=session_id) + # SessionService verifies ownership and raises NotFoundError / AuthorizationError + await session_service.revoke_session(db, session_id=session_id, user_id=current_user.id) return MessageResponse(success=True, message="Session revoked") ``` @@ -1092,8 +1046,8 @@ async def cleanup_expired_sessions(): Runs daily at 2 AM. Removes sessions expired for more than 30 days. """ try: - with transaction_scope() as db: - count = session_crud.cleanup_expired(db, keep_days=30) + async with AsyncSessionLocal() as db: + count = await session_repo.cleanup_expired(db, keep_days=30) logger.info(f"Cleaned up {count} expired sessions") except Exception as e: logger.error(f"Error cleaning up sessions: {str(e)}", exc_info=True) @@ -1110,7 +1064,7 @@ async def cleanup_expired_sessions(): │Integration │ ← API endpoint tests │ Tests │ ├─────────────┤ - │ Unit │ ← CRUD, services, utilities + │ Unit │ ← repositories, services, utilities │ Tests │ └─────────────┘ ``` diff --git a/backend/docs/CODING_STANDARDS.md b/backend/docs/CODING_STANDARDS.md index 6f4b254..409664a 100644 --- a/backend/docs/CODING_STANDARDS.md +++ b/backend/docs/CODING_STANDARDS.md @@ -75,15 +75,14 @@ def create_user(db: Session, user_in: UserCreate) -> User: ### 4. Code Formatting Use automated formatters: -- **Black**: Code formatting -- **isort**: Import sorting -- **flake8**: Linting +- **Ruff**: Code formatting and linting (replaces Black, isort, flake8) +- **pyright**: Static type checking -Run before committing: +Run before committing (or use `make validate`): ```bash -black app tests -isort app tests -flake8 app tests +uv run ruff format app tests +uv run ruff check app tests +uv run pyright app ``` ## Code Organization @@ -94,19 +93,17 @@ Follow the 5-layer architecture strictly: ``` API Layer (routes/) - ↓ calls -Dependencies (dependencies/) - ↓ injects + ↓ calls (via service injected from dependencies/services.py) Service Layer (services/) ↓ calls -CRUD Layer (crud/) +Repository Layer (repositories/) ↓ uses Models & Schemas (models/, schemas/) ``` **Rules:** -- Routes should NOT directly call CRUD operations (use services when business logic is needed) -- CRUD operations should NOT contain business logic +- Routes must NEVER import repositories directly — always use a service +- Services call repositories; repositories contain only database operations - Models should NOT import from higher layers - Each layer should only depend on the layer directly below it @@ -125,7 +122,7 @@ from sqlalchemy.orm import Session # 3. Local application imports from app.api.dependencies.auth import get_current_user -from app.crud import user_crud +from app.api.dependencies.services import get_user_service from app.models.user import User from app.schemas.users import UserResponse, UserCreate ``` @@ -442,19 +439,19 @@ backend/app/alembic/versions/ 4. **Testability**: Easy to mock and test 5. **Consistent Ordering**: Always order queries for pagination -### Use the Async CRUD Base Class +### Use the Async Repository Base Class -Always inherit from `CRUDBase` for database operations: +Always inherit from `RepositoryBase` for database operations: ```python from sqlalchemy.ext.asyncio import AsyncSession from sqlalchemy import select -from app.crud.base import CRUDBase +from app.repositories.base import RepositoryBase from app.models.user import User from app.schemas.users import UserCreate, UserUpdate -class CRUDUser(CRUDBase[User, UserCreate, UserUpdate]): - """CRUD operations for User model.""" +class UserRepository(RepositoryBase[User, UserCreate, UserUpdate]): + """Repository for User model — database operations only.""" async def get_by_email( self, @@ -467,7 +464,7 @@ class CRUDUser(CRUDBase[User, UserCreate, UserUpdate]): ) return result.scalar_one_or_none() -user_crud = CRUDUser(User) +user_repo = UserRepository(User) ``` **Key Points:** @@ -476,6 +473,7 @@ user_crud = CRUDUser(User) - Use `await db.execute()` for queries - Use `.scalar_one_or_none()` instead of `.first()` - Use `T | None` instead of `Optional[T]` +- Repository instances are used internally by services — never import them in routes ### Modern SQLAlchemy Patterns @@ -563,7 +561,7 @@ async def create_user( The database session is automatically managed by FastAPI. Commit on success, rollback on error. """ - return await user_crud.create(db, obj_in=user_in) + return await user_service.create_user(db, obj_in=user_in) ``` **Key Points:** @@ -582,12 +580,11 @@ async def complex_operation( """ Perform multiple database operations atomically. - The session automatically commits on success or rolls back on error. + Services call repositories; commit/rollback is handled inside + each repository method. """ - user = await user_crud.create(db, obj_in=user_data) - session = await session_crud.create(db, obj_in=session_data) - - # Commit is handled by the route's dependency + user = await user_repo.create(db, obj_in=user_data) + session = await session_repo.create(db, obj_in=session_data) return user, session ``` @@ -597,10 +594,10 @@ Prefer soft deletes over hard deletes for audit trails: ```python # Good - Soft delete (sets deleted_at) -await user_crud.soft_delete(db, id=user_id) +await user_repo.soft_delete(db, id=user_id) # Acceptable only when required - Hard delete -user_crud.remove(db, id=user_id) +await user_repo.remove(db, id=user_id) ``` ### Query Patterns @@ -740,9 +737,10 @@ Always implement pagination for list endpoints: from app.schemas.common import PaginationParams, PaginatedResponse @router.get("/users", response_model=PaginatedResponse[UserResponse]) -def list_users( +async def list_users( pagination: PaginationParams = Depends(), - db: Session = Depends(get_db) + user_service: UserService = Depends(get_user_service), + db: AsyncSession = Depends(get_db), ): """ List all users with pagination. @@ -750,10 +748,8 @@ def list_users( Default page size: 20 Maximum page size: 100 """ - users, total = user_crud.get_multi_with_total( - db, - skip=pagination.offset, - limit=pagination.limit + users, total = await user_service.get_users( + db, skip=pagination.offset, limit=pagination.limit ) return PaginatedResponse(data=users, pagination=pagination.create_meta(total)) ``` @@ -816,19 +812,17 @@ def admin_route( pass # Check ownership -def delete_resource( +async def delete_resource( resource_id: UUID, current_user: User = Depends(get_current_user), - db: Session = Depends(get_db) + resource_service: ResourceService = Depends(get_resource_service), + db: AsyncSession = Depends(get_db), ): - resource = resource_crud.get(db, id=resource_id) - if not resource: - raise NotFoundError("Resource not found") - - if resource.user_id != current_user.id and not current_user.is_superuser: - raise AuthorizationError("You can only delete your own resources") - - resource_crud.remove(db, id=resource_id) + # Service handles ownership check and raises appropriate errors + await resource_service.delete_resource( + db, resource_id=resource_id, user_id=current_user.id, + is_superuser=current_user.is_superuser, + ) ``` ### Input Validation @@ -862,9 +856,9 @@ tests/ ├── api/ # Integration tests │ ├── test_users.py │ └── test_auth.py -├── crud/ # Unit tests for CRUD -├── models/ # Model tests -└── services/ # Service tests +├── repositories/ # Unit tests for repositories +├── services/ # Unit tests for services +└── models/ # Model tests ``` ### Async Testing with pytest-asyncio @@ -927,7 +921,7 @@ async def test_user(db_session: AsyncSession) -> User: @pytest.mark.asyncio async def test_get_user(db_session: AsyncSession, test_user: User): """Test retrieving a user by ID.""" - user = await user_crud.get(db_session, id=test_user.id) + user = await user_repo.get(db_session, id=test_user.id) assert user is not None assert user.email == test_user.email ``` diff --git a/backend/docs/COMMON_PITFALLS.md b/backend/docs/COMMON_PITFALLS.md index 8bbb064..963afef 100644 --- a/backend/docs/COMMON_PITFALLS.md +++ b/backend/docs/COMMON_PITFALLS.md @@ -616,7 +616,43 @@ def create_user( return user ``` -**Rule**: Add type hints to ALL functions. Use `mypy` to enforce type checking. +**Rule**: Add type hints to ALL functions. Use `pyright` to enforce type checking (`make type-check`). + +--- + +--- + +### ❌ PITFALL #19: Importing Repositories Directly in Routes + +**Issue**: Routes should never call repositories directly. The layered architecture requires all business operations to go through the service layer. + +```python +# ❌ WRONG - Route bypasses service layer +from app.repositories.session import session_repo + +@router.get("/sessions/me") +async def list_sessions( + current_user: User = Depends(get_current_active_user), + db: AsyncSession = Depends(get_db), +): + return await session_repo.get_user_sessions(db, user_id=current_user.id) +``` + +```python +# ✅ CORRECT - Route calls service injected via dependency +from app.api.dependencies.services import get_session_service +from app.services.session_service import SessionService + +@router.get("/sessions/me") +async def list_sessions( + current_user: User = Depends(get_current_active_user), + session_service: SessionService = Depends(get_session_service), + db: AsyncSession = Depends(get_db), +): + return await session_service.get_user_sessions(db, user_id=current_user.id) +``` + +**Rule**: Routes import from `app.api.dependencies.services`, never from `app.repositories.*`. Services are the only callers of repositories. --- @@ -649,6 +685,11 @@ Use this checklist to catch issues before code review: - [ ] Resource ownership verification - [ ] CORS configured (no wildcards in production) +### Architecture +- [ ] Routes never import repositories directly (only services) +- [ ] Services call repositories; repositories call database only +- [ ] New service registered in `app/api/dependencies/services.py` + ### Python - [ ] Use `==` not `is` for value comparison - [ ] No mutable default arguments @@ -661,21 +702,18 @@ Use this checklist to catch issues before code review: ### Pre-commit Checks -Add these to your development workflow: +Add these to your development workflow (or use `make validate`): ```bash -# Format code -black app tests -isort app tests +# Format + lint (Ruff replaces Black, isort, flake8) +uv run ruff format app tests +uv run ruff check app tests # Type checking -mypy app --strict - -# Linting -flake8 app tests +uv run pyright app # Run tests -pytest --cov=app --cov-report=term-missing +IS_TEST=True uv run pytest --cov=app --cov-report=term-missing # Check coverage (should be 80%+) coverage report --fail-under=80 @@ -693,6 +731,6 @@ Add new entries when: --- -**Last Updated**: 2025-10-31 -**Issues Cataloged**: 18 common pitfalls +**Last Updated**: 2026-02-28 +**Issues Cataloged**: 19 common pitfalls **Remember**: This document exists because these issues HAVE occurred. Don't skip it. diff --git a/backend/docs/FEATURE_EXAMPLE.md b/backend/docs/FEATURE_EXAMPLE.md index 263bba0..b973650 100644 --- a/backend/docs/FEATURE_EXAMPLE.md +++ b/backend/docs/FEATURE_EXAMPLE.md @@ -8,7 +8,7 @@ This guide walks through implementing a complete feature using the **User Sessio - [Implementation Steps](#implementation-steps) - [Step 1: Design the Database Model](#step-1-design-the-database-model) - [Step 2: Create Pydantic Schemas](#step-2-create-pydantic-schemas) - - [Step 3: Implement CRUD Operations](#step-3-implement-crud-operations) + - [Step 3: Implement Repository](#step-3-implement-repository) - [Step 4: Create API Endpoints](#step-4-create-api-endpoints) - [Step 5: Integrate with Existing Features](#step-5-integrate-with-existing-features) - [Step 6: Add Background Jobs](#step-6-add-background-jobs) @@ -204,8 +204,8 @@ Follow the standard pattern: ``` SessionBase (common fields) - ├── SessionCreate (internal: CRUD operations) - ├── SessionUpdate (internal: CRUD operations) + ├── SessionCreate (internal: repository operations) + ├── SessionUpdate (internal: repository operations) └── SessionResponse (external: API responses) ``` @@ -240,7 +240,7 @@ class SessionCreate(SessionBase): """ Schema for creating a new session (internal use). - Used by CRUD operations, not exposed to API. + Used by repository operations, not exposed to API. Contains all fields needed to create a session. """ user_id: UUID @@ -344,37 +344,37 @@ class DeviceInfo(BaseModel): 5. **OpenAPI Documentation**: `json_schema_extra` provides examples in API docs 6. **Type Safety**: Comprehensive type hints for all fields -### Step 3: Implement CRUD Operations +### Step 3: Implement Repository -**File**: `app/crud/session.py` +**File**: `app/repositories/session.py` -CRUD layer handles all database operations. No business logic here! +The repository layer handles all database operations. No business logic here — that belongs in services! -#### 3.1 Extend the Base CRUD Class +#### 3.1 Extend the Base Repository Class ```python """ -CRUD operations for user sessions. +Repository for user sessions. """ from datetime import datetime, timezone, timedelta -from typing import List, Optional from uuid import UUID -from sqlalchemy.orm import Session -from sqlalchemy import and_ + +from sqlalchemy import and_, select, update +from sqlalchemy.ext.asyncio import AsyncSession import logging -from app.crud.base import CRUDBase +from app.repositories.base import RepositoryBase from app.models.user_session import UserSession from app.schemas.sessions import SessionCreate, SessionUpdate logger = logging.getLogger(__name__) -class CRUDSession(CRUDBase[UserSession, SessionCreate, SessionUpdate]): +class SessionRepository(RepositoryBase[UserSession, SessionCreate, SessionUpdate]): """ - CRUD operations for user sessions. + Repository for user sessions. - Inherits standard operations from CRUDBase: + Inherits standard operations from RepositoryBase: - get(db, id) - Get by ID - get_multi(db, skip, limit) - List with pagination - create(db, obj_in) - Create new session @@ -382,111 +382,62 @@ class CRUDSession(CRUDBase[UserSession, SessionCreate, SessionUpdate]): - remove(db, id) - Delete session """ - # Custom query methods - # -------------------- - - def get_by_jti(self, db: Session, *, jti: str) -> Optional[UserSession]: + async def get_by_jti(self, db: AsyncSession, *, jti: str) -> UserSession | None: """ Get session by refresh token JTI. Used during token refresh to find the corresponding session. - - Args: - db: Database session - jti: Refresh token JWT ID - - Returns: - UserSession if found, None otherwise """ - try: - return db.query(UserSession).filter( - UserSession.refresh_token_jti == jti - ).first() - except Exception as e: - logger.error(f"Error getting session by JTI {jti}: {str(e)}") - raise + result = await db.execute( + select(UserSession).where(UserSession.refresh_token_jti == jti) + ) + return result.scalar_one_or_none() - def get_active_by_jti(self, db: Session, *, jti: str) -> Optional[UserSession]: - """ - Get active session by refresh token JTI. - - Only returns the session if it's currently active. - - Args: - db: Database session - jti: Refresh token JWT ID - - Returns: - Active UserSession if found, None otherwise - """ - try: - return db.query(UserSession).filter( + async def get_active_by_jti(self, db: AsyncSession, *, jti: str) -> UserSession | None: + """Get active (non-expired) session by refresh token JTI.""" + result = await db.execute( + select(UserSession).where( and_( UserSession.refresh_token_jti == jti, - UserSession.is_active == True + UserSession.is_active.is_(True), ) - ).first() - except Exception as e: - logger.error(f"Error getting active session by JTI {jti}: {str(e)}") - raise + ) + ) + session = result.scalar_one_or_none() + if session and not session.is_expired: + return session + return None - def get_user_sessions( + async def get_user_sessions( self, - db: Session, + db: AsyncSession, *, - user_id: str, - active_only: bool = True - ) -> List[UserSession]: + user_id: UUID, + active_only: bool = True, + ) -> list[UserSession]: """ - Get all sessions for a user. - - Args: - db: Database session - user_id: User ID - active_only: If True, return only active sessions - - Returns: - List of UserSession objects, ordered by most recently used + Get all sessions for a user, ordered by most recently used. """ - try: - # Convert user_id string to UUID if needed - user_uuid = UUID(user_id) if isinstance(user_id, str) else user_id + query = select(UserSession).where(UserSession.user_id == user_id) + if active_only: + query = query.where(UserSession.is_active.is_(True)) + query = query.order_by(UserSession.last_used_at.desc()) + result = await db.execute(query) + return list(result.scalars().all()) - query = db.query(UserSession).filter(UserSession.user_id == user_uuid) - - if active_only: - query = query.filter(UserSession.is_active == True) - - # Order by most recently used first - return query.order_by(UserSession.last_used_at.desc()).all() - except Exception as e: - logger.error(f"Error getting sessions for user {user_id}: {str(e)}") - raise - - # Creation methods - # ---------------- - - def create_session( + async def create_session( self, - db: Session, + db: AsyncSession, *, - obj_in: SessionCreate + obj_in: SessionCreate, ) -> UserSession: """ Create a new user session. - Args: - db: Database session - obj_in: SessionCreate schema with session data - - Returns: - Created UserSession - Raises: - ValueError: If session creation fails + DuplicateEntryError: If a session with the same JTI already exists """ try: - # Create model instance from schema db_obj = UserSession( user_id=obj_in.user_id, refresh_token_jti=obj_in.refresh_token_jti, @@ -501,248 +452,93 @@ class CRUDSession(CRUDBase[UserSession, SessionCreate, SessionUpdate]): location_country=obj_in.location_country, ) - db.add(db_obj) - db.commit() - db.refresh(db_obj) - + db_obj.add(db_obj) + await db.commit() + await db.refresh(db_obj) logger.info( - f"Session created for user {obj_in.user_id} from {obj_in.device_name} " - f"(IP: {obj_in.ip_address})" + f"Session created for user {obj_in.user_id} from {obj_in.device_name}" ) - return db_obj - except Exception as e: - db.rollback() + await db.rollback() logger.error(f"Error creating session: {str(e)}", exc_info=True) - raise ValueError(f"Failed to create session: {str(e)}") - - # Update methods - # -------------- - - def deactivate(self, db: Session, *, session_id: str) -> Optional[UserSession]: - """ - Deactivate a session (logout from device). - - Args: - db: Database session - session_id: Session UUID - - Returns: - Deactivated UserSession if found, None otherwise - """ - try: - session = self.get(db, id=session_id) - if not session: - logger.warning(f"Session {session_id} not found for deactivation") - return None - - session.is_active = False - db.add(session) - db.commit() - db.refresh(session) - - logger.info( - f"Session {session_id} deactivated for user {session.user_id} " - f"({session.device_name})" - ) - - return session - - except Exception as e: - db.rollback() - logger.error(f"Error deactivating session {session_id}: {str(e)}") raise - def deactivate_all_user_sessions( + async def deactivate(self, db: AsyncSession, *, session_id: UUID) -> UserSession | None: + """Deactivate a session (logout from device).""" + session = await self.get(db, id=session_id) + if not session: + return None + session.is_active = False + await db.commit() + await db.refresh(session) + logger.info(f"Session {session_id} deactivated ({session.device_name})") + return session + + async def deactivate_all_user_sessions( self, - db: Session, + db: AsyncSession, *, - user_id: str + user_id: UUID, ) -> int: """ Deactivate all active sessions for a user (logout from all devices). - Uses bulk update for efficiency. - - Args: - db: Database session - user_id: User ID - - Returns: - Number of sessions deactivated + Uses a bulk UPDATE for efficiency — no N+1 queries. """ - try: - # Convert user_id string to UUID if needed - user_uuid = UUID(user_id) if isinstance(user_id, str) else user_id - - # Bulk update query - count = db.query(UserSession).filter( - and_( - UserSession.user_id == user_uuid, - UserSession.is_active == True - ) - ).update({"is_active": False}) - - db.commit() - - logger.info(f"Deactivated {count} sessions for user {user_id}") - - return count - - except Exception as e: - db.rollback() - logger.error(f"Error deactivating all sessions for user {user_id}: {str(e)}") - raise - - def update_last_used( - self, - db: Session, - *, - session: UserSession - ) -> UserSession: - """ - Update the last_used_at timestamp for a session. - - Called when a refresh token is used. - - Args: - db: Database session - session: UserSession object - - Returns: - Updated UserSession - """ - try: - session.last_used_at = datetime.now(timezone.utc) - db.add(session) - db.commit() - db.refresh(session) - return session - except Exception as e: - db.rollback() - logger.error(f"Error updating last_used for session {session.id}: {str(e)}") - raise - - def update_refresh_token( - self, - db: Session, - *, - session: UserSession, - new_jti: str, - new_expires_at: datetime - ) -> UserSession: - """ - Update session with new refresh token JTI and expiration. - - Called during token refresh (token rotation). - - Args: - db: Database session - session: UserSession object - new_jti: New refresh token JTI - new_expires_at: New expiration datetime - - Returns: - Updated UserSession - """ - try: - session.refresh_token_jti = new_jti - session.expires_at = new_expires_at - session.last_used_at = datetime.now(timezone.utc) - db.add(session) - db.commit() - db.refresh(session) - return session - except Exception as e: - db.rollback() - logger.error(f"Error updating refresh token for session {session.id}: {str(e)}") - raise - - # Cleanup methods - # --------------- - - def cleanup_expired(self, db: Session, *, keep_days: int = 30) -> int: - """ - Clean up expired sessions. - - Deletes sessions that are: - - Expired (expires_at < now) AND inactive - - Older than keep_days (for audit trail) - - Args: - db: Database session - keep_days: Keep inactive sessions for this many days - - Returns: - Number of sessions deleted - """ - try: - cutoff_date = datetime.now(timezone.utc) - timedelta(days=keep_days) - - count = db.query(UserSession).filter( - and_( - UserSession.is_active == False, - UserSession.expires_at < datetime.now(timezone.utc), - UserSession.created_at < cutoff_date - ) - ).delete() - - db.commit() - - if count > 0: - logger.info(f"Cleaned up {count} expired sessions") - - return count - - except Exception as e: - db.rollback() - logger.error(f"Error cleaning up expired sessions: {str(e)}") - raise - - # Utility methods - # --------------- - - def get_user_session_count(self, db: Session, *, user_id: str) -> int: - """ - Get count of active sessions for a user. - - Useful for session limits or security monitoring. - - Args: - db: Database session - user_id: User ID - - Returns: - Number of active sessions - """ - try: - return db.query(UserSession).filter( + result = await db.execute( + update(UserSession) + .where( and_( UserSession.user_id == user_id, - UserSession.is_active == True + UserSession.is_active.is_(True), ) - ).count() - except Exception as e: - logger.error(f"Error counting sessions for user {user_id}: {str(e)}") - raise + ) + .values(is_active=False) + ) + await db.commit() + count = result.rowcount + logger.info(f"Deactivated {count} sessions for user {user_id}") + return count + + async def cleanup_expired(self, db: AsyncSession, *, keep_days: int = 30) -> int: + """ + Hard-delete inactive sessions older than keep_days. + + Returns the number of sessions deleted. + """ + cutoff_date = datetime.now(timezone.utc) - timedelta(days=keep_days) + result = await db.execute( + select(UserSession).where( + and_( + UserSession.is_active.is_(False), + UserSession.expires_at < datetime.now(timezone.utc), + UserSession.created_at < cutoff_date, + ) + ) + ) + sessions = list(result.scalars().all()) + for s in sessions: + await db.delete(s) + await db.commit() + if sessions: + logger.info(f"Cleaned up {len(sessions)} expired sessions") + return len(sessions) -# Create singleton instance -# This is the instance that will be imported and used throughout the app -session = CRUDSession(UserSession) +# Singleton instance — used by services, never imported directly in routes +session_repo = SessionRepository(UserSession) ``` **Key Patterns**: -1. **Error Handling**: Every method has try/except with rollback -2. **Logging**: Log all significant actions (create, delete, errors) -3. **Type Safety**: Full type hints for parameters and returns -4. **Docstrings**: Document what each method does, args, returns, raises -5. **Bulk Operations**: Use `query().update()` for efficiency when updating many rows -6. **UUID Handling**: Convert string UUIDs to UUID objects when needed -7. **Ordering**: Return results in a logical order (most recent first) -8. **Singleton Pattern**: Create one instance to be imported elsewhere +1. **Async everywhere**: All methods use `async def` and `await` +2. **Modern SQLAlchemy**: `select()` API, never `db.query()` +3. **Bulk updates**: Use `update()` statement for multi-row changes (no N+1) +4. **Error handling**: `try/except` with `await db.rollback()` in mutating methods +5. **Logging**: Log all significant actions (create, delete, errors) +6. **Type safety**: Full type hints; `UUID` not raw `str` for IDs +7. **Singleton pattern**: One module-level instance used by services ### Step 4: Create API Endpoints @@ -772,14 +568,15 @@ from uuid import UUID from fastapi import APIRouter, Depends, HTTPException, status, Request from slowapi import Limiter from slowapi.util import get_remote_address -from sqlalchemy.orm import Session +from sqlalchemy.ext.asyncio import AsyncSession from app.api.dependencies.auth import get_current_user +from app.api.dependencies.services import get_session_service from app.core.database import get_db from app.models.user import User from app.schemas.sessions import SessionResponse, SessionListResponse from app.schemas.common import MessageResponse -from app.crud.session import session as session_crud +from app.services.session_service import SessionService from app.core.exceptions import NotFoundError, AuthorizationError, ErrorCode router = APIRouter() @@ -803,61 +600,21 @@ limiter = Limiter(key_func=get_remote_address) operation_id="list_my_sessions" ) @limiter.limit("30/minute") -def list_my_sessions( +async def list_my_sessions( request: Request, current_user: User = Depends(get_current_user), - db: Session = Depends(get_db) + session_service: SessionService = Depends(get_session_service), + db: AsyncSession = Depends(get_db), ) -> Any: - """ - List all active sessions for the current user. - - Args: - request: FastAPI request object (for rate limiting) - current_user: Current authenticated user (injected) - db: Database session (injected) - - Returns: - SessionListResponse with list of active sessions - """ - try: - # Get all active sessions for user - sessions = session_crud.get_user_sessions( - db, - user_id=str(current_user.id), - active_only=True - ) - - # Convert to response format - session_responses = [] - for idx, s in enumerate(sessions): - session_response = SessionResponse( - id=s.id, - device_name=s.device_name, - device_id=s.device_id, - ip_address=s.ip_address, - location_city=s.location_city, - location_country=s.location_country, - last_used_at=s.last_used_at, - created_at=s.created_at, - expires_at=s.expires_at, - # Mark the most recently used session as current - is_current=(idx == 0) - ) - session_responses.append(session_response) - - logger.info(f"User {current_user.id} listed {len(session_responses)} active sessions") - - return SessionListResponse( - sessions=session_responses, - total=len(session_responses) - ) - - except Exception as e: - logger.error(f"Error listing sessions for user {current_user.id}: {str(e)}", exc_info=True) - raise HTTPException( - status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, - detail="Failed to retrieve sessions" - ) + """List all active sessions for the current user.""" + sessions = await session_service.get_user_sessions( + db, user_id=current_user.id, active_only=True + ) + session_responses = [ + SessionResponse.model_validate(s) | {"is_current": idx == 0} + for idx, s in enumerate(sessions) + ] + return SessionListResponse(sessions=session_responses, total=len(session_responses)) @router.delete( @@ -876,71 +633,26 @@ def list_my_sessions( operation_id="revoke_session" ) @limiter.limit("10/minute") -def revoke_session( +async def revoke_session( request: Request, session_id: UUID, current_user: User = Depends(get_current_user), - db: Session = Depends(get_db) + session_service: SessionService = Depends(get_session_service), + db: AsyncSession = Depends(get_db), ) -> Any: """ Revoke a specific session by ID. - Args: - request: FastAPI request object (for rate limiting) - session_id: UUID of the session to revoke - current_user: Current authenticated user (injected) - db: Database session (injected) - - Returns: - MessageResponse with success message - - Raises: - NotFoundError: If session doesn't exist - AuthorizationError: If session belongs to another user + The service verifies ownership and raises NotFoundError / + AuthorizationError which are handled by global exception handlers. """ - try: - # Get the session - session = session_crud.get(db, id=str(session_id)) - - if not session: - raise NotFoundError( - message=f"Session {session_id} not found", - error_code=ErrorCode.NOT_FOUND - ) - - # Verify session belongs to current user (authorization check) - if str(session.user_id) != str(current_user.id): - logger.warning( - f"User {current_user.id} attempted to revoke session {session_id} " - f"belonging to user {session.user_id}" - ) - raise AuthorizationError( - message="You can only revoke your own sessions", - error_code=ErrorCode.INSUFFICIENT_PERMISSIONS - ) - - # Deactivate the session - session_crud.deactivate(db, session_id=str(session_id)) - - logger.info( - f"User {current_user.id} revoked session {session_id} " - f"({session.device_name})" - ) - - return MessageResponse( - success=True, - message=f"Session revoked: {session.device_name or 'Unknown device'}" - ) - - except (NotFoundError, AuthorizationError): - # Re-raise custom exceptions (they'll be handled by global handlers) - raise - except Exception as e: - logger.error(f"Error revoking session {session_id}: {str(e)}", exc_info=True) - raise HTTPException( - status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, - detail="Failed to revoke session" - ) + device_name = await session_service.revoke_session( + db, session_id=session_id, user_id=current_user.id + ) + return MessageResponse( + success=True, + message=f"Session revoked: {device_name or 'Unknown device'}" + ) @router.delete( @@ -958,55 +670,20 @@ def revoke_session( operation_id="cleanup_expired_sessions" ) @limiter.limit("5/minute") -def cleanup_expired_sessions( +async def cleanup_expired_sessions( request: Request, current_user: User = Depends(get_current_user), - db: Session = Depends(get_db) + session_service: SessionService = Depends(get_session_service), + db: AsyncSession = Depends(get_db), ) -> Any: - """ - Cleanup expired sessions for the current user. - - Args: - request: FastAPI request object (for rate limiting) - current_user: Current authenticated user (injected) - db: Database session (injected) - - Returns: - MessageResponse with count of sessions cleaned - """ - try: - from datetime import datetime, timezone - - # Get all sessions for user (including inactive) - all_sessions = session_crud.get_user_sessions( - db, - user_id=str(current_user.id), - active_only=False - ) - - # Delete expired and inactive sessions - deleted_count = 0 - for s in all_sessions: - if not s.is_active and s.expires_at < datetime.now(timezone.utc): - db.delete(s) - deleted_count += 1 - - db.commit() - - logger.info(f"User {current_user.id} cleaned up {deleted_count} expired sessions") - - return MessageResponse( - success=True, - message=f"Cleaned up {deleted_count} expired sessions" - ) - - except Exception as e: - logger.error(f"Error cleaning up sessions for user {current_user.id}: {str(e)}", exc_info=True) - db.rollback() - raise HTTPException( - status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, - detail="Failed to cleanup sessions" - ) + """Cleanup expired sessions for the current user.""" + deleted_count = await session_service.cleanup_user_expired_sessions( + db, user_id=current_user.id + ) + return MessageResponse( + success=True, + message=f"Cleaned up {deleted_count} expired sessions" + ) ``` **Key Patterns**: @@ -1079,59 +756,25 @@ Session management needs to be integrated into the authentication flow. ```python from app.utils.device import extract_device_info -from app.crud.session import session as session_crud -from app.schemas.sessions import SessionCreate +from app.api.dependencies.services import get_auth_service +from app.services.auth_service import AuthService @router.post("/login") async def login( request: Request, credentials: OAuth2PasswordRequestForm = Depends(), - db: Session = Depends(get_db) + auth_service: AuthService = Depends(get_auth_service), + db: AsyncSession = Depends(get_db), ): """Authenticate user and create session.""" - - # 1. Validate credentials - user = user_crud.get_by_email(db, email=credentials.username) - if not user or not verify_password(credentials.password, user.hashed_password): - raise AuthenticationError("Invalid credentials") - - if not user.is_active: - raise AuthenticationError("Account is inactive") - - # 2. Extract device information from request - device_info = extract_device_info(request) - - # 3. Generate tokens - jti = str(uuid.uuid4()) # Generate JTI for refresh token - access_token = create_access_token(subject=str(user.id)) - refresh_token = create_refresh_token(subject=str(user.id), jti=jti) - - # 4. Create session record - from datetime import datetime, timezone, timedelta - - session_data = SessionCreate( - user_id=user.id, - refresh_token_jti=jti, - device_name=device_info.device_name, - device_id=device_info.device_id, - ip_address=device_info.ip_address, - user_agent=device_info.user_agent, - last_used_at=datetime.now(timezone.utc), - expires_at=datetime.now(timezone.utc) + timedelta(days=7), - location_city=device_info.location_city, - location_country=device_info.location_country, + # All business logic (validate credentials, create session, generate tokens) + # is delegated to AuthService which calls the appropriate repositories. + return await auth_service.login( + db, + email=credentials.username, + password=credentials.password, + request=request, ) - - session_crud.create_session(db, obj_in=session_data) - - logger.info(f"User {user.email} logged in from {device_info.device_name}") - - return { - "access_token": access_token, - "refresh_token": refresh_token, - "token_type": "bearer", - "user": UserResponse.model_validate(user) - } ``` #### 5.2 Create Device Info Utility @@ -1193,89 +836,35 @@ def extract_device_info(request: Request) -> DeviceInfo: ```python @router.post("/refresh") -def refresh_token( +async def refresh_token( refresh_request: RefreshRequest, - db: Session = Depends(get_db) + auth_service: AuthService = Depends(get_auth_service), + db: AsyncSession = Depends(get_db), ): """Refresh access token using refresh token.""" - - try: - # 1. Decode and validate refresh token - payload = decode_token(refresh_request.refresh_token) - - if payload.get("type") != "refresh": - raise AuthenticationError("Invalid token type") - - user_id = UUID(payload.get("sub")) - jti = payload.get("jti") - - # 2. Find and validate session - session = session_crud.get_active_by_jti(db, jti=jti) - - if not session: - raise AuthenticationError("Session not found or expired") - - if session.user_id != user_id: - raise AuthenticationError("Token mismatch") - - # 3. Generate new tokens (token rotation) - new_jti = str(uuid.uuid4()) - new_access_token = create_access_token(subject=str(user_id)) - new_refresh_token = create_refresh_token(subject=str(user_id), jti=new_jti) - - # 4. Update session with new JTI - session_crud.update_refresh_token( - db, - session=session, - new_jti=new_jti, - new_expires_at=datetime.now(timezone.utc) + timedelta(days=7) - ) - - logger.info(f"Tokens refreshed for user {user_id}") - - return { - "access_token": new_access_token, - "refresh_token": new_refresh_token, - "token_type": "bearer" - } - - except Exception as e: - logger.error(f"Token refresh failed: {str(e)}") - raise AuthenticationError("Failed to refresh token") + # AuthService handles token validation, session lookup, token rotation + return await auth_service.refresh_tokens( + db, refresh_token=refresh_request.refresh_token + ) ``` #### 5.4 Update Logout Endpoint ```python @router.post("/logout") -def logout( +async def logout( logout_request: LogoutRequest, current_user: User = Depends(get_current_user), - db: Session = Depends(get_db) + auth_service: AuthService = Depends(get_auth_service), + db: AsyncSession = Depends(get_db), ): """Logout from current device.""" - - try: - # Decode refresh token to get JTI - payload = decode_token(logout_request.refresh_token) - jti = payload.get("jti") - - # Find and deactivate session - session = session_crud.get_by_jti(db, jti=jti) - - if session and session.user_id == current_user.id: - session_crud.deactivate(db, session_id=str(session.id)) - logger.info(f"User {current_user.id} logged out from {session.device_name}") - - return MessageResponse( - success=True, - message="Logged out successfully" - ) - - except Exception as e: - logger.error(f"Logout failed: {str(e)}") - # Even if cleanup fails, return success (user intended to logout) - return MessageResponse(success=True, message="Logged out") + await auth_service.logout( + db, + refresh_token=logout_request.refresh_token, + user_id=current_user.id, + ) + return MessageResponse(success=True, message="Logged out successfully") ``` ### Step 6: Add Background Jobs @@ -1287,8 +876,8 @@ def logout( Background job for cleaning up expired sessions. """ import logging -from app.core.database import SessionLocal -from app.crud.session import session as session_crud +from app.core.database import AsyncSessionLocal +from app.repositories.session import session_repo logger = logging.getLogger(__name__) @@ -1302,14 +891,12 @@ async def cleanup_expired_sessions(): - Inactive (is_active = False) - Older than 30 days (for audit trail) """ - db = SessionLocal() - try: - count = session_crud.cleanup_expired(db, keep_days=30) - logger.info(f"Background cleanup: Removed {count} expired sessions") - except Exception as e: - logger.error(f"Error in session cleanup job: {str(e)}", exc_info=True) - finally: - db.close() + async with AsyncSessionLocal() as db: + try: + count = await session_repo.cleanup_expired(db, keep_days=30) + logger.info(f"Background cleanup: Removed {count} expired sessions") + except Exception as e: + logger.error(f"Error in session cleanup job: {str(e)}", exc_info=True) ``` **Register in** `app/main.py`: @@ -1679,7 +1266,8 @@ You've now implemented a complete feature! Here's what was created: **Files Created/Modified**: 1. `app/models/user_session.py` - Database model 2. `app/schemas/sessions.py` - Pydantic schemas -3. `app/crud/session.py` - CRUD operations +3. `app/repositories/session.py` - Repository (data access) +4. `app/services/session_service.py` - Service (business logic) 4. `app/api/routes/sessions.py` - API endpoints 5. `app/utils/device.py` - Device detection utility 6. `app/services/session_cleanup.py` - Background job @@ -1715,7 +1303,7 @@ You've now implemented a complete feature! Here's what was created: ### Don'ts -1. **Don't Mix Layers**: Keep business logic out of CRUD, database ops out of routes +1. **Don't Mix Layers**: Keep business logic in services, database ops in repositories, routing in routes 2. **Don't Expose Internals**: Never return sensitive data in API responses 3. **Don't Trust Input**: Always validate and sanitize user input 4. **Don't Ignore Errors**: Always handle exceptions properly @@ -1733,7 +1321,9 @@ When implementing a new feature, use this checklist: - [ ] Design database schema - [ ] Create SQLAlchemy model - [ ] Design Pydantic schemas (Create, Update, Response) -- [ ] Implement CRUD operations +- [ ] Implement repository (data access) +- [ ] Implement service (business logic) +- [ ] Register service in `app/api/dependencies/services.py` - [ ] Create API endpoints - [ ] Add authentication/authorization - [ ] Implement rate limiting