refactor(docs): update architecture to reflect repository migration

- Rename CRUD layer to Repository layer throughout architecture documentation.
- Update dependency injection examples to use repository classes.
- Add async SQLAlchemy pattern for Repository methods (`select()` and transactions).
- Replace CRUD references in FEATURE_EXAMPLE.md with Repository-focused implementation details.
- Highlight repository class responsibilities and remove outdated CRUD patterns.
This commit is contained in:
2026-03-01 11:13:51 +01:00
parent 80d2dc0cb2
commit 68275b1dd3
4 changed files with 349 additions and 773 deletions

View File

@@ -117,7 +117,8 @@ backend/
│ ├── api/ # API layer │ ├── api/ # API layer
│ │ ├── dependencies/ # Dependency injection │ │ ├── dependencies/ # Dependency injection
│ │ │ ├── auth.py # Authentication dependencies │ │ │ ├── auth.py # Authentication dependencies
│ │ │ ── permissions.py # Authorization dependencies │ │ │ ── permissions.py # Authorization dependencies
│ │ │ └── services.py # Service singleton injection
│ │ ├── routes/ # API endpoints │ │ ├── routes/ # API endpoints
│ │ │ ├── auth.py # Authentication routes │ │ │ ├── auth.py # Authentication routes
│ │ │ ├── users.py # User management routes │ │ │ ├── users.py # User management routes
@@ -131,13 +132,14 @@ backend/
│ │ ├── config.py # Application configuration │ │ ├── config.py # Application configuration
│ │ ├── database.py # Database connection │ │ ├── database.py # Database connection
│ │ ├── exceptions.py # Custom exception classes │ │ ├── exceptions.py # Custom exception classes
│ │ ├── repository_exceptions.py # Repository-level exception hierarchy
│ │ └── middleware.py # Custom middleware │ │ └── middleware.py # Custom middleware
│ │ │ │
│ ├── crud/ # Database operations │ ├── repositories/ # Data access layer
│ │ ├── base.py # Generic CRUD base class │ │ ├── base.py # Generic repository base class
│ │ ├── user.py # User CRUD operations │ │ ├── user.py # User repository
│ │ ├── session.py # Session CRUD operations │ │ ├── session.py # Session repository
│ │ └── organization.py # Organization CRUD │ │ └── organization.py # Organization repository
│ │ │ │
│ ├── models/ # SQLAlchemy models │ ├── models/ # SQLAlchemy models
│ │ ├── base.py # Base model with mixins │ │ ├── base.py # Base model with mixins
@@ -153,8 +155,11 @@ backend/
│ │ ├── sessions.py # Session schemas │ │ ├── sessions.py # Session schemas
│ │ └── organizations.py # Organization schemas │ │ └── organizations.py # Organization schemas
│ │ │ │
│ ├── services/ # Business logic │ ├── services/ # Business logic layer
│ │ ├── auth_service.py # Authentication service │ │ ├── auth_service.py # Authentication service
│ │ ├── user_service.py # User management service
│ │ ├── session_service.py # Session management service
│ │ ├── organization_service.py # Organization service
│ │ ├── email_service.py # Email service │ │ ├── email_service.py # Email service
│ │ └── session_cleanup.py # Background cleanup │ │ └── session_cleanup.py # Background cleanup
│ │ │ │
@@ -168,9 +173,9 @@ backend/
├── tests/ # Test suite ├── tests/ # Test suite
│ ├── api/ # Integration tests │ ├── api/ # Integration tests
│ ├── crud/ # CRUD tests │ ├── repositories/ # Repository unit tests
│ ├── services/ # Service unit tests
│ ├── models/ # Model tests │ ├── models/ # Model tests
│ ├── services/ # Service tests
│ └── conftest.py # Test configuration │ └── conftest.py # Test configuration
├── docs/ # Documentation ├── docs/ # Documentation
@@ -214,11 +219,11 @@ The application follows a strict 5-layer architecture:
└──────────────────────────┬──────────────────────────────────┘ └──────────────────────────┬──────────────────────────────────┘
│ calls │ calls
┌──────────────────────────▼──────────────────────────────────┐ ┌──────────────────────────▼──────────────────────────────────┐
CRUD Layer (crud/) Repository Layer (repositories/)
│ - Database operations │ │ - Database operations │
│ - Query building │ │ - Query building │
│ - Transaction management │ - Custom repository exceptions
│ - Error handling │ - No business logic
└──────────────────────────┬──────────────────────────────────┘ └──────────────────────────┬──────────────────────────────────┘
│ uses │ uses
┌──────────────────────────▼──────────────────────────────────┐ ┌──────────────────────────▼──────────────────────────────────┐
@@ -262,7 +267,7 @@ async def get_current_user_info(
**Rules**: **Rules**:
- Should NOT contain business logic - Should NOT contain business logic
- Should NOT directly perform database operations (use CRUD or services) - Should NOT directly call repositories (use services injected via `dependencies/services.py`)
- Must validate all input via Pydantic schemas - Must validate all input via Pydantic schemas
- Must specify response models - Must specify response models
- Should apply appropriate rate limits - Should apply appropriate rate limits
@@ -279,9 +284,9 @@ async def get_current_user_info(
**Example**: **Example**:
```python ```python
def get_current_user( async def get_current_user(
token: str = Depends(oauth2_scheme), token: str = Depends(oauth2_scheme),
db: Session = Depends(get_db) db: AsyncSession = Depends(get_db)
) -> User: ) -> User:
""" """
Extract and validate user from JWT token. Extract and validate user from JWT token.
@@ -295,7 +300,7 @@ def get_current_user(
except Exception: except Exception:
raise AuthenticationError("Invalid authentication credentials") raise AuthenticationError("Invalid authentication credentials")
user = user_crud.get(db, id=user_id) user = await user_repo.get(db, id=user_id)
if not user: if not user:
raise AuthenticationError("User not found") raise AuthenticationError("User not found")
@@ -313,7 +318,7 @@ def get_current_user(
**Responsibility**: Implement complex business logic **Responsibility**: Implement complex business logic
**Key Functions**: **Key Functions**:
- Orchestrate multiple CRUD operations - Orchestrate multiple repository operations
- Implement business rules - Implement business rules
- Handle external service integration - Handle external service integration
- Coordinate transactions - Coordinate transactions
@@ -323,9 +328,9 @@ def get_current_user(
class AuthService: class AuthService:
"""Authentication service with business logic.""" """Authentication service with business logic."""
def login( async def login(
self, self,
db: Session, db: AsyncSession,
email: str, email: str,
password: str, password: str,
request: Request request: Request
@@ -339,8 +344,8 @@ class AuthService:
3. Generate tokens 3. Generate tokens
4. Return tokens and user info 4. Return tokens and user info
""" """
# Validate credentials # Validate credentials via repository
user = user_crud.get_by_email(db, email=email) user = await user_repo.get_by_email(db, email=email)
if not user or not verify_password(password, user.hashed_password): if not user or not verify_password(password, user.hashed_password):
raise AuthenticationError("Invalid credentials") raise AuthenticationError("Invalid credentials")
@@ -350,11 +355,10 @@ class AuthService:
# Extract device info # Extract device info
device_info = extract_device_info(request) device_info = extract_device_info(request)
# Create session # Create session via repository
session = session_crud.create_session( session = await session_repo.create(
db, db,
user_id=user.id, obj_in=SessionCreate(user_id=user.id, **device_info)
device_info=device_info
) )
# Generate tokens # Generate tokens
@@ -373,75 +377,60 @@ class AuthService:
**Rules**: **Rules**:
- Contains business logic, not just data operations - Contains business logic, not just data operations
- Can call multiple CRUD operations - Can call multiple repository operations
- Should handle complex workflows - Should handle complex workflows
- Must maintain data consistency - Must maintain data consistency
- Should use transactions when needed - Should use transactions when needed
#### 4. CRUD Layer (`app/crud/`) #### 4. Repository Layer (`app/repositories/`)
**Responsibility**: Database operations and queries **Responsibility**: Database operations and queries — no business logic
**Key Functions**: **Key Functions**:
- Create, read, update, delete operations - Create, read, update, delete operations
- Build database queries - Build database queries
- Handle database errors - Raise custom repository exceptions (`DuplicateEntryError`, `IntegrityConstraintError`)
- Manage soft deletes - Manage soft deletes
- Implement pagination and filtering - Implement pagination and filtering
**Example**: **Example**:
```python ```python
class CRUDSession(CRUDBase[UserSession, SessionCreate, SessionUpdate]): class SessionRepository(RepositoryBase[UserSession, SessionCreate, SessionUpdate]):
"""CRUD operations for user sessions.""" """Repository for user sessions — database operations only."""
def get_by_jti(self, db: Session, jti: UUID) -> Optional[UserSession]: async def get_by_jti(self, db: AsyncSession, *, jti: str) -> UserSession | None:
"""Get session by refresh token JTI.""" """Get session by refresh token JTI."""
try: result = await db.execute(
return ( select(UserSession).where(UserSession.refresh_token_jti == jti)
db.query(UserSession)
.filter(UserSession.refresh_token_jti == jti)
.first()
) )
except Exception as e: return result.scalar_one_or_none()
logger.error(f"Error getting session by JTI: {str(e)}")
return None
def get_active_by_jti( async def deactivate(self, db: AsyncSession, *, session_id: UUID) -> bool:
self,
db: Session,
jti: UUID
) -> Optional[UserSession]:
"""Get active session by refresh token JTI."""
session = self.get_by_jti(db, jti=jti)
if session and session.is_active and not session.is_expired:
return session
return None
def deactivate(self, db: Session, session_id: UUID) -> bool:
"""Deactivate a session (logout).""" """Deactivate a session (logout)."""
try: try:
session = self.get(db, id=session_id) session = await self.get(db, id=session_id)
if not session: if not session:
return False return False
session.is_active = False session.is_active = False
db.commit() await db.commit()
logger.info(f"Session {session_id} deactivated") logger.info(f"Session {session_id} deactivated")
return True return True
except Exception as e: except Exception as e:
db.rollback() await db.rollback()
logger.error(f"Error deactivating session: {str(e)}") logger.error(f"Error deactivating session: {str(e)}")
return False return False
``` ```
**Rules**: **Rules**:
- Should NOT contain business logic - Should NOT contain business logic
- Must handle database exceptions - Must raise custom repository exceptions (not raw `ValueError`/`IntegrityError`)
- Must use parameterized queries (SQLAlchemy does this) - Must use async SQLAlchemy 2.0 `select()` API (never `db.query()`)
- Should log all database errors - Should log all database errors
- Must rollback on errors - Must rollback on errors
- Should use soft deletes when possible - Should use soft deletes when possible
- **Never imported directly by routes** — always called through services
#### 5. Data Layer (`app/models/` + `app/schemas/`) #### 5. Data Layer (`app/models/` + `app/schemas/`)
@@ -546,51 +535,23 @@ SessionLocal = sessionmaker(
#### Dependency Injection Pattern #### Dependency Injection Pattern
```python ```python
def get_db() -> Generator[Session, None, None]: async def get_db() -> AsyncGenerator[AsyncSession, None]:
""" """
Database session dependency for FastAPI routes. Async database session dependency for FastAPI routes.
Automatically commits on success, rolls back on error. The session is passed to service methods; commit/rollback is
managed inside service or repository methods.
""" """
db = SessionLocal() async with AsyncSessionLocal() as db:
try:
yield db yield db
finally:
db.close()
# Usage in routes # Usage in routes — always through a service, never direct repository
@router.get("/users") @router.get("/users")
def list_users(db: Session = Depends(get_db)): async def list_users(
return user_crud.get_multi(db) user_service: UserService = Depends(get_user_service),
``` db: AsyncSession = Depends(get_db),
):
#### Context Manager Pattern return await user_service.get_users(db)
```python
@contextmanager
def transaction_scope() -> Generator[Session, None, None]:
"""
Context manager for database transactions.
Use for complex operations requiring multiple steps.
Automatically commits on success, rolls back on error.
"""
db = SessionLocal()
try:
yield db
db.commit()
except Exception:
db.rollback()
raise
finally:
db.close()
# Usage in services
def complex_operation():
with transaction_scope() as db:
user = user_crud.create(db, obj_in=user_data)
session = session_crud.create(db, session_data)
return user, session
``` ```
### Model Mixins ### Model Mixins
@@ -782,22 +743,15 @@ def get_profile(
```python ```python
@router.delete("/sessions/{session_id}") @router.delete("/sessions/{session_id}")
def revoke_session( async def revoke_session(
session_id: UUID, session_id: UUID,
current_user: User = Depends(get_current_user), current_user: User = Depends(get_current_user),
db: Session = Depends(get_db) session_service: SessionService = Depends(get_session_service),
db: AsyncSession = Depends(get_db),
): ):
"""Users can only revoke their own sessions.""" """Users can only revoke their own sessions."""
session = session_crud.get(db, id=session_id) # SessionService verifies ownership and raises NotFoundError / AuthorizationError
await session_service.revoke_session(db, session_id=session_id, user_id=current_user.id)
if not session:
raise NotFoundError("Session not found")
# Check ownership
if session.user_id != current_user.id:
raise AuthorizationError("You can only revoke your own sessions")
session_crud.deactivate(db, session_id=session_id)
return MessageResponse(success=True, message="Session revoked") return MessageResponse(success=True, message="Session revoked")
``` ```
@@ -1092,8 +1046,8 @@ async def cleanup_expired_sessions():
Runs daily at 2 AM. Removes sessions expired for more than 30 days. Runs daily at 2 AM. Removes sessions expired for more than 30 days.
""" """
try: try:
with transaction_scope() as db: async with AsyncSessionLocal() as db:
count = session_crud.cleanup_expired(db, keep_days=30) count = await session_repo.cleanup_expired(db, keep_days=30)
logger.info(f"Cleaned up {count} expired sessions") logger.info(f"Cleaned up {count} expired sessions")
except Exception as e: except Exception as e:
logger.error(f"Error cleaning up sessions: {str(e)}", exc_info=True) logger.error(f"Error cleaning up sessions: {str(e)}", exc_info=True)
@@ -1110,7 +1064,7 @@ async def cleanup_expired_sessions():
│Integration │ ← API endpoint tests │Integration │ ← API endpoint tests
│ Tests │ │ Tests │
├─────────────┤ ├─────────────┤
│ Unit │ ← CRUD, services, utilities │ Unit │ ← repositories, services, utilities
│ Tests │ │ Tests │
└─────────────┘ └─────────────┘
``` ```

View File

@@ -75,15 +75,14 @@ def create_user(db: Session, user_in: UserCreate) -> User:
### 4. Code Formatting ### 4. Code Formatting
Use automated formatters: Use automated formatters:
- **Black**: Code formatting - **Ruff**: Code formatting and linting (replaces Black, isort, flake8)
- **isort**: Import sorting - **pyright**: Static type checking
- **flake8**: Linting
Run before committing: Run before committing (or use `make validate`):
```bash ```bash
black app tests uv run ruff format app tests
isort app tests uv run ruff check app tests
flake8 app tests uv run pyright app
``` ```
## Code Organization ## Code Organization
@@ -94,19 +93,17 @@ Follow the 5-layer architecture strictly:
``` ```
API Layer (routes/) API Layer (routes/)
↓ calls ↓ calls (via service injected from dependencies/services.py)
Dependencies (dependencies/)
↓ injects
Service Layer (services/) Service Layer (services/)
↓ calls ↓ calls
CRUD Layer (crud/) Repository Layer (repositories/)
↓ uses ↓ uses
Models & Schemas (models/, schemas/) Models & Schemas (models/, schemas/)
``` ```
**Rules:** **Rules:**
- Routes should NOT directly call CRUD operations (use services when business logic is needed) - Routes must NEVER import repositories directly — always use a service
- CRUD operations should NOT contain business logic - Services call repositories; repositories contain only database operations
- Models should NOT import from higher layers - Models should NOT import from higher layers
- Each layer should only depend on the layer directly below it - Each layer should only depend on the layer directly below it
@@ -125,7 +122,7 @@ from sqlalchemy.orm import Session
# 3. Local application imports # 3. Local application imports
from app.api.dependencies.auth import get_current_user from app.api.dependencies.auth import get_current_user
from app.crud import user_crud from app.api.dependencies.services import get_user_service
from app.models.user import User from app.models.user import User
from app.schemas.users import UserResponse, UserCreate from app.schemas.users import UserResponse, UserCreate
``` ```
@@ -442,19 +439,19 @@ backend/app/alembic/versions/
4. **Testability**: Easy to mock and test 4. **Testability**: Easy to mock and test
5. **Consistent Ordering**: Always order queries for pagination 5. **Consistent Ordering**: Always order queries for pagination
### Use the Async CRUD Base Class ### Use the Async Repository Base Class
Always inherit from `CRUDBase` for database operations: Always inherit from `RepositoryBase` for database operations:
```python ```python
from sqlalchemy.ext.asyncio import AsyncSession from sqlalchemy.ext.asyncio import AsyncSession
from sqlalchemy import select from sqlalchemy import select
from app.crud.base import CRUDBase from app.repositories.base import RepositoryBase
from app.models.user import User from app.models.user import User
from app.schemas.users import UserCreate, UserUpdate from app.schemas.users import UserCreate, UserUpdate
class CRUDUser(CRUDBase[User, UserCreate, UserUpdate]): class UserRepository(RepositoryBase[User, UserCreate, UserUpdate]):
"""CRUD operations for User model.""" """Repository for User model — database operations only."""
async def get_by_email( async def get_by_email(
self, self,
@@ -467,7 +464,7 @@ class CRUDUser(CRUDBase[User, UserCreate, UserUpdate]):
) )
return result.scalar_one_or_none() return result.scalar_one_or_none()
user_crud = CRUDUser(User) user_repo = UserRepository(User)
``` ```
**Key Points:** **Key Points:**
@@ -476,6 +473,7 @@ user_crud = CRUDUser(User)
- Use `await db.execute()` for queries - Use `await db.execute()` for queries
- Use `.scalar_one_or_none()` instead of `.first()` - Use `.scalar_one_or_none()` instead of `.first()`
- Use `T | None` instead of `Optional[T]` - Use `T | None` instead of `Optional[T]`
- Repository instances are used internally by services — never import them in routes
### Modern SQLAlchemy Patterns ### Modern SQLAlchemy Patterns
@@ -563,7 +561,7 @@ async def create_user(
The database session is automatically managed by FastAPI. The database session is automatically managed by FastAPI.
Commit on success, rollback on error. Commit on success, rollback on error.
""" """
return await user_crud.create(db, obj_in=user_in) return await user_service.create_user(db, obj_in=user_in)
``` ```
**Key Points:** **Key Points:**
@@ -582,12 +580,11 @@ async def complex_operation(
""" """
Perform multiple database operations atomically. Perform multiple database operations atomically.
The session automatically commits on success or rolls back on error. Services call repositories; commit/rollback is handled inside
each repository method.
""" """
user = await user_crud.create(db, obj_in=user_data) user = await user_repo.create(db, obj_in=user_data)
session = await session_crud.create(db, obj_in=session_data) session = await session_repo.create(db, obj_in=session_data)
# Commit is handled by the route's dependency
return user, session return user, session
``` ```
@@ -597,10 +594,10 @@ Prefer soft deletes over hard deletes for audit trails:
```python ```python
# Good - Soft delete (sets deleted_at) # Good - Soft delete (sets deleted_at)
await user_crud.soft_delete(db, id=user_id) await user_repo.soft_delete(db, id=user_id)
# Acceptable only when required - Hard delete # Acceptable only when required - Hard delete
user_crud.remove(db, id=user_id) await user_repo.remove(db, id=user_id)
``` ```
### Query Patterns ### Query Patterns
@@ -740,9 +737,10 @@ Always implement pagination for list endpoints:
from app.schemas.common import PaginationParams, PaginatedResponse from app.schemas.common import PaginationParams, PaginatedResponse
@router.get("/users", response_model=PaginatedResponse[UserResponse]) @router.get("/users", response_model=PaginatedResponse[UserResponse])
def list_users( async def list_users(
pagination: PaginationParams = Depends(), pagination: PaginationParams = Depends(),
db: Session = Depends(get_db) user_service: UserService = Depends(get_user_service),
db: AsyncSession = Depends(get_db),
): ):
""" """
List all users with pagination. List all users with pagination.
@@ -750,10 +748,8 @@ def list_users(
Default page size: 20 Default page size: 20
Maximum page size: 100 Maximum page size: 100
""" """
users, total = user_crud.get_multi_with_total( users, total = await user_service.get_users(
db, db, skip=pagination.offset, limit=pagination.limit
skip=pagination.offset,
limit=pagination.limit
) )
return PaginatedResponse(data=users, pagination=pagination.create_meta(total)) return PaginatedResponse(data=users, pagination=pagination.create_meta(total))
``` ```
@@ -816,19 +812,17 @@ def admin_route(
pass pass
# Check ownership # Check ownership
def delete_resource( async def delete_resource(
resource_id: UUID, resource_id: UUID,
current_user: User = Depends(get_current_user), current_user: User = Depends(get_current_user),
db: Session = Depends(get_db) resource_service: ResourceService = Depends(get_resource_service),
db: AsyncSession = Depends(get_db),
): ):
resource = resource_crud.get(db, id=resource_id) # Service handles ownership check and raises appropriate errors
if not resource: await resource_service.delete_resource(
raise NotFoundError("Resource not found") db, resource_id=resource_id, user_id=current_user.id,
is_superuser=current_user.is_superuser,
if resource.user_id != current_user.id and not current_user.is_superuser: )
raise AuthorizationError("You can only delete your own resources")
resource_crud.remove(db, id=resource_id)
``` ```
### Input Validation ### Input Validation
@@ -862,9 +856,9 @@ tests/
├── api/ # Integration tests ├── api/ # Integration tests
│ ├── test_users.py │ ├── test_users.py
│ └── test_auth.py │ └── test_auth.py
├── crud/ # Unit tests for CRUD ├── repositories/ # Unit tests for repositories
├── models/ # Model tests ├── services/ # Unit tests for services
└── services/ # Service tests └── models/ # Model tests
``` ```
### Async Testing with pytest-asyncio ### Async Testing with pytest-asyncio
@@ -927,7 +921,7 @@ async def test_user(db_session: AsyncSession) -> User:
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_get_user(db_session: AsyncSession, test_user: User): async def test_get_user(db_session: AsyncSession, test_user: User):
"""Test retrieving a user by ID.""" """Test retrieving a user by ID."""
user = await user_crud.get(db_session, id=test_user.id) user = await user_repo.get(db_session, id=test_user.id)
assert user is not None assert user is not None
assert user.email == test_user.email assert user.email == test_user.email
``` ```

View File

@@ -616,7 +616,43 @@ def create_user(
return user return user
``` ```
**Rule**: Add type hints to ALL functions. Use `mypy` to enforce type checking. **Rule**: Add type hints to ALL functions. Use `pyright` to enforce type checking (`make type-check`).
---
---
### ❌ PITFALL #19: Importing Repositories Directly in Routes
**Issue**: Routes should never call repositories directly. The layered architecture requires all business operations to go through the service layer.
```python
# ❌ WRONG - Route bypasses service layer
from app.repositories.session import session_repo
@router.get("/sessions/me")
async def list_sessions(
current_user: User = Depends(get_current_active_user),
db: AsyncSession = Depends(get_db),
):
return await session_repo.get_user_sessions(db, user_id=current_user.id)
```
```python
# ✅ CORRECT - Route calls service injected via dependency
from app.api.dependencies.services import get_session_service
from app.services.session_service import SessionService
@router.get("/sessions/me")
async def list_sessions(
current_user: User = Depends(get_current_active_user),
session_service: SessionService = Depends(get_session_service),
db: AsyncSession = Depends(get_db),
):
return await session_service.get_user_sessions(db, user_id=current_user.id)
```
**Rule**: Routes import from `app.api.dependencies.services`, never from `app.repositories.*`. Services are the only callers of repositories.
--- ---
@@ -649,6 +685,11 @@ Use this checklist to catch issues before code review:
- [ ] Resource ownership verification - [ ] Resource ownership verification
- [ ] CORS configured (no wildcards in production) - [ ] CORS configured (no wildcards in production)
### Architecture
- [ ] Routes never import repositories directly (only services)
- [ ] Services call repositories; repositories call database only
- [ ] New service registered in `app/api/dependencies/services.py`
### Python ### Python
- [ ] Use `==` not `is` for value comparison - [ ] Use `==` not `is` for value comparison
- [ ] No mutable default arguments - [ ] No mutable default arguments
@@ -661,21 +702,18 @@ Use this checklist to catch issues before code review:
### Pre-commit Checks ### Pre-commit Checks
Add these to your development workflow: Add these to your development workflow (or use `make validate`):
```bash ```bash
# Format code # Format + lint (Ruff replaces Black, isort, flake8)
black app tests uv run ruff format app tests
isort app tests uv run ruff check app tests
# Type checking # Type checking
mypy app --strict uv run pyright app
# Linting
flake8 app tests
# Run tests # Run tests
pytest --cov=app --cov-report=term-missing IS_TEST=True uv run pytest --cov=app --cov-report=term-missing
# Check coverage (should be 80%+) # Check coverage (should be 80%+)
coverage report --fail-under=80 coverage report --fail-under=80
@@ -693,6 +731,6 @@ Add new entries when:
--- ---
**Last Updated**: 2025-10-31 **Last Updated**: 2026-02-28
**Issues Cataloged**: 18 common pitfalls **Issues Cataloged**: 19 common pitfalls
**Remember**: This document exists because these issues HAVE occurred. Don't skip it. **Remember**: This document exists because these issues HAVE occurred. Don't skip it.

View File

@@ -8,7 +8,7 @@ This guide walks through implementing a complete feature using the **User Sessio
- [Implementation Steps](#implementation-steps) - [Implementation Steps](#implementation-steps)
- [Step 1: Design the Database Model](#step-1-design-the-database-model) - [Step 1: Design the Database Model](#step-1-design-the-database-model)
- [Step 2: Create Pydantic Schemas](#step-2-create-pydantic-schemas) - [Step 2: Create Pydantic Schemas](#step-2-create-pydantic-schemas)
- [Step 3: Implement CRUD Operations](#step-3-implement-crud-operations) - [Step 3: Implement Repository](#step-3-implement-repository)
- [Step 4: Create API Endpoints](#step-4-create-api-endpoints) - [Step 4: Create API Endpoints](#step-4-create-api-endpoints)
- [Step 5: Integrate with Existing Features](#step-5-integrate-with-existing-features) - [Step 5: Integrate with Existing Features](#step-5-integrate-with-existing-features)
- [Step 6: Add Background Jobs](#step-6-add-background-jobs) - [Step 6: Add Background Jobs](#step-6-add-background-jobs)
@@ -204,8 +204,8 @@ Follow the standard pattern:
``` ```
SessionBase (common fields) SessionBase (common fields)
├── SessionCreate (internal: CRUD operations) ├── SessionCreate (internal: repository operations)
├── SessionUpdate (internal: CRUD operations) ├── SessionUpdate (internal: repository operations)
└── SessionResponse (external: API responses) └── SessionResponse (external: API responses)
``` ```
@@ -240,7 +240,7 @@ class SessionCreate(SessionBase):
""" """
Schema for creating a new session (internal use). Schema for creating a new session (internal use).
Used by CRUD operations, not exposed to API. Used by repository operations, not exposed to API.
Contains all fields needed to create a session. Contains all fields needed to create a session.
""" """
user_id: UUID user_id: UUID
@@ -344,37 +344,37 @@ class DeviceInfo(BaseModel):
5. **OpenAPI Documentation**: `json_schema_extra` provides examples in API docs 5. **OpenAPI Documentation**: `json_schema_extra` provides examples in API docs
6. **Type Safety**: Comprehensive type hints for all fields 6. **Type Safety**: Comprehensive type hints for all fields
### Step 3: Implement CRUD Operations ### Step 3: Implement Repository
**File**: `app/crud/session.py` **File**: `app/repositories/session.py`
CRUD layer handles all database operations. No business logic here! The repository layer handles all database operations. No business logic here — that belongs in services!
#### 3.1 Extend the Base CRUD Class #### 3.1 Extend the Base Repository Class
```python ```python
""" """
CRUD operations for user sessions. Repository for user sessions.
""" """
from datetime import datetime, timezone, timedelta from datetime import datetime, timezone, timedelta
from typing import List, Optional
from uuid import UUID from uuid import UUID
from sqlalchemy.orm import Session
from sqlalchemy import and_ from sqlalchemy import and_, select, update
from sqlalchemy.ext.asyncio import AsyncSession
import logging import logging
from app.crud.base import CRUDBase from app.repositories.base import RepositoryBase
from app.models.user_session import UserSession from app.models.user_session import UserSession
from app.schemas.sessions import SessionCreate, SessionUpdate from app.schemas.sessions import SessionCreate, SessionUpdate
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
class CRUDSession(CRUDBase[UserSession, SessionCreate, SessionUpdate]): class SessionRepository(RepositoryBase[UserSession, SessionCreate, SessionUpdate]):
""" """
CRUD operations for user sessions. Repository for user sessions.
Inherits standard operations from CRUDBase: Inherits standard operations from RepositoryBase:
- get(db, id) - Get by ID - get(db, id) - Get by ID
- get_multi(db, skip, limit) - List with pagination - get_multi(db, skip, limit) - List with pagination
- create(db, obj_in) - Create new session - create(db, obj_in) - Create new session
@@ -382,111 +382,62 @@ class CRUDSession(CRUDBase[UserSession, SessionCreate, SessionUpdate]):
- remove(db, id) - Delete session - remove(db, id) - Delete session
""" """
# Custom query methods async def get_by_jti(self, db: AsyncSession, *, jti: str) -> UserSession | None:
# --------------------
def get_by_jti(self, db: Session, *, jti: str) -> Optional[UserSession]:
""" """
Get session by refresh token JTI. Get session by refresh token JTI.
Used during token refresh to find the corresponding session. Used during token refresh to find the corresponding session.
Args:
db: Database session
jti: Refresh token JWT ID
Returns:
UserSession if found, None otherwise
""" """
try: result = await db.execute(
return db.query(UserSession).filter( select(UserSession).where(UserSession.refresh_token_jti == jti)
UserSession.refresh_token_jti == jti )
).first() return result.scalar_one_or_none()
except Exception as e:
logger.error(f"Error getting session by JTI {jti}: {str(e)}")
raise
def get_active_by_jti(self, db: Session, *, jti: str) -> Optional[UserSession]: async def get_active_by_jti(self, db: AsyncSession, *, jti: str) -> UserSession | None:
""" """Get active (non-expired) session by refresh token JTI."""
Get active session by refresh token JTI. result = await db.execute(
select(UserSession).where(
Only returns the session if it's currently active.
Args:
db: Database session
jti: Refresh token JWT ID
Returns:
Active UserSession if found, None otherwise
"""
try:
return db.query(UserSession).filter(
and_( and_(
UserSession.refresh_token_jti == jti, UserSession.refresh_token_jti == jti,
UserSession.is_active == True UserSession.is_active.is_(True),
) )
).first() )
except Exception as e: )
logger.error(f"Error getting active session by JTI {jti}: {str(e)}") session = result.scalar_one_or_none()
raise if session and not session.is_expired:
return session
return None
def get_user_sessions( async def get_user_sessions(
self, self,
db: Session, db: AsyncSession,
*, *,
user_id: str, user_id: UUID,
active_only: bool = True active_only: bool = True,
) -> List[UserSession]: ) -> list[UserSession]:
""" """
Get all sessions for a user. Get all sessions for a user, ordered by most recently used.
Args:
db: Database session
user_id: User ID
active_only: If True, return only active sessions
Returns:
List of UserSession objects, ordered by most recently used
""" """
try: query = select(UserSession).where(UserSession.user_id == user_id)
# Convert user_id string to UUID if needed
user_uuid = UUID(user_id) if isinstance(user_id, str) else user_id
query = db.query(UserSession).filter(UserSession.user_id == user_uuid)
if active_only: if active_only:
query = query.filter(UserSession.is_active == True) query = query.where(UserSession.is_active.is_(True))
query = query.order_by(UserSession.last_used_at.desc())
result = await db.execute(query)
return list(result.scalars().all())
# Order by most recently used first async def create_session(
return query.order_by(UserSession.last_used_at.desc()).all()
except Exception as e:
logger.error(f"Error getting sessions for user {user_id}: {str(e)}")
raise
# Creation methods
# ----------------
def create_session(
self, self,
db: Session, db: AsyncSession,
*, *,
obj_in: SessionCreate obj_in: SessionCreate,
) -> UserSession: ) -> UserSession:
""" """
Create a new user session. Create a new user session.
Args:
db: Database session
obj_in: SessionCreate schema with session data
Returns:
Created UserSession
Raises: Raises:
ValueError: If session creation fails DuplicateEntryError: If a session with the same JTI already exists
""" """
try: try:
# Create model instance from schema
db_obj = UserSession( db_obj = UserSession(
user_id=obj_in.user_id, user_id=obj_in.user_id,
refresh_token_jti=obj_in.refresh_token_jti, refresh_token_jti=obj_in.refresh_token_jti,
@@ -501,248 +452,93 @@ class CRUDSession(CRUDBase[UserSession, SessionCreate, SessionUpdate]):
location_country=obj_in.location_country, location_country=obj_in.location_country,
) )
db.add(db_obj) db_obj.add(db_obj)
db.commit() await db.commit()
db.refresh(db_obj) await db.refresh(db_obj)
logger.info( logger.info(
f"Session created for user {obj_in.user_id} from {obj_in.device_name} " f"Session created for user {obj_in.user_id} from {obj_in.device_name}"
f"(IP: {obj_in.ip_address})"
) )
return db_obj return db_obj
except Exception as e: except Exception as e:
db.rollback() await db.rollback()
logger.error(f"Error creating session: {str(e)}", exc_info=True) logger.error(f"Error creating session: {str(e)}", exc_info=True)
raise ValueError(f"Failed to create session: {str(e)}")
# Update methods
# --------------
def deactivate(self, db: Session, *, session_id: str) -> Optional[UserSession]:
"""
Deactivate a session (logout from device).
Args:
db: Database session
session_id: Session UUID
Returns:
Deactivated UserSession if found, None otherwise
"""
try:
session = self.get(db, id=session_id)
if not session:
logger.warning(f"Session {session_id} not found for deactivation")
return None
session.is_active = False
db.add(session)
db.commit()
db.refresh(session)
logger.info(
f"Session {session_id} deactivated for user {session.user_id} "
f"({session.device_name})"
)
return session
except Exception as e:
db.rollback()
logger.error(f"Error deactivating session {session_id}: {str(e)}")
raise raise
def deactivate_all_user_sessions( async def deactivate(self, db: AsyncSession, *, session_id: UUID) -> UserSession | None:
"""Deactivate a session (logout from device)."""
session = await self.get(db, id=session_id)
if not session:
return None
session.is_active = False
await db.commit()
await db.refresh(session)
logger.info(f"Session {session_id} deactivated ({session.device_name})")
return session
async def deactivate_all_user_sessions(
self, self,
db: Session, db: AsyncSession,
*, *,
user_id: str user_id: UUID,
) -> int: ) -> int:
""" """
Deactivate all active sessions for a user (logout from all devices). Deactivate all active sessions for a user (logout from all devices).
Uses bulk update for efficiency. Uses a bulk UPDATE for efficiency — no N+1 queries.
Args:
db: Database session
user_id: User ID
Returns:
Number of sessions deactivated
""" """
try: result = await db.execute(
# Convert user_id string to UUID if needed update(UserSession)
user_uuid = UUID(user_id) if isinstance(user_id, str) else user_id .where(
# Bulk update query
count = db.query(UserSession).filter(
and_(
UserSession.user_id == user_uuid,
UserSession.is_active == True
)
).update({"is_active": False})
db.commit()
logger.info(f"Deactivated {count} sessions for user {user_id}")
return count
except Exception as e:
db.rollback()
logger.error(f"Error deactivating all sessions for user {user_id}: {str(e)}")
raise
def update_last_used(
self,
db: Session,
*,
session: UserSession
) -> UserSession:
"""
Update the last_used_at timestamp for a session.
Called when a refresh token is used.
Args:
db: Database session
session: UserSession object
Returns:
Updated UserSession
"""
try:
session.last_used_at = datetime.now(timezone.utc)
db.add(session)
db.commit()
db.refresh(session)
return session
except Exception as e:
db.rollback()
logger.error(f"Error updating last_used for session {session.id}: {str(e)}")
raise
def update_refresh_token(
self,
db: Session,
*,
session: UserSession,
new_jti: str,
new_expires_at: datetime
) -> UserSession:
"""
Update session with new refresh token JTI and expiration.
Called during token refresh (token rotation).
Args:
db: Database session
session: UserSession object
new_jti: New refresh token JTI
new_expires_at: New expiration datetime
Returns:
Updated UserSession
"""
try:
session.refresh_token_jti = new_jti
session.expires_at = new_expires_at
session.last_used_at = datetime.now(timezone.utc)
db.add(session)
db.commit()
db.refresh(session)
return session
except Exception as e:
db.rollback()
logger.error(f"Error updating refresh token for session {session.id}: {str(e)}")
raise
# Cleanup methods
# ---------------
def cleanup_expired(self, db: Session, *, keep_days: int = 30) -> int:
"""
Clean up expired sessions.
Deletes sessions that are:
- Expired (expires_at < now) AND inactive
- Older than keep_days (for audit trail)
Args:
db: Database session
keep_days: Keep inactive sessions for this many days
Returns:
Number of sessions deleted
"""
try:
cutoff_date = datetime.now(timezone.utc) - timedelta(days=keep_days)
count = db.query(UserSession).filter(
and_(
UserSession.is_active == False,
UserSession.expires_at < datetime.now(timezone.utc),
UserSession.created_at < cutoff_date
)
).delete()
db.commit()
if count > 0:
logger.info(f"Cleaned up {count} expired sessions")
return count
except Exception as e:
db.rollback()
logger.error(f"Error cleaning up expired sessions: {str(e)}")
raise
# Utility methods
# ---------------
def get_user_session_count(self, db: Session, *, user_id: str) -> int:
"""
Get count of active sessions for a user.
Useful for session limits or security monitoring.
Args:
db: Database session
user_id: User ID
Returns:
Number of active sessions
"""
try:
return db.query(UserSession).filter(
and_( and_(
UserSession.user_id == user_id, UserSession.user_id == user_id,
UserSession.is_active == True UserSession.is_active.is_(True),
) )
).count() )
except Exception as e: .values(is_active=False)
logger.error(f"Error counting sessions for user {user_id}: {str(e)}") )
raise await db.commit()
count = result.rowcount
logger.info(f"Deactivated {count} sessions for user {user_id}")
return count
async def cleanup_expired(self, db: AsyncSession, *, keep_days: int = 30) -> int:
"""
Hard-delete inactive sessions older than keep_days.
Returns the number of sessions deleted.
"""
cutoff_date = datetime.now(timezone.utc) - timedelta(days=keep_days)
result = await db.execute(
select(UserSession).where(
and_(
UserSession.is_active.is_(False),
UserSession.expires_at < datetime.now(timezone.utc),
UserSession.created_at < cutoff_date,
)
)
)
sessions = list(result.scalars().all())
for s in sessions:
await db.delete(s)
await db.commit()
if sessions:
logger.info(f"Cleaned up {len(sessions)} expired sessions")
return len(sessions)
# Create singleton instance # Singleton instance — used by services, never imported directly in routes
# This is the instance that will be imported and used throughout the app session_repo = SessionRepository(UserSession)
session = CRUDSession(UserSession)
``` ```
**Key Patterns**: **Key Patterns**:
1. **Error Handling**: Every method has try/except with rollback 1. **Async everywhere**: All methods use `async def` and `await`
2. **Logging**: Log all significant actions (create, delete, errors) 2. **Modern SQLAlchemy**: `select()` API, never `db.query()`
3. **Type Safety**: Full type hints for parameters and returns 3. **Bulk updates**: Use `update()` statement for multi-row changes (no N+1)
4. **Docstrings**: Document what each method does, args, returns, raises 4. **Error handling**: `try/except` with `await db.rollback()` in mutating methods
5. **Bulk Operations**: Use `query().update()` for efficiency when updating many rows 5. **Logging**: Log all significant actions (create, delete, errors)
6. **UUID Handling**: Convert string UUIDs to UUID objects when needed 6. **Type safety**: Full type hints; `UUID` not raw `str` for IDs
7. **Ordering**: Return results in a logical order (most recent first) 7. **Singleton pattern**: One module-level instance used by services
8. **Singleton Pattern**: Create one instance to be imported elsewhere
### Step 4: Create API Endpoints ### Step 4: Create API Endpoints
@@ -772,14 +568,15 @@ from uuid import UUID
from fastapi import APIRouter, Depends, HTTPException, status, Request from fastapi import APIRouter, Depends, HTTPException, status, Request
from slowapi import Limiter from slowapi import Limiter
from slowapi.util import get_remote_address from slowapi.util import get_remote_address
from sqlalchemy.orm import Session from sqlalchemy.ext.asyncio import AsyncSession
from app.api.dependencies.auth import get_current_user from app.api.dependencies.auth import get_current_user
from app.api.dependencies.services import get_session_service
from app.core.database import get_db from app.core.database import get_db
from app.models.user import User from app.models.user import User
from app.schemas.sessions import SessionResponse, SessionListResponse from app.schemas.sessions import SessionResponse, SessionListResponse
from app.schemas.common import MessageResponse from app.schemas.common import MessageResponse
from app.crud.session import session as session_crud from app.services.session_service import SessionService
from app.core.exceptions import NotFoundError, AuthorizationError, ErrorCode from app.core.exceptions import NotFoundError, AuthorizationError, ErrorCode
router = APIRouter() router = APIRouter()
@@ -803,61 +600,21 @@ limiter = Limiter(key_func=get_remote_address)
operation_id="list_my_sessions" operation_id="list_my_sessions"
) )
@limiter.limit("30/minute") @limiter.limit("30/minute")
def list_my_sessions( async def list_my_sessions(
request: Request, request: Request,
current_user: User = Depends(get_current_user), current_user: User = Depends(get_current_user),
db: Session = Depends(get_db) session_service: SessionService = Depends(get_session_service),
db: AsyncSession = Depends(get_db),
) -> Any: ) -> Any:
""" """List all active sessions for the current user."""
List all active sessions for the current user. sessions = await session_service.get_user_sessions(
db, user_id=current_user.id, active_only=True
Args:
request: FastAPI request object (for rate limiting)
current_user: Current authenticated user (injected)
db: Database session (injected)
Returns:
SessionListResponse with list of active sessions
"""
try:
# Get all active sessions for user
sessions = session_crud.get_user_sessions(
db,
user_id=str(current_user.id),
active_only=True
)
# Convert to response format
session_responses = []
for idx, s in enumerate(sessions):
session_response = SessionResponse(
id=s.id,
device_name=s.device_name,
device_id=s.device_id,
ip_address=s.ip_address,
location_city=s.location_city,
location_country=s.location_country,
last_used_at=s.last_used_at,
created_at=s.created_at,
expires_at=s.expires_at,
# Mark the most recently used session as current
is_current=(idx == 0)
)
session_responses.append(session_response)
logger.info(f"User {current_user.id} listed {len(session_responses)} active sessions")
return SessionListResponse(
sessions=session_responses,
total=len(session_responses)
)
except Exception as e:
logger.error(f"Error listing sessions for user {current_user.id}: {str(e)}", exc_info=True)
raise HTTPException(
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
detail="Failed to retrieve sessions"
) )
session_responses = [
SessionResponse.model_validate(s) | {"is_current": idx == 0}
for idx, s in enumerate(sessions)
]
return SessionListResponse(sessions=session_responses, total=len(session_responses))
@router.delete( @router.delete(
@@ -876,70 +633,25 @@ def list_my_sessions(
operation_id="revoke_session" operation_id="revoke_session"
) )
@limiter.limit("10/minute") @limiter.limit("10/minute")
def revoke_session( async def revoke_session(
request: Request, request: Request,
session_id: UUID, session_id: UUID,
current_user: User = Depends(get_current_user), current_user: User = Depends(get_current_user),
db: Session = Depends(get_db) session_service: SessionService = Depends(get_session_service),
db: AsyncSession = Depends(get_db),
) -> Any: ) -> Any:
""" """
Revoke a specific session by ID. Revoke a specific session by ID.
Args: The service verifies ownership and raises NotFoundError /
request: FastAPI request object (for rate limiting) AuthorizationError which are handled by global exception handlers.
session_id: UUID of the session to revoke
current_user: Current authenticated user (injected)
db: Database session (injected)
Returns:
MessageResponse with success message
Raises:
NotFoundError: If session doesn't exist
AuthorizationError: If session belongs to another user
""" """
try: device_name = await session_service.revoke_session(
# Get the session db, session_id=session_id, user_id=current_user.id
session = session_crud.get(db, id=str(session_id))
if not session:
raise NotFoundError(
message=f"Session {session_id} not found",
error_code=ErrorCode.NOT_FOUND
) )
# Verify session belongs to current user (authorization check)
if str(session.user_id) != str(current_user.id):
logger.warning(
f"User {current_user.id} attempted to revoke session {session_id} "
f"belonging to user {session.user_id}"
)
raise AuthorizationError(
message="You can only revoke your own sessions",
error_code=ErrorCode.INSUFFICIENT_PERMISSIONS
)
# Deactivate the session
session_crud.deactivate(db, session_id=str(session_id))
logger.info(
f"User {current_user.id} revoked session {session_id} "
f"({session.device_name})"
)
return MessageResponse( return MessageResponse(
success=True, success=True,
message=f"Session revoked: {session.device_name or 'Unknown device'}" message=f"Session revoked: {device_name or 'Unknown device'}"
)
except (NotFoundError, AuthorizationError):
# Re-raise custom exceptions (they'll be handled by global handlers)
raise
except Exception as e:
logger.error(f"Error revoking session {session_id}: {str(e)}", exc_info=True)
raise HTTPException(
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
detail="Failed to revoke session"
) )
@@ -958,55 +670,20 @@ def revoke_session(
operation_id="cleanup_expired_sessions" operation_id="cleanup_expired_sessions"
) )
@limiter.limit("5/minute") @limiter.limit("5/minute")
def cleanup_expired_sessions( async def cleanup_expired_sessions(
request: Request, request: Request,
current_user: User = Depends(get_current_user), current_user: User = Depends(get_current_user),
db: Session = Depends(get_db) session_service: SessionService = Depends(get_session_service),
db: AsyncSession = Depends(get_db),
) -> Any: ) -> Any:
""" """Cleanup expired sessions for the current user."""
Cleanup expired sessions for the current user. deleted_count = await session_service.cleanup_user_expired_sessions(
db, user_id=current_user.id
Args:
request: FastAPI request object (for rate limiting)
current_user: Current authenticated user (injected)
db: Database session (injected)
Returns:
MessageResponse with count of sessions cleaned
"""
try:
from datetime import datetime, timezone
# Get all sessions for user (including inactive)
all_sessions = session_crud.get_user_sessions(
db,
user_id=str(current_user.id),
active_only=False
) )
# Delete expired and inactive sessions
deleted_count = 0
for s in all_sessions:
if not s.is_active and s.expires_at < datetime.now(timezone.utc):
db.delete(s)
deleted_count += 1
db.commit()
logger.info(f"User {current_user.id} cleaned up {deleted_count} expired sessions")
return MessageResponse( return MessageResponse(
success=True, success=True,
message=f"Cleaned up {deleted_count} expired sessions" message=f"Cleaned up {deleted_count} expired sessions"
) )
except Exception as e:
logger.error(f"Error cleaning up sessions for user {current_user.id}: {str(e)}", exc_info=True)
db.rollback()
raise HTTPException(
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
detail="Failed to cleanup sessions"
)
``` ```
**Key Patterns**: **Key Patterns**:
@@ -1079,59 +756,25 @@ Session management needs to be integrated into the authentication flow.
```python ```python
from app.utils.device import extract_device_info from app.utils.device import extract_device_info
from app.crud.session import session as session_crud from app.api.dependencies.services import get_auth_service
from app.schemas.sessions import SessionCreate from app.services.auth_service import AuthService
@router.post("/login") @router.post("/login")
async def login( async def login(
request: Request, request: Request,
credentials: OAuth2PasswordRequestForm = Depends(), credentials: OAuth2PasswordRequestForm = Depends(),
db: Session = Depends(get_db) auth_service: AuthService = Depends(get_auth_service),
db: AsyncSession = Depends(get_db),
): ):
"""Authenticate user and create session.""" """Authenticate user and create session."""
# All business logic (validate credentials, create session, generate tokens)
# 1. Validate credentials # is delegated to AuthService which calls the appropriate repositories.
user = user_crud.get_by_email(db, email=credentials.username) return await auth_service.login(
if not user or not verify_password(credentials.password, user.hashed_password): db,
raise AuthenticationError("Invalid credentials") email=credentials.username,
password=credentials.password,
if not user.is_active: request=request,
raise AuthenticationError("Account is inactive")
# 2. Extract device information from request
device_info = extract_device_info(request)
# 3. Generate tokens
jti = str(uuid.uuid4()) # Generate JTI for refresh token
access_token = create_access_token(subject=str(user.id))
refresh_token = create_refresh_token(subject=str(user.id), jti=jti)
# 4. Create session record
from datetime import datetime, timezone, timedelta
session_data = SessionCreate(
user_id=user.id,
refresh_token_jti=jti,
device_name=device_info.device_name,
device_id=device_info.device_id,
ip_address=device_info.ip_address,
user_agent=device_info.user_agent,
last_used_at=datetime.now(timezone.utc),
expires_at=datetime.now(timezone.utc) + timedelta(days=7),
location_city=device_info.location_city,
location_country=device_info.location_country,
) )
session_crud.create_session(db, obj_in=session_data)
logger.info(f"User {user.email} logged in from {device_info.device_name}")
return {
"access_token": access_token,
"refresh_token": refresh_token,
"token_type": "bearer",
"user": UserResponse.model_validate(user)
}
``` ```
#### 5.2 Create Device Info Utility #### 5.2 Create Device Info Utility
@@ -1193,89 +836,35 @@ def extract_device_info(request: Request) -> DeviceInfo:
```python ```python
@router.post("/refresh") @router.post("/refresh")
def refresh_token( async def refresh_token(
refresh_request: RefreshRequest, refresh_request: RefreshRequest,
db: Session = Depends(get_db) auth_service: AuthService = Depends(get_auth_service),
db: AsyncSession = Depends(get_db),
): ):
"""Refresh access token using refresh token.""" """Refresh access token using refresh token."""
# AuthService handles token validation, session lookup, token rotation
try: return await auth_service.refresh_tokens(
# 1. Decode and validate refresh token db, refresh_token=refresh_request.refresh_token
payload = decode_token(refresh_request.refresh_token)
if payload.get("type") != "refresh":
raise AuthenticationError("Invalid token type")
user_id = UUID(payload.get("sub"))
jti = payload.get("jti")
# 2. Find and validate session
session = session_crud.get_active_by_jti(db, jti=jti)
if not session:
raise AuthenticationError("Session not found or expired")
if session.user_id != user_id:
raise AuthenticationError("Token mismatch")
# 3. Generate new tokens (token rotation)
new_jti = str(uuid.uuid4())
new_access_token = create_access_token(subject=str(user_id))
new_refresh_token = create_refresh_token(subject=str(user_id), jti=new_jti)
# 4. Update session with new JTI
session_crud.update_refresh_token(
db,
session=session,
new_jti=new_jti,
new_expires_at=datetime.now(timezone.utc) + timedelta(days=7)
) )
logger.info(f"Tokens refreshed for user {user_id}")
return {
"access_token": new_access_token,
"refresh_token": new_refresh_token,
"token_type": "bearer"
}
except Exception as e:
logger.error(f"Token refresh failed: {str(e)}")
raise AuthenticationError("Failed to refresh token")
``` ```
#### 5.4 Update Logout Endpoint #### 5.4 Update Logout Endpoint
```python ```python
@router.post("/logout") @router.post("/logout")
def logout( async def logout(
logout_request: LogoutRequest, logout_request: LogoutRequest,
current_user: User = Depends(get_current_user), current_user: User = Depends(get_current_user),
db: Session = Depends(get_db) auth_service: AuthService = Depends(get_auth_service),
db: AsyncSession = Depends(get_db),
): ):
"""Logout from current device.""" """Logout from current device."""
await auth_service.logout(
try: db,
# Decode refresh token to get JTI refresh_token=logout_request.refresh_token,
payload = decode_token(logout_request.refresh_token) user_id=current_user.id,
jti = payload.get("jti")
# Find and deactivate session
session = session_crud.get_by_jti(db, jti=jti)
if session and session.user_id == current_user.id:
session_crud.deactivate(db, session_id=str(session.id))
logger.info(f"User {current_user.id} logged out from {session.device_name}")
return MessageResponse(
success=True,
message="Logged out successfully"
) )
return MessageResponse(success=True, message="Logged out successfully")
except Exception as e:
logger.error(f"Logout failed: {str(e)}")
# Even if cleanup fails, return success (user intended to logout)
return MessageResponse(success=True, message="Logged out")
``` ```
### Step 6: Add Background Jobs ### Step 6: Add Background Jobs
@@ -1287,8 +876,8 @@ def logout(
Background job for cleaning up expired sessions. Background job for cleaning up expired sessions.
""" """
import logging import logging
from app.core.database import SessionLocal from app.core.database import AsyncSessionLocal
from app.crud.session import session as session_crud from app.repositories.session import session_repo
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
@@ -1302,14 +891,12 @@ async def cleanup_expired_sessions():
- Inactive (is_active = False) - Inactive (is_active = False)
- Older than 30 days (for audit trail) - Older than 30 days (for audit trail)
""" """
db = SessionLocal() async with AsyncSessionLocal() as db:
try: try:
count = session_crud.cleanup_expired(db, keep_days=30) count = await session_repo.cleanup_expired(db, keep_days=30)
logger.info(f"Background cleanup: Removed {count} expired sessions") logger.info(f"Background cleanup: Removed {count} expired sessions")
except Exception as e: except Exception as e:
logger.error(f"Error in session cleanup job: {str(e)}", exc_info=True) logger.error(f"Error in session cleanup job: {str(e)}", exc_info=True)
finally:
db.close()
``` ```
**Register in** `app/main.py`: **Register in** `app/main.py`:
@@ -1679,7 +1266,8 @@ You've now implemented a complete feature! Here's what was created:
**Files Created/Modified**: **Files Created/Modified**:
1. `app/models/user_session.py` - Database model 1. `app/models/user_session.py` - Database model
2. `app/schemas/sessions.py` - Pydantic schemas 2. `app/schemas/sessions.py` - Pydantic schemas
3. `app/crud/session.py` - CRUD operations 3. `app/repositories/session.py` - Repository (data access)
4. `app/services/session_service.py` - Service (business logic)
4. `app/api/routes/sessions.py` - API endpoints 4. `app/api/routes/sessions.py` - API endpoints
5. `app/utils/device.py` - Device detection utility 5. `app/utils/device.py` - Device detection utility
6. `app/services/session_cleanup.py` - Background job 6. `app/services/session_cleanup.py` - Background job
@@ -1715,7 +1303,7 @@ You've now implemented a complete feature! Here's what was created:
### Don'ts ### Don'ts
1. **Don't Mix Layers**: Keep business logic out of CRUD, database ops out of routes 1. **Don't Mix Layers**: Keep business logic in services, database ops in repositories, routing in routes
2. **Don't Expose Internals**: Never return sensitive data in API responses 2. **Don't Expose Internals**: Never return sensitive data in API responses
3. **Don't Trust Input**: Always validate and sanitize user input 3. **Don't Trust Input**: Always validate and sanitize user input
4. **Don't Ignore Errors**: Always handle exceptions properly 4. **Don't Ignore Errors**: Always handle exceptions properly
@@ -1733,7 +1321,9 @@ When implementing a new feature, use this checklist:
- [ ] Design database schema - [ ] Design database schema
- [ ] Create SQLAlchemy model - [ ] Create SQLAlchemy model
- [ ] Design Pydantic schemas (Create, Update, Response) - [ ] Design Pydantic schemas (Create, Update, Response)
- [ ] Implement CRUD operations - [ ] Implement repository (data access)
- [ ] Implement service (business logic)
- [ ] Register service in `app/api/dependencies/services.py`
- [ ] Create API endpoints - [ ] Create API endpoints
- [ ] Add authentication/authorization - [ ] Add authentication/authorization
- [ ] Implement rate limiting - [ ] Implement rate limiting