refactor(docs): update architecture to reflect repository migration

- Rename CRUD layer to Repository layer throughout architecture documentation.
- Update dependency injection examples to use repository classes.
- Add async SQLAlchemy pattern for Repository methods (`select()` and transactions).
- Replace CRUD references in FEATURE_EXAMPLE.md with Repository-focused implementation details.
- Highlight repository class responsibilities and remove outdated CRUD patterns.
This commit is contained in:
2026-03-01 11:13:51 +01:00
parent 80d2dc0cb2
commit 68275b1dd3
4 changed files with 349 additions and 773 deletions

View File

@@ -117,7 +117,8 @@ backend/
│ ├── api/ # API layer
│ │ ├── dependencies/ # Dependency injection
│ │ │ ├── auth.py # Authentication dependencies
│ │ │ ── permissions.py # Authorization dependencies
│ │ │ ── permissions.py # Authorization dependencies
│ │ │ └── services.py # Service singleton injection
│ │ ├── routes/ # API endpoints
│ │ │ ├── auth.py # Authentication routes
│ │ │ ├── users.py # User management routes
@@ -131,13 +132,14 @@ backend/
│ │ ├── config.py # Application configuration
│ │ ├── database.py # Database connection
│ │ ├── exceptions.py # Custom exception classes
│ │ ├── repository_exceptions.py # Repository-level exception hierarchy
│ │ └── middleware.py # Custom middleware
│ │
│ ├── crud/ # Database operations
│ │ ├── base.py # Generic CRUD base class
│ │ ├── user.py # User CRUD operations
│ │ ├── session.py # Session CRUD operations
│ │ └── organization.py # Organization CRUD
│ ├── repositories/ # Data access layer
│ │ ├── base.py # Generic repository base class
│ │ ├── user.py # User repository
│ │ ├── session.py # Session repository
│ │ └── organization.py # Organization repository
│ │
│ ├── models/ # SQLAlchemy models
│ │ ├── base.py # Base model with mixins
@@ -153,8 +155,11 @@ backend/
│ │ ├── sessions.py # Session schemas
│ │ └── organizations.py # Organization schemas
│ │
│ ├── services/ # Business logic
│ ├── services/ # Business logic layer
│ │ ├── auth_service.py # Authentication service
│ │ ├── user_service.py # User management service
│ │ ├── session_service.py # Session management service
│ │ ├── organization_service.py # Organization service
│ │ ├── email_service.py # Email service
│ │ └── session_cleanup.py # Background cleanup
│ │
@@ -168,9 +173,9 @@ backend/
├── tests/ # Test suite
│ ├── api/ # Integration tests
│ ├── crud/ # CRUD tests
│ ├── repositories/ # Repository unit tests
│ ├── services/ # Service unit tests
│ ├── models/ # Model tests
│ ├── services/ # Service tests
│ └── conftest.py # Test configuration
├── docs/ # Documentation
@@ -214,11 +219,11 @@ The application follows a strict 5-layer architecture:
└──────────────────────────┬──────────────────────────────────┘
│ calls
┌──────────────────────────▼──────────────────────────────────┐
CRUD Layer (crud/)
Repository Layer (repositories/)
│ - Database operations │
│ - Query building │
│ - Transaction management
│ - Error handling
│ - Custom repository exceptions
│ - No business logic
└──────────────────────────┬──────────────────────────────────┘
│ uses
┌──────────────────────────▼──────────────────────────────────┐
@@ -262,7 +267,7 @@ async def get_current_user_info(
**Rules**:
- Should NOT contain business logic
- Should NOT directly perform database operations (use CRUD or services)
- Should NOT directly call repositories (use services injected via `dependencies/services.py`)
- Must validate all input via Pydantic schemas
- Must specify response models
- Should apply appropriate rate limits
@@ -279,9 +284,9 @@ async def get_current_user_info(
**Example**:
```python
def get_current_user(
async def get_current_user(
token: str = Depends(oauth2_scheme),
db: Session = Depends(get_db)
db: AsyncSession = Depends(get_db)
) -> User:
"""
Extract and validate user from JWT token.
@@ -295,7 +300,7 @@ def get_current_user(
except Exception:
raise AuthenticationError("Invalid authentication credentials")
user = user_crud.get(db, id=user_id)
user = await user_repo.get(db, id=user_id)
if not user:
raise AuthenticationError("User not found")
@@ -313,7 +318,7 @@ def get_current_user(
**Responsibility**: Implement complex business logic
**Key Functions**:
- Orchestrate multiple CRUD operations
- Orchestrate multiple repository operations
- Implement business rules
- Handle external service integration
- Coordinate transactions
@@ -323,9 +328,9 @@ def get_current_user(
class AuthService:
"""Authentication service with business logic."""
def login(
async def login(
self,
db: Session,
db: AsyncSession,
email: str,
password: str,
request: Request
@@ -339,8 +344,8 @@ class AuthService:
3. Generate tokens
4. Return tokens and user info
"""
# Validate credentials
user = user_crud.get_by_email(db, email=email)
# Validate credentials via repository
user = await user_repo.get_by_email(db, email=email)
if not user or not verify_password(password, user.hashed_password):
raise AuthenticationError("Invalid credentials")
@@ -350,11 +355,10 @@ class AuthService:
# Extract device info
device_info = extract_device_info(request)
# Create session
session = session_crud.create_session(
# Create session via repository
session = await session_repo.create(
db,
user_id=user.id,
device_info=device_info
obj_in=SessionCreate(user_id=user.id, **device_info)
)
# Generate tokens
@@ -373,75 +377,60 @@ class AuthService:
**Rules**:
- Contains business logic, not just data operations
- Can call multiple CRUD operations
- Can call multiple repository operations
- Should handle complex workflows
- Must maintain data consistency
- Should use transactions when needed
#### 4. CRUD Layer (`app/crud/`)
#### 4. Repository Layer (`app/repositories/`)
**Responsibility**: Database operations and queries
**Responsibility**: Database operations and queries — no business logic
**Key Functions**:
- Create, read, update, delete operations
- Build database queries
- Handle database errors
- Raise custom repository exceptions (`DuplicateEntryError`, `IntegrityConstraintError`)
- Manage soft deletes
- Implement pagination and filtering
**Example**:
```python
class CRUDSession(CRUDBase[UserSession, SessionCreate, SessionUpdate]):
"""CRUD operations for user sessions."""
class SessionRepository(RepositoryBase[UserSession, SessionCreate, SessionUpdate]):
"""Repository for user sessions — database operations only."""
def get_by_jti(self, db: Session, jti: UUID) -> Optional[UserSession]:
async def get_by_jti(self, db: AsyncSession, *, jti: str) -> UserSession | None:
"""Get session by refresh token JTI."""
try:
return (
db.query(UserSession)
.filter(UserSession.refresh_token_jti == jti)
.first()
result = await db.execute(
select(UserSession).where(UserSession.refresh_token_jti == jti)
)
except Exception as e:
logger.error(f"Error getting session by JTI: {str(e)}")
return None
return result.scalar_one_or_none()
def get_active_by_jti(
self,
db: Session,
jti: UUID
) -> Optional[UserSession]:
"""Get active session by refresh token JTI."""
session = self.get_by_jti(db, jti=jti)
if session and session.is_active and not session.is_expired:
return session
return None
def deactivate(self, db: Session, session_id: UUID) -> bool:
async def deactivate(self, db: AsyncSession, *, session_id: UUID) -> bool:
"""Deactivate a session (logout)."""
try:
session = self.get(db, id=session_id)
session = await self.get(db, id=session_id)
if not session:
return False
session.is_active = False
db.commit()
await db.commit()
logger.info(f"Session {session_id} deactivated")
return True
except Exception as e:
db.rollback()
await db.rollback()
logger.error(f"Error deactivating session: {str(e)}")
return False
```
**Rules**:
- Should NOT contain business logic
- Must handle database exceptions
- Must use parameterized queries (SQLAlchemy does this)
- Must raise custom repository exceptions (not raw `ValueError`/`IntegrityError`)
- Must use async SQLAlchemy 2.0 `select()` API (never `db.query()`)
- Should log all database errors
- Must rollback on errors
- Should use soft deletes when possible
- **Never imported directly by routes** — always called through services
#### 5. Data Layer (`app/models/` + `app/schemas/`)
@@ -546,51 +535,23 @@ SessionLocal = sessionmaker(
#### Dependency Injection Pattern
```python
def get_db() -> Generator[Session, None, None]:
async def get_db() -> AsyncGenerator[AsyncSession, None]:
"""
Database session dependency for FastAPI routes.
Async database session dependency for FastAPI routes.
Automatically commits on success, rolls back on error.
The session is passed to service methods; commit/rollback is
managed inside service or repository methods.
"""
db = SessionLocal()
try:
async with AsyncSessionLocal() as db:
yield db
finally:
db.close()
# Usage in routes
# Usage in routes — always through a service, never direct repository
@router.get("/users")
def list_users(db: Session = Depends(get_db)):
return user_crud.get_multi(db)
```
#### Context Manager Pattern
```python
@contextmanager
def transaction_scope() -> Generator[Session, None, None]:
"""
Context manager for database transactions.
Use for complex operations requiring multiple steps.
Automatically commits on success, rolls back on error.
"""
db = SessionLocal()
try:
yield db
db.commit()
except Exception:
db.rollback()
raise
finally:
db.close()
# Usage in services
def complex_operation():
with transaction_scope() as db:
user = user_crud.create(db, obj_in=user_data)
session = session_crud.create(db, session_data)
return user, session
async def list_users(
user_service: UserService = Depends(get_user_service),
db: AsyncSession = Depends(get_db),
):
return await user_service.get_users(db)
```
### Model Mixins
@@ -782,22 +743,15 @@ def get_profile(
```python
@router.delete("/sessions/{session_id}")
def revoke_session(
async def revoke_session(
session_id: UUID,
current_user: User = Depends(get_current_user),
db: Session = Depends(get_db)
session_service: SessionService = Depends(get_session_service),
db: AsyncSession = Depends(get_db),
):
"""Users can only revoke their own sessions."""
session = session_crud.get(db, id=session_id)
if not session:
raise NotFoundError("Session not found")
# Check ownership
if session.user_id != current_user.id:
raise AuthorizationError("You can only revoke your own sessions")
session_crud.deactivate(db, session_id=session_id)
# SessionService verifies ownership and raises NotFoundError / AuthorizationError
await session_service.revoke_session(db, session_id=session_id, user_id=current_user.id)
return MessageResponse(success=True, message="Session revoked")
```
@@ -1092,8 +1046,8 @@ async def cleanup_expired_sessions():
Runs daily at 2 AM. Removes sessions expired for more than 30 days.
"""
try:
with transaction_scope() as db:
count = session_crud.cleanup_expired(db, keep_days=30)
async with AsyncSessionLocal() as db:
count = await session_repo.cleanup_expired(db, keep_days=30)
logger.info(f"Cleaned up {count} expired sessions")
except Exception as e:
logger.error(f"Error cleaning up sessions: {str(e)}", exc_info=True)
@@ -1110,7 +1064,7 @@ async def cleanup_expired_sessions():
│Integration │ ← API endpoint tests
│ Tests │
├─────────────┤
│ Unit │ ← CRUD, services, utilities
│ Unit │ ← repositories, services, utilities
│ Tests │
└─────────────┘
```

View File

@@ -75,15 +75,14 @@ def create_user(db: Session, user_in: UserCreate) -> User:
### 4. Code Formatting
Use automated formatters:
- **Black**: Code formatting
- **isort**: Import sorting
- **flake8**: Linting
- **Ruff**: Code formatting and linting (replaces Black, isort, flake8)
- **pyright**: Static type checking
Run before committing:
Run before committing (or use `make validate`):
```bash
black app tests
isort app tests
flake8 app tests
uv run ruff format app tests
uv run ruff check app tests
uv run pyright app
```
## Code Organization
@@ -94,19 +93,17 @@ Follow the 5-layer architecture strictly:
```
API Layer (routes/)
↓ calls
Dependencies (dependencies/)
↓ injects
↓ calls (via service injected from dependencies/services.py)
Service Layer (services/)
↓ calls
CRUD Layer (crud/)
Repository Layer (repositories/)
↓ uses
Models & Schemas (models/, schemas/)
```
**Rules:**
- Routes should NOT directly call CRUD operations (use services when business logic is needed)
- CRUD operations should NOT contain business logic
- Routes must NEVER import repositories directly — always use a service
- Services call repositories; repositories contain only database operations
- Models should NOT import from higher layers
- Each layer should only depend on the layer directly below it
@@ -125,7 +122,7 @@ from sqlalchemy.orm import Session
# 3. Local application imports
from app.api.dependencies.auth import get_current_user
from app.crud import user_crud
from app.api.dependencies.services import get_user_service
from app.models.user import User
from app.schemas.users import UserResponse, UserCreate
```
@@ -442,19 +439,19 @@ backend/app/alembic/versions/
4. **Testability**: Easy to mock and test
5. **Consistent Ordering**: Always order queries for pagination
### Use the Async CRUD Base Class
### Use the Async Repository Base Class
Always inherit from `CRUDBase` for database operations:
Always inherit from `RepositoryBase` for database operations:
```python
from sqlalchemy.ext.asyncio import AsyncSession
from sqlalchemy import select
from app.crud.base import CRUDBase
from app.repositories.base import RepositoryBase
from app.models.user import User
from app.schemas.users import UserCreate, UserUpdate
class CRUDUser(CRUDBase[User, UserCreate, UserUpdate]):
"""CRUD operations for User model."""
class UserRepository(RepositoryBase[User, UserCreate, UserUpdate]):
"""Repository for User model — database operations only."""
async def get_by_email(
self,
@@ -467,7 +464,7 @@ class CRUDUser(CRUDBase[User, UserCreate, UserUpdate]):
)
return result.scalar_one_or_none()
user_crud = CRUDUser(User)
user_repo = UserRepository(User)
```
**Key Points:**
@@ -476,6 +473,7 @@ user_crud = CRUDUser(User)
- Use `await db.execute()` for queries
- Use `.scalar_one_or_none()` instead of `.first()`
- Use `T | None` instead of `Optional[T]`
- Repository instances are used internally by services — never import them in routes
### Modern SQLAlchemy Patterns
@@ -563,7 +561,7 @@ async def create_user(
The database session is automatically managed by FastAPI.
Commit on success, rollback on error.
"""
return await user_crud.create(db, obj_in=user_in)
return await user_service.create_user(db, obj_in=user_in)
```
**Key Points:**
@@ -582,12 +580,11 @@ async def complex_operation(
"""
Perform multiple database operations atomically.
The session automatically commits on success or rolls back on error.
Services call repositories; commit/rollback is handled inside
each repository method.
"""
user = await user_crud.create(db, obj_in=user_data)
session = await session_crud.create(db, obj_in=session_data)
# Commit is handled by the route's dependency
user = await user_repo.create(db, obj_in=user_data)
session = await session_repo.create(db, obj_in=session_data)
return user, session
```
@@ -597,10 +594,10 @@ Prefer soft deletes over hard deletes for audit trails:
```python
# Good - Soft delete (sets deleted_at)
await user_crud.soft_delete(db, id=user_id)
await user_repo.soft_delete(db, id=user_id)
# Acceptable only when required - Hard delete
user_crud.remove(db, id=user_id)
await user_repo.remove(db, id=user_id)
```
### Query Patterns
@@ -740,9 +737,10 @@ Always implement pagination for list endpoints:
from app.schemas.common import PaginationParams, PaginatedResponse
@router.get("/users", response_model=PaginatedResponse[UserResponse])
def list_users(
async def list_users(
pagination: PaginationParams = Depends(),
db: Session = Depends(get_db)
user_service: UserService = Depends(get_user_service),
db: AsyncSession = Depends(get_db),
):
"""
List all users with pagination.
@@ -750,10 +748,8 @@ def list_users(
Default page size: 20
Maximum page size: 100
"""
users, total = user_crud.get_multi_with_total(
db,
skip=pagination.offset,
limit=pagination.limit
users, total = await user_service.get_users(
db, skip=pagination.offset, limit=pagination.limit
)
return PaginatedResponse(data=users, pagination=pagination.create_meta(total))
```
@@ -816,19 +812,17 @@ def admin_route(
pass
# Check ownership
def delete_resource(
async def delete_resource(
resource_id: UUID,
current_user: User = Depends(get_current_user),
db: Session = Depends(get_db)
resource_service: ResourceService = Depends(get_resource_service),
db: AsyncSession = Depends(get_db),
):
resource = resource_crud.get(db, id=resource_id)
if not resource:
raise NotFoundError("Resource not found")
if resource.user_id != current_user.id and not current_user.is_superuser:
raise AuthorizationError("You can only delete your own resources")
resource_crud.remove(db, id=resource_id)
# Service handles ownership check and raises appropriate errors
await resource_service.delete_resource(
db, resource_id=resource_id, user_id=current_user.id,
is_superuser=current_user.is_superuser,
)
```
### Input Validation
@@ -862,9 +856,9 @@ tests/
├── api/ # Integration tests
│ ├── test_users.py
│ └── test_auth.py
├── crud/ # Unit tests for CRUD
├── models/ # Model tests
└── services/ # Service tests
├── repositories/ # Unit tests for repositories
├── services/ # Unit tests for services
└── models/ # Model tests
```
### Async Testing with pytest-asyncio
@@ -927,7 +921,7 @@ async def test_user(db_session: AsyncSession) -> User:
@pytest.mark.asyncio
async def test_get_user(db_session: AsyncSession, test_user: User):
"""Test retrieving a user by ID."""
user = await user_crud.get(db_session, id=test_user.id)
user = await user_repo.get(db_session, id=test_user.id)
assert user is not None
assert user.email == test_user.email
```

View File

@@ -616,7 +616,43 @@ def create_user(
return user
```
**Rule**: Add type hints to ALL functions. Use `mypy` to enforce type checking.
**Rule**: Add type hints to ALL functions. Use `pyright` to enforce type checking (`make type-check`).
---
---
### ❌ PITFALL #19: Importing Repositories Directly in Routes
**Issue**: Routes should never call repositories directly. The layered architecture requires all business operations to go through the service layer.
```python
# ❌ WRONG - Route bypasses service layer
from app.repositories.session import session_repo
@router.get("/sessions/me")
async def list_sessions(
current_user: User = Depends(get_current_active_user),
db: AsyncSession = Depends(get_db),
):
return await session_repo.get_user_sessions(db, user_id=current_user.id)
```
```python
# ✅ CORRECT - Route calls service injected via dependency
from app.api.dependencies.services import get_session_service
from app.services.session_service import SessionService
@router.get("/sessions/me")
async def list_sessions(
current_user: User = Depends(get_current_active_user),
session_service: SessionService = Depends(get_session_service),
db: AsyncSession = Depends(get_db),
):
return await session_service.get_user_sessions(db, user_id=current_user.id)
```
**Rule**: Routes import from `app.api.dependencies.services`, never from `app.repositories.*`. Services are the only callers of repositories.
---
@@ -649,6 +685,11 @@ Use this checklist to catch issues before code review:
- [ ] Resource ownership verification
- [ ] CORS configured (no wildcards in production)
### Architecture
- [ ] Routes never import repositories directly (only services)
- [ ] Services call repositories; repositories call database only
- [ ] New service registered in `app/api/dependencies/services.py`
### Python
- [ ] Use `==` not `is` for value comparison
- [ ] No mutable default arguments
@@ -661,21 +702,18 @@ Use this checklist to catch issues before code review:
### Pre-commit Checks
Add these to your development workflow:
Add these to your development workflow (or use `make validate`):
```bash
# Format code
black app tests
isort app tests
# Format + lint (Ruff replaces Black, isort, flake8)
uv run ruff format app tests
uv run ruff check app tests
# Type checking
mypy app --strict
# Linting
flake8 app tests
uv run pyright app
# Run tests
pytest --cov=app --cov-report=term-missing
IS_TEST=True uv run pytest --cov=app --cov-report=term-missing
# Check coverage (should be 80%+)
coverage report --fail-under=80
@@ -693,6 +731,6 @@ Add new entries when:
---
**Last Updated**: 2025-10-31
**Issues Cataloged**: 18 common pitfalls
**Last Updated**: 2026-02-28
**Issues Cataloged**: 19 common pitfalls
**Remember**: This document exists because these issues HAVE occurred. Don't skip it.

View File

@@ -8,7 +8,7 @@ This guide walks through implementing a complete feature using the **User Sessio
- [Implementation Steps](#implementation-steps)
- [Step 1: Design the Database Model](#step-1-design-the-database-model)
- [Step 2: Create Pydantic Schemas](#step-2-create-pydantic-schemas)
- [Step 3: Implement CRUD Operations](#step-3-implement-crud-operations)
- [Step 3: Implement Repository](#step-3-implement-repository)
- [Step 4: Create API Endpoints](#step-4-create-api-endpoints)
- [Step 5: Integrate with Existing Features](#step-5-integrate-with-existing-features)
- [Step 6: Add Background Jobs](#step-6-add-background-jobs)
@@ -204,8 +204,8 @@ Follow the standard pattern:
```
SessionBase (common fields)
├── SessionCreate (internal: CRUD operations)
├── SessionUpdate (internal: CRUD operations)
├── SessionCreate (internal: repository operations)
├── SessionUpdate (internal: repository operations)
└── SessionResponse (external: API responses)
```
@@ -240,7 +240,7 @@ class SessionCreate(SessionBase):
"""
Schema for creating a new session (internal use).
Used by CRUD operations, not exposed to API.
Used by repository operations, not exposed to API.
Contains all fields needed to create a session.
"""
user_id: UUID
@@ -344,37 +344,37 @@ class DeviceInfo(BaseModel):
5. **OpenAPI Documentation**: `json_schema_extra` provides examples in API docs
6. **Type Safety**: Comprehensive type hints for all fields
### Step 3: Implement CRUD Operations
### Step 3: Implement Repository
**File**: `app/crud/session.py`
**File**: `app/repositories/session.py`
CRUD layer handles all database operations. No business logic here!
The repository layer handles all database operations. No business logic here — that belongs in services!
#### 3.1 Extend the Base CRUD Class
#### 3.1 Extend the Base Repository Class
```python
"""
CRUD operations for user sessions.
Repository for user sessions.
"""
from datetime import datetime, timezone, timedelta
from typing import List, Optional
from uuid import UUID
from sqlalchemy.orm import Session
from sqlalchemy import and_
from sqlalchemy import and_, select, update
from sqlalchemy.ext.asyncio import AsyncSession
import logging
from app.crud.base import CRUDBase
from app.repositories.base import RepositoryBase
from app.models.user_session import UserSession
from app.schemas.sessions import SessionCreate, SessionUpdate
logger = logging.getLogger(__name__)
class CRUDSession(CRUDBase[UserSession, SessionCreate, SessionUpdate]):
class SessionRepository(RepositoryBase[UserSession, SessionCreate, SessionUpdate]):
"""
CRUD operations for user sessions.
Repository for user sessions.
Inherits standard operations from CRUDBase:
Inherits standard operations from RepositoryBase:
- get(db, id) - Get by ID
- get_multi(db, skip, limit) - List with pagination
- create(db, obj_in) - Create new session
@@ -382,111 +382,62 @@ class CRUDSession(CRUDBase[UserSession, SessionCreate, SessionUpdate]):
- remove(db, id) - Delete session
"""
# Custom query methods
# --------------------
def get_by_jti(self, db: Session, *, jti: str) -> Optional[UserSession]:
async def get_by_jti(self, db: AsyncSession, *, jti: str) -> UserSession | None:
"""
Get session by refresh token JTI.
Used during token refresh to find the corresponding session.
Args:
db: Database session
jti: Refresh token JWT ID
Returns:
UserSession if found, None otherwise
"""
try:
return db.query(UserSession).filter(
UserSession.refresh_token_jti == jti
).first()
except Exception as e:
logger.error(f"Error getting session by JTI {jti}: {str(e)}")
raise
result = await db.execute(
select(UserSession).where(UserSession.refresh_token_jti == jti)
)
return result.scalar_one_or_none()
def get_active_by_jti(self, db: Session, *, jti: str) -> Optional[UserSession]:
"""
Get active session by refresh token JTI.
Only returns the session if it's currently active.
Args:
db: Database session
jti: Refresh token JWT ID
Returns:
Active UserSession if found, None otherwise
"""
try:
return db.query(UserSession).filter(
async def get_active_by_jti(self, db: AsyncSession, *, jti: str) -> UserSession | None:
"""Get active (non-expired) session by refresh token JTI."""
result = await db.execute(
select(UserSession).where(
and_(
UserSession.refresh_token_jti == jti,
UserSession.is_active == True
UserSession.is_active.is_(True),
)
).first()
except Exception as e:
logger.error(f"Error getting active session by JTI {jti}: {str(e)}")
raise
)
)
session = result.scalar_one_or_none()
if session and not session.is_expired:
return session
return None
def get_user_sessions(
async def get_user_sessions(
self,
db: Session,
db: AsyncSession,
*,
user_id: str,
active_only: bool = True
) -> List[UserSession]:
user_id: UUID,
active_only: bool = True,
) -> list[UserSession]:
"""
Get all sessions for a user.
Args:
db: Database session
user_id: User ID
active_only: If True, return only active sessions
Returns:
List of UserSession objects, ordered by most recently used
Get all sessions for a user, ordered by most recently used.
"""
try:
# Convert user_id string to UUID if needed
user_uuid = UUID(user_id) if isinstance(user_id, str) else user_id
query = db.query(UserSession).filter(UserSession.user_id == user_uuid)
query = select(UserSession).where(UserSession.user_id == user_id)
if active_only:
query = query.filter(UserSession.is_active == True)
query = query.where(UserSession.is_active.is_(True))
query = query.order_by(UserSession.last_used_at.desc())
result = await db.execute(query)
return list(result.scalars().all())
# Order by most recently used first
return query.order_by(UserSession.last_used_at.desc()).all()
except Exception as e:
logger.error(f"Error getting sessions for user {user_id}: {str(e)}")
raise
# Creation methods
# ----------------
def create_session(
async def create_session(
self,
db: Session,
db: AsyncSession,
*,
obj_in: SessionCreate
obj_in: SessionCreate,
) -> UserSession:
"""
Create a new user session.
Args:
db: Database session
obj_in: SessionCreate schema with session data
Returns:
Created UserSession
Raises:
ValueError: If session creation fails
DuplicateEntryError: If a session with the same JTI already exists
"""
try:
# Create model instance from schema
db_obj = UserSession(
user_id=obj_in.user_id,
refresh_token_jti=obj_in.refresh_token_jti,
@@ -501,248 +452,93 @@ class CRUDSession(CRUDBase[UserSession, SessionCreate, SessionUpdate]):
location_country=obj_in.location_country,
)
db.add(db_obj)
db.commit()
db.refresh(db_obj)
db_obj.add(db_obj)
await db.commit()
await db.refresh(db_obj)
logger.info(
f"Session created for user {obj_in.user_id} from {obj_in.device_name}"
f"(IP: {obj_in.ip_address})"
)
return db_obj
except Exception as e:
db.rollback()
await db.rollback()
logger.error(f"Error creating session: {str(e)}", exc_info=True)
raise ValueError(f"Failed to create session: {str(e)}")
# Update methods
# --------------
def deactivate(self, db: Session, *, session_id: str) -> Optional[UserSession]:
"""
Deactivate a session (logout from device).
Args:
db: Database session
session_id: Session UUID
Returns:
Deactivated UserSession if found, None otherwise
"""
try:
session = self.get(db, id=session_id)
if not session:
logger.warning(f"Session {session_id} not found for deactivation")
return None
session.is_active = False
db.add(session)
db.commit()
db.refresh(session)
logger.info(
f"Session {session_id} deactivated for user {session.user_id} "
f"({session.device_name})"
)
return session
except Exception as e:
db.rollback()
logger.error(f"Error deactivating session {session_id}: {str(e)}")
raise
def deactivate_all_user_sessions(
async def deactivate(self, db: AsyncSession, *, session_id: UUID) -> UserSession | None:
"""Deactivate a session (logout from device)."""
session = await self.get(db, id=session_id)
if not session:
return None
session.is_active = False
await db.commit()
await db.refresh(session)
logger.info(f"Session {session_id} deactivated ({session.device_name})")
return session
async def deactivate_all_user_sessions(
self,
db: Session,
db: AsyncSession,
*,
user_id: str
user_id: UUID,
) -> int:
"""
Deactivate all active sessions for a user (logout from all devices).
Uses bulk update for efficiency.
Args:
db: Database session
user_id: User ID
Returns:
Number of sessions deactivated
Uses a bulk UPDATE for efficiency — no N+1 queries.
"""
try:
# Convert user_id string to UUID if needed
user_uuid = UUID(user_id) if isinstance(user_id, str) else user_id
# Bulk update query
count = db.query(UserSession).filter(
and_(
UserSession.user_id == user_uuid,
UserSession.is_active == True
)
).update({"is_active": False})
db.commit()
logger.info(f"Deactivated {count} sessions for user {user_id}")
return count
except Exception as e:
db.rollback()
logger.error(f"Error deactivating all sessions for user {user_id}: {str(e)}")
raise
def update_last_used(
self,
db: Session,
*,
session: UserSession
) -> UserSession:
"""
Update the last_used_at timestamp for a session.
Called when a refresh token is used.
Args:
db: Database session
session: UserSession object
Returns:
Updated UserSession
"""
try:
session.last_used_at = datetime.now(timezone.utc)
db.add(session)
db.commit()
db.refresh(session)
return session
except Exception as e:
db.rollback()
logger.error(f"Error updating last_used for session {session.id}: {str(e)}")
raise
def update_refresh_token(
self,
db: Session,
*,
session: UserSession,
new_jti: str,
new_expires_at: datetime
) -> UserSession:
"""
Update session with new refresh token JTI and expiration.
Called during token refresh (token rotation).
Args:
db: Database session
session: UserSession object
new_jti: New refresh token JTI
new_expires_at: New expiration datetime
Returns:
Updated UserSession
"""
try:
session.refresh_token_jti = new_jti
session.expires_at = new_expires_at
session.last_used_at = datetime.now(timezone.utc)
db.add(session)
db.commit()
db.refresh(session)
return session
except Exception as e:
db.rollback()
logger.error(f"Error updating refresh token for session {session.id}: {str(e)}")
raise
# Cleanup methods
# ---------------
def cleanup_expired(self, db: Session, *, keep_days: int = 30) -> int:
"""
Clean up expired sessions.
Deletes sessions that are:
- Expired (expires_at < now) AND inactive
- Older than keep_days (for audit trail)
Args:
db: Database session
keep_days: Keep inactive sessions for this many days
Returns:
Number of sessions deleted
"""
try:
cutoff_date = datetime.now(timezone.utc) - timedelta(days=keep_days)
count = db.query(UserSession).filter(
and_(
UserSession.is_active == False,
UserSession.expires_at < datetime.now(timezone.utc),
UserSession.created_at < cutoff_date
)
).delete()
db.commit()
if count > 0:
logger.info(f"Cleaned up {count} expired sessions")
return count
except Exception as e:
db.rollback()
logger.error(f"Error cleaning up expired sessions: {str(e)}")
raise
# Utility methods
# ---------------
def get_user_session_count(self, db: Session, *, user_id: str) -> int:
"""
Get count of active sessions for a user.
Useful for session limits or security monitoring.
Args:
db: Database session
user_id: User ID
Returns:
Number of active sessions
"""
try:
return db.query(UserSession).filter(
result = await db.execute(
update(UserSession)
.where(
and_(
UserSession.user_id == user_id,
UserSession.is_active == True
UserSession.is_active.is_(True),
)
).count()
except Exception as e:
logger.error(f"Error counting sessions for user {user_id}: {str(e)}")
raise
)
.values(is_active=False)
)
await db.commit()
count = result.rowcount
logger.info(f"Deactivated {count} sessions for user {user_id}")
return count
async def cleanup_expired(self, db: AsyncSession, *, keep_days: int = 30) -> int:
"""
Hard-delete inactive sessions older than keep_days.
Returns the number of sessions deleted.
"""
cutoff_date = datetime.now(timezone.utc) - timedelta(days=keep_days)
result = await db.execute(
select(UserSession).where(
and_(
UserSession.is_active.is_(False),
UserSession.expires_at < datetime.now(timezone.utc),
UserSession.created_at < cutoff_date,
)
)
)
sessions = list(result.scalars().all())
for s in sessions:
await db.delete(s)
await db.commit()
if sessions:
logger.info(f"Cleaned up {len(sessions)} expired sessions")
return len(sessions)
# Create singleton instance
# This is the instance that will be imported and used throughout the app
session = CRUDSession(UserSession)
# Singleton instance — used by services, never imported directly in routes
session_repo = SessionRepository(UserSession)
```
**Key Patterns**:
1. **Error Handling**: Every method has try/except with rollback
2. **Logging**: Log all significant actions (create, delete, errors)
3. **Type Safety**: Full type hints for parameters and returns
4. **Docstrings**: Document what each method does, args, returns, raises
5. **Bulk Operations**: Use `query().update()` for efficiency when updating many rows
6. **UUID Handling**: Convert string UUIDs to UUID objects when needed
7. **Ordering**: Return results in a logical order (most recent first)
8. **Singleton Pattern**: Create one instance to be imported elsewhere
1. **Async everywhere**: All methods use `async def` and `await`
2. **Modern SQLAlchemy**: `select()` API, never `db.query()`
3. **Bulk updates**: Use `update()` statement for multi-row changes (no N+1)
4. **Error handling**: `try/except` with `await db.rollback()` in mutating methods
5. **Logging**: Log all significant actions (create, delete, errors)
6. **Type safety**: Full type hints; `UUID` not raw `str` for IDs
7. **Singleton pattern**: One module-level instance used by services
### Step 4: Create API Endpoints
@@ -772,14 +568,15 @@ from uuid import UUID
from fastapi import APIRouter, Depends, HTTPException, status, Request
from slowapi import Limiter
from slowapi.util import get_remote_address
from sqlalchemy.orm import Session
from sqlalchemy.ext.asyncio import AsyncSession
from app.api.dependencies.auth import get_current_user
from app.api.dependencies.services import get_session_service
from app.core.database import get_db
from app.models.user import User
from app.schemas.sessions import SessionResponse, SessionListResponse
from app.schemas.common import MessageResponse
from app.crud.session import session as session_crud
from app.services.session_service import SessionService
from app.core.exceptions import NotFoundError, AuthorizationError, ErrorCode
router = APIRouter()
@@ -803,61 +600,21 @@ limiter = Limiter(key_func=get_remote_address)
operation_id="list_my_sessions"
)
@limiter.limit("30/minute")
def list_my_sessions(
async def list_my_sessions(
request: Request,
current_user: User = Depends(get_current_user),
db: Session = Depends(get_db)
session_service: SessionService = Depends(get_session_service),
db: AsyncSession = Depends(get_db),
) -> Any:
"""
List all active sessions for the current user.
Args:
request: FastAPI request object (for rate limiting)
current_user: Current authenticated user (injected)
db: Database session (injected)
Returns:
SessionListResponse with list of active sessions
"""
try:
# Get all active sessions for user
sessions = session_crud.get_user_sessions(
db,
user_id=str(current_user.id),
active_only=True
)
# Convert to response format
session_responses = []
for idx, s in enumerate(sessions):
session_response = SessionResponse(
id=s.id,
device_name=s.device_name,
device_id=s.device_id,
ip_address=s.ip_address,
location_city=s.location_city,
location_country=s.location_country,
last_used_at=s.last_used_at,
created_at=s.created_at,
expires_at=s.expires_at,
# Mark the most recently used session as current
is_current=(idx == 0)
)
session_responses.append(session_response)
logger.info(f"User {current_user.id} listed {len(session_responses)} active sessions")
return SessionListResponse(
sessions=session_responses,
total=len(session_responses)
)
except Exception as e:
logger.error(f"Error listing sessions for user {current_user.id}: {str(e)}", exc_info=True)
raise HTTPException(
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
detail="Failed to retrieve sessions"
"""List all active sessions for the current user."""
sessions = await session_service.get_user_sessions(
db, user_id=current_user.id, active_only=True
)
session_responses = [
SessionResponse.model_validate(s) | {"is_current": idx == 0}
for idx, s in enumerate(sessions)
]
return SessionListResponse(sessions=session_responses, total=len(session_responses))
@router.delete(
@@ -876,70 +633,25 @@ def list_my_sessions(
operation_id="revoke_session"
)
@limiter.limit("10/minute")
def revoke_session(
async def revoke_session(
request: Request,
session_id: UUID,
current_user: User = Depends(get_current_user),
db: Session = Depends(get_db)
session_service: SessionService = Depends(get_session_service),
db: AsyncSession = Depends(get_db),
) -> Any:
"""
Revoke a specific session by ID.
Args:
request: FastAPI request object (for rate limiting)
session_id: UUID of the session to revoke
current_user: Current authenticated user (injected)
db: Database session (injected)
Returns:
MessageResponse with success message
Raises:
NotFoundError: If session doesn't exist
AuthorizationError: If session belongs to another user
The service verifies ownership and raises NotFoundError /
AuthorizationError which are handled by global exception handlers.
"""
try:
# Get the session
session = session_crud.get(db, id=str(session_id))
if not session:
raise NotFoundError(
message=f"Session {session_id} not found",
error_code=ErrorCode.NOT_FOUND
device_name = await session_service.revoke_session(
db, session_id=session_id, user_id=current_user.id
)
# Verify session belongs to current user (authorization check)
if str(session.user_id) != str(current_user.id):
logger.warning(
f"User {current_user.id} attempted to revoke session {session_id} "
f"belonging to user {session.user_id}"
)
raise AuthorizationError(
message="You can only revoke your own sessions",
error_code=ErrorCode.INSUFFICIENT_PERMISSIONS
)
# Deactivate the session
session_crud.deactivate(db, session_id=str(session_id))
logger.info(
f"User {current_user.id} revoked session {session_id} "
f"({session.device_name})"
)
return MessageResponse(
success=True,
message=f"Session revoked: {session.device_name or 'Unknown device'}"
)
except (NotFoundError, AuthorizationError):
# Re-raise custom exceptions (they'll be handled by global handlers)
raise
except Exception as e:
logger.error(f"Error revoking session {session_id}: {str(e)}", exc_info=True)
raise HTTPException(
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
detail="Failed to revoke session"
message=f"Session revoked: {device_name or 'Unknown device'}"
)
@@ -958,55 +670,20 @@ def revoke_session(
operation_id="cleanup_expired_sessions"
)
@limiter.limit("5/minute")
def cleanup_expired_sessions(
async def cleanup_expired_sessions(
request: Request,
current_user: User = Depends(get_current_user),
db: Session = Depends(get_db)
session_service: SessionService = Depends(get_session_service),
db: AsyncSession = Depends(get_db),
) -> Any:
"""
Cleanup expired sessions for the current user.
Args:
request: FastAPI request object (for rate limiting)
current_user: Current authenticated user (injected)
db: Database session (injected)
Returns:
MessageResponse with count of sessions cleaned
"""
try:
from datetime import datetime, timezone
# Get all sessions for user (including inactive)
all_sessions = session_crud.get_user_sessions(
db,
user_id=str(current_user.id),
active_only=False
"""Cleanup expired sessions for the current user."""
deleted_count = await session_service.cleanup_user_expired_sessions(
db, user_id=current_user.id
)
# Delete expired and inactive sessions
deleted_count = 0
for s in all_sessions:
if not s.is_active and s.expires_at < datetime.now(timezone.utc):
db.delete(s)
deleted_count += 1
db.commit()
logger.info(f"User {current_user.id} cleaned up {deleted_count} expired sessions")
return MessageResponse(
success=True,
message=f"Cleaned up {deleted_count} expired sessions"
)
except Exception as e:
logger.error(f"Error cleaning up sessions for user {current_user.id}: {str(e)}", exc_info=True)
db.rollback()
raise HTTPException(
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
detail="Failed to cleanup sessions"
)
```
**Key Patterns**:
@@ -1079,59 +756,25 @@ Session management needs to be integrated into the authentication flow.
```python
from app.utils.device import extract_device_info
from app.crud.session import session as session_crud
from app.schemas.sessions import SessionCreate
from app.api.dependencies.services import get_auth_service
from app.services.auth_service import AuthService
@router.post("/login")
async def login(
request: Request,
credentials: OAuth2PasswordRequestForm = Depends(),
db: Session = Depends(get_db)
auth_service: AuthService = Depends(get_auth_service),
db: AsyncSession = Depends(get_db),
):
"""Authenticate user and create session."""
# 1. Validate credentials
user = user_crud.get_by_email(db, email=credentials.username)
if not user or not verify_password(credentials.password, user.hashed_password):
raise AuthenticationError("Invalid credentials")
if not user.is_active:
raise AuthenticationError("Account is inactive")
# 2. Extract device information from request
device_info = extract_device_info(request)
# 3. Generate tokens
jti = str(uuid.uuid4()) # Generate JTI for refresh token
access_token = create_access_token(subject=str(user.id))
refresh_token = create_refresh_token(subject=str(user.id), jti=jti)
# 4. Create session record
from datetime import datetime, timezone, timedelta
session_data = SessionCreate(
user_id=user.id,
refresh_token_jti=jti,
device_name=device_info.device_name,
device_id=device_info.device_id,
ip_address=device_info.ip_address,
user_agent=device_info.user_agent,
last_used_at=datetime.now(timezone.utc),
expires_at=datetime.now(timezone.utc) + timedelta(days=7),
location_city=device_info.location_city,
location_country=device_info.location_country,
# All business logic (validate credentials, create session, generate tokens)
# is delegated to AuthService which calls the appropriate repositories.
return await auth_service.login(
db,
email=credentials.username,
password=credentials.password,
request=request,
)
session_crud.create_session(db, obj_in=session_data)
logger.info(f"User {user.email} logged in from {device_info.device_name}")
return {
"access_token": access_token,
"refresh_token": refresh_token,
"token_type": "bearer",
"user": UserResponse.model_validate(user)
}
```
#### 5.2 Create Device Info Utility
@@ -1193,89 +836,35 @@ def extract_device_info(request: Request) -> DeviceInfo:
```python
@router.post("/refresh")
def refresh_token(
async def refresh_token(
refresh_request: RefreshRequest,
db: Session = Depends(get_db)
auth_service: AuthService = Depends(get_auth_service),
db: AsyncSession = Depends(get_db),
):
"""Refresh access token using refresh token."""
try:
# 1. Decode and validate refresh token
payload = decode_token(refresh_request.refresh_token)
if payload.get("type") != "refresh":
raise AuthenticationError("Invalid token type")
user_id = UUID(payload.get("sub"))
jti = payload.get("jti")
# 2. Find and validate session
session = session_crud.get_active_by_jti(db, jti=jti)
if not session:
raise AuthenticationError("Session not found or expired")
if session.user_id != user_id:
raise AuthenticationError("Token mismatch")
# 3. Generate new tokens (token rotation)
new_jti = str(uuid.uuid4())
new_access_token = create_access_token(subject=str(user_id))
new_refresh_token = create_refresh_token(subject=str(user_id), jti=new_jti)
# 4. Update session with new JTI
session_crud.update_refresh_token(
db,
session=session,
new_jti=new_jti,
new_expires_at=datetime.now(timezone.utc) + timedelta(days=7)
# AuthService handles token validation, session lookup, token rotation
return await auth_service.refresh_tokens(
db, refresh_token=refresh_request.refresh_token
)
logger.info(f"Tokens refreshed for user {user_id}")
return {
"access_token": new_access_token,
"refresh_token": new_refresh_token,
"token_type": "bearer"
}
except Exception as e:
logger.error(f"Token refresh failed: {str(e)}")
raise AuthenticationError("Failed to refresh token")
```
#### 5.4 Update Logout Endpoint
```python
@router.post("/logout")
def logout(
async def logout(
logout_request: LogoutRequest,
current_user: User = Depends(get_current_user),
db: Session = Depends(get_db)
auth_service: AuthService = Depends(get_auth_service),
db: AsyncSession = Depends(get_db),
):
"""Logout from current device."""
try:
# Decode refresh token to get JTI
payload = decode_token(logout_request.refresh_token)
jti = payload.get("jti")
# Find and deactivate session
session = session_crud.get_by_jti(db, jti=jti)
if session and session.user_id == current_user.id:
session_crud.deactivate(db, session_id=str(session.id))
logger.info(f"User {current_user.id} logged out from {session.device_name}")
return MessageResponse(
success=True,
message="Logged out successfully"
await auth_service.logout(
db,
refresh_token=logout_request.refresh_token,
user_id=current_user.id,
)
except Exception as e:
logger.error(f"Logout failed: {str(e)}")
# Even if cleanup fails, return success (user intended to logout)
return MessageResponse(success=True, message="Logged out")
return MessageResponse(success=True, message="Logged out successfully")
```
### Step 6: Add Background Jobs
@@ -1287,8 +876,8 @@ def logout(
Background job for cleaning up expired sessions.
"""
import logging
from app.core.database import SessionLocal
from app.crud.session import session as session_crud
from app.core.database import AsyncSessionLocal
from app.repositories.session import session_repo
logger = logging.getLogger(__name__)
@@ -1302,14 +891,12 @@ async def cleanup_expired_sessions():
- Inactive (is_active = False)
- Older than 30 days (for audit trail)
"""
db = SessionLocal()
async with AsyncSessionLocal() as db:
try:
count = session_crud.cleanup_expired(db, keep_days=30)
count = await session_repo.cleanup_expired(db, keep_days=30)
logger.info(f"Background cleanup: Removed {count} expired sessions")
except Exception as e:
logger.error(f"Error in session cleanup job: {str(e)}", exc_info=True)
finally:
db.close()
```
**Register in** `app/main.py`:
@@ -1679,7 +1266,8 @@ You've now implemented a complete feature! Here's what was created:
**Files Created/Modified**:
1. `app/models/user_session.py` - Database model
2. `app/schemas/sessions.py` - Pydantic schemas
3. `app/crud/session.py` - CRUD operations
3. `app/repositories/session.py` - Repository (data access)
4. `app/services/session_service.py` - Service (business logic)
4. `app/api/routes/sessions.py` - API endpoints
5. `app/utils/device.py` - Device detection utility
6. `app/services/session_cleanup.py` - Background job
@@ -1715,7 +1303,7 @@ You've now implemented a complete feature! Here's what was created:
### Don'ts
1. **Don't Mix Layers**: Keep business logic out of CRUD, database ops out of routes
1. **Don't Mix Layers**: Keep business logic in services, database ops in repositories, routing in routes
2. **Don't Expose Internals**: Never return sensitive data in API responses
3. **Don't Trust Input**: Always validate and sanitize user input
4. **Don't Ignore Errors**: Always handle exceptions properly
@@ -1733,7 +1321,9 @@ When implementing a new feature, use this checklist:
- [ ] Design database schema
- [ ] Create SQLAlchemy model
- [ ] Design Pydantic schemas (Create, Update, Response)
- [ ] Implement CRUD operations
- [ ] Implement repository (data access)
- [ ] Implement service (business logic)
- [ ] Register service in `app/api/dependencies/services.py`
- [ ] Create API endpoints
- [ ] Add authentication/authorization
- [ ] Implement rate limiting