forked from cardosofelipe/fast-next-template
Enhance user management, improve API structure, add database optimizations, and update Docker setup
- Introduced endpoints for user management, including CRUD operations, pagination, and password management. - Added new schema validations for user updates, password strength, pagination, and standardized error responses. - Integrated custom exception handling for a consistent API error experience. - Refined CORS settings: restricted methods and allowed headers, added header exposure, and preflight caching. - Optimized database: added indexes on `is_active` and `is_superuser` fields, updated column types, enforced constraints, and set defaults. - Updated `Dockerfile` to improve security by using a non-root user and adding a health check for the application. - Enhanced tests for database initialization, user operations, and exception handling to ensure better coverage.
This commit is contained in:
@@ -1,34 +1,66 @@
|
||||
# Development stage
|
||||
FROM python:3.12-slim AS development
|
||||
|
||||
# Create non-root user
|
||||
RUN groupadd -r appuser && useradd -r -g appuser appuser
|
||||
|
||||
WORKDIR /app
|
||||
ENV PYTHONDONTWRITEBYTECODE=1 \
|
||||
PYTHONUNBUFFERED=1 \
|
||||
PYTHONPATH=/app
|
||||
|
||||
RUN apt-get update && \
|
||||
apt-get install -y --no-install-recommends gcc postgresql-client && \
|
||||
apt-get install -y --no-install-recommends gcc postgresql-client curl && \
|
||||
apt-get clean && \
|
||||
rm -rf /var/lib/apt/lists/*
|
||||
|
||||
COPY requirements.txt .
|
||||
RUN pip install --no-cache-dir -r requirements.txt
|
||||
|
||||
COPY . .
|
||||
COPY entrypoint.sh /usr/local/bin/
|
||||
RUN chmod +x /usr/local/bin/entrypoint.sh
|
||||
|
||||
# Set ownership to non-root user
|
||||
RUN chown -R appuser:appuser /app
|
||||
|
||||
# Switch to non-root user
|
||||
USER appuser
|
||||
|
||||
ENTRYPOINT ["/usr/local/bin/entrypoint.sh"]
|
||||
|
||||
# Production stage
|
||||
FROM python:3.12-slim AS production
|
||||
|
||||
# Create non-root user
|
||||
RUN groupadd -r appuser && useradd -r -g appuser appuser
|
||||
|
||||
WORKDIR /app
|
||||
ENV PYTHONDONTWRITEBYTECODE=1 \
|
||||
PYTHONUNBUFFERED=1 \
|
||||
PYTHONPATH=/app
|
||||
|
||||
RUN apt-get update && \
|
||||
apt-get install -y --no-install-recommends postgresql-client && \
|
||||
apt-get install -y --no-install-recommends postgresql-client curl && \
|
||||
apt-get clean && \
|
||||
rm -rf /var/lib/apt/lists/*
|
||||
|
||||
COPY requirements.txt .
|
||||
RUN pip install --no-cache-dir -r requirements.txt
|
||||
|
||||
COPY . .
|
||||
COPY entrypoint.sh /usr/local/bin/
|
||||
RUN chmod +x /usr/local/bin/entrypoint.sh
|
||||
|
||||
# Set ownership to non-root user
|
||||
RUN chown -R appuser:appuser /app
|
||||
|
||||
# Switch to non-root user
|
||||
USER appuser
|
||||
|
||||
# Add health check
|
||||
HEALTHCHECK --interval=30s --timeout=10s --start-period=40s --retries=3 \
|
||||
CMD curl -f http://localhost:8000/health || exit 1
|
||||
|
||||
ENTRYPOINT ["/usr/local/bin/entrypoint.sh"]
|
||||
CMD ["uvicorn", "app.main:app", "--host", "0.0.0.0", "--port", "8000"]
|
||||
@@ -0,0 +1,84 @@
|
||||
"""Add missing indexes and fix column types
|
||||
|
||||
Revision ID: 9e4f2a1b8c7d
|
||||
Revises: 38bf9e7e74b3
|
||||
Create Date: 2025-10-30 10:00:00.000000
|
||||
|
||||
"""
|
||||
from typing import Sequence, Union
|
||||
|
||||
from alembic import op
|
||||
import sqlalchemy as sa
|
||||
|
||||
|
||||
# revision identifiers, used by Alembic.
|
||||
revision: str = '9e4f2a1b8c7d'
|
||||
down_revision: Union[str, None] = '38bf9e7e74b3'
|
||||
branch_labels: Union[str, Sequence[str], None] = None
|
||||
depends_on: Union[str, Sequence[str], None] = None
|
||||
|
||||
|
||||
def upgrade() -> None:
|
||||
# Add missing indexes for is_active and is_superuser
|
||||
op.create_index(op.f('ix_users_is_active'), 'users', ['is_active'], unique=False)
|
||||
op.create_index(op.f('ix_users_is_superuser'), 'users', ['is_superuser'], unique=False)
|
||||
|
||||
# Fix column types to match model definitions with explicit lengths
|
||||
op.alter_column('users', 'email',
|
||||
existing_type=sa.String(),
|
||||
type_=sa.String(length=255),
|
||||
nullable=False)
|
||||
|
||||
op.alter_column('users', 'password_hash',
|
||||
existing_type=sa.String(),
|
||||
type_=sa.String(length=255),
|
||||
nullable=False)
|
||||
|
||||
op.alter_column('users', 'first_name',
|
||||
existing_type=sa.String(),
|
||||
type_=sa.String(length=100),
|
||||
nullable=False,
|
||||
server_default='user') # Add server default
|
||||
|
||||
op.alter_column('users', 'last_name',
|
||||
existing_type=sa.String(),
|
||||
type_=sa.String(length=100),
|
||||
nullable=True)
|
||||
|
||||
op.alter_column('users', 'phone_number',
|
||||
existing_type=sa.String(),
|
||||
type_=sa.String(length=20),
|
||||
nullable=True)
|
||||
|
||||
|
||||
def downgrade() -> None:
|
||||
# Revert column types
|
||||
op.alter_column('users', 'phone_number',
|
||||
existing_type=sa.String(length=20),
|
||||
type_=sa.String(),
|
||||
nullable=True)
|
||||
|
||||
op.alter_column('users', 'last_name',
|
||||
existing_type=sa.String(length=100),
|
||||
type_=sa.String(),
|
||||
nullable=True)
|
||||
|
||||
op.alter_column('users', 'first_name',
|
||||
existing_type=sa.String(length=100),
|
||||
type_=sa.String(),
|
||||
nullable=False,
|
||||
server_default=None) # Remove server default
|
||||
|
||||
op.alter_column('users', 'password_hash',
|
||||
existing_type=sa.String(length=255),
|
||||
type_=sa.String(),
|
||||
nullable=False)
|
||||
|
||||
op.alter_column('users', 'email',
|
||||
existing_type=sa.String(length=255),
|
||||
type_=sa.String(),
|
||||
nullable=False)
|
||||
|
||||
# Drop indexes
|
||||
op.drop_index(op.f('ix_users_is_superuser'), table_name='users')
|
||||
op.drop_index(op.f('ix_users_is_active'), table_name='users')
|
||||
@@ -1,6 +1,7 @@
|
||||
from fastapi import APIRouter
|
||||
|
||||
from app.api.routes import auth
|
||||
from app.api.routes import auth, users
|
||||
|
||||
api_router = APIRouter()
|
||||
api_router.include_router(auth.router, prefix="/auth", tags=["auth"])
|
||||
api_router.include_router(auth.router, prefix="/auth", tags=["Authentication"])
|
||||
api_router.include_router(users.router, prefix="/users", tags=["Users"])
|
||||
|
||||
@@ -196,44 +196,6 @@ async def refresh_token(
|
||||
)
|
||||
|
||||
|
||||
@router.post("/change-password", status_code=status.HTTP_200_OK, operation_id="change_password")
|
||||
@limiter.limit("5/minute")
|
||||
async def change_password(
|
||||
request: Request,
|
||||
current_password: str = Body(..., embed=True),
|
||||
new_password: str = Body(..., embed=True),
|
||||
current_user: User = Depends(get_current_user),
|
||||
db: Session = Depends(get_db)
|
||||
) -> Any:
|
||||
"""
|
||||
Change current user's password.
|
||||
|
||||
Requires authentication.
|
||||
"""
|
||||
try:
|
||||
success = AuthService.change_password(
|
||||
db=db,
|
||||
user_id=current_user.id,
|
||||
current_password=current_password,
|
||||
new_password=new_password
|
||||
)
|
||||
|
||||
if success:
|
||||
return {"message": "Password changed successfully"}
|
||||
except AuthenticationError as e:
|
||||
logger.warning(f"Password change failed: {str(e)}")
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_400_BAD_REQUEST,
|
||||
detail=str(e)
|
||||
)
|
||||
except Exception as e:
|
||||
logger.error(f"Unexpected error during password change: {str(e)}")
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
|
||||
detail="An unexpected error occurred. Please try again later."
|
||||
)
|
||||
|
||||
|
||||
@router.get("/me", response_model=UserResponse, operation_id="get_current_user_info")
|
||||
@limiter.limit("60/minute")
|
||||
async def get_current_user_info(
|
||||
|
||||
370
backend/app/api/routes/users.py
Normal file
370
backend/app/api/routes/users.py
Normal file
@@ -0,0 +1,370 @@
|
||||
"""
|
||||
User management endpoints for CRUD operations.
|
||||
"""
|
||||
import logging
|
||||
from typing import Any
|
||||
from uuid import UUID
|
||||
|
||||
from fastapi import APIRouter, Depends, Query, status, Request
|
||||
from sqlalchemy.orm import Session
|
||||
from slowapi import Limiter
|
||||
from slowapi.util import get_remote_address
|
||||
|
||||
from app.api.dependencies.auth import get_current_user, get_current_superuser
|
||||
from app.core.database import get_db
|
||||
from app.crud.user import user as user_crud
|
||||
from app.models.user import User
|
||||
from app.schemas.users import UserResponse, UserUpdate, PasswordChange
|
||||
from app.schemas.common import (
|
||||
PaginationParams,
|
||||
PaginatedResponse,
|
||||
MessageResponse,
|
||||
create_pagination_meta
|
||||
)
|
||||
from app.services.auth_service import AuthService
|
||||
from app.core.exceptions import (
|
||||
NotFoundError,
|
||||
AuthorizationError,
|
||||
ErrorCode
|
||||
)
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
router = APIRouter()
|
||||
limiter = Limiter(key_func=get_remote_address)
|
||||
|
||||
|
||||
@router.get(
|
||||
"",
|
||||
response_model=PaginatedResponse[UserResponse],
|
||||
summary="List Users",
|
||||
description="""
|
||||
List all users with pagination (admin only).
|
||||
|
||||
**Authentication**: Required (Bearer token)
|
||||
**Authorization**: Superuser only
|
||||
|
||||
**Rate Limit**: 60 requests/minute
|
||||
""",
|
||||
operation_id="list_users"
|
||||
)
|
||||
def list_users(
|
||||
pagination: PaginationParams = Depends(),
|
||||
current_user: User = Depends(get_current_superuser),
|
||||
db: Session = Depends(get_db)
|
||||
) -> Any:
|
||||
"""
|
||||
List all users with pagination.
|
||||
|
||||
Only accessible by superusers.
|
||||
"""
|
||||
try:
|
||||
# Get paginated users with total count
|
||||
users, total = user_crud.get_multi_with_total(
|
||||
db,
|
||||
skip=pagination.offset,
|
||||
limit=pagination.limit
|
||||
)
|
||||
|
||||
# Create pagination metadata
|
||||
pagination_meta = create_pagination_meta(
|
||||
total=total,
|
||||
page=pagination.page,
|
||||
limit=pagination.limit,
|
||||
items_count=len(users)
|
||||
)
|
||||
|
||||
return PaginatedResponse(
|
||||
data=users,
|
||||
pagination=pagination_meta
|
||||
)
|
||||
except Exception as e:
|
||||
logger.error(f"Error listing users: {str(e)}", exc_info=True)
|
||||
raise
|
||||
|
||||
|
||||
@router.get(
|
||||
"/me",
|
||||
response_model=UserResponse,
|
||||
summary="Get Current User",
|
||||
description="""
|
||||
Get the current authenticated user's profile.
|
||||
|
||||
**Authentication**: Required (Bearer token)
|
||||
|
||||
**Rate Limit**: 60 requests/minute
|
||||
""",
|
||||
operation_id="get_current_user_profile"
|
||||
)
|
||||
def get_current_user_profile(
|
||||
current_user: User = Depends(get_current_user)
|
||||
) -> Any:
|
||||
"""Get current user's profile."""
|
||||
return current_user
|
||||
|
||||
|
||||
@router.patch(
|
||||
"/me",
|
||||
response_model=UserResponse,
|
||||
summary="Update Current User",
|
||||
description="""
|
||||
Update the current authenticated user's profile.
|
||||
|
||||
Users can update their own profile information (except is_superuser).
|
||||
|
||||
**Authentication**: Required (Bearer token)
|
||||
|
||||
**Rate Limit**: 30 requests/minute
|
||||
""",
|
||||
operation_id="update_current_user"
|
||||
)
|
||||
def update_current_user(
|
||||
user_update: UserUpdate,
|
||||
current_user: User = Depends(get_current_user),
|
||||
db: Session = Depends(get_db)
|
||||
) -> Any:
|
||||
"""
|
||||
Update current user's profile.
|
||||
|
||||
Users cannot elevate their own permissions (is_superuser).
|
||||
"""
|
||||
# Prevent users from making themselves superuser
|
||||
if user_update.is_superuser is not None:
|
||||
logger.warning(f"User {current_user.id} attempted to modify is_superuser field")
|
||||
raise AuthorizationError(
|
||||
message="Cannot modify superuser status",
|
||||
error_code=ErrorCode.INSUFFICIENT_PERMISSIONS
|
||||
)
|
||||
|
||||
try:
|
||||
updated_user = user_crud.update(
|
||||
db,
|
||||
db_obj=current_user,
|
||||
obj_in=user_update
|
||||
)
|
||||
logger.info(f"User {current_user.id} updated their profile")
|
||||
return updated_user
|
||||
except ValueError as e:
|
||||
logger.error(f"Error updating user {current_user.id}: {str(e)}")
|
||||
raise
|
||||
except Exception as e:
|
||||
logger.error(f"Unexpected error updating user {current_user.id}: {str(e)}", exc_info=True)
|
||||
raise
|
||||
|
||||
|
||||
@router.get(
|
||||
"/{user_id}",
|
||||
response_model=UserResponse,
|
||||
summary="Get User by ID",
|
||||
description="""
|
||||
Get a specific user by their ID.
|
||||
|
||||
**Authentication**: Required (Bearer token)
|
||||
**Authorization**:
|
||||
- Regular users: Can only access their own profile
|
||||
- Superusers: Can access any profile
|
||||
|
||||
**Rate Limit**: 60 requests/minute
|
||||
""",
|
||||
operation_id="get_user_by_id"
|
||||
)
|
||||
def get_user_by_id(
|
||||
user_id: UUID,
|
||||
current_user: User = Depends(get_current_user),
|
||||
db: Session = Depends(get_db)
|
||||
) -> Any:
|
||||
"""
|
||||
Get user by ID.
|
||||
|
||||
Users can only view their own profile unless they are superusers.
|
||||
"""
|
||||
# Check permissions
|
||||
if str(user_id) != str(current_user.id) and not current_user.is_superuser:
|
||||
logger.warning(
|
||||
f"User {current_user.id} attempted to access user {user_id} without permission"
|
||||
)
|
||||
raise AuthorizationError(
|
||||
message="Not enough permissions to view this user",
|
||||
error_code=ErrorCode.INSUFFICIENT_PERMISSIONS
|
||||
)
|
||||
|
||||
# Get user
|
||||
user = user_crud.get(db, id=str(user_id))
|
||||
if not user:
|
||||
raise NotFoundError(
|
||||
message=f"User with id {user_id} not found",
|
||||
error_code=ErrorCode.USER_NOT_FOUND
|
||||
)
|
||||
|
||||
return user
|
||||
|
||||
|
||||
@router.patch(
|
||||
"/{user_id}",
|
||||
response_model=UserResponse,
|
||||
summary="Update User",
|
||||
description="""
|
||||
Update a specific user by their ID.
|
||||
|
||||
**Authentication**: Required (Bearer token)
|
||||
**Authorization**:
|
||||
- Regular users: Can only update their own profile (except is_superuser)
|
||||
- Superusers: Can update any profile
|
||||
|
||||
**Rate Limit**: 30 requests/minute
|
||||
""",
|
||||
operation_id="update_user"
|
||||
)
|
||||
def update_user(
|
||||
user_id: UUID,
|
||||
user_update: UserUpdate,
|
||||
current_user: User = Depends(get_current_user),
|
||||
db: Session = Depends(get_db)
|
||||
) -> Any:
|
||||
"""
|
||||
Update user by ID.
|
||||
|
||||
Users can update their own profile. Superusers can update any profile.
|
||||
Regular users cannot modify is_superuser field.
|
||||
"""
|
||||
# Check permissions
|
||||
is_own_profile = str(user_id) == str(current_user.id)
|
||||
|
||||
if not is_own_profile and not current_user.is_superuser:
|
||||
logger.warning(
|
||||
f"User {current_user.id} attempted to update user {user_id} without permission"
|
||||
)
|
||||
raise AuthorizationError(
|
||||
message="Not enough permissions to update this user",
|
||||
error_code=ErrorCode.INSUFFICIENT_PERMISSIONS
|
||||
)
|
||||
|
||||
# Get user
|
||||
user = user_crud.get(db, id=str(user_id))
|
||||
if not user:
|
||||
raise NotFoundError(
|
||||
message=f"User with id {user_id} not found",
|
||||
error_code=ErrorCode.USER_NOT_FOUND
|
||||
)
|
||||
|
||||
# Prevent non-superusers from modifying superuser status
|
||||
if user_update.is_superuser is not None and not current_user.is_superuser:
|
||||
logger.warning(f"User {current_user.id} attempted to modify is_superuser field")
|
||||
raise AuthorizationError(
|
||||
message="Cannot modify superuser status",
|
||||
error_code=ErrorCode.INSUFFICIENT_PERMISSIONS
|
||||
)
|
||||
|
||||
try:
|
||||
updated_user = user_crud.update(db, db_obj=user, obj_in=user_update)
|
||||
logger.info(f"User {user_id} updated by {current_user.id}")
|
||||
return updated_user
|
||||
except ValueError as e:
|
||||
logger.error(f"Error updating user {user_id}: {str(e)}")
|
||||
raise
|
||||
except Exception as e:
|
||||
logger.error(f"Unexpected error updating user {user_id}: {str(e)}", exc_info=True)
|
||||
raise
|
||||
|
||||
|
||||
@router.patch(
|
||||
"/me/password",
|
||||
response_model=MessageResponse,
|
||||
summary="Change Current User Password",
|
||||
description="""
|
||||
Change the current authenticated user's password.
|
||||
|
||||
Requires the current password for verification.
|
||||
|
||||
**Authentication**: Required (Bearer token)
|
||||
|
||||
**Rate Limit**: 5 requests/minute
|
||||
""",
|
||||
operation_id="change_current_user_password"
|
||||
)
|
||||
@limiter.limit("5/minute")
|
||||
def change_current_user_password(
|
||||
request: Request,
|
||||
password_change: PasswordChange,
|
||||
current_user: User = Depends(get_current_user),
|
||||
db: Session = Depends(get_db)
|
||||
) -> Any:
|
||||
"""
|
||||
Change current user's password.
|
||||
|
||||
Requires current password for verification.
|
||||
"""
|
||||
try:
|
||||
success = AuthService.change_password(
|
||||
db=db,
|
||||
user_id=current_user.id,
|
||||
current_password=password_change.current_password,
|
||||
new_password=password_change.new_password
|
||||
)
|
||||
|
||||
if success:
|
||||
logger.info(f"User {current_user.id} changed their password")
|
||||
return MessageResponse(
|
||||
success=True,
|
||||
message="Password changed successfully"
|
||||
)
|
||||
except Exception as e:
|
||||
logger.error(f"Error changing password for user {current_user.id}: {str(e)}")
|
||||
raise
|
||||
|
||||
|
||||
@router.delete(
|
||||
"/{user_id}",
|
||||
status_code=status.HTTP_200_OK,
|
||||
response_model=MessageResponse,
|
||||
summary="Delete User",
|
||||
description="""
|
||||
Delete a specific user by their ID.
|
||||
|
||||
**Authentication**: Required (Bearer token)
|
||||
**Authorization**: Superuser only
|
||||
|
||||
**Rate Limit**: 10 requests/minute
|
||||
|
||||
**Note**: This performs a hard delete. Consider implementing soft deletes for production.
|
||||
""",
|
||||
operation_id="delete_user"
|
||||
)
|
||||
def delete_user(
|
||||
user_id: UUID,
|
||||
current_user: User = Depends(get_current_superuser),
|
||||
db: Session = Depends(get_db)
|
||||
) -> Any:
|
||||
"""
|
||||
Delete user by ID (superuser only).
|
||||
|
||||
This is a hard delete operation.
|
||||
"""
|
||||
# Prevent self-deletion
|
||||
if str(user_id) == str(current_user.id):
|
||||
raise AuthorizationError(
|
||||
message="Cannot delete your own account",
|
||||
error_code=ErrorCode.INSUFFICIENT_PERMISSIONS
|
||||
)
|
||||
|
||||
# Get user
|
||||
user = user_crud.get(db, id=str(user_id))
|
||||
if not user:
|
||||
raise NotFoundError(
|
||||
message=f"User with id {user_id} not found",
|
||||
error_code=ErrorCode.USER_NOT_FOUND
|
||||
)
|
||||
|
||||
try:
|
||||
user_crud.remove(db, id=str(user_id))
|
||||
logger.info(f"User {user_id} deleted by {current_user.id}")
|
||||
return MessageResponse(
|
||||
success=True,
|
||||
message=f"User {user_id} deleted successfully"
|
||||
)
|
||||
except ValueError as e:
|
||||
logger.error(f"Error deleting user {user_id}: {str(e)}")
|
||||
raise
|
||||
except Exception as e:
|
||||
logger.error(f"Unexpected error deleting user {user_id}: {str(e)}", exc_info=True)
|
||||
raise
|
||||
@@ -1,8 +1,10 @@
|
||||
# app/core/database.py
|
||||
import logging
|
||||
from sqlalchemy import create_engine
|
||||
from contextlib import contextmanager
|
||||
from typing import Generator
|
||||
from sqlalchemy import create_engine, text
|
||||
from sqlalchemy.ext.declarative import declarative_base
|
||||
from sqlalchemy.orm import sessionmaker
|
||||
from sqlalchemy.orm import sessionmaker, Session
|
||||
from sqlalchemy.ext.compiler import compiles
|
||||
from sqlalchemy.dialects.postgresql import JSONB, UUID
|
||||
|
||||
@@ -49,12 +51,62 @@ def create_production_engine():
|
||||
|
||||
# Default production engine and session factory
|
||||
engine = create_production_engine()
|
||||
SessionLocal = sessionmaker(autocommit=False, autoflush=False, bind=engine)
|
||||
SessionLocal = sessionmaker(
|
||||
autocommit=False,
|
||||
autoflush=False,
|
||||
bind=engine,
|
||||
expire_on_commit=False # Prevent unnecessary queries after commit
|
||||
)
|
||||
|
||||
# FastAPI dependency
|
||||
def get_db():
|
||||
def get_db() -> Generator[Session, None, None]:
|
||||
"""
|
||||
FastAPI dependency that provides a database session.
|
||||
Automatically closes the session after the request completes.
|
||||
"""
|
||||
db = SessionLocal()
|
||||
try:
|
||||
yield db
|
||||
finally:
|
||||
db.close()
|
||||
db.close()
|
||||
|
||||
|
||||
@contextmanager
|
||||
def transaction_scope() -> Generator[Session, None, None]:
|
||||
"""
|
||||
Provide a transactional scope for database operations.
|
||||
|
||||
Automatically commits on success or rolls back on exception.
|
||||
Useful for grouping multiple operations in a single transaction.
|
||||
|
||||
Usage:
|
||||
with transaction_scope() as db:
|
||||
user = user_crud.create(db, obj_in=user_create)
|
||||
profile = profile_crud.create(db, obj_in=profile_create)
|
||||
# Both operations committed together
|
||||
"""
|
||||
db = SessionLocal()
|
||||
try:
|
||||
yield db
|
||||
db.commit()
|
||||
logger.debug("Transaction committed successfully")
|
||||
except Exception as e:
|
||||
db.rollback()
|
||||
logger.error(f"Transaction failed, rolling back: {str(e)}")
|
||||
raise
|
||||
finally:
|
||||
db.close()
|
||||
|
||||
|
||||
def check_database_health() -> bool:
|
||||
"""
|
||||
Check if database connection is healthy.
|
||||
Returns True if connection is successful, False otherwise.
|
||||
"""
|
||||
try:
|
||||
with transaction_scope() as db:
|
||||
db.execute(text("SELECT 1"))
|
||||
return True
|
||||
except Exception as e:
|
||||
logger.error(f"Database health check failed: {str(e)}")
|
||||
return False
|
||||
281
backend/app/core/exceptions.py
Normal file
281
backend/app/core/exceptions.py
Normal file
@@ -0,0 +1,281 @@
|
||||
"""
|
||||
Custom exceptions and global exception handlers for the API.
|
||||
"""
|
||||
import logging
|
||||
from typing import Optional, Union, List
|
||||
from fastapi import HTTPException, Request, status
|
||||
from fastapi.responses import JSONResponse
|
||||
from fastapi.exceptions import RequestValidationError
|
||||
from pydantic import ValidationError
|
||||
|
||||
from app.schemas.errors import ErrorCode, ErrorDetail, ErrorResponse
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class APIException(HTTPException):
|
||||
"""
|
||||
Base exception class with error code support.
|
||||
|
||||
This exception provides a standardized way to raise HTTP exceptions
|
||||
with machine-readable error codes.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
status_code: int,
|
||||
error_code: ErrorCode,
|
||||
message: str,
|
||||
field: Optional[str] = None,
|
||||
headers: Optional[dict] = None
|
||||
):
|
||||
self.error_code = error_code
|
||||
self.field = field
|
||||
self.message = message
|
||||
super().__init__(
|
||||
status_code=status_code,
|
||||
detail=message,
|
||||
headers=headers
|
||||
)
|
||||
|
||||
|
||||
class AuthenticationError(APIException):
|
||||
"""Raised when authentication fails."""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
message: str = "Authentication failed",
|
||||
error_code: ErrorCode = ErrorCode.INVALID_CREDENTIALS,
|
||||
field: Optional[str] = None
|
||||
):
|
||||
super().__init__(
|
||||
status_code=status.HTTP_401_UNAUTHORIZED,
|
||||
error_code=error_code,
|
||||
message=message,
|
||||
field=field,
|
||||
headers={"WWW-Authenticate": "Bearer"}
|
||||
)
|
||||
|
||||
|
||||
class AuthorizationError(APIException):
|
||||
"""Raised when user lacks required permissions."""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
message: str = "Insufficient permissions",
|
||||
error_code: ErrorCode = ErrorCode.INSUFFICIENT_PERMISSIONS
|
||||
):
|
||||
super().__init__(
|
||||
status_code=status.HTTP_403_FORBIDDEN,
|
||||
error_code=error_code,
|
||||
message=message
|
||||
)
|
||||
|
||||
|
||||
class NotFoundError(APIException):
|
||||
"""Raised when a resource is not found."""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
message: str = "Resource not found",
|
||||
error_code: ErrorCode = ErrorCode.NOT_FOUND
|
||||
):
|
||||
super().__init__(
|
||||
status_code=status.HTTP_404_NOT_FOUND,
|
||||
error_code=error_code,
|
||||
message=message
|
||||
)
|
||||
|
||||
|
||||
class DuplicateError(APIException):
|
||||
"""Raised when attempting to create a duplicate resource."""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
message: str = "Resource already exists",
|
||||
error_code: ErrorCode = ErrorCode.DUPLICATE_ENTRY,
|
||||
field: Optional[str] = None
|
||||
):
|
||||
super().__init__(
|
||||
status_code=status.HTTP_409_CONFLICT,
|
||||
error_code=error_code,
|
||||
message=message,
|
||||
field=field
|
||||
)
|
||||
|
||||
|
||||
class ValidationException(APIException):
|
||||
"""Raised when input validation fails."""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
message: str = "Validation error",
|
||||
error_code: ErrorCode = ErrorCode.VALIDATION_ERROR,
|
||||
field: Optional[str] = None
|
||||
):
|
||||
super().__init__(
|
||||
status_code=status.HTTP_422_UNPROCESSABLE_ENTITY,
|
||||
error_code=error_code,
|
||||
message=message,
|
||||
field=field
|
||||
)
|
||||
|
||||
|
||||
class DatabaseError(APIException):
|
||||
"""Raised when a database operation fails."""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
message: str = "Database operation failed",
|
||||
error_code: ErrorCode = ErrorCode.DATABASE_ERROR
|
||||
):
|
||||
super().__init__(
|
||||
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
|
||||
error_code=error_code,
|
||||
message=message
|
||||
)
|
||||
|
||||
|
||||
# Global exception handlers
|
||||
|
||||
|
||||
async def api_exception_handler(request: Request, exc: APIException) -> JSONResponse:
|
||||
"""
|
||||
Handler for APIException and its subclasses.
|
||||
|
||||
Returns a standardized error response with error code and message.
|
||||
"""
|
||||
logger.warning(
|
||||
f"API exception: {exc.error_code} - {exc.message} "
|
||||
f"(status: {exc.status_code}, path: {request.url.path})"
|
||||
)
|
||||
|
||||
error_response = ErrorResponse(
|
||||
errors=[ErrorDetail(
|
||||
code=exc.error_code,
|
||||
message=exc.message,
|
||||
field=exc.field
|
||||
)]
|
||||
)
|
||||
|
||||
return JSONResponse(
|
||||
status_code=exc.status_code,
|
||||
content=error_response.model_dump(),
|
||||
headers=exc.headers
|
||||
)
|
||||
|
||||
|
||||
async def validation_exception_handler(
|
||||
request: Request,
|
||||
exc: Union[RequestValidationError, ValidationError]
|
||||
) -> JSONResponse:
|
||||
"""
|
||||
Handler for Pydantic validation errors.
|
||||
|
||||
Converts Pydantic validation errors to standardized error response format.
|
||||
"""
|
||||
errors = []
|
||||
|
||||
if isinstance(exc, RequestValidationError):
|
||||
validation_errors = exc.errors()
|
||||
else:
|
||||
validation_errors = exc.errors()
|
||||
|
||||
for error in validation_errors:
|
||||
# Extract field name from error location
|
||||
field = None
|
||||
if error.get("loc") and len(error["loc"]) > 1:
|
||||
# Skip 'body' or 'query' prefix in location
|
||||
field = ".".join(str(x) for x in error["loc"][1:])
|
||||
|
||||
errors.append(ErrorDetail(
|
||||
code=ErrorCode.VALIDATION_ERROR,
|
||||
message=error["msg"],
|
||||
field=field
|
||||
))
|
||||
|
||||
logger.warning(
|
||||
f"Validation error: {len(errors)} errors "
|
||||
f"(path: {request.url.path})"
|
||||
)
|
||||
|
||||
error_response = ErrorResponse(errors=errors)
|
||||
|
||||
return JSONResponse(
|
||||
status_code=status.HTTP_422_UNPROCESSABLE_ENTITY,
|
||||
content=error_response.model_dump()
|
||||
)
|
||||
|
||||
|
||||
async def http_exception_handler(request: Request, exc: HTTPException) -> JSONResponse:
|
||||
"""
|
||||
Handler for standard HTTPException.
|
||||
|
||||
Converts standard FastAPI HTTPException to standardized error response format.
|
||||
"""
|
||||
# Map status codes to error codes
|
||||
status_code_to_error_code = {
|
||||
400: ErrorCode.INVALID_INPUT,
|
||||
401: ErrorCode.AUTHENTICATION_REQUIRED,
|
||||
403: ErrorCode.INSUFFICIENT_PERMISSIONS,
|
||||
404: ErrorCode.NOT_FOUND,
|
||||
405: ErrorCode.METHOD_NOT_ALLOWED,
|
||||
429: ErrorCode.RATE_LIMIT_EXCEEDED,
|
||||
500: ErrorCode.INTERNAL_ERROR,
|
||||
}
|
||||
|
||||
error_code = status_code_to_error_code.get(
|
||||
exc.status_code,
|
||||
ErrorCode.INTERNAL_ERROR
|
||||
)
|
||||
|
||||
logger.warning(
|
||||
f"HTTP exception: {exc.status_code} - {exc.detail} "
|
||||
f"(path: {request.url.path})"
|
||||
)
|
||||
|
||||
error_response = ErrorResponse(
|
||||
errors=[ErrorDetail(
|
||||
code=error_code,
|
||||
message=str(exc.detail)
|
||||
)]
|
||||
)
|
||||
|
||||
return JSONResponse(
|
||||
status_code=exc.status_code,
|
||||
content=error_response.model_dump(),
|
||||
headers=exc.headers
|
||||
)
|
||||
|
||||
|
||||
async def unhandled_exception_handler(request: Request, exc: Exception) -> JSONResponse:
|
||||
"""
|
||||
Handler for unhandled exceptions.
|
||||
|
||||
Logs the full exception and returns a generic error response to avoid
|
||||
leaking sensitive information in production.
|
||||
"""
|
||||
logger.error(
|
||||
f"Unhandled exception: {type(exc).__name__} - {str(exc)} "
|
||||
f"(path: {request.url.path})",
|
||||
exc_info=True
|
||||
)
|
||||
|
||||
# In production, don't expose internal error details
|
||||
from app.core.config import settings
|
||||
if settings.ENVIRONMENT == "production":
|
||||
message = "An internal error occurred. Please try again later."
|
||||
else:
|
||||
message = f"{type(exc).__name__}: {str(exc)}"
|
||||
|
||||
error_response = ErrorResponse(
|
||||
errors=[ErrorDetail(
|
||||
code=ErrorCode.INTERNAL_ERROR,
|
||||
message=message
|
||||
)]
|
||||
)
|
||||
|
||||
return JSONResponse(
|
||||
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
|
||||
content=error_response.model_dump()
|
||||
)
|
||||
@@ -1,8 +1,14 @@
|
||||
from typing import Any, Dict, Generic, List, Optional, Type, TypeVar, Union
|
||||
from typing import Any, Dict, Generic, List, Optional, Type, TypeVar, Union, Tuple
|
||||
from fastapi.encoders import jsonable_encoder
|
||||
from pydantic import BaseModel
|
||||
from sqlalchemy.orm import Session
|
||||
from sqlalchemy.exc import IntegrityError, OperationalError, DataError
|
||||
from sqlalchemy import func
|
||||
from app.core.database import Base
|
||||
import logging
|
||||
import uuid
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
ModelType = TypeVar("ModelType", bound=Base)
|
||||
CreateSchemaType = TypeVar("CreateSchemaType", bound=BaseModel)
|
||||
@@ -20,20 +26,63 @@ class CRUDBase(Generic[ModelType, CreateSchemaType, UpdateSchemaType]):
|
||||
self.model = model
|
||||
|
||||
def get(self, db: Session, id: str) -> Optional[ModelType]:
|
||||
return db.query(self.model).filter(self.model.id == id).first()
|
||||
"""Get a single record by ID with UUID validation."""
|
||||
# Validate UUID format
|
||||
try:
|
||||
uuid.UUID(id)
|
||||
except (ValueError, AttributeError):
|
||||
logger.warning(f"Invalid UUID format: {id}")
|
||||
return None
|
||||
|
||||
try:
|
||||
return db.query(self.model).filter(self.model.id == id).first()
|
||||
except Exception as e:
|
||||
logger.error(f"Error retrieving {self.model.__name__} with id {id}: {str(e)}")
|
||||
raise
|
||||
|
||||
def get_multi(
|
||||
self, db: Session, *, skip: int = 0, limit: int = 100
|
||||
) -> List[ModelType]:
|
||||
return db.query(self.model).offset(skip).limit(limit).all()
|
||||
"""Get multiple records with pagination validation."""
|
||||
# Validate pagination parameters
|
||||
if skip < 0:
|
||||
raise ValueError("skip must be non-negative")
|
||||
if limit < 0:
|
||||
raise ValueError("limit must be non-negative")
|
||||
if limit > 1000:
|
||||
raise ValueError("Maximum limit is 1000")
|
||||
|
||||
try:
|
||||
return db.query(self.model).offset(skip).limit(limit).all()
|
||||
except Exception as e:
|
||||
logger.error(f"Error retrieving multiple {self.model.__name__} records: {str(e)}")
|
||||
raise
|
||||
|
||||
def create(self, db: Session, *, obj_in: CreateSchemaType) -> ModelType:
|
||||
obj_in_data = jsonable_encoder(obj_in)
|
||||
db_obj = self.model(**obj_in_data)
|
||||
db.add(db_obj)
|
||||
db.commit()
|
||||
db.refresh(db_obj)
|
||||
return db_obj
|
||||
"""Create a new record with error handling."""
|
||||
try:
|
||||
obj_in_data = jsonable_encoder(obj_in)
|
||||
db_obj = self.model(**obj_in_data)
|
||||
db.add(db_obj)
|
||||
db.commit()
|
||||
db.refresh(db_obj)
|
||||
return db_obj
|
||||
except IntegrityError as e:
|
||||
db.rollback()
|
||||
error_msg = str(e.orig) if hasattr(e, 'orig') else str(e)
|
||||
if "unique" in error_msg.lower() or "duplicate" in error_msg.lower():
|
||||
logger.warning(f"Duplicate entry attempted for {self.model.__name__}: {error_msg}")
|
||||
raise ValueError(f"A {self.model.__name__} with this data already exists")
|
||||
logger.error(f"Integrity error creating {self.model.__name__}: {error_msg}")
|
||||
raise ValueError(f"Database integrity error: {error_msg}")
|
||||
except (OperationalError, DataError) as e:
|
||||
db.rollback()
|
||||
logger.error(f"Database error creating {self.model.__name__}: {str(e)}")
|
||||
raise ValueError(f"Database operation failed: {str(e)}")
|
||||
except Exception as e:
|
||||
db.rollback()
|
||||
logger.error(f"Unexpected error creating {self.model.__name__}: {str(e)}", exc_info=True)
|
||||
raise
|
||||
|
||||
def update(
|
||||
self,
|
||||
@@ -42,21 +91,90 @@ class CRUDBase(Generic[ModelType, CreateSchemaType, UpdateSchemaType]):
|
||||
db_obj: ModelType,
|
||||
obj_in: Union[UpdateSchemaType, Dict[str, Any]]
|
||||
) -> ModelType:
|
||||
obj_data = jsonable_encoder(db_obj)
|
||||
if isinstance(obj_in, dict):
|
||||
update_data = obj_in
|
||||
else:
|
||||
update_data = obj_in.model_dump(exclude_unset=True)
|
||||
for field in obj_data:
|
||||
if field in update_data:
|
||||
setattr(db_obj, field, update_data[field])
|
||||
db.add(db_obj)
|
||||
db.commit()
|
||||
db.refresh(db_obj)
|
||||
return db_obj
|
||||
"""Update a record with error handling."""
|
||||
try:
|
||||
obj_data = jsonable_encoder(db_obj)
|
||||
if isinstance(obj_in, dict):
|
||||
update_data = obj_in
|
||||
else:
|
||||
update_data = obj_in.model_dump(exclude_unset=True)
|
||||
for field in obj_data:
|
||||
if field in update_data:
|
||||
setattr(db_obj, field, update_data[field])
|
||||
db.add(db_obj)
|
||||
db.commit()
|
||||
db.refresh(db_obj)
|
||||
return db_obj
|
||||
except IntegrityError as e:
|
||||
db.rollback()
|
||||
error_msg = str(e.orig) if hasattr(e, 'orig') else str(e)
|
||||
if "unique" in error_msg.lower() or "duplicate" in error_msg.lower():
|
||||
logger.warning(f"Duplicate entry attempted for {self.model.__name__}: {error_msg}")
|
||||
raise ValueError(f"A {self.model.__name__} with this data already exists")
|
||||
logger.error(f"Integrity error updating {self.model.__name__}: {error_msg}")
|
||||
raise ValueError(f"Database integrity error: {error_msg}")
|
||||
except (OperationalError, DataError) as e:
|
||||
db.rollback()
|
||||
logger.error(f"Database error updating {self.model.__name__}: {str(e)}")
|
||||
raise ValueError(f"Database operation failed: {str(e)}")
|
||||
except Exception as e:
|
||||
db.rollback()
|
||||
logger.error(f"Unexpected error updating {self.model.__name__}: {str(e)}", exc_info=True)
|
||||
raise
|
||||
|
||||
def remove(self, db: Session, *, id: str) -> ModelType:
|
||||
obj = db.query(self.model).get(id)
|
||||
db.delete(obj)
|
||||
db.commit()
|
||||
return obj
|
||||
def remove(self, db: Session, *, id: str) -> Optional[ModelType]:
|
||||
"""Delete a record with error handling and null check."""
|
||||
# Validate UUID format
|
||||
try:
|
||||
uuid.UUID(id)
|
||||
except (ValueError, AttributeError):
|
||||
logger.warning(f"Invalid UUID format for deletion: {id}")
|
||||
return None
|
||||
|
||||
try:
|
||||
obj = db.query(self.model).filter(self.model.id == id).first()
|
||||
if obj is None:
|
||||
logger.warning(f"{self.model.__name__} with id {id} not found for deletion")
|
||||
return None
|
||||
|
||||
db.delete(obj)
|
||||
db.commit()
|
||||
return obj
|
||||
except IntegrityError as e:
|
||||
db.rollback()
|
||||
error_msg = str(e.orig) if hasattr(e, 'orig') else str(e)
|
||||
logger.error(f"Integrity error deleting {self.model.__name__}: {error_msg}")
|
||||
raise ValueError(f"Cannot delete {self.model.__name__}: referenced by other records")
|
||||
except Exception as e:
|
||||
db.rollback()
|
||||
logger.error(f"Error deleting {self.model.__name__} with id {id}: {str(e)}", exc_info=True)
|
||||
raise
|
||||
|
||||
def get_multi_with_total(
|
||||
self, db: Session, *, skip: int = 0, limit: int = 100
|
||||
) -> Tuple[List[ModelType], int]:
|
||||
"""
|
||||
Get multiple records with total count for pagination.
|
||||
|
||||
Returns:
|
||||
Tuple of (items, total_count)
|
||||
"""
|
||||
# Validate pagination parameters
|
||||
if skip < 0:
|
||||
raise ValueError("skip must be non-negative")
|
||||
if limit < 0:
|
||||
raise ValueError("limit must be non-negative")
|
||||
if limit > 1000:
|
||||
raise ValueError("Maximum limit is 1000")
|
||||
|
||||
try:
|
||||
# Get total count
|
||||
total = db.query(func.count(self.model.id)).scalar()
|
||||
|
||||
# Get paginated items
|
||||
items = db.query(self.model).offset(skip).limit(limit).all()
|
||||
|
||||
return items, total
|
||||
except Exception as e:
|
||||
logger.error(f"Error retrieving paginated {self.model.__name__} records: {str(e)}")
|
||||
raise
|
||||
@@ -1,10 +1,14 @@
|
||||
# app/crud/user.py
|
||||
from typing import Optional, Union, Dict, Any
|
||||
from sqlalchemy.orm import Session
|
||||
from sqlalchemy.exc import IntegrityError
|
||||
from app.crud.base import CRUDBase
|
||||
from app.models.user import User
|
||||
from app.schemas.users import UserCreate, UserUpdate
|
||||
from app.core.auth import get_password_hash
|
||||
import logging
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class CRUDUser(CRUDBase[User, UserCreate, UserUpdate]):
|
||||
@@ -12,19 +16,33 @@ class CRUDUser(CRUDBase[User, UserCreate, UserUpdate]):
|
||||
return db.query(User).filter(User.email == email).first()
|
||||
|
||||
def create(self, db: Session, *, obj_in: UserCreate) -> User:
|
||||
db_obj = User(
|
||||
email=obj_in.email,
|
||||
password_hash=get_password_hash(obj_in.password),
|
||||
first_name=obj_in.first_name,
|
||||
last_name=obj_in.last_name,
|
||||
phone_number=obj_in.phone_number if hasattr(obj_in, 'phone_number') else None,
|
||||
is_superuser=obj_in.is_superuser if hasattr(obj_in, 'is_superuser') else False,
|
||||
preferences={}
|
||||
)
|
||||
db.add(db_obj)
|
||||
db.commit()
|
||||
db.refresh(db_obj)
|
||||
return db_obj
|
||||
"""Create a new user with password hashing and error handling."""
|
||||
try:
|
||||
db_obj = User(
|
||||
email=obj_in.email,
|
||||
password_hash=get_password_hash(obj_in.password),
|
||||
first_name=obj_in.first_name,
|
||||
last_name=obj_in.last_name,
|
||||
phone_number=obj_in.phone_number if hasattr(obj_in, 'phone_number') else None,
|
||||
is_superuser=obj_in.is_superuser if hasattr(obj_in, 'is_superuser') else False,
|
||||
preferences={}
|
||||
)
|
||||
db.add(db_obj)
|
||||
db.commit()
|
||||
db.refresh(db_obj)
|
||||
return db_obj
|
||||
except IntegrityError as e:
|
||||
db.rollback()
|
||||
error_msg = str(e.orig) if hasattr(e, 'orig') else str(e)
|
||||
if "email" in error_msg.lower():
|
||||
logger.warning(f"Duplicate email attempted: {obj_in.email}")
|
||||
raise ValueError(f"User with email {obj_in.email} already exists")
|
||||
logger.error(f"Integrity error creating user: {error_msg}")
|
||||
raise ValueError(f"Database integrity error: {error_msg}")
|
||||
except Exception as e:
|
||||
db.rollback()
|
||||
logger.error(f"Unexpected error creating user: {str(e)}", exc_info=True)
|
||||
raise
|
||||
|
||||
def update(
|
||||
self,
|
||||
|
||||
@@ -19,7 +19,7 @@ def init_db(db: Session) -> Optional[UserCreate]:
|
||||
"""
|
||||
# Use default values if not set in environment variables
|
||||
superuser_email = settings.FIRST_SUPERUSER_EMAIL or "admin@example.com"
|
||||
superuser_password = settings.FIRST_SUPERUSER_PASSWORD or "admin123"
|
||||
superuser_password = settings.FIRST_SUPERUSER_PASSWORD or "Admin123!Change"
|
||||
|
||||
if not settings.FIRST_SUPERUSER_EMAIL or not settings.FIRST_SUPERUSER_PASSWORD:
|
||||
logger.warning(
|
||||
|
||||
@@ -3,9 +3,10 @@ from datetime import datetime
|
||||
from typing import Dict, Any
|
||||
|
||||
from apscheduler.schedulers.asyncio import AsyncIOScheduler
|
||||
from fastapi import FastAPI, status, Request
|
||||
from fastapi import FastAPI, status, Request, HTTPException
|
||||
from fastapi.middleware.cors import CORSMiddleware
|
||||
from fastapi.responses import HTMLResponse, JSONResponse
|
||||
from fastapi.exceptions import RequestValidationError
|
||||
from slowapi import Limiter, _rate_limit_exceeded_handler
|
||||
from slowapi.util import get_remote_address
|
||||
from slowapi.errors import RateLimitExceeded
|
||||
@@ -13,7 +14,14 @@ from sqlalchemy import text
|
||||
|
||||
from app.api.main import api_router
|
||||
from app.core.config import settings
|
||||
from app.core.database import get_db
|
||||
from app.core.database import get_db, check_database_health
|
||||
from app.core.exceptions import (
|
||||
APIException,
|
||||
api_exception_handler,
|
||||
validation_exception_handler,
|
||||
http_exception_handler,
|
||||
unhandled_exception_handler
|
||||
)
|
||||
|
||||
scheduler = AsyncIOScheduler()
|
||||
|
||||
@@ -33,13 +41,30 @@ app = FastAPI(
|
||||
app.state.limiter = limiter
|
||||
app.add_exception_handler(RateLimitExceeded, _rate_limit_exceeded_handler)
|
||||
|
||||
# Set up CORS middleware
|
||||
# Register custom exception handlers (order matters - most specific first)
|
||||
app.add_exception_handler(APIException, api_exception_handler)
|
||||
app.add_exception_handler(RequestValidationError, validation_exception_handler)
|
||||
app.add_exception_handler(HTTPException, http_exception_handler)
|
||||
app.add_exception_handler(Exception, unhandled_exception_handler)
|
||||
|
||||
# Set up CORS middleware with explicit allowed methods and headers
|
||||
app.add_middleware(
|
||||
CORSMiddleware,
|
||||
allow_origins=settings.BACKEND_CORS_ORIGINS,
|
||||
allow_credentials=True,
|
||||
allow_methods=["*"],
|
||||
allow_headers=["*"],
|
||||
allow_methods=["GET", "POST", "PUT", "PATCH", "DELETE", "OPTIONS"], # Explicit methods only
|
||||
allow_headers=[
|
||||
"Content-Type",
|
||||
"Authorization",
|
||||
"Accept",
|
||||
"Origin",
|
||||
"User-Agent",
|
||||
"DNT",
|
||||
"Cache-Control",
|
||||
"X-Requested-With",
|
||||
], # Explicit headers only
|
||||
expose_headers=["Content-Length"],
|
||||
max_age=600, # Cache preflight requests for 10 minutes
|
||||
)
|
||||
|
||||
|
||||
@@ -120,15 +145,16 @@ async def health_check() -> JSONResponse:
|
||||
|
||||
response_status = status.HTTP_200_OK
|
||||
|
||||
# Database health check
|
||||
# Database health check using dedicated health check function
|
||||
try:
|
||||
db = next(get_db())
|
||||
db.execute(text("SELECT 1"))
|
||||
health_status["checks"]["database"] = {
|
||||
"status": "healthy",
|
||||
"message": "Database connection successful"
|
||||
}
|
||||
db.close()
|
||||
db_healthy = check_database_health()
|
||||
if db_healthy:
|
||||
health_status["checks"]["database"] = {
|
||||
"status": "healthy",
|
||||
"message": "Database connection successful"
|
||||
}
|
||||
else:
|
||||
raise Exception("Database health check returned unhealthy status")
|
||||
except Exception as e:
|
||||
health_status["status"] = "unhealthy"
|
||||
health_status["checks"]["database"] = {
|
||||
|
||||
139
backend/app/schemas/common.py
Normal file
139
backend/app/schemas/common.py
Normal file
@@ -0,0 +1,139 @@
|
||||
"""
|
||||
Common schemas used across the API for pagination, responses, etc.
|
||||
"""
|
||||
from typing import Generic, TypeVar, List, Optional
|
||||
from pydantic import BaseModel, Field
|
||||
from math import ceil
|
||||
|
||||
|
||||
T = TypeVar('T')
|
||||
|
||||
|
||||
class PaginationParams(BaseModel):
|
||||
"""Parameters for pagination."""
|
||||
|
||||
page: int = Field(
|
||||
default=1,
|
||||
ge=1,
|
||||
description="Page number (1-indexed)"
|
||||
)
|
||||
limit: int = Field(
|
||||
default=20,
|
||||
ge=1,
|
||||
le=100,
|
||||
description="Number of items per page (max 100)"
|
||||
)
|
||||
|
||||
@property
|
||||
def offset(self) -> int:
|
||||
"""Calculate the offset for database queries."""
|
||||
return (self.page - 1) * self.limit
|
||||
|
||||
@property
|
||||
def skip(self) -> int:
|
||||
"""Alias for offset (compatibility with existing code)."""
|
||||
return self.offset
|
||||
|
||||
model_config = {
|
||||
"json_schema_extra": {
|
||||
"example": {
|
||||
"page": 1,
|
||||
"limit": 20
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
class PaginationMeta(BaseModel):
|
||||
"""Metadata for paginated responses."""
|
||||
|
||||
total: int = Field(..., description="Total number of items")
|
||||
page: int = Field(..., description="Current page number")
|
||||
page_size: int = Field(..., description="Number of items in current page")
|
||||
total_pages: int = Field(..., description="Total number of pages")
|
||||
has_next: bool = Field(..., description="Whether there is a next page")
|
||||
has_prev: bool = Field(..., description="Whether there is a previous page")
|
||||
|
||||
model_config = {
|
||||
"json_schema_extra": {
|
||||
"example": {
|
||||
"total": 150,
|
||||
"page": 1,
|
||||
"page_size": 20,
|
||||
"total_pages": 8,
|
||||
"has_next": True,
|
||||
"has_prev": False
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
class PaginatedResponse(BaseModel, Generic[T]):
|
||||
"""Generic paginated response wrapper."""
|
||||
|
||||
data: List[T] = Field(..., description="List of items")
|
||||
pagination: PaginationMeta = Field(..., description="Pagination metadata")
|
||||
|
||||
model_config = {
|
||||
"json_schema_extra": {
|
||||
"example": {
|
||||
"data": [
|
||||
{"id": "123", "name": "Example Item"}
|
||||
],
|
||||
"pagination": {
|
||||
"total": 150,
|
||||
"page": 1,
|
||||
"page_size": 20,
|
||||
"total_pages": 8,
|
||||
"has_next": True,
|
||||
"has_prev": False
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
class MessageResponse(BaseModel):
|
||||
"""Simple message response."""
|
||||
|
||||
success: bool = Field(default=True, description="Operation success status")
|
||||
message: str = Field(..., description="Human-readable message")
|
||||
|
||||
model_config = {
|
||||
"json_schema_extra": {
|
||||
"example": {
|
||||
"success": True,
|
||||
"message": "Operation completed successfully"
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
def create_pagination_meta(
|
||||
total: int,
|
||||
page: int,
|
||||
limit: int,
|
||||
items_count: int
|
||||
) -> PaginationMeta:
|
||||
"""
|
||||
Helper function to create pagination metadata.
|
||||
|
||||
Args:
|
||||
total: Total number of items
|
||||
page: Current page number
|
||||
limit: Items per page
|
||||
items_count: Number of items in current page
|
||||
|
||||
Returns:
|
||||
PaginationMeta object with calculated values
|
||||
"""
|
||||
total_pages = ceil(total / limit) if limit > 0 else 0
|
||||
|
||||
return PaginationMeta(
|
||||
total=total,
|
||||
page=page,
|
||||
page_size=items_count,
|
||||
total_pages=total_pages,
|
||||
has_next=page < total_pages,
|
||||
has_prev=page > 1
|
||||
)
|
||||
85
backend/app/schemas/errors.py
Normal file
85
backend/app/schemas/errors.py
Normal file
@@ -0,0 +1,85 @@
|
||||
"""
|
||||
Error schemas for standardized API error responses.
|
||||
"""
|
||||
from enum import Enum
|
||||
from typing import List, Optional
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
|
||||
class ErrorCode(str, Enum):
|
||||
"""Standard error codes for the API."""
|
||||
|
||||
# Authentication errors (AUTH_xxx)
|
||||
INVALID_CREDENTIALS = "AUTH_001"
|
||||
TOKEN_EXPIRED = "AUTH_002"
|
||||
TOKEN_INVALID = "AUTH_003"
|
||||
INSUFFICIENT_PERMISSIONS = "AUTH_004"
|
||||
USER_INACTIVE = "AUTH_005"
|
||||
AUTHENTICATION_REQUIRED = "AUTH_006"
|
||||
|
||||
# User errors (USER_xxx)
|
||||
USER_NOT_FOUND = "USER_001"
|
||||
USER_ALREADY_EXISTS = "USER_002"
|
||||
USER_CREATION_FAILED = "USER_003"
|
||||
USER_UPDATE_FAILED = "USER_004"
|
||||
USER_DELETION_FAILED = "USER_005"
|
||||
|
||||
# Validation errors (VAL_xxx)
|
||||
VALIDATION_ERROR = "VAL_001"
|
||||
INVALID_PASSWORD = "VAL_002"
|
||||
INVALID_EMAIL = "VAL_003"
|
||||
INVALID_PHONE_NUMBER = "VAL_004"
|
||||
INVALID_UUID = "VAL_005"
|
||||
INVALID_INPUT = "VAL_006"
|
||||
|
||||
# Database errors (DB_xxx)
|
||||
DATABASE_ERROR = "DB_001"
|
||||
DUPLICATE_ENTRY = "DB_002"
|
||||
FOREIGN_KEY_VIOLATION = "DB_003"
|
||||
RECORD_NOT_FOUND = "DB_004"
|
||||
|
||||
# Generic errors (SYS_xxx)
|
||||
INTERNAL_ERROR = "SYS_001"
|
||||
NOT_FOUND = "SYS_002"
|
||||
METHOD_NOT_ALLOWED = "SYS_003"
|
||||
RATE_LIMIT_EXCEEDED = "SYS_004"
|
||||
|
||||
|
||||
class ErrorDetail(BaseModel):
|
||||
"""Detailed information about a single error."""
|
||||
|
||||
code: ErrorCode = Field(..., description="Machine-readable error code")
|
||||
message: str = Field(..., description="Human-readable error message")
|
||||
field: Optional[str] = Field(None, description="Field name if error is field-specific")
|
||||
|
||||
model_config = {
|
||||
"json_schema_extra": {
|
||||
"example": {
|
||||
"code": "VAL_002",
|
||||
"message": "Password must be at least 8 characters long",
|
||||
"field": "password"
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
class ErrorResponse(BaseModel):
|
||||
"""Standardized error response format."""
|
||||
|
||||
success: bool = Field(default=False, description="Always false for error responses")
|
||||
errors: List[ErrorDetail] = Field(..., description="List of errors that occurred")
|
||||
|
||||
model_config = {
|
||||
"json_schema_extra": {
|
||||
"example": {
|
||||
"success": False,
|
||||
"errors": [
|
||||
{
|
||||
"code": "AUTH_001",
|
||||
"message": "Invalid email or password",
|
||||
"field": None
|
||||
}
|
||||
]
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -123,7 +123,26 @@ class TokenData(BaseModel):
|
||||
is_superuser: bool = False
|
||||
|
||||
|
||||
class PasswordChange(BaseModel):
|
||||
"""Schema for changing password (requires current password)."""
|
||||
current_password: str
|
||||
new_password: str
|
||||
|
||||
@field_validator('new_password')
|
||||
@classmethod
|
||||
def password_strength(cls, v: str) -> str:
|
||||
"""Basic password strength validation"""
|
||||
if len(v) < 8:
|
||||
raise ValueError('Password must be at least 8 characters')
|
||||
if not any(char.isdigit() for char in v):
|
||||
raise ValueError('Password must contain at least one digit')
|
||||
if not any(char.isupper() for char in v):
|
||||
raise ValueError('Password must contain at least one uppercase letter')
|
||||
return v
|
||||
|
||||
|
||||
class PasswordReset(BaseModel):
|
||||
"""Schema for resetting password (via email token)."""
|
||||
token: str
|
||||
new_password: str
|
||||
|
||||
|
||||
1
backend/coverage.json
Normal file
1
backend/coverage.json
Normal file
File diff suppressed because one or more lines are too long
223
backend/tests/test_init_db.py
Normal file
223
backend/tests/test_init_db.py
Normal file
@@ -0,0 +1,223 @@
|
||||
# tests/test_init_db.py
|
||||
import pytest
|
||||
from unittest.mock import patch, MagicMock
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from app.init_db import init_db
|
||||
from app.models.user import User
|
||||
from app.schemas.users import UserCreate
|
||||
|
||||
|
||||
class TestInitDB:
|
||||
"""Tests for database initialization"""
|
||||
|
||||
def test_init_db_creates_superuser_when_not_exists(self, db_session, monkeypatch):
|
||||
"""Test that init_db creates superuser when it doesn't exist"""
|
||||
# Set environment variables
|
||||
monkeypatch.setenv("FIRST_SUPERUSER_EMAIL", "admin@test.com")
|
||||
monkeypatch.setenv("FIRST_SUPERUSER_PASSWORD", "TestPassword123!")
|
||||
|
||||
# Reload settings to pick up environment variables
|
||||
from app.core import config
|
||||
import importlib
|
||||
importlib.reload(config)
|
||||
from app.core.config import settings
|
||||
|
||||
# Mock user_crud to return None (user doesn't exist)
|
||||
with patch('app.init_db.user_crud') as mock_crud:
|
||||
mock_crud.get_by_email.return_value = None
|
||||
|
||||
# Create a mock user to return from create
|
||||
from datetime import datetime, timezone
|
||||
import uuid
|
||||
mock_user = User(
|
||||
id=uuid.uuid4(),
|
||||
email="admin@test.com",
|
||||
password_hash="hashed",
|
||||
first_name="Admin",
|
||||
last_name="User",
|
||||
is_active=True,
|
||||
is_superuser=True,
|
||||
created_at=datetime.now(timezone.utc),
|
||||
updated_at=datetime.now(timezone.utc)
|
||||
)
|
||||
mock_crud.create.return_value = mock_user
|
||||
|
||||
# Call init_db
|
||||
user = init_db(db_session)
|
||||
|
||||
# Verify user was created
|
||||
assert user is not None
|
||||
assert user.email == "admin@test.com"
|
||||
assert user.is_superuser is True
|
||||
mock_crud.create.assert_called_once()
|
||||
|
||||
def test_init_db_returns_existing_superuser(self, db_session, monkeypatch):
|
||||
"""Test that init_db returns existing superuser without creating new one"""
|
||||
# Set environment variables
|
||||
monkeypatch.setenv("FIRST_SUPERUSER_EMAIL", "existing@test.com")
|
||||
monkeypatch.setenv("FIRST_SUPERUSER_PASSWORD", "TestPassword123!")
|
||||
|
||||
# Reload settings
|
||||
from app.core import config
|
||||
import importlib
|
||||
importlib.reload(config)
|
||||
|
||||
# Mock user_crud to return existing user
|
||||
with patch('app.init_db.user_crud') as mock_crud:
|
||||
from datetime import datetime, timezone
|
||||
import uuid
|
||||
existing_user = User(
|
||||
id=uuid.uuid4(),
|
||||
email="existing@test.com",
|
||||
password_hash="hashed",
|
||||
first_name="Existing",
|
||||
last_name="User",
|
||||
is_active=True,
|
||||
is_superuser=True,
|
||||
created_at=datetime.now(timezone.utc),
|
||||
updated_at=datetime.now(timezone.utc)
|
||||
)
|
||||
mock_crud.get_by_email.return_value = existing_user
|
||||
|
||||
# Call init_db
|
||||
user = init_db(db_session)
|
||||
|
||||
# Verify existing user was returned
|
||||
assert user is not None
|
||||
assert user.email == "existing@test.com"
|
||||
# create should NOT be called
|
||||
mock_crud.create.assert_not_called()
|
||||
|
||||
def test_init_db_uses_defaults_when_env_not_set(self, db_session):
|
||||
"""Test that init_db uses default credentials when env vars not set"""
|
||||
# Mock settings to return None for superuser credentials
|
||||
with patch('app.init_db.settings') as mock_settings:
|
||||
mock_settings.FIRST_SUPERUSER_EMAIL = None
|
||||
mock_settings.FIRST_SUPERUSER_PASSWORD = None
|
||||
|
||||
# Mock user_crud
|
||||
with patch('app.init_db.user_crud') as mock_crud:
|
||||
mock_crud.get_by_email.return_value = None
|
||||
|
||||
from datetime import datetime, timezone
|
||||
import uuid
|
||||
mock_user = User(
|
||||
id=uuid.uuid4(),
|
||||
email="admin@example.com",
|
||||
password_hash="hashed",
|
||||
first_name="Admin",
|
||||
last_name="User",
|
||||
is_active=True,
|
||||
is_superuser=True,
|
||||
created_at=datetime.now(timezone.utc),
|
||||
updated_at=datetime.now(timezone.utc)
|
||||
)
|
||||
mock_crud.create.return_value = mock_user
|
||||
|
||||
# Call init_db
|
||||
with patch('app.init_db.logger') as mock_logger:
|
||||
user = init_db(db_session)
|
||||
|
||||
# Verify default email was used
|
||||
mock_crud.get_by_email.assert_called_with(db_session, email="admin@example.com")
|
||||
# Verify warning was logged since credentials not set
|
||||
assert mock_logger.warning.called
|
||||
|
||||
def test_init_db_handles_creation_error(self, db_session, monkeypatch):
|
||||
"""Test that init_db handles errors during user creation"""
|
||||
# Set environment variables
|
||||
monkeypatch.setenv("FIRST_SUPERUSER_EMAIL", "admin@test.com")
|
||||
monkeypatch.setenv("FIRST_SUPERUSER_PASSWORD", "TestPassword123!")
|
||||
|
||||
# Reload settings
|
||||
from app.core import config
|
||||
import importlib
|
||||
importlib.reload(config)
|
||||
|
||||
# Mock user_crud to raise an exception
|
||||
with patch('app.init_db.user_crud') as mock_crud:
|
||||
mock_crud.get_by_email.return_value = None
|
||||
mock_crud.create.side_effect = Exception("Database error")
|
||||
|
||||
# Call init_db and expect exception
|
||||
with pytest.raises(Exception) as exc_info:
|
||||
init_db(db_session)
|
||||
|
||||
assert "Database error" in str(exc_info.value)
|
||||
|
||||
def test_init_db_logs_superuser_creation(self, db_session, monkeypatch):
|
||||
"""Test that init_db logs appropriate messages"""
|
||||
# Set environment variables
|
||||
monkeypatch.setenv("FIRST_SUPERUSER_EMAIL", "admin@test.com")
|
||||
monkeypatch.setenv("FIRST_SUPERUSER_PASSWORD", "TestPassword123!")
|
||||
|
||||
# Reload settings
|
||||
from app.core import config
|
||||
import importlib
|
||||
importlib.reload(config)
|
||||
|
||||
# Mock user_crud
|
||||
with patch('app.init_db.user_crud') as mock_crud:
|
||||
mock_crud.get_by_email.return_value = None
|
||||
|
||||
from datetime import datetime, timezone
|
||||
import uuid
|
||||
mock_user = User(
|
||||
id=uuid.uuid4(),
|
||||
email="admin@test.com",
|
||||
password_hash="hashed",
|
||||
first_name="Admin",
|
||||
last_name="User",
|
||||
is_active=True,
|
||||
is_superuser=True,
|
||||
created_at=datetime.now(timezone.utc),
|
||||
updated_at=datetime.now(timezone.utc)
|
||||
)
|
||||
mock_crud.create.return_value = mock_user
|
||||
|
||||
# Call init_db with logger mock
|
||||
with patch('app.init_db.logger') as mock_logger:
|
||||
user = init_db(db_session)
|
||||
|
||||
# Verify info log was called
|
||||
assert mock_logger.info.called
|
||||
info_call_args = str(mock_logger.info.call_args)
|
||||
assert "Created first superuser" in info_call_args
|
||||
|
||||
def test_init_db_logs_existing_user(self, db_session, monkeypatch):
|
||||
"""Test that init_db logs when user already exists"""
|
||||
# Set environment variables
|
||||
monkeypatch.setenv("FIRST_SUPERUSER_EMAIL", "existing@test.com")
|
||||
monkeypatch.setenv("FIRST_SUPERUSER_PASSWORD", "TestPassword123!")
|
||||
|
||||
# Reload settings
|
||||
from app.core import config
|
||||
import importlib
|
||||
importlib.reload(config)
|
||||
|
||||
# Mock user_crud to return existing user
|
||||
with patch('app.init_db.user_crud') as mock_crud:
|
||||
from datetime import datetime, timezone
|
||||
import uuid
|
||||
existing_user = User(
|
||||
id=uuid.uuid4(),
|
||||
email="existing@test.com",
|
||||
password_hash="hashed",
|
||||
first_name="Existing",
|
||||
last_name="User",
|
||||
is_active=True,
|
||||
is_superuser=True,
|
||||
created_at=datetime.now(timezone.utc),
|
||||
updated_at=datetime.now(timezone.utc)
|
||||
)
|
||||
mock_crud.get_by_email.return_value = existing_user
|
||||
|
||||
# Call init_db with logger mock
|
||||
with patch('app.init_db.logger') as mock_logger:
|
||||
user = init_db(db_session)
|
||||
|
||||
# Verify info log was called
|
||||
assert mock_logger.info.called
|
||||
info_call_args = str(mock_logger.info.call_args)
|
||||
assert "already exists" in info_call_args.lower()
|
||||
Reference in New Issue
Block a user