refactor(docs): update architecture to reflect repository migration
- Rename CRUD layer to Repository layer throughout architecture documentation. - Update dependency injection examples to use repository classes. - Add async SQLAlchemy pattern for Repository methods (`select()` and transactions). - Replace CRUD references in FEATURE_EXAMPLE.md with Repository-focused implementation details. - Highlight repository class responsibilities and remove outdated CRUD patterns.
This commit is contained in:
@@ -117,7 +117,8 @@ backend/
|
|||||||
│ ├── api/ # API layer
|
│ ├── api/ # API layer
|
||||||
│ │ ├── dependencies/ # Dependency injection
|
│ │ ├── dependencies/ # Dependency injection
|
||||||
│ │ │ ├── auth.py # Authentication dependencies
|
│ │ │ ├── auth.py # Authentication dependencies
|
||||||
│ │ │ └── permissions.py # Authorization dependencies
|
│ │ │ ├── permissions.py # Authorization dependencies
|
||||||
|
│ │ │ └── services.py # Service singleton injection
|
||||||
│ │ ├── routes/ # API endpoints
|
│ │ ├── routes/ # API endpoints
|
||||||
│ │ │ ├── auth.py # Authentication routes
|
│ │ │ ├── auth.py # Authentication routes
|
||||||
│ │ │ ├── users.py # User management routes
|
│ │ │ ├── users.py # User management routes
|
||||||
@@ -131,13 +132,14 @@ backend/
|
|||||||
│ │ ├── config.py # Application configuration
|
│ │ ├── config.py # Application configuration
|
||||||
│ │ ├── database.py # Database connection
|
│ │ ├── database.py # Database connection
|
||||||
│ │ ├── exceptions.py # Custom exception classes
|
│ │ ├── exceptions.py # Custom exception classes
|
||||||
|
│ │ ├── repository_exceptions.py # Repository-level exception hierarchy
|
||||||
│ │ └── middleware.py # Custom middleware
|
│ │ └── middleware.py # Custom middleware
|
||||||
│ │
|
│ │
|
||||||
│ ├── crud/ # Database operations
|
│ ├── repositories/ # Data access layer
|
||||||
│ │ ├── base.py # Generic CRUD base class
|
│ │ ├── base.py # Generic repository base class
|
||||||
│ │ ├── user.py # User CRUD operations
|
│ │ ├── user.py # User repository
|
||||||
│ │ ├── session.py # Session CRUD operations
|
│ │ ├── session.py # Session repository
|
||||||
│ │ └── organization.py # Organization CRUD
|
│ │ └── organization.py # Organization repository
|
||||||
│ │
|
│ │
|
||||||
│ ├── models/ # SQLAlchemy models
|
│ ├── models/ # SQLAlchemy models
|
||||||
│ │ ├── base.py # Base model with mixins
|
│ │ ├── base.py # Base model with mixins
|
||||||
@@ -153,8 +155,11 @@ backend/
|
|||||||
│ │ ├── sessions.py # Session schemas
|
│ │ ├── sessions.py # Session schemas
|
||||||
│ │ └── organizations.py # Organization schemas
|
│ │ └── organizations.py # Organization schemas
|
||||||
│ │
|
│ │
|
||||||
│ ├── services/ # Business logic
|
│ ├── services/ # Business logic layer
|
||||||
│ │ ├── auth_service.py # Authentication service
|
│ │ ├── auth_service.py # Authentication service
|
||||||
|
│ │ ├── user_service.py # User management service
|
||||||
|
│ │ ├── session_service.py # Session management service
|
||||||
|
│ │ ├── organization_service.py # Organization service
|
||||||
│ │ ├── email_service.py # Email service
|
│ │ ├── email_service.py # Email service
|
||||||
│ │ └── session_cleanup.py # Background cleanup
|
│ │ └── session_cleanup.py # Background cleanup
|
||||||
│ │
|
│ │
|
||||||
@@ -168,9 +173,9 @@ backend/
|
|||||||
│
|
│
|
||||||
├── tests/ # Test suite
|
├── tests/ # Test suite
|
||||||
│ ├── api/ # Integration tests
|
│ ├── api/ # Integration tests
|
||||||
│ ├── crud/ # CRUD tests
|
│ ├── repositories/ # Repository unit tests
|
||||||
|
│ ├── services/ # Service unit tests
|
||||||
│ ├── models/ # Model tests
|
│ ├── models/ # Model tests
|
||||||
│ ├── services/ # Service tests
|
|
||||||
│ └── conftest.py # Test configuration
|
│ └── conftest.py # Test configuration
|
||||||
│
|
│
|
||||||
├── docs/ # Documentation
|
├── docs/ # Documentation
|
||||||
@@ -214,11 +219,11 @@ The application follows a strict 5-layer architecture:
|
|||||||
└──────────────────────────┬──────────────────────────────────┘
|
└──────────────────────────┬──────────────────────────────────┘
|
||||||
│ calls
|
│ calls
|
||||||
┌──────────────────────────▼──────────────────────────────────┐
|
┌──────────────────────────▼──────────────────────────────────┐
|
||||||
│ CRUD Layer (crud/) │
|
│ Repository Layer (repositories/) │
|
||||||
│ - Database operations │
|
│ - Database operations │
|
||||||
│ - Query building │
|
│ - Query building │
|
||||||
│ - Transaction management │
|
│ - Custom repository exceptions │
|
||||||
│ - Error handling │
|
│ - No business logic │
|
||||||
└──────────────────────────┬──────────────────────────────────┘
|
└──────────────────────────┬──────────────────────────────────┘
|
||||||
│ uses
|
│ uses
|
||||||
┌──────────────────────────▼──────────────────────────────────┐
|
┌──────────────────────────▼──────────────────────────────────┐
|
||||||
@@ -262,7 +267,7 @@ async def get_current_user_info(
|
|||||||
|
|
||||||
**Rules**:
|
**Rules**:
|
||||||
- Should NOT contain business logic
|
- Should NOT contain business logic
|
||||||
- Should NOT directly perform database operations (use CRUD or services)
|
- Should NOT directly call repositories (use services injected via `dependencies/services.py`)
|
||||||
- Must validate all input via Pydantic schemas
|
- Must validate all input via Pydantic schemas
|
||||||
- Must specify response models
|
- Must specify response models
|
||||||
- Should apply appropriate rate limits
|
- Should apply appropriate rate limits
|
||||||
@@ -279,9 +284,9 @@ async def get_current_user_info(
|
|||||||
|
|
||||||
**Example**:
|
**Example**:
|
||||||
```python
|
```python
|
||||||
def get_current_user(
|
async def get_current_user(
|
||||||
token: str = Depends(oauth2_scheme),
|
token: str = Depends(oauth2_scheme),
|
||||||
db: Session = Depends(get_db)
|
db: AsyncSession = Depends(get_db)
|
||||||
) -> User:
|
) -> User:
|
||||||
"""
|
"""
|
||||||
Extract and validate user from JWT token.
|
Extract and validate user from JWT token.
|
||||||
@@ -295,7 +300,7 @@ def get_current_user(
|
|||||||
except Exception:
|
except Exception:
|
||||||
raise AuthenticationError("Invalid authentication credentials")
|
raise AuthenticationError("Invalid authentication credentials")
|
||||||
|
|
||||||
user = user_crud.get(db, id=user_id)
|
user = await user_repo.get(db, id=user_id)
|
||||||
if not user:
|
if not user:
|
||||||
raise AuthenticationError("User not found")
|
raise AuthenticationError("User not found")
|
||||||
|
|
||||||
@@ -313,7 +318,7 @@ def get_current_user(
|
|||||||
**Responsibility**: Implement complex business logic
|
**Responsibility**: Implement complex business logic
|
||||||
|
|
||||||
**Key Functions**:
|
**Key Functions**:
|
||||||
- Orchestrate multiple CRUD operations
|
- Orchestrate multiple repository operations
|
||||||
- Implement business rules
|
- Implement business rules
|
||||||
- Handle external service integration
|
- Handle external service integration
|
||||||
- Coordinate transactions
|
- Coordinate transactions
|
||||||
@@ -323,9 +328,9 @@ def get_current_user(
|
|||||||
class AuthService:
|
class AuthService:
|
||||||
"""Authentication service with business logic."""
|
"""Authentication service with business logic."""
|
||||||
|
|
||||||
def login(
|
async def login(
|
||||||
self,
|
self,
|
||||||
db: Session,
|
db: AsyncSession,
|
||||||
email: str,
|
email: str,
|
||||||
password: str,
|
password: str,
|
||||||
request: Request
|
request: Request
|
||||||
@@ -339,8 +344,8 @@ class AuthService:
|
|||||||
3. Generate tokens
|
3. Generate tokens
|
||||||
4. Return tokens and user info
|
4. Return tokens and user info
|
||||||
"""
|
"""
|
||||||
# Validate credentials
|
# Validate credentials via repository
|
||||||
user = user_crud.get_by_email(db, email=email)
|
user = await user_repo.get_by_email(db, email=email)
|
||||||
if not user or not verify_password(password, user.hashed_password):
|
if not user or not verify_password(password, user.hashed_password):
|
||||||
raise AuthenticationError("Invalid credentials")
|
raise AuthenticationError("Invalid credentials")
|
||||||
|
|
||||||
@@ -350,11 +355,10 @@ class AuthService:
|
|||||||
# Extract device info
|
# Extract device info
|
||||||
device_info = extract_device_info(request)
|
device_info = extract_device_info(request)
|
||||||
|
|
||||||
# Create session
|
# Create session via repository
|
||||||
session = session_crud.create_session(
|
session = await session_repo.create(
|
||||||
db,
|
db,
|
||||||
user_id=user.id,
|
obj_in=SessionCreate(user_id=user.id, **device_info)
|
||||||
device_info=device_info
|
|
||||||
)
|
)
|
||||||
|
|
||||||
# Generate tokens
|
# Generate tokens
|
||||||
@@ -373,75 +377,60 @@ class AuthService:
|
|||||||
|
|
||||||
**Rules**:
|
**Rules**:
|
||||||
- Contains business logic, not just data operations
|
- Contains business logic, not just data operations
|
||||||
- Can call multiple CRUD operations
|
- Can call multiple repository operations
|
||||||
- Should handle complex workflows
|
- Should handle complex workflows
|
||||||
- Must maintain data consistency
|
- Must maintain data consistency
|
||||||
- Should use transactions when needed
|
- Should use transactions when needed
|
||||||
|
|
||||||
#### 4. CRUD Layer (`app/crud/`)
|
#### 4. Repository Layer (`app/repositories/`)
|
||||||
|
|
||||||
**Responsibility**: Database operations and queries
|
**Responsibility**: Database operations and queries — no business logic
|
||||||
|
|
||||||
**Key Functions**:
|
**Key Functions**:
|
||||||
- Create, read, update, delete operations
|
- Create, read, update, delete operations
|
||||||
- Build database queries
|
- Build database queries
|
||||||
- Handle database errors
|
- Raise custom repository exceptions (`DuplicateEntryError`, `IntegrityConstraintError`)
|
||||||
- Manage soft deletes
|
- Manage soft deletes
|
||||||
- Implement pagination and filtering
|
- Implement pagination and filtering
|
||||||
|
|
||||||
**Example**:
|
**Example**:
|
||||||
```python
|
```python
|
||||||
class CRUDSession(CRUDBase[UserSession, SessionCreate, SessionUpdate]):
|
class SessionRepository(RepositoryBase[UserSession, SessionCreate, SessionUpdate]):
|
||||||
"""CRUD operations for user sessions."""
|
"""Repository for user sessions — database operations only."""
|
||||||
|
|
||||||
def get_by_jti(self, db: Session, jti: UUID) -> Optional[UserSession]:
|
async def get_by_jti(self, db: AsyncSession, *, jti: str) -> UserSession | None:
|
||||||
"""Get session by refresh token JTI."""
|
"""Get session by refresh token JTI."""
|
||||||
try:
|
result = await db.execute(
|
||||||
return (
|
select(UserSession).where(UserSession.refresh_token_jti == jti)
|
||||||
db.query(UserSession)
|
|
||||||
.filter(UserSession.refresh_token_jti == jti)
|
|
||||||
.first()
|
|
||||||
)
|
)
|
||||||
except Exception as e:
|
return result.scalar_one_or_none()
|
||||||
logger.error(f"Error getting session by JTI: {str(e)}")
|
|
||||||
return None
|
|
||||||
|
|
||||||
def get_active_by_jti(
|
async def deactivate(self, db: AsyncSession, *, session_id: UUID) -> bool:
|
||||||
self,
|
|
||||||
db: Session,
|
|
||||||
jti: UUID
|
|
||||||
) -> Optional[UserSession]:
|
|
||||||
"""Get active session by refresh token JTI."""
|
|
||||||
session = self.get_by_jti(db, jti=jti)
|
|
||||||
if session and session.is_active and not session.is_expired:
|
|
||||||
return session
|
|
||||||
return None
|
|
||||||
|
|
||||||
def deactivate(self, db: Session, session_id: UUID) -> bool:
|
|
||||||
"""Deactivate a session (logout)."""
|
"""Deactivate a session (logout)."""
|
||||||
try:
|
try:
|
||||||
session = self.get(db, id=session_id)
|
session = await self.get(db, id=session_id)
|
||||||
if not session:
|
if not session:
|
||||||
return False
|
return False
|
||||||
|
|
||||||
session.is_active = False
|
session.is_active = False
|
||||||
db.commit()
|
await db.commit()
|
||||||
logger.info(f"Session {session_id} deactivated")
|
logger.info(f"Session {session_id} deactivated")
|
||||||
return True
|
return True
|
||||||
|
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
db.rollback()
|
await db.rollback()
|
||||||
logger.error(f"Error deactivating session: {str(e)}")
|
logger.error(f"Error deactivating session: {str(e)}")
|
||||||
return False
|
return False
|
||||||
```
|
```
|
||||||
|
|
||||||
**Rules**:
|
**Rules**:
|
||||||
- Should NOT contain business logic
|
- Should NOT contain business logic
|
||||||
- Must handle database exceptions
|
- Must raise custom repository exceptions (not raw `ValueError`/`IntegrityError`)
|
||||||
- Must use parameterized queries (SQLAlchemy does this)
|
- Must use async SQLAlchemy 2.0 `select()` API (never `db.query()`)
|
||||||
- Should log all database errors
|
- Should log all database errors
|
||||||
- Must rollback on errors
|
- Must rollback on errors
|
||||||
- Should use soft deletes when possible
|
- Should use soft deletes when possible
|
||||||
|
- **Never imported directly by routes** — always called through services
|
||||||
|
|
||||||
#### 5. Data Layer (`app/models/` + `app/schemas/`)
|
#### 5. Data Layer (`app/models/` + `app/schemas/`)
|
||||||
|
|
||||||
@@ -546,51 +535,23 @@ SessionLocal = sessionmaker(
|
|||||||
#### Dependency Injection Pattern
|
#### Dependency Injection Pattern
|
||||||
|
|
||||||
```python
|
```python
|
||||||
def get_db() -> Generator[Session, None, None]:
|
async def get_db() -> AsyncGenerator[AsyncSession, None]:
|
||||||
"""
|
"""
|
||||||
Database session dependency for FastAPI routes.
|
Async database session dependency for FastAPI routes.
|
||||||
|
|
||||||
Automatically commits on success, rolls back on error.
|
The session is passed to service methods; commit/rollback is
|
||||||
|
managed inside service or repository methods.
|
||||||
"""
|
"""
|
||||||
db = SessionLocal()
|
async with AsyncSessionLocal() as db:
|
||||||
try:
|
|
||||||
yield db
|
yield db
|
||||||
finally:
|
|
||||||
db.close()
|
|
||||||
|
|
||||||
# Usage in routes
|
# Usage in routes — always through a service, never direct repository
|
||||||
@router.get("/users")
|
@router.get("/users")
|
||||||
def list_users(db: Session = Depends(get_db)):
|
async def list_users(
|
||||||
return user_crud.get_multi(db)
|
user_service: UserService = Depends(get_user_service),
|
||||||
```
|
db: AsyncSession = Depends(get_db),
|
||||||
|
):
|
||||||
#### Context Manager Pattern
|
return await user_service.get_users(db)
|
||||||
|
|
||||||
```python
|
|
||||||
@contextmanager
|
|
||||||
def transaction_scope() -> Generator[Session, None, None]:
|
|
||||||
"""
|
|
||||||
Context manager for database transactions.
|
|
||||||
|
|
||||||
Use for complex operations requiring multiple steps.
|
|
||||||
Automatically commits on success, rolls back on error.
|
|
||||||
"""
|
|
||||||
db = SessionLocal()
|
|
||||||
try:
|
|
||||||
yield db
|
|
||||||
db.commit()
|
|
||||||
except Exception:
|
|
||||||
db.rollback()
|
|
||||||
raise
|
|
||||||
finally:
|
|
||||||
db.close()
|
|
||||||
|
|
||||||
# Usage in services
|
|
||||||
def complex_operation():
|
|
||||||
with transaction_scope() as db:
|
|
||||||
user = user_crud.create(db, obj_in=user_data)
|
|
||||||
session = session_crud.create(db, session_data)
|
|
||||||
return user, session
|
|
||||||
```
|
```
|
||||||
|
|
||||||
### Model Mixins
|
### Model Mixins
|
||||||
@@ -782,22 +743,15 @@ def get_profile(
|
|||||||
|
|
||||||
```python
|
```python
|
||||||
@router.delete("/sessions/{session_id}")
|
@router.delete("/sessions/{session_id}")
|
||||||
def revoke_session(
|
async def revoke_session(
|
||||||
session_id: UUID,
|
session_id: UUID,
|
||||||
current_user: User = Depends(get_current_user),
|
current_user: User = Depends(get_current_user),
|
||||||
db: Session = Depends(get_db)
|
session_service: SessionService = Depends(get_session_service),
|
||||||
|
db: AsyncSession = Depends(get_db),
|
||||||
):
|
):
|
||||||
"""Users can only revoke their own sessions."""
|
"""Users can only revoke their own sessions."""
|
||||||
session = session_crud.get(db, id=session_id)
|
# SessionService verifies ownership and raises NotFoundError / AuthorizationError
|
||||||
|
await session_service.revoke_session(db, session_id=session_id, user_id=current_user.id)
|
||||||
if not session:
|
|
||||||
raise NotFoundError("Session not found")
|
|
||||||
|
|
||||||
# Check ownership
|
|
||||||
if session.user_id != current_user.id:
|
|
||||||
raise AuthorizationError("You can only revoke your own sessions")
|
|
||||||
|
|
||||||
session_crud.deactivate(db, session_id=session_id)
|
|
||||||
return MessageResponse(success=True, message="Session revoked")
|
return MessageResponse(success=True, message="Session revoked")
|
||||||
```
|
```
|
||||||
|
|
||||||
@@ -1092,8 +1046,8 @@ async def cleanup_expired_sessions():
|
|||||||
Runs daily at 2 AM. Removes sessions expired for more than 30 days.
|
Runs daily at 2 AM. Removes sessions expired for more than 30 days.
|
||||||
"""
|
"""
|
||||||
try:
|
try:
|
||||||
with transaction_scope() as db:
|
async with AsyncSessionLocal() as db:
|
||||||
count = session_crud.cleanup_expired(db, keep_days=30)
|
count = await session_repo.cleanup_expired(db, keep_days=30)
|
||||||
logger.info(f"Cleaned up {count} expired sessions")
|
logger.info(f"Cleaned up {count} expired sessions")
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.error(f"Error cleaning up sessions: {str(e)}", exc_info=True)
|
logger.error(f"Error cleaning up sessions: {str(e)}", exc_info=True)
|
||||||
@@ -1110,7 +1064,7 @@ async def cleanup_expired_sessions():
|
|||||||
│Integration │ ← API endpoint tests
|
│Integration │ ← API endpoint tests
|
||||||
│ Tests │
|
│ Tests │
|
||||||
├─────────────┤
|
├─────────────┤
|
||||||
│ Unit │ ← CRUD, services, utilities
|
│ Unit │ ← repositories, services, utilities
|
||||||
│ Tests │
|
│ Tests │
|
||||||
└─────────────┘
|
└─────────────┘
|
||||||
```
|
```
|
||||||
|
|||||||
@@ -75,15 +75,14 @@ def create_user(db: Session, user_in: UserCreate) -> User:
|
|||||||
### 4. Code Formatting
|
### 4. Code Formatting
|
||||||
|
|
||||||
Use automated formatters:
|
Use automated formatters:
|
||||||
- **Black**: Code formatting
|
- **Ruff**: Code formatting and linting (replaces Black, isort, flake8)
|
||||||
- **isort**: Import sorting
|
- **pyright**: Static type checking
|
||||||
- **flake8**: Linting
|
|
||||||
|
|
||||||
Run before committing:
|
Run before committing (or use `make validate`):
|
||||||
```bash
|
```bash
|
||||||
black app tests
|
uv run ruff format app tests
|
||||||
isort app tests
|
uv run ruff check app tests
|
||||||
flake8 app tests
|
uv run pyright app
|
||||||
```
|
```
|
||||||
|
|
||||||
## Code Organization
|
## Code Organization
|
||||||
@@ -94,19 +93,17 @@ Follow the 5-layer architecture strictly:
|
|||||||
|
|
||||||
```
|
```
|
||||||
API Layer (routes/)
|
API Layer (routes/)
|
||||||
↓ calls
|
↓ calls (via service injected from dependencies/services.py)
|
||||||
Dependencies (dependencies/)
|
|
||||||
↓ injects
|
|
||||||
Service Layer (services/)
|
Service Layer (services/)
|
||||||
↓ calls
|
↓ calls
|
||||||
CRUD Layer (crud/)
|
Repository Layer (repositories/)
|
||||||
↓ uses
|
↓ uses
|
||||||
Models & Schemas (models/, schemas/)
|
Models & Schemas (models/, schemas/)
|
||||||
```
|
```
|
||||||
|
|
||||||
**Rules:**
|
**Rules:**
|
||||||
- Routes should NOT directly call CRUD operations (use services when business logic is needed)
|
- Routes must NEVER import repositories directly — always use a service
|
||||||
- CRUD operations should NOT contain business logic
|
- Services call repositories; repositories contain only database operations
|
||||||
- Models should NOT import from higher layers
|
- Models should NOT import from higher layers
|
||||||
- Each layer should only depend on the layer directly below it
|
- Each layer should only depend on the layer directly below it
|
||||||
|
|
||||||
@@ -125,7 +122,7 @@ from sqlalchemy.orm import Session
|
|||||||
|
|
||||||
# 3. Local application imports
|
# 3. Local application imports
|
||||||
from app.api.dependencies.auth import get_current_user
|
from app.api.dependencies.auth import get_current_user
|
||||||
from app.crud import user_crud
|
from app.api.dependencies.services import get_user_service
|
||||||
from app.models.user import User
|
from app.models.user import User
|
||||||
from app.schemas.users import UserResponse, UserCreate
|
from app.schemas.users import UserResponse, UserCreate
|
||||||
```
|
```
|
||||||
@@ -442,19 +439,19 @@ backend/app/alembic/versions/
|
|||||||
4. **Testability**: Easy to mock and test
|
4. **Testability**: Easy to mock and test
|
||||||
5. **Consistent Ordering**: Always order queries for pagination
|
5. **Consistent Ordering**: Always order queries for pagination
|
||||||
|
|
||||||
### Use the Async CRUD Base Class
|
### Use the Async Repository Base Class
|
||||||
|
|
||||||
Always inherit from `CRUDBase` for database operations:
|
Always inherit from `RepositoryBase` for database operations:
|
||||||
|
|
||||||
```python
|
```python
|
||||||
from sqlalchemy.ext.asyncio import AsyncSession
|
from sqlalchemy.ext.asyncio import AsyncSession
|
||||||
from sqlalchemy import select
|
from sqlalchemy import select
|
||||||
from app.crud.base import CRUDBase
|
from app.repositories.base import RepositoryBase
|
||||||
from app.models.user import User
|
from app.models.user import User
|
||||||
from app.schemas.users import UserCreate, UserUpdate
|
from app.schemas.users import UserCreate, UserUpdate
|
||||||
|
|
||||||
class CRUDUser(CRUDBase[User, UserCreate, UserUpdate]):
|
class UserRepository(RepositoryBase[User, UserCreate, UserUpdate]):
|
||||||
"""CRUD operations for User model."""
|
"""Repository for User model — database operations only."""
|
||||||
|
|
||||||
async def get_by_email(
|
async def get_by_email(
|
||||||
self,
|
self,
|
||||||
@@ -467,7 +464,7 @@ class CRUDUser(CRUDBase[User, UserCreate, UserUpdate]):
|
|||||||
)
|
)
|
||||||
return result.scalar_one_or_none()
|
return result.scalar_one_or_none()
|
||||||
|
|
||||||
user_crud = CRUDUser(User)
|
user_repo = UserRepository(User)
|
||||||
```
|
```
|
||||||
|
|
||||||
**Key Points:**
|
**Key Points:**
|
||||||
@@ -476,6 +473,7 @@ user_crud = CRUDUser(User)
|
|||||||
- Use `await db.execute()` for queries
|
- Use `await db.execute()` for queries
|
||||||
- Use `.scalar_one_or_none()` instead of `.first()`
|
- Use `.scalar_one_or_none()` instead of `.first()`
|
||||||
- Use `T | None` instead of `Optional[T]`
|
- Use `T | None` instead of `Optional[T]`
|
||||||
|
- Repository instances are used internally by services — never import them in routes
|
||||||
|
|
||||||
### Modern SQLAlchemy Patterns
|
### Modern SQLAlchemy Patterns
|
||||||
|
|
||||||
@@ -563,7 +561,7 @@ async def create_user(
|
|||||||
The database session is automatically managed by FastAPI.
|
The database session is automatically managed by FastAPI.
|
||||||
Commit on success, rollback on error.
|
Commit on success, rollback on error.
|
||||||
"""
|
"""
|
||||||
return await user_crud.create(db, obj_in=user_in)
|
return await user_service.create_user(db, obj_in=user_in)
|
||||||
```
|
```
|
||||||
|
|
||||||
**Key Points:**
|
**Key Points:**
|
||||||
@@ -582,12 +580,11 @@ async def complex_operation(
|
|||||||
"""
|
"""
|
||||||
Perform multiple database operations atomically.
|
Perform multiple database operations atomically.
|
||||||
|
|
||||||
The session automatically commits on success or rolls back on error.
|
Services call repositories; commit/rollback is handled inside
|
||||||
|
each repository method.
|
||||||
"""
|
"""
|
||||||
user = await user_crud.create(db, obj_in=user_data)
|
user = await user_repo.create(db, obj_in=user_data)
|
||||||
session = await session_crud.create(db, obj_in=session_data)
|
session = await session_repo.create(db, obj_in=session_data)
|
||||||
|
|
||||||
# Commit is handled by the route's dependency
|
|
||||||
return user, session
|
return user, session
|
||||||
```
|
```
|
||||||
|
|
||||||
@@ -597,10 +594,10 @@ Prefer soft deletes over hard deletes for audit trails:
|
|||||||
|
|
||||||
```python
|
```python
|
||||||
# Good - Soft delete (sets deleted_at)
|
# Good - Soft delete (sets deleted_at)
|
||||||
await user_crud.soft_delete(db, id=user_id)
|
await user_repo.soft_delete(db, id=user_id)
|
||||||
|
|
||||||
# Acceptable only when required - Hard delete
|
# Acceptable only when required - Hard delete
|
||||||
user_crud.remove(db, id=user_id)
|
await user_repo.remove(db, id=user_id)
|
||||||
```
|
```
|
||||||
|
|
||||||
### Query Patterns
|
### Query Patterns
|
||||||
@@ -740,9 +737,10 @@ Always implement pagination for list endpoints:
|
|||||||
from app.schemas.common import PaginationParams, PaginatedResponse
|
from app.schemas.common import PaginationParams, PaginatedResponse
|
||||||
|
|
||||||
@router.get("/users", response_model=PaginatedResponse[UserResponse])
|
@router.get("/users", response_model=PaginatedResponse[UserResponse])
|
||||||
def list_users(
|
async def list_users(
|
||||||
pagination: PaginationParams = Depends(),
|
pagination: PaginationParams = Depends(),
|
||||||
db: Session = Depends(get_db)
|
user_service: UserService = Depends(get_user_service),
|
||||||
|
db: AsyncSession = Depends(get_db),
|
||||||
):
|
):
|
||||||
"""
|
"""
|
||||||
List all users with pagination.
|
List all users with pagination.
|
||||||
@@ -750,10 +748,8 @@ def list_users(
|
|||||||
Default page size: 20
|
Default page size: 20
|
||||||
Maximum page size: 100
|
Maximum page size: 100
|
||||||
"""
|
"""
|
||||||
users, total = user_crud.get_multi_with_total(
|
users, total = await user_service.get_users(
|
||||||
db,
|
db, skip=pagination.offset, limit=pagination.limit
|
||||||
skip=pagination.offset,
|
|
||||||
limit=pagination.limit
|
|
||||||
)
|
)
|
||||||
return PaginatedResponse(data=users, pagination=pagination.create_meta(total))
|
return PaginatedResponse(data=users, pagination=pagination.create_meta(total))
|
||||||
```
|
```
|
||||||
@@ -816,19 +812,17 @@ def admin_route(
|
|||||||
pass
|
pass
|
||||||
|
|
||||||
# Check ownership
|
# Check ownership
|
||||||
def delete_resource(
|
async def delete_resource(
|
||||||
resource_id: UUID,
|
resource_id: UUID,
|
||||||
current_user: User = Depends(get_current_user),
|
current_user: User = Depends(get_current_user),
|
||||||
db: Session = Depends(get_db)
|
resource_service: ResourceService = Depends(get_resource_service),
|
||||||
|
db: AsyncSession = Depends(get_db),
|
||||||
):
|
):
|
||||||
resource = resource_crud.get(db, id=resource_id)
|
# Service handles ownership check and raises appropriate errors
|
||||||
if not resource:
|
await resource_service.delete_resource(
|
||||||
raise NotFoundError("Resource not found")
|
db, resource_id=resource_id, user_id=current_user.id,
|
||||||
|
is_superuser=current_user.is_superuser,
|
||||||
if resource.user_id != current_user.id and not current_user.is_superuser:
|
)
|
||||||
raise AuthorizationError("You can only delete your own resources")
|
|
||||||
|
|
||||||
resource_crud.remove(db, id=resource_id)
|
|
||||||
```
|
```
|
||||||
|
|
||||||
### Input Validation
|
### Input Validation
|
||||||
@@ -862,9 +856,9 @@ tests/
|
|||||||
├── api/ # Integration tests
|
├── api/ # Integration tests
|
||||||
│ ├── test_users.py
|
│ ├── test_users.py
|
||||||
│ └── test_auth.py
|
│ └── test_auth.py
|
||||||
├── crud/ # Unit tests for CRUD
|
├── repositories/ # Unit tests for repositories
|
||||||
├── models/ # Model tests
|
├── services/ # Unit tests for services
|
||||||
└── services/ # Service tests
|
└── models/ # Model tests
|
||||||
```
|
```
|
||||||
|
|
||||||
### Async Testing with pytest-asyncio
|
### Async Testing with pytest-asyncio
|
||||||
@@ -927,7 +921,7 @@ async def test_user(db_session: AsyncSession) -> User:
|
|||||||
@pytest.mark.asyncio
|
@pytest.mark.asyncio
|
||||||
async def test_get_user(db_session: AsyncSession, test_user: User):
|
async def test_get_user(db_session: AsyncSession, test_user: User):
|
||||||
"""Test retrieving a user by ID."""
|
"""Test retrieving a user by ID."""
|
||||||
user = await user_crud.get(db_session, id=test_user.id)
|
user = await user_repo.get(db_session, id=test_user.id)
|
||||||
assert user is not None
|
assert user is not None
|
||||||
assert user.email == test_user.email
|
assert user.email == test_user.email
|
||||||
```
|
```
|
||||||
|
|||||||
@@ -616,7 +616,43 @@ def create_user(
|
|||||||
return user
|
return user
|
||||||
```
|
```
|
||||||
|
|
||||||
**Rule**: Add type hints to ALL functions. Use `mypy` to enforce type checking.
|
**Rule**: Add type hints to ALL functions. Use `pyright` to enforce type checking (`make type-check`).
|
||||||
|
|
||||||
|
---
|
||||||
|
|
||||||
|
---
|
||||||
|
|
||||||
|
### ❌ PITFALL #19: Importing Repositories Directly in Routes
|
||||||
|
|
||||||
|
**Issue**: Routes should never call repositories directly. The layered architecture requires all business operations to go through the service layer.
|
||||||
|
|
||||||
|
```python
|
||||||
|
# ❌ WRONG - Route bypasses service layer
|
||||||
|
from app.repositories.session import session_repo
|
||||||
|
|
||||||
|
@router.get("/sessions/me")
|
||||||
|
async def list_sessions(
|
||||||
|
current_user: User = Depends(get_current_active_user),
|
||||||
|
db: AsyncSession = Depends(get_db),
|
||||||
|
):
|
||||||
|
return await session_repo.get_user_sessions(db, user_id=current_user.id)
|
||||||
|
```
|
||||||
|
|
||||||
|
```python
|
||||||
|
# ✅ CORRECT - Route calls service injected via dependency
|
||||||
|
from app.api.dependencies.services import get_session_service
|
||||||
|
from app.services.session_service import SessionService
|
||||||
|
|
||||||
|
@router.get("/sessions/me")
|
||||||
|
async def list_sessions(
|
||||||
|
current_user: User = Depends(get_current_active_user),
|
||||||
|
session_service: SessionService = Depends(get_session_service),
|
||||||
|
db: AsyncSession = Depends(get_db),
|
||||||
|
):
|
||||||
|
return await session_service.get_user_sessions(db, user_id=current_user.id)
|
||||||
|
```
|
||||||
|
|
||||||
|
**Rule**: Routes import from `app.api.dependencies.services`, never from `app.repositories.*`. Services are the only callers of repositories.
|
||||||
|
|
||||||
---
|
---
|
||||||
|
|
||||||
@@ -649,6 +685,11 @@ Use this checklist to catch issues before code review:
|
|||||||
- [ ] Resource ownership verification
|
- [ ] Resource ownership verification
|
||||||
- [ ] CORS configured (no wildcards in production)
|
- [ ] CORS configured (no wildcards in production)
|
||||||
|
|
||||||
|
### Architecture
|
||||||
|
- [ ] Routes never import repositories directly (only services)
|
||||||
|
- [ ] Services call repositories; repositories call database only
|
||||||
|
- [ ] New service registered in `app/api/dependencies/services.py`
|
||||||
|
|
||||||
### Python
|
### Python
|
||||||
- [ ] Use `==` not `is` for value comparison
|
- [ ] Use `==` not `is` for value comparison
|
||||||
- [ ] No mutable default arguments
|
- [ ] No mutable default arguments
|
||||||
@@ -661,21 +702,18 @@ Use this checklist to catch issues before code review:
|
|||||||
|
|
||||||
### Pre-commit Checks
|
### Pre-commit Checks
|
||||||
|
|
||||||
Add these to your development workflow:
|
Add these to your development workflow (or use `make validate`):
|
||||||
|
|
||||||
```bash
|
```bash
|
||||||
# Format code
|
# Format + lint (Ruff replaces Black, isort, flake8)
|
||||||
black app tests
|
uv run ruff format app tests
|
||||||
isort app tests
|
uv run ruff check app tests
|
||||||
|
|
||||||
# Type checking
|
# Type checking
|
||||||
mypy app --strict
|
uv run pyright app
|
||||||
|
|
||||||
# Linting
|
|
||||||
flake8 app tests
|
|
||||||
|
|
||||||
# Run tests
|
# Run tests
|
||||||
pytest --cov=app --cov-report=term-missing
|
IS_TEST=True uv run pytest --cov=app --cov-report=term-missing
|
||||||
|
|
||||||
# Check coverage (should be 80%+)
|
# Check coverage (should be 80%+)
|
||||||
coverage report --fail-under=80
|
coverage report --fail-under=80
|
||||||
@@ -693,6 +731,6 @@ Add new entries when:
|
|||||||
|
|
||||||
---
|
---
|
||||||
|
|
||||||
**Last Updated**: 2025-10-31
|
**Last Updated**: 2026-02-28
|
||||||
**Issues Cataloged**: 18 common pitfalls
|
**Issues Cataloged**: 19 common pitfalls
|
||||||
**Remember**: This document exists because these issues HAVE occurred. Don't skip it.
|
**Remember**: This document exists because these issues HAVE occurred. Don't skip it.
|
||||||
|
|||||||
@@ -8,7 +8,7 @@ This guide walks through implementing a complete feature using the **User Sessio
|
|||||||
- [Implementation Steps](#implementation-steps)
|
- [Implementation Steps](#implementation-steps)
|
||||||
- [Step 1: Design the Database Model](#step-1-design-the-database-model)
|
- [Step 1: Design the Database Model](#step-1-design-the-database-model)
|
||||||
- [Step 2: Create Pydantic Schemas](#step-2-create-pydantic-schemas)
|
- [Step 2: Create Pydantic Schemas](#step-2-create-pydantic-schemas)
|
||||||
- [Step 3: Implement CRUD Operations](#step-3-implement-crud-operations)
|
- [Step 3: Implement Repository](#step-3-implement-repository)
|
||||||
- [Step 4: Create API Endpoints](#step-4-create-api-endpoints)
|
- [Step 4: Create API Endpoints](#step-4-create-api-endpoints)
|
||||||
- [Step 5: Integrate with Existing Features](#step-5-integrate-with-existing-features)
|
- [Step 5: Integrate with Existing Features](#step-5-integrate-with-existing-features)
|
||||||
- [Step 6: Add Background Jobs](#step-6-add-background-jobs)
|
- [Step 6: Add Background Jobs](#step-6-add-background-jobs)
|
||||||
@@ -204,8 +204,8 @@ Follow the standard pattern:
|
|||||||
|
|
||||||
```
|
```
|
||||||
SessionBase (common fields)
|
SessionBase (common fields)
|
||||||
├── SessionCreate (internal: CRUD operations)
|
├── SessionCreate (internal: repository operations)
|
||||||
├── SessionUpdate (internal: CRUD operations)
|
├── SessionUpdate (internal: repository operations)
|
||||||
└── SessionResponse (external: API responses)
|
└── SessionResponse (external: API responses)
|
||||||
```
|
```
|
||||||
|
|
||||||
@@ -240,7 +240,7 @@ class SessionCreate(SessionBase):
|
|||||||
"""
|
"""
|
||||||
Schema for creating a new session (internal use).
|
Schema for creating a new session (internal use).
|
||||||
|
|
||||||
Used by CRUD operations, not exposed to API.
|
Used by repository operations, not exposed to API.
|
||||||
Contains all fields needed to create a session.
|
Contains all fields needed to create a session.
|
||||||
"""
|
"""
|
||||||
user_id: UUID
|
user_id: UUID
|
||||||
@@ -344,37 +344,37 @@ class DeviceInfo(BaseModel):
|
|||||||
5. **OpenAPI Documentation**: `json_schema_extra` provides examples in API docs
|
5. **OpenAPI Documentation**: `json_schema_extra` provides examples in API docs
|
||||||
6. **Type Safety**: Comprehensive type hints for all fields
|
6. **Type Safety**: Comprehensive type hints for all fields
|
||||||
|
|
||||||
### Step 3: Implement CRUD Operations
|
### Step 3: Implement Repository
|
||||||
|
|
||||||
**File**: `app/crud/session.py`
|
**File**: `app/repositories/session.py`
|
||||||
|
|
||||||
CRUD layer handles all database operations. No business logic here!
|
The repository layer handles all database operations. No business logic here — that belongs in services!
|
||||||
|
|
||||||
#### 3.1 Extend the Base CRUD Class
|
#### 3.1 Extend the Base Repository Class
|
||||||
|
|
||||||
```python
|
```python
|
||||||
"""
|
"""
|
||||||
CRUD operations for user sessions.
|
Repository for user sessions.
|
||||||
"""
|
"""
|
||||||
from datetime import datetime, timezone, timedelta
|
from datetime import datetime, timezone, timedelta
|
||||||
from typing import List, Optional
|
|
||||||
from uuid import UUID
|
from uuid import UUID
|
||||||
from sqlalchemy.orm import Session
|
|
||||||
from sqlalchemy import and_
|
from sqlalchemy import and_, select, update
|
||||||
|
from sqlalchemy.ext.asyncio import AsyncSession
|
||||||
import logging
|
import logging
|
||||||
|
|
||||||
from app.crud.base import CRUDBase
|
from app.repositories.base import RepositoryBase
|
||||||
from app.models.user_session import UserSession
|
from app.models.user_session import UserSession
|
||||||
from app.schemas.sessions import SessionCreate, SessionUpdate
|
from app.schemas.sessions import SessionCreate, SessionUpdate
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
class CRUDSession(CRUDBase[UserSession, SessionCreate, SessionUpdate]):
|
class SessionRepository(RepositoryBase[UserSession, SessionCreate, SessionUpdate]):
|
||||||
"""
|
"""
|
||||||
CRUD operations for user sessions.
|
Repository for user sessions.
|
||||||
|
|
||||||
Inherits standard operations from CRUDBase:
|
Inherits standard operations from RepositoryBase:
|
||||||
- get(db, id) - Get by ID
|
- get(db, id) - Get by ID
|
||||||
- get_multi(db, skip, limit) - List with pagination
|
- get_multi(db, skip, limit) - List with pagination
|
||||||
- create(db, obj_in) - Create new session
|
- create(db, obj_in) - Create new session
|
||||||
@@ -382,111 +382,62 @@ class CRUDSession(CRUDBase[UserSession, SessionCreate, SessionUpdate]):
|
|||||||
- remove(db, id) - Delete session
|
- remove(db, id) - Delete session
|
||||||
"""
|
"""
|
||||||
|
|
||||||
# Custom query methods
|
async def get_by_jti(self, db: AsyncSession, *, jti: str) -> UserSession | None:
|
||||||
# --------------------
|
|
||||||
|
|
||||||
def get_by_jti(self, db: Session, *, jti: str) -> Optional[UserSession]:
|
|
||||||
"""
|
"""
|
||||||
Get session by refresh token JTI.
|
Get session by refresh token JTI.
|
||||||
|
|
||||||
Used during token refresh to find the corresponding session.
|
Used during token refresh to find the corresponding session.
|
||||||
|
|
||||||
Args:
|
|
||||||
db: Database session
|
|
||||||
jti: Refresh token JWT ID
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
UserSession if found, None otherwise
|
|
||||||
"""
|
"""
|
||||||
try:
|
result = await db.execute(
|
||||||
return db.query(UserSession).filter(
|
select(UserSession).where(UserSession.refresh_token_jti == jti)
|
||||||
UserSession.refresh_token_jti == jti
|
)
|
||||||
).first()
|
return result.scalar_one_or_none()
|
||||||
except Exception as e:
|
|
||||||
logger.error(f"Error getting session by JTI {jti}: {str(e)}")
|
|
||||||
raise
|
|
||||||
|
|
||||||
def get_active_by_jti(self, db: Session, *, jti: str) -> Optional[UserSession]:
|
async def get_active_by_jti(self, db: AsyncSession, *, jti: str) -> UserSession | None:
|
||||||
"""
|
"""Get active (non-expired) session by refresh token JTI."""
|
||||||
Get active session by refresh token JTI.
|
result = await db.execute(
|
||||||
|
select(UserSession).where(
|
||||||
Only returns the session if it's currently active.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
db: Database session
|
|
||||||
jti: Refresh token JWT ID
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
Active UserSession if found, None otherwise
|
|
||||||
"""
|
|
||||||
try:
|
|
||||||
return db.query(UserSession).filter(
|
|
||||||
and_(
|
and_(
|
||||||
UserSession.refresh_token_jti == jti,
|
UserSession.refresh_token_jti == jti,
|
||||||
UserSession.is_active == True
|
UserSession.is_active.is_(True),
|
||||||
)
|
)
|
||||||
).first()
|
)
|
||||||
except Exception as e:
|
)
|
||||||
logger.error(f"Error getting active session by JTI {jti}: {str(e)}")
|
session = result.scalar_one_or_none()
|
||||||
raise
|
if session and not session.is_expired:
|
||||||
|
return session
|
||||||
|
return None
|
||||||
|
|
||||||
def get_user_sessions(
|
async def get_user_sessions(
|
||||||
self,
|
self,
|
||||||
db: Session,
|
db: AsyncSession,
|
||||||
*,
|
*,
|
||||||
user_id: str,
|
user_id: UUID,
|
||||||
active_only: bool = True
|
active_only: bool = True,
|
||||||
) -> List[UserSession]:
|
) -> list[UserSession]:
|
||||||
"""
|
"""
|
||||||
Get all sessions for a user.
|
Get all sessions for a user, ordered by most recently used.
|
||||||
|
|
||||||
Args:
|
|
||||||
db: Database session
|
|
||||||
user_id: User ID
|
|
||||||
active_only: If True, return only active sessions
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
List of UserSession objects, ordered by most recently used
|
|
||||||
"""
|
"""
|
||||||
try:
|
query = select(UserSession).where(UserSession.user_id == user_id)
|
||||||
# Convert user_id string to UUID if needed
|
|
||||||
user_uuid = UUID(user_id) if isinstance(user_id, str) else user_id
|
|
||||||
|
|
||||||
query = db.query(UserSession).filter(UserSession.user_id == user_uuid)
|
|
||||||
|
|
||||||
if active_only:
|
if active_only:
|
||||||
query = query.filter(UserSession.is_active == True)
|
query = query.where(UserSession.is_active.is_(True))
|
||||||
|
query = query.order_by(UserSession.last_used_at.desc())
|
||||||
|
result = await db.execute(query)
|
||||||
|
return list(result.scalars().all())
|
||||||
|
|
||||||
# Order by most recently used first
|
async def create_session(
|
||||||
return query.order_by(UserSession.last_used_at.desc()).all()
|
|
||||||
except Exception as e:
|
|
||||||
logger.error(f"Error getting sessions for user {user_id}: {str(e)}")
|
|
||||||
raise
|
|
||||||
|
|
||||||
# Creation methods
|
|
||||||
# ----------------
|
|
||||||
|
|
||||||
def create_session(
|
|
||||||
self,
|
self,
|
||||||
db: Session,
|
db: AsyncSession,
|
||||||
*,
|
*,
|
||||||
obj_in: SessionCreate
|
obj_in: SessionCreate,
|
||||||
) -> UserSession:
|
) -> UserSession:
|
||||||
"""
|
"""
|
||||||
Create a new user session.
|
Create a new user session.
|
||||||
|
|
||||||
Args:
|
|
||||||
db: Database session
|
|
||||||
obj_in: SessionCreate schema with session data
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
Created UserSession
|
|
||||||
|
|
||||||
Raises:
|
Raises:
|
||||||
ValueError: If session creation fails
|
DuplicateEntryError: If a session with the same JTI already exists
|
||||||
"""
|
"""
|
||||||
try:
|
try:
|
||||||
# Create model instance from schema
|
|
||||||
db_obj = UserSession(
|
db_obj = UserSession(
|
||||||
user_id=obj_in.user_id,
|
user_id=obj_in.user_id,
|
||||||
refresh_token_jti=obj_in.refresh_token_jti,
|
refresh_token_jti=obj_in.refresh_token_jti,
|
||||||
@@ -501,248 +452,93 @@ class CRUDSession(CRUDBase[UserSession, SessionCreate, SessionUpdate]):
|
|||||||
location_country=obj_in.location_country,
|
location_country=obj_in.location_country,
|
||||||
)
|
)
|
||||||
|
|
||||||
db.add(db_obj)
|
db_obj.add(db_obj)
|
||||||
db.commit()
|
await db.commit()
|
||||||
db.refresh(db_obj)
|
await db.refresh(db_obj)
|
||||||
|
|
||||||
logger.info(
|
logger.info(
|
||||||
f"Session created for user {obj_in.user_id} from {obj_in.device_name}"
|
f"Session created for user {obj_in.user_id} from {obj_in.device_name}"
|
||||||
f"(IP: {obj_in.ip_address})"
|
|
||||||
)
|
)
|
||||||
|
|
||||||
return db_obj
|
return db_obj
|
||||||
|
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
db.rollback()
|
await db.rollback()
|
||||||
logger.error(f"Error creating session: {str(e)}", exc_info=True)
|
logger.error(f"Error creating session: {str(e)}", exc_info=True)
|
||||||
raise ValueError(f"Failed to create session: {str(e)}")
|
|
||||||
|
|
||||||
# Update methods
|
|
||||||
# --------------
|
|
||||||
|
|
||||||
def deactivate(self, db: Session, *, session_id: str) -> Optional[UserSession]:
|
|
||||||
"""
|
|
||||||
Deactivate a session (logout from device).
|
|
||||||
|
|
||||||
Args:
|
|
||||||
db: Database session
|
|
||||||
session_id: Session UUID
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
Deactivated UserSession if found, None otherwise
|
|
||||||
"""
|
|
||||||
try:
|
|
||||||
session = self.get(db, id=session_id)
|
|
||||||
if not session:
|
|
||||||
logger.warning(f"Session {session_id} not found for deactivation")
|
|
||||||
return None
|
|
||||||
|
|
||||||
session.is_active = False
|
|
||||||
db.add(session)
|
|
||||||
db.commit()
|
|
||||||
db.refresh(session)
|
|
||||||
|
|
||||||
logger.info(
|
|
||||||
f"Session {session_id} deactivated for user {session.user_id} "
|
|
||||||
f"({session.device_name})"
|
|
||||||
)
|
|
||||||
|
|
||||||
return session
|
|
||||||
|
|
||||||
except Exception as e:
|
|
||||||
db.rollback()
|
|
||||||
logger.error(f"Error deactivating session {session_id}: {str(e)}")
|
|
||||||
raise
|
raise
|
||||||
|
|
||||||
def deactivate_all_user_sessions(
|
async def deactivate(self, db: AsyncSession, *, session_id: UUID) -> UserSession | None:
|
||||||
|
"""Deactivate a session (logout from device)."""
|
||||||
|
session = await self.get(db, id=session_id)
|
||||||
|
if not session:
|
||||||
|
return None
|
||||||
|
session.is_active = False
|
||||||
|
await db.commit()
|
||||||
|
await db.refresh(session)
|
||||||
|
logger.info(f"Session {session_id} deactivated ({session.device_name})")
|
||||||
|
return session
|
||||||
|
|
||||||
|
async def deactivate_all_user_sessions(
|
||||||
self,
|
self,
|
||||||
db: Session,
|
db: AsyncSession,
|
||||||
*,
|
*,
|
||||||
user_id: str
|
user_id: UUID,
|
||||||
) -> int:
|
) -> int:
|
||||||
"""
|
"""
|
||||||
Deactivate all active sessions for a user (logout from all devices).
|
Deactivate all active sessions for a user (logout from all devices).
|
||||||
|
|
||||||
Uses bulk update for efficiency.
|
Uses a bulk UPDATE for efficiency — no N+1 queries.
|
||||||
|
|
||||||
Args:
|
|
||||||
db: Database session
|
|
||||||
user_id: User ID
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
Number of sessions deactivated
|
|
||||||
"""
|
"""
|
||||||
try:
|
result = await db.execute(
|
||||||
# Convert user_id string to UUID if needed
|
update(UserSession)
|
||||||
user_uuid = UUID(user_id) if isinstance(user_id, str) else user_id
|
.where(
|
||||||
|
|
||||||
# Bulk update query
|
|
||||||
count = db.query(UserSession).filter(
|
|
||||||
and_(
|
|
||||||
UserSession.user_id == user_uuid,
|
|
||||||
UserSession.is_active == True
|
|
||||||
)
|
|
||||||
).update({"is_active": False})
|
|
||||||
|
|
||||||
db.commit()
|
|
||||||
|
|
||||||
logger.info(f"Deactivated {count} sessions for user {user_id}")
|
|
||||||
|
|
||||||
return count
|
|
||||||
|
|
||||||
except Exception as e:
|
|
||||||
db.rollback()
|
|
||||||
logger.error(f"Error deactivating all sessions for user {user_id}: {str(e)}")
|
|
||||||
raise
|
|
||||||
|
|
||||||
def update_last_used(
|
|
||||||
self,
|
|
||||||
db: Session,
|
|
||||||
*,
|
|
||||||
session: UserSession
|
|
||||||
) -> UserSession:
|
|
||||||
"""
|
|
||||||
Update the last_used_at timestamp for a session.
|
|
||||||
|
|
||||||
Called when a refresh token is used.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
db: Database session
|
|
||||||
session: UserSession object
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
Updated UserSession
|
|
||||||
"""
|
|
||||||
try:
|
|
||||||
session.last_used_at = datetime.now(timezone.utc)
|
|
||||||
db.add(session)
|
|
||||||
db.commit()
|
|
||||||
db.refresh(session)
|
|
||||||
return session
|
|
||||||
except Exception as e:
|
|
||||||
db.rollback()
|
|
||||||
logger.error(f"Error updating last_used for session {session.id}: {str(e)}")
|
|
||||||
raise
|
|
||||||
|
|
||||||
def update_refresh_token(
|
|
||||||
self,
|
|
||||||
db: Session,
|
|
||||||
*,
|
|
||||||
session: UserSession,
|
|
||||||
new_jti: str,
|
|
||||||
new_expires_at: datetime
|
|
||||||
) -> UserSession:
|
|
||||||
"""
|
|
||||||
Update session with new refresh token JTI and expiration.
|
|
||||||
|
|
||||||
Called during token refresh (token rotation).
|
|
||||||
|
|
||||||
Args:
|
|
||||||
db: Database session
|
|
||||||
session: UserSession object
|
|
||||||
new_jti: New refresh token JTI
|
|
||||||
new_expires_at: New expiration datetime
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
Updated UserSession
|
|
||||||
"""
|
|
||||||
try:
|
|
||||||
session.refresh_token_jti = new_jti
|
|
||||||
session.expires_at = new_expires_at
|
|
||||||
session.last_used_at = datetime.now(timezone.utc)
|
|
||||||
db.add(session)
|
|
||||||
db.commit()
|
|
||||||
db.refresh(session)
|
|
||||||
return session
|
|
||||||
except Exception as e:
|
|
||||||
db.rollback()
|
|
||||||
logger.error(f"Error updating refresh token for session {session.id}: {str(e)}")
|
|
||||||
raise
|
|
||||||
|
|
||||||
# Cleanup methods
|
|
||||||
# ---------------
|
|
||||||
|
|
||||||
def cleanup_expired(self, db: Session, *, keep_days: int = 30) -> int:
|
|
||||||
"""
|
|
||||||
Clean up expired sessions.
|
|
||||||
|
|
||||||
Deletes sessions that are:
|
|
||||||
- Expired (expires_at < now) AND inactive
|
|
||||||
- Older than keep_days (for audit trail)
|
|
||||||
|
|
||||||
Args:
|
|
||||||
db: Database session
|
|
||||||
keep_days: Keep inactive sessions for this many days
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
Number of sessions deleted
|
|
||||||
"""
|
|
||||||
try:
|
|
||||||
cutoff_date = datetime.now(timezone.utc) - timedelta(days=keep_days)
|
|
||||||
|
|
||||||
count = db.query(UserSession).filter(
|
|
||||||
and_(
|
|
||||||
UserSession.is_active == False,
|
|
||||||
UserSession.expires_at < datetime.now(timezone.utc),
|
|
||||||
UserSession.created_at < cutoff_date
|
|
||||||
)
|
|
||||||
).delete()
|
|
||||||
|
|
||||||
db.commit()
|
|
||||||
|
|
||||||
if count > 0:
|
|
||||||
logger.info(f"Cleaned up {count} expired sessions")
|
|
||||||
|
|
||||||
return count
|
|
||||||
|
|
||||||
except Exception as e:
|
|
||||||
db.rollback()
|
|
||||||
logger.error(f"Error cleaning up expired sessions: {str(e)}")
|
|
||||||
raise
|
|
||||||
|
|
||||||
# Utility methods
|
|
||||||
# ---------------
|
|
||||||
|
|
||||||
def get_user_session_count(self, db: Session, *, user_id: str) -> int:
|
|
||||||
"""
|
|
||||||
Get count of active sessions for a user.
|
|
||||||
|
|
||||||
Useful for session limits or security monitoring.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
db: Database session
|
|
||||||
user_id: User ID
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
Number of active sessions
|
|
||||||
"""
|
|
||||||
try:
|
|
||||||
return db.query(UserSession).filter(
|
|
||||||
and_(
|
and_(
|
||||||
UserSession.user_id == user_id,
|
UserSession.user_id == user_id,
|
||||||
UserSession.is_active == True
|
UserSession.is_active.is_(True),
|
||||||
)
|
)
|
||||||
).count()
|
)
|
||||||
except Exception as e:
|
.values(is_active=False)
|
||||||
logger.error(f"Error counting sessions for user {user_id}: {str(e)}")
|
)
|
||||||
raise
|
await db.commit()
|
||||||
|
count = result.rowcount
|
||||||
|
logger.info(f"Deactivated {count} sessions for user {user_id}")
|
||||||
|
return count
|
||||||
|
|
||||||
|
async def cleanup_expired(self, db: AsyncSession, *, keep_days: int = 30) -> int:
|
||||||
|
"""
|
||||||
|
Hard-delete inactive sessions older than keep_days.
|
||||||
|
|
||||||
|
Returns the number of sessions deleted.
|
||||||
|
"""
|
||||||
|
cutoff_date = datetime.now(timezone.utc) - timedelta(days=keep_days)
|
||||||
|
result = await db.execute(
|
||||||
|
select(UserSession).where(
|
||||||
|
and_(
|
||||||
|
UserSession.is_active.is_(False),
|
||||||
|
UserSession.expires_at < datetime.now(timezone.utc),
|
||||||
|
UserSession.created_at < cutoff_date,
|
||||||
|
)
|
||||||
|
)
|
||||||
|
)
|
||||||
|
sessions = list(result.scalars().all())
|
||||||
|
for s in sessions:
|
||||||
|
await db.delete(s)
|
||||||
|
await db.commit()
|
||||||
|
if sessions:
|
||||||
|
logger.info(f"Cleaned up {len(sessions)} expired sessions")
|
||||||
|
return len(sessions)
|
||||||
|
|
||||||
|
|
||||||
# Create singleton instance
|
# Singleton instance — used by services, never imported directly in routes
|
||||||
# This is the instance that will be imported and used throughout the app
|
session_repo = SessionRepository(UserSession)
|
||||||
session = CRUDSession(UserSession)
|
|
||||||
```
|
```
|
||||||
|
|
||||||
**Key Patterns**:
|
**Key Patterns**:
|
||||||
|
|
||||||
1. **Error Handling**: Every method has try/except with rollback
|
1. **Async everywhere**: All methods use `async def` and `await`
|
||||||
2. **Logging**: Log all significant actions (create, delete, errors)
|
2. **Modern SQLAlchemy**: `select()` API, never `db.query()`
|
||||||
3. **Type Safety**: Full type hints for parameters and returns
|
3. **Bulk updates**: Use `update()` statement for multi-row changes (no N+1)
|
||||||
4. **Docstrings**: Document what each method does, args, returns, raises
|
4. **Error handling**: `try/except` with `await db.rollback()` in mutating methods
|
||||||
5. **Bulk Operations**: Use `query().update()` for efficiency when updating many rows
|
5. **Logging**: Log all significant actions (create, delete, errors)
|
||||||
6. **UUID Handling**: Convert string UUIDs to UUID objects when needed
|
6. **Type safety**: Full type hints; `UUID` not raw `str` for IDs
|
||||||
7. **Ordering**: Return results in a logical order (most recent first)
|
7. **Singleton pattern**: One module-level instance used by services
|
||||||
8. **Singleton Pattern**: Create one instance to be imported elsewhere
|
|
||||||
|
|
||||||
### Step 4: Create API Endpoints
|
### Step 4: Create API Endpoints
|
||||||
|
|
||||||
@@ -772,14 +568,15 @@ from uuid import UUID
|
|||||||
from fastapi import APIRouter, Depends, HTTPException, status, Request
|
from fastapi import APIRouter, Depends, HTTPException, status, Request
|
||||||
from slowapi import Limiter
|
from slowapi import Limiter
|
||||||
from slowapi.util import get_remote_address
|
from slowapi.util import get_remote_address
|
||||||
from sqlalchemy.orm import Session
|
from sqlalchemy.ext.asyncio import AsyncSession
|
||||||
|
|
||||||
from app.api.dependencies.auth import get_current_user
|
from app.api.dependencies.auth import get_current_user
|
||||||
|
from app.api.dependencies.services import get_session_service
|
||||||
from app.core.database import get_db
|
from app.core.database import get_db
|
||||||
from app.models.user import User
|
from app.models.user import User
|
||||||
from app.schemas.sessions import SessionResponse, SessionListResponse
|
from app.schemas.sessions import SessionResponse, SessionListResponse
|
||||||
from app.schemas.common import MessageResponse
|
from app.schemas.common import MessageResponse
|
||||||
from app.crud.session import session as session_crud
|
from app.services.session_service import SessionService
|
||||||
from app.core.exceptions import NotFoundError, AuthorizationError, ErrorCode
|
from app.core.exceptions import NotFoundError, AuthorizationError, ErrorCode
|
||||||
|
|
||||||
router = APIRouter()
|
router = APIRouter()
|
||||||
@@ -803,61 +600,21 @@ limiter = Limiter(key_func=get_remote_address)
|
|||||||
operation_id="list_my_sessions"
|
operation_id="list_my_sessions"
|
||||||
)
|
)
|
||||||
@limiter.limit("30/minute")
|
@limiter.limit("30/minute")
|
||||||
def list_my_sessions(
|
async def list_my_sessions(
|
||||||
request: Request,
|
request: Request,
|
||||||
current_user: User = Depends(get_current_user),
|
current_user: User = Depends(get_current_user),
|
||||||
db: Session = Depends(get_db)
|
session_service: SessionService = Depends(get_session_service),
|
||||||
|
db: AsyncSession = Depends(get_db),
|
||||||
) -> Any:
|
) -> Any:
|
||||||
"""
|
"""List all active sessions for the current user."""
|
||||||
List all active sessions for the current user.
|
sessions = await session_service.get_user_sessions(
|
||||||
|
db, user_id=current_user.id, active_only=True
|
||||||
Args:
|
|
||||||
request: FastAPI request object (for rate limiting)
|
|
||||||
current_user: Current authenticated user (injected)
|
|
||||||
db: Database session (injected)
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
SessionListResponse with list of active sessions
|
|
||||||
"""
|
|
||||||
try:
|
|
||||||
# Get all active sessions for user
|
|
||||||
sessions = session_crud.get_user_sessions(
|
|
||||||
db,
|
|
||||||
user_id=str(current_user.id),
|
|
||||||
active_only=True
|
|
||||||
)
|
|
||||||
|
|
||||||
# Convert to response format
|
|
||||||
session_responses = []
|
|
||||||
for idx, s in enumerate(sessions):
|
|
||||||
session_response = SessionResponse(
|
|
||||||
id=s.id,
|
|
||||||
device_name=s.device_name,
|
|
||||||
device_id=s.device_id,
|
|
||||||
ip_address=s.ip_address,
|
|
||||||
location_city=s.location_city,
|
|
||||||
location_country=s.location_country,
|
|
||||||
last_used_at=s.last_used_at,
|
|
||||||
created_at=s.created_at,
|
|
||||||
expires_at=s.expires_at,
|
|
||||||
# Mark the most recently used session as current
|
|
||||||
is_current=(idx == 0)
|
|
||||||
)
|
|
||||||
session_responses.append(session_response)
|
|
||||||
|
|
||||||
logger.info(f"User {current_user.id} listed {len(session_responses)} active sessions")
|
|
||||||
|
|
||||||
return SessionListResponse(
|
|
||||||
sessions=session_responses,
|
|
||||||
total=len(session_responses)
|
|
||||||
)
|
|
||||||
|
|
||||||
except Exception as e:
|
|
||||||
logger.error(f"Error listing sessions for user {current_user.id}: {str(e)}", exc_info=True)
|
|
||||||
raise HTTPException(
|
|
||||||
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
|
|
||||||
detail="Failed to retrieve sessions"
|
|
||||||
)
|
)
|
||||||
|
session_responses = [
|
||||||
|
SessionResponse.model_validate(s) | {"is_current": idx == 0}
|
||||||
|
for idx, s in enumerate(sessions)
|
||||||
|
]
|
||||||
|
return SessionListResponse(sessions=session_responses, total=len(session_responses))
|
||||||
|
|
||||||
|
|
||||||
@router.delete(
|
@router.delete(
|
||||||
@@ -876,70 +633,25 @@ def list_my_sessions(
|
|||||||
operation_id="revoke_session"
|
operation_id="revoke_session"
|
||||||
)
|
)
|
||||||
@limiter.limit("10/minute")
|
@limiter.limit("10/minute")
|
||||||
def revoke_session(
|
async def revoke_session(
|
||||||
request: Request,
|
request: Request,
|
||||||
session_id: UUID,
|
session_id: UUID,
|
||||||
current_user: User = Depends(get_current_user),
|
current_user: User = Depends(get_current_user),
|
||||||
db: Session = Depends(get_db)
|
session_service: SessionService = Depends(get_session_service),
|
||||||
|
db: AsyncSession = Depends(get_db),
|
||||||
) -> Any:
|
) -> Any:
|
||||||
"""
|
"""
|
||||||
Revoke a specific session by ID.
|
Revoke a specific session by ID.
|
||||||
|
|
||||||
Args:
|
The service verifies ownership and raises NotFoundError /
|
||||||
request: FastAPI request object (for rate limiting)
|
AuthorizationError which are handled by global exception handlers.
|
||||||
session_id: UUID of the session to revoke
|
|
||||||
current_user: Current authenticated user (injected)
|
|
||||||
db: Database session (injected)
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
MessageResponse with success message
|
|
||||||
|
|
||||||
Raises:
|
|
||||||
NotFoundError: If session doesn't exist
|
|
||||||
AuthorizationError: If session belongs to another user
|
|
||||||
"""
|
"""
|
||||||
try:
|
device_name = await session_service.revoke_session(
|
||||||
# Get the session
|
db, session_id=session_id, user_id=current_user.id
|
||||||
session = session_crud.get(db, id=str(session_id))
|
|
||||||
|
|
||||||
if not session:
|
|
||||||
raise NotFoundError(
|
|
||||||
message=f"Session {session_id} not found",
|
|
||||||
error_code=ErrorCode.NOT_FOUND
|
|
||||||
)
|
)
|
||||||
|
|
||||||
# Verify session belongs to current user (authorization check)
|
|
||||||
if str(session.user_id) != str(current_user.id):
|
|
||||||
logger.warning(
|
|
||||||
f"User {current_user.id} attempted to revoke session {session_id} "
|
|
||||||
f"belonging to user {session.user_id}"
|
|
||||||
)
|
|
||||||
raise AuthorizationError(
|
|
||||||
message="You can only revoke your own sessions",
|
|
||||||
error_code=ErrorCode.INSUFFICIENT_PERMISSIONS
|
|
||||||
)
|
|
||||||
|
|
||||||
# Deactivate the session
|
|
||||||
session_crud.deactivate(db, session_id=str(session_id))
|
|
||||||
|
|
||||||
logger.info(
|
|
||||||
f"User {current_user.id} revoked session {session_id} "
|
|
||||||
f"({session.device_name})"
|
|
||||||
)
|
|
||||||
|
|
||||||
return MessageResponse(
|
return MessageResponse(
|
||||||
success=True,
|
success=True,
|
||||||
message=f"Session revoked: {session.device_name or 'Unknown device'}"
|
message=f"Session revoked: {device_name or 'Unknown device'}"
|
||||||
)
|
|
||||||
|
|
||||||
except (NotFoundError, AuthorizationError):
|
|
||||||
# Re-raise custom exceptions (they'll be handled by global handlers)
|
|
||||||
raise
|
|
||||||
except Exception as e:
|
|
||||||
logger.error(f"Error revoking session {session_id}: {str(e)}", exc_info=True)
|
|
||||||
raise HTTPException(
|
|
||||||
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
|
|
||||||
detail="Failed to revoke session"
|
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
@@ -958,55 +670,20 @@ def revoke_session(
|
|||||||
operation_id="cleanup_expired_sessions"
|
operation_id="cleanup_expired_sessions"
|
||||||
)
|
)
|
||||||
@limiter.limit("5/minute")
|
@limiter.limit("5/minute")
|
||||||
def cleanup_expired_sessions(
|
async def cleanup_expired_sessions(
|
||||||
request: Request,
|
request: Request,
|
||||||
current_user: User = Depends(get_current_user),
|
current_user: User = Depends(get_current_user),
|
||||||
db: Session = Depends(get_db)
|
session_service: SessionService = Depends(get_session_service),
|
||||||
|
db: AsyncSession = Depends(get_db),
|
||||||
) -> Any:
|
) -> Any:
|
||||||
"""
|
"""Cleanup expired sessions for the current user."""
|
||||||
Cleanup expired sessions for the current user.
|
deleted_count = await session_service.cleanup_user_expired_sessions(
|
||||||
|
db, user_id=current_user.id
|
||||||
Args:
|
|
||||||
request: FastAPI request object (for rate limiting)
|
|
||||||
current_user: Current authenticated user (injected)
|
|
||||||
db: Database session (injected)
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
MessageResponse with count of sessions cleaned
|
|
||||||
"""
|
|
||||||
try:
|
|
||||||
from datetime import datetime, timezone
|
|
||||||
|
|
||||||
# Get all sessions for user (including inactive)
|
|
||||||
all_sessions = session_crud.get_user_sessions(
|
|
||||||
db,
|
|
||||||
user_id=str(current_user.id),
|
|
||||||
active_only=False
|
|
||||||
)
|
)
|
||||||
|
|
||||||
# Delete expired and inactive sessions
|
|
||||||
deleted_count = 0
|
|
||||||
for s in all_sessions:
|
|
||||||
if not s.is_active and s.expires_at < datetime.now(timezone.utc):
|
|
||||||
db.delete(s)
|
|
||||||
deleted_count += 1
|
|
||||||
|
|
||||||
db.commit()
|
|
||||||
|
|
||||||
logger.info(f"User {current_user.id} cleaned up {deleted_count} expired sessions")
|
|
||||||
|
|
||||||
return MessageResponse(
|
return MessageResponse(
|
||||||
success=True,
|
success=True,
|
||||||
message=f"Cleaned up {deleted_count} expired sessions"
|
message=f"Cleaned up {deleted_count} expired sessions"
|
||||||
)
|
)
|
||||||
|
|
||||||
except Exception as e:
|
|
||||||
logger.error(f"Error cleaning up sessions for user {current_user.id}: {str(e)}", exc_info=True)
|
|
||||||
db.rollback()
|
|
||||||
raise HTTPException(
|
|
||||||
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
|
|
||||||
detail="Failed to cleanup sessions"
|
|
||||||
)
|
|
||||||
```
|
```
|
||||||
|
|
||||||
**Key Patterns**:
|
**Key Patterns**:
|
||||||
@@ -1079,59 +756,25 @@ Session management needs to be integrated into the authentication flow.
|
|||||||
|
|
||||||
```python
|
```python
|
||||||
from app.utils.device import extract_device_info
|
from app.utils.device import extract_device_info
|
||||||
from app.crud.session import session as session_crud
|
from app.api.dependencies.services import get_auth_service
|
||||||
from app.schemas.sessions import SessionCreate
|
from app.services.auth_service import AuthService
|
||||||
|
|
||||||
@router.post("/login")
|
@router.post("/login")
|
||||||
async def login(
|
async def login(
|
||||||
request: Request,
|
request: Request,
|
||||||
credentials: OAuth2PasswordRequestForm = Depends(),
|
credentials: OAuth2PasswordRequestForm = Depends(),
|
||||||
db: Session = Depends(get_db)
|
auth_service: AuthService = Depends(get_auth_service),
|
||||||
|
db: AsyncSession = Depends(get_db),
|
||||||
):
|
):
|
||||||
"""Authenticate user and create session."""
|
"""Authenticate user and create session."""
|
||||||
|
# All business logic (validate credentials, create session, generate tokens)
|
||||||
# 1. Validate credentials
|
# is delegated to AuthService which calls the appropriate repositories.
|
||||||
user = user_crud.get_by_email(db, email=credentials.username)
|
return await auth_service.login(
|
||||||
if not user or not verify_password(credentials.password, user.hashed_password):
|
db,
|
||||||
raise AuthenticationError("Invalid credentials")
|
email=credentials.username,
|
||||||
|
password=credentials.password,
|
||||||
if not user.is_active:
|
request=request,
|
||||||
raise AuthenticationError("Account is inactive")
|
|
||||||
|
|
||||||
# 2. Extract device information from request
|
|
||||||
device_info = extract_device_info(request)
|
|
||||||
|
|
||||||
# 3. Generate tokens
|
|
||||||
jti = str(uuid.uuid4()) # Generate JTI for refresh token
|
|
||||||
access_token = create_access_token(subject=str(user.id))
|
|
||||||
refresh_token = create_refresh_token(subject=str(user.id), jti=jti)
|
|
||||||
|
|
||||||
# 4. Create session record
|
|
||||||
from datetime import datetime, timezone, timedelta
|
|
||||||
|
|
||||||
session_data = SessionCreate(
|
|
||||||
user_id=user.id,
|
|
||||||
refresh_token_jti=jti,
|
|
||||||
device_name=device_info.device_name,
|
|
||||||
device_id=device_info.device_id,
|
|
||||||
ip_address=device_info.ip_address,
|
|
||||||
user_agent=device_info.user_agent,
|
|
||||||
last_used_at=datetime.now(timezone.utc),
|
|
||||||
expires_at=datetime.now(timezone.utc) + timedelta(days=7),
|
|
||||||
location_city=device_info.location_city,
|
|
||||||
location_country=device_info.location_country,
|
|
||||||
)
|
)
|
||||||
|
|
||||||
session_crud.create_session(db, obj_in=session_data)
|
|
||||||
|
|
||||||
logger.info(f"User {user.email} logged in from {device_info.device_name}")
|
|
||||||
|
|
||||||
return {
|
|
||||||
"access_token": access_token,
|
|
||||||
"refresh_token": refresh_token,
|
|
||||||
"token_type": "bearer",
|
|
||||||
"user": UserResponse.model_validate(user)
|
|
||||||
}
|
|
||||||
```
|
```
|
||||||
|
|
||||||
#### 5.2 Create Device Info Utility
|
#### 5.2 Create Device Info Utility
|
||||||
@@ -1193,89 +836,35 @@ def extract_device_info(request: Request) -> DeviceInfo:
|
|||||||
|
|
||||||
```python
|
```python
|
||||||
@router.post("/refresh")
|
@router.post("/refresh")
|
||||||
def refresh_token(
|
async def refresh_token(
|
||||||
refresh_request: RefreshRequest,
|
refresh_request: RefreshRequest,
|
||||||
db: Session = Depends(get_db)
|
auth_service: AuthService = Depends(get_auth_service),
|
||||||
|
db: AsyncSession = Depends(get_db),
|
||||||
):
|
):
|
||||||
"""Refresh access token using refresh token."""
|
"""Refresh access token using refresh token."""
|
||||||
|
# AuthService handles token validation, session lookup, token rotation
|
||||||
try:
|
return await auth_service.refresh_tokens(
|
||||||
# 1. Decode and validate refresh token
|
db, refresh_token=refresh_request.refresh_token
|
||||||
payload = decode_token(refresh_request.refresh_token)
|
|
||||||
|
|
||||||
if payload.get("type") != "refresh":
|
|
||||||
raise AuthenticationError("Invalid token type")
|
|
||||||
|
|
||||||
user_id = UUID(payload.get("sub"))
|
|
||||||
jti = payload.get("jti")
|
|
||||||
|
|
||||||
# 2. Find and validate session
|
|
||||||
session = session_crud.get_active_by_jti(db, jti=jti)
|
|
||||||
|
|
||||||
if not session:
|
|
||||||
raise AuthenticationError("Session not found or expired")
|
|
||||||
|
|
||||||
if session.user_id != user_id:
|
|
||||||
raise AuthenticationError("Token mismatch")
|
|
||||||
|
|
||||||
# 3. Generate new tokens (token rotation)
|
|
||||||
new_jti = str(uuid.uuid4())
|
|
||||||
new_access_token = create_access_token(subject=str(user_id))
|
|
||||||
new_refresh_token = create_refresh_token(subject=str(user_id), jti=new_jti)
|
|
||||||
|
|
||||||
# 4. Update session with new JTI
|
|
||||||
session_crud.update_refresh_token(
|
|
||||||
db,
|
|
||||||
session=session,
|
|
||||||
new_jti=new_jti,
|
|
||||||
new_expires_at=datetime.now(timezone.utc) + timedelta(days=7)
|
|
||||||
)
|
)
|
||||||
|
|
||||||
logger.info(f"Tokens refreshed for user {user_id}")
|
|
||||||
|
|
||||||
return {
|
|
||||||
"access_token": new_access_token,
|
|
||||||
"refresh_token": new_refresh_token,
|
|
||||||
"token_type": "bearer"
|
|
||||||
}
|
|
||||||
|
|
||||||
except Exception as e:
|
|
||||||
logger.error(f"Token refresh failed: {str(e)}")
|
|
||||||
raise AuthenticationError("Failed to refresh token")
|
|
||||||
```
|
```
|
||||||
|
|
||||||
#### 5.4 Update Logout Endpoint
|
#### 5.4 Update Logout Endpoint
|
||||||
|
|
||||||
```python
|
```python
|
||||||
@router.post("/logout")
|
@router.post("/logout")
|
||||||
def logout(
|
async def logout(
|
||||||
logout_request: LogoutRequest,
|
logout_request: LogoutRequest,
|
||||||
current_user: User = Depends(get_current_user),
|
current_user: User = Depends(get_current_user),
|
||||||
db: Session = Depends(get_db)
|
auth_service: AuthService = Depends(get_auth_service),
|
||||||
|
db: AsyncSession = Depends(get_db),
|
||||||
):
|
):
|
||||||
"""Logout from current device."""
|
"""Logout from current device."""
|
||||||
|
await auth_service.logout(
|
||||||
try:
|
db,
|
||||||
# Decode refresh token to get JTI
|
refresh_token=logout_request.refresh_token,
|
||||||
payload = decode_token(logout_request.refresh_token)
|
user_id=current_user.id,
|
||||||
jti = payload.get("jti")
|
|
||||||
|
|
||||||
# Find and deactivate session
|
|
||||||
session = session_crud.get_by_jti(db, jti=jti)
|
|
||||||
|
|
||||||
if session and session.user_id == current_user.id:
|
|
||||||
session_crud.deactivate(db, session_id=str(session.id))
|
|
||||||
logger.info(f"User {current_user.id} logged out from {session.device_name}")
|
|
||||||
|
|
||||||
return MessageResponse(
|
|
||||||
success=True,
|
|
||||||
message="Logged out successfully"
|
|
||||||
)
|
)
|
||||||
|
return MessageResponse(success=True, message="Logged out successfully")
|
||||||
except Exception as e:
|
|
||||||
logger.error(f"Logout failed: {str(e)}")
|
|
||||||
# Even if cleanup fails, return success (user intended to logout)
|
|
||||||
return MessageResponse(success=True, message="Logged out")
|
|
||||||
```
|
```
|
||||||
|
|
||||||
### Step 6: Add Background Jobs
|
### Step 6: Add Background Jobs
|
||||||
@@ -1287,8 +876,8 @@ def logout(
|
|||||||
Background job for cleaning up expired sessions.
|
Background job for cleaning up expired sessions.
|
||||||
"""
|
"""
|
||||||
import logging
|
import logging
|
||||||
from app.core.database import SessionLocal
|
from app.core.database import AsyncSessionLocal
|
||||||
from app.crud.session import session as session_crud
|
from app.repositories.session import session_repo
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
@@ -1302,14 +891,12 @@ async def cleanup_expired_sessions():
|
|||||||
- Inactive (is_active = False)
|
- Inactive (is_active = False)
|
||||||
- Older than 30 days (for audit trail)
|
- Older than 30 days (for audit trail)
|
||||||
"""
|
"""
|
||||||
db = SessionLocal()
|
async with AsyncSessionLocal() as db:
|
||||||
try:
|
try:
|
||||||
count = session_crud.cleanup_expired(db, keep_days=30)
|
count = await session_repo.cleanup_expired(db, keep_days=30)
|
||||||
logger.info(f"Background cleanup: Removed {count} expired sessions")
|
logger.info(f"Background cleanup: Removed {count} expired sessions")
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.error(f"Error in session cleanup job: {str(e)}", exc_info=True)
|
logger.error(f"Error in session cleanup job: {str(e)}", exc_info=True)
|
||||||
finally:
|
|
||||||
db.close()
|
|
||||||
```
|
```
|
||||||
|
|
||||||
**Register in** `app/main.py`:
|
**Register in** `app/main.py`:
|
||||||
@@ -1679,7 +1266,8 @@ You've now implemented a complete feature! Here's what was created:
|
|||||||
**Files Created/Modified**:
|
**Files Created/Modified**:
|
||||||
1. `app/models/user_session.py` - Database model
|
1. `app/models/user_session.py` - Database model
|
||||||
2. `app/schemas/sessions.py` - Pydantic schemas
|
2. `app/schemas/sessions.py` - Pydantic schemas
|
||||||
3. `app/crud/session.py` - CRUD operations
|
3. `app/repositories/session.py` - Repository (data access)
|
||||||
|
4. `app/services/session_service.py` - Service (business logic)
|
||||||
4. `app/api/routes/sessions.py` - API endpoints
|
4. `app/api/routes/sessions.py` - API endpoints
|
||||||
5. `app/utils/device.py` - Device detection utility
|
5. `app/utils/device.py` - Device detection utility
|
||||||
6. `app/services/session_cleanup.py` - Background job
|
6. `app/services/session_cleanup.py` - Background job
|
||||||
@@ -1715,7 +1303,7 @@ You've now implemented a complete feature! Here's what was created:
|
|||||||
|
|
||||||
### Don'ts
|
### Don'ts
|
||||||
|
|
||||||
1. **Don't Mix Layers**: Keep business logic out of CRUD, database ops out of routes
|
1. **Don't Mix Layers**: Keep business logic in services, database ops in repositories, routing in routes
|
||||||
2. **Don't Expose Internals**: Never return sensitive data in API responses
|
2. **Don't Expose Internals**: Never return sensitive data in API responses
|
||||||
3. **Don't Trust Input**: Always validate and sanitize user input
|
3. **Don't Trust Input**: Always validate and sanitize user input
|
||||||
4. **Don't Ignore Errors**: Always handle exceptions properly
|
4. **Don't Ignore Errors**: Always handle exceptions properly
|
||||||
@@ -1733,7 +1321,9 @@ When implementing a new feature, use this checklist:
|
|||||||
- [ ] Design database schema
|
- [ ] Design database schema
|
||||||
- [ ] Create SQLAlchemy model
|
- [ ] Create SQLAlchemy model
|
||||||
- [ ] Design Pydantic schemas (Create, Update, Response)
|
- [ ] Design Pydantic schemas (Create, Update, Response)
|
||||||
- [ ] Implement CRUD operations
|
- [ ] Implement repository (data access)
|
||||||
|
- [ ] Implement service (business logic)
|
||||||
|
- [ ] Register service in `app/api/dependencies/services.py`
|
||||||
- [ ] Create API endpoints
|
- [ ] Create API endpoints
|
||||||
- [ ] Add authentication/authorization
|
- [ ] Add authentication/authorization
|
||||||
- [ ] Implement rate limiting
|
- [ ] Implement rate limiting
|
||||||
|
|||||||
Reference in New Issue
Block a user