Enhance user management, improve API structure, add database optimizations, and update Docker setup

- Introduced endpoints for user management, including CRUD operations, pagination, and password management.
- Added new schema validations for user updates, password strength, pagination, and standardized error responses.
- Integrated custom exception handling for a consistent API error experience.
- Refined CORS settings: restricted methods and allowed headers, added header exposure, and preflight caching.
- Optimized database: added indexes on `is_active` and `is_superuser` fields, updated column types, enforced constraints, and set defaults.
- Updated `Dockerfile` to improve security by using a non-root user and adding a health check for the application.
- Enhanced tests for database initialization, user operations, and exception handling to ensure better coverage.
This commit is contained in:
Felipe Cardoso
2025-10-30 15:43:52 +01:00
parent d83959963b
commit 2c600290a1
16 changed files with 1511 additions and 100 deletions

View File

@@ -1,34 +1,66 @@
# Development stage
FROM python:3.12-slim AS development
# Create non-root user
RUN groupadd -r appuser && useradd -r -g appuser appuser
WORKDIR /app
ENV PYTHONDONTWRITEBYTECODE=1 \
PYTHONUNBUFFERED=1 \
PYTHONPATH=/app
RUN apt-get update && \
apt-get install -y --no-install-recommends gcc postgresql-client && \
apt-get install -y --no-install-recommends gcc postgresql-client curl && \
apt-get clean && \
rm -rf /var/lib/apt/lists/*
COPY requirements.txt .
RUN pip install --no-cache-dir -r requirements.txt
COPY . .
COPY entrypoint.sh /usr/local/bin/
RUN chmod +x /usr/local/bin/entrypoint.sh
# Set ownership to non-root user
RUN chown -R appuser:appuser /app
# Switch to non-root user
USER appuser
ENTRYPOINT ["/usr/local/bin/entrypoint.sh"]
# Production stage
FROM python:3.12-slim AS production
# Create non-root user
RUN groupadd -r appuser && useradd -r -g appuser appuser
WORKDIR /app
ENV PYTHONDONTWRITEBYTECODE=1 \
PYTHONUNBUFFERED=1 \
PYTHONPATH=/app
RUN apt-get update && \
apt-get install -y --no-install-recommends postgresql-client && \
apt-get install -y --no-install-recommends postgresql-client curl && \
apt-get clean && \
rm -rf /var/lib/apt/lists/*
COPY requirements.txt .
RUN pip install --no-cache-dir -r requirements.txt
COPY . .
COPY entrypoint.sh /usr/local/bin/
RUN chmod +x /usr/local/bin/entrypoint.sh
# Set ownership to non-root user
RUN chown -R appuser:appuser /app
# Switch to non-root user
USER appuser
# Add health check
HEALTHCHECK --interval=30s --timeout=10s --start-period=40s --retries=3 \
CMD curl -f http://localhost:8000/health || exit 1
ENTRYPOINT ["/usr/local/bin/entrypoint.sh"]
CMD ["uvicorn", "app.main:app", "--host", "0.0.0.0", "--port", "8000"]

View File

@@ -0,0 +1,84 @@
"""Add missing indexes and fix column types
Revision ID: 9e4f2a1b8c7d
Revises: 38bf9e7e74b3
Create Date: 2025-10-30 10:00:00.000000
"""
from typing import Sequence, Union
from alembic import op
import sqlalchemy as sa
# revision identifiers, used by Alembic.
revision: str = '9e4f2a1b8c7d'
down_revision: Union[str, None] = '38bf9e7e74b3'
branch_labels: Union[str, Sequence[str], None] = None
depends_on: Union[str, Sequence[str], None] = None
def upgrade() -> None:
# Add missing indexes for is_active and is_superuser
op.create_index(op.f('ix_users_is_active'), 'users', ['is_active'], unique=False)
op.create_index(op.f('ix_users_is_superuser'), 'users', ['is_superuser'], unique=False)
# Fix column types to match model definitions with explicit lengths
op.alter_column('users', 'email',
existing_type=sa.String(),
type_=sa.String(length=255),
nullable=False)
op.alter_column('users', 'password_hash',
existing_type=sa.String(),
type_=sa.String(length=255),
nullable=False)
op.alter_column('users', 'first_name',
existing_type=sa.String(),
type_=sa.String(length=100),
nullable=False,
server_default='user') # Add server default
op.alter_column('users', 'last_name',
existing_type=sa.String(),
type_=sa.String(length=100),
nullable=True)
op.alter_column('users', 'phone_number',
existing_type=sa.String(),
type_=sa.String(length=20),
nullable=True)
def downgrade() -> None:
# Revert column types
op.alter_column('users', 'phone_number',
existing_type=sa.String(length=20),
type_=sa.String(),
nullable=True)
op.alter_column('users', 'last_name',
existing_type=sa.String(length=100),
type_=sa.String(),
nullable=True)
op.alter_column('users', 'first_name',
existing_type=sa.String(length=100),
type_=sa.String(),
nullable=False,
server_default=None) # Remove server default
op.alter_column('users', 'password_hash',
existing_type=sa.String(length=255),
type_=sa.String(),
nullable=False)
op.alter_column('users', 'email',
existing_type=sa.String(length=255),
type_=sa.String(),
nullable=False)
# Drop indexes
op.drop_index(op.f('ix_users_is_superuser'), table_name='users')
op.drop_index(op.f('ix_users_is_active'), table_name='users')

View File

@@ -1,6 +1,7 @@
from fastapi import APIRouter
from app.api.routes import auth
from app.api.routes import auth, users
api_router = APIRouter()
api_router.include_router(auth.router, prefix="/auth", tags=["auth"])
api_router.include_router(auth.router, prefix="/auth", tags=["Authentication"])
api_router.include_router(users.router, prefix="/users", tags=["Users"])

View File

@@ -196,44 +196,6 @@ async def refresh_token(
)
@router.post("/change-password", status_code=status.HTTP_200_OK, operation_id="change_password")
@limiter.limit("5/minute")
async def change_password(
request: Request,
current_password: str = Body(..., embed=True),
new_password: str = Body(..., embed=True),
current_user: User = Depends(get_current_user),
db: Session = Depends(get_db)
) -> Any:
"""
Change current user's password.
Requires authentication.
"""
try:
success = AuthService.change_password(
db=db,
user_id=current_user.id,
current_password=current_password,
new_password=new_password
)
if success:
return {"message": "Password changed successfully"}
except AuthenticationError as e:
logger.warning(f"Password change failed: {str(e)}")
raise HTTPException(
status_code=status.HTTP_400_BAD_REQUEST,
detail=str(e)
)
except Exception as e:
logger.error(f"Unexpected error during password change: {str(e)}")
raise HTTPException(
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
detail="An unexpected error occurred. Please try again later."
)
@router.get("/me", response_model=UserResponse, operation_id="get_current_user_info")
@limiter.limit("60/minute")
async def get_current_user_info(

View File

@@ -0,0 +1,370 @@
"""
User management endpoints for CRUD operations.
"""
import logging
from typing import Any
from uuid import UUID
from fastapi import APIRouter, Depends, Query, status, Request
from sqlalchemy.orm import Session
from slowapi import Limiter
from slowapi.util import get_remote_address
from app.api.dependencies.auth import get_current_user, get_current_superuser
from app.core.database import get_db
from app.crud.user import user as user_crud
from app.models.user import User
from app.schemas.users import UserResponse, UserUpdate, PasswordChange
from app.schemas.common import (
PaginationParams,
PaginatedResponse,
MessageResponse,
create_pagination_meta
)
from app.services.auth_service import AuthService
from app.core.exceptions import (
NotFoundError,
AuthorizationError,
ErrorCode
)
logger = logging.getLogger(__name__)
router = APIRouter()
limiter = Limiter(key_func=get_remote_address)
@router.get(
"",
response_model=PaginatedResponse[UserResponse],
summary="List Users",
description="""
List all users with pagination (admin only).
**Authentication**: Required (Bearer token)
**Authorization**: Superuser only
**Rate Limit**: 60 requests/minute
""",
operation_id="list_users"
)
def list_users(
pagination: PaginationParams = Depends(),
current_user: User = Depends(get_current_superuser),
db: Session = Depends(get_db)
) -> Any:
"""
List all users with pagination.
Only accessible by superusers.
"""
try:
# Get paginated users with total count
users, total = user_crud.get_multi_with_total(
db,
skip=pagination.offset,
limit=pagination.limit
)
# Create pagination metadata
pagination_meta = create_pagination_meta(
total=total,
page=pagination.page,
limit=pagination.limit,
items_count=len(users)
)
return PaginatedResponse(
data=users,
pagination=pagination_meta
)
except Exception as e:
logger.error(f"Error listing users: {str(e)}", exc_info=True)
raise
@router.get(
"/me",
response_model=UserResponse,
summary="Get Current User",
description="""
Get the current authenticated user's profile.
**Authentication**: Required (Bearer token)
**Rate Limit**: 60 requests/minute
""",
operation_id="get_current_user_profile"
)
def get_current_user_profile(
current_user: User = Depends(get_current_user)
) -> Any:
"""Get current user's profile."""
return current_user
@router.patch(
"/me",
response_model=UserResponse,
summary="Update Current User",
description="""
Update the current authenticated user's profile.
Users can update their own profile information (except is_superuser).
**Authentication**: Required (Bearer token)
**Rate Limit**: 30 requests/minute
""",
operation_id="update_current_user"
)
def update_current_user(
user_update: UserUpdate,
current_user: User = Depends(get_current_user),
db: Session = Depends(get_db)
) -> Any:
"""
Update current user's profile.
Users cannot elevate their own permissions (is_superuser).
"""
# Prevent users from making themselves superuser
if user_update.is_superuser is not None:
logger.warning(f"User {current_user.id} attempted to modify is_superuser field")
raise AuthorizationError(
message="Cannot modify superuser status",
error_code=ErrorCode.INSUFFICIENT_PERMISSIONS
)
try:
updated_user = user_crud.update(
db,
db_obj=current_user,
obj_in=user_update
)
logger.info(f"User {current_user.id} updated their profile")
return updated_user
except ValueError as e:
logger.error(f"Error updating user {current_user.id}: {str(e)}")
raise
except Exception as e:
logger.error(f"Unexpected error updating user {current_user.id}: {str(e)}", exc_info=True)
raise
@router.get(
"/{user_id}",
response_model=UserResponse,
summary="Get User by ID",
description="""
Get a specific user by their ID.
**Authentication**: Required (Bearer token)
**Authorization**:
- Regular users: Can only access their own profile
- Superusers: Can access any profile
**Rate Limit**: 60 requests/minute
""",
operation_id="get_user_by_id"
)
def get_user_by_id(
user_id: UUID,
current_user: User = Depends(get_current_user),
db: Session = Depends(get_db)
) -> Any:
"""
Get user by ID.
Users can only view their own profile unless they are superusers.
"""
# Check permissions
if str(user_id) != str(current_user.id) and not current_user.is_superuser:
logger.warning(
f"User {current_user.id} attempted to access user {user_id} without permission"
)
raise AuthorizationError(
message="Not enough permissions to view this user",
error_code=ErrorCode.INSUFFICIENT_PERMISSIONS
)
# Get user
user = user_crud.get(db, id=str(user_id))
if not user:
raise NotFoundError(
message=f"User with id {user_id} not found",
error_code=ErrorCode.USER_NOT_FOUND
)
return user
@router.patch(
"/{user_id}",
response_model=UserResponse,
summary="Update User",
description="""
Update a specific user by their ID.
**Authentication**: Required (Bearer token)
**Authorization**:
- Regular users: Can only update their own profile (except is_superuser)
- Superusers: Can update any profile
**Rate Limit**: 30 requests/minute
""",
operation_id="update_user"
)
def update_user(
user_id: UUID,
user_update: UserUpdate,
current_user: User = Depends(get_current_user),
db: Session = Depends(get_db)
) -> Any:
"""
Update user by ID.
Users can update their own profile. Superusers can update any profile.
Regular users cannot modify is_superuser field.
"""
# Check permissions
is_own_profile = str(user_id) == str(current_user.id)
if not is_own_profile and not current_user.is_superuser:
logger.warning(
f"User {current_user.id} attempted to update user {user_id} without permission"
)
raise AuthorizationError(
message="Not enough permissions to update this user",
error_code=ErrorCode.INSUFFICIENT_PERMISSIONS
)
# Get user
user = user_crud.get(db, id=str(user_id))
if not user:
raise NotFoundError(
message=f"User with id {user_id} not found",
error_code=ErrorCode.USER_NOT_FOUND
)
# Prevent non-superusers from modifying superuser status
if user_update.is_superuser is not None and not current_user.is_superuser:
logger.warning(f"User {current_user.id} attempted to modify is_superuser field")
raise AuthorizationError(
message="Cannot modify superuser status",
error_code=ErrorCode.INSUFFICIENT_PERMISSIONS
)
try:
updated_user = user_crud.update(db, db_obj=user, obj_in=user_update)
logger.info(f"User {user_id} updated by {current_user.id}")
return updated_user
except ValueError as e:
logger.error(f"Error updating user {user_id}: {str(e)}")
raise
except Exception as e:
logger.error(f"Unexpected error updating user {user_id}: {str(e)}", exc_info=True)
raise
@router.patch(
"/me/password",
response_model=MessageResponse,
summary="Change Current User Password",
description="""
Change the current authenticated user's password.
Requires the current password for verification.
**Authentication**: Required (Bearer token)
**Rate Limit**: 5 requests/minute
""",
operation_id="change_current_user_password"
)
@limiter.limit("5/minute")
def change_current_user_password(
request: Request,
password_change: PasswordChange,
current_user: User = Depends(get_current_user),
db: Session = Depends(get_db)
) -> Any:
"""
Change current user's password.
Requires current password for verification.
"""
try:
success = AuthService.change_password(
db=db,
user_id=current_user.id,
current_password=password_change.current_password,
new_password=password_change.new_password
)
if success:
logger.info(f"User {current_user.id} changed their password")
return MessageResponse(
success=True,
message="Password changed successfully"
)
except Exception as e:
logger.error(f"Error changing password for user {current_user.id}: {str(e)}")
raise
@router.delete(
"/{user_id}",
status_code=status.HTTP_200_OK,
response_model=MessageResponse,
summary="Delete User",
description="""
Delete a specific user by their ID.
**Authentication**: Required (Bearer token)
**Authorization**: Superuser only
**Rate Limit**: 10 requests/minute
**Note**: This performs a hard delete. Consider implementing soft deletes for production.
""",
operation_id="delete_user"
)
def delete_user(
user_id: UUID,
current_user: User = Depends(get_current_superuser),
db: Session = Depends(get_db)
) -> Any:
"""
Delete user by ID (superuser only).
This is a hard delete operation.
"""
# Prevent self-deletion
if str(user_id) == str(current_user.id):
raise AuthorizationError(
message="Cannot delete your own account",
error_code=ErrorCode.INSUFFICIENT_PERMISSIONS
)
# Get user
user = user_crud.get(db, id=str(user_id))
if not user:
raise NotFoundError(
message=f"User with id {user_id} not found",
error_code=ErrorCode.USER_NOT_FOUND
)
try:
user_crud.remove(db, id=str(user_id))
logger.info(f"User {user_id} deleted by {current_user.id}")
return MessageResponse(
success=True,
message=f"User {user_id} deleted successfully"
)
except ValueError as e:
logger.error(f"Error deleting user {user_id}: {str(e)}")
raise
except Exception as e:
logger.error(f"Unexpected error deleting user {user_id}: {str(e)}", exc_info=True)
raise

View File

@@ -1,8 +1,10 @@
# app/core/database.py
import logging
from sqlalchemy import create_engine
from contextlib import contextmanager
from typing import Generator
from sqlalchemy import create_engine, text
from sqlalchemy.ext.declarative import declarative_base
from sqlalchemy.orm import sessionmaker
from sqlalchemy.orm import sessionmaker, Session
from sqlalchemy.ext.compiler import compiles
from sqlalchemy.dialects.postgresql import JSONB, UUID
@@ -49,12 +51,62 @@ def create_production_engine():
# Default production engine and session factory
engine = create_production_engine()
SessionLocal = sessionmaker(autocommit=False, autoflush=False, bind=engine)
SessionLocal = sessionmaker(
autocommit=False,
autoflush=False,
bind=engine,
expire_on_commit=False # Prevent unnecessary queries after commit
)
# FastAPI dependency
def get_db():
def get_db() -> Generator[Session, None, None]:
"""
FastAPI dependency that provides a database session.
Automatically closes the session after the request completes.
"""
db = SessionLocal()
try:
yield db
finally:
db.close()
db.close()
@contextmanager
def transaction_scope() -> Generator[Session, None, None]:
"""
Provide a transactional scope for database operations.
Automatically commits on success or rolls back on exception.
Useful for grouping multiple operations in a single transaction.
Usage:
with transaction_scope() as db:
user = user_crud.create(db, obj_in=user_create)
profile = profile_crud.create(db, obj_in=profile_create)
# Both operations committed together
"""
db = SessionLocal()
try:
yield db
db.commit()
logger.debug("Transaction committed successfully")
except Exception as e:
db.rollback()
logger.error(f"Transaction failed, rolling back: {str(e)}")
raise
finally:
db.close()
def check_database_health() -> bool:
"""
Check if database connection is healthy.
Returns True if connection is successful, False otherwise.
"""
try:
with transaction_scope() as db:
db.execute(text("SELECT 1"))
return True
except Exception as e:
logger.error(f"Database health check failed: {str(e)}")
return False

View File

@@ -0,0 +1,281 @@
"""
Custom exceptions and global exception handlers for the API.
"""
import logging
from typing import Optional, Union, List
from fastapi import HTTPException, Request, status
from fastapi.responses import JSONResponse
from fastapi.exceptions import RequestValidationError
from pydantic import ValidationError
from app.schemas.errors import ErrorCode, ErrorDetail, ErrorResponse
logger = logging.getLogger(__name__)
class APIException(HTTPException):
"""
Base exception class with error code support.
This exception provides a standardized way to raise HTTP exceptions
with machine-readable error codes.
"""
def __init__(
self,
status_code: int,
error_code: ErrorCode,
message: str,
field: Optional[str] = None,
headers: Optional[dict] = None
):
self.error_code = error_code
self.field = field
self.message = message
super().__init__(
status_code=status_code,
detail=message,
headers=headers
)
class AuthenticationError(APIException):
"""Raised when authentication fails."""
def __init__(
self,
message: str = "Authentication failed",
error_code: ErrorCode = ErrorCode.INVALID_CREDENTIALS,
field: Optional[str] = None
):
super().__init__(
status_code=status.HTTP_401_UNAUTHORIZED,
error_code=error_code,
message=message,
field=field,
headers={"WWW-Authenticate": "Bearer"}
)
class AuthorizationError(APIException):
"""Raised when user lacks required permissions."""
def __init__(
self,
message: str = "Insufficient permissions",
error_code: ErrorCode = ErrorCode.INSUFFICIENT_PERMISSIONS
):
super().__init__(
status_code=status.HTTP_403_FORBIDDEN,
error_code=error_code,
message=message
)
class NotFoundError(APIException):
"""Raised when a resource is not found."""
def __init__(
self,
message: str = "Resource not found",
error_code: ErrorCode = ErrorCode.NOT_FOUND
):
super().__init__(
status_code=status.HTTP_404_NOT_FOUND,
error_code=error_code,
message=message
)
class DuplicateError(APIException):
"""Raised when attempting to create a duplicate resource."""
def __init__(
self,
message: str = "Resource already exists",
error_code: ErrorCode = ErrorCode.DUPLICATE_ENTRY,
field: Optional[str] = None
):
super().__init__(
status_code=status.HTTP_409_CONFLICT,
error_code=error_code,
message=message,
field=field
)
class ValidationException(APIException):
"""Raised when input validation fails."""
def __init__(
self,
message: str = "Validation error",
error_code: ErrorCode = ErrorCode.VALIDATION_ERROR,
field: Optional[str] = None
):
super().__init__(
status_code=status.HTTP_422_UNPROCESSABLE_ENTITY,
error_code=error_code,
message=message,
field=field
)
class DatabaseError(APIException):
"""Raised when a database operation fails."""
def __init__(
self,
message: str = "Database operation failed",
error_code: ErrorCode = ErrorCode.DATABASE_ERROR
):
super().__init__(
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
error_code=error_code,
message=message
)
# Global exception handlers
async def api_exception_handler(request: Request, exc: APIException) -> JSONResponse:
"""
Handler for APIException and its subclasses.
Returns a standardized error response with error code and message.
"""
logger.warning(
f"API exception: {exc.error_code} - {exc.message} "
f"(status: {exc.status_code}, path: {request.url.path})"
)
error_response = ErrorResponse(
errors=[ErrorDetail(
code=exc.error_code,
message=exc.message,
field=exc.field
)]
)
return JSONResponse(
status_code=exc.status_code,
content=error_response.model_dump(),
headers=exc.headers
)
async def validation_exception_handler(
request: Request,
exc: Union[RequestValidationError, ValidationError]
) -> JSONResponse:
"""
Handler for Pydantic validation errors.
Converts Pydantic validation errors to standardized error response format.
"""
errors = []
if isinstance(exc, RequestValidationError):
validation_errors = exc.errors()
else:
validation_errors = exc.errors()
for error in validation_errors:
# Extract field name from error location
field = None
if error.get("loc") and len(error["loc"]) > 1:
# Skip 'body' or 'query' prefix in location
field = ".".join(str(x) for x in error["loc"][1:])
errors.append(ErrorDetail(
code=ErrorCode.VALIDATION_ERROR,
message=error["msg"],
field=field
))
logger.warning(
f"Validation error: {len(errors)} errors "
f"(path: {request.url.path})"
)
error_response = ErrorResponse(errors=errors)
return JSONResponse(
status_code=status.HTTP_422_UNPROCESSABLE_ENTITY,
content=error_response.model_dump()
)
async def http_exception_handler(request: Request, exc: HTTPException) -> JSONResponse:
"""
Handler for standard HTTPException.
Converts standard FastAPI HTTPException to standardized error response format.
"""
# Map status codes to error codes
status_code_to_error_code = {
400: ErrorCode.INVALID_INPUT,
401: ErrorCode.AUTHENTICATION_REQUIRED,
403: ErrorCode.INSUFFICIENT_PERMISSIONS,
404: ErrorCode.NOT_FOUND,
405: ErrorCode.METHOD_NOT_ALLOWED,
429: ErrorCode.RATE_LIMIT_EXCEEDED,
500: ErrorCode.INTERNAL_ERROR,
}
error_code = status_code_to_error_code.get(
exc.status_code,
ErrorCode.INTERNAL_ERROR
)
logger.warning(
f"HTTP exception: {exc.status_code} - {exc.detail} "
f"(path: {request.url.path})"
)
error_response = ErrorResponse(
errors=[ErrorDetail(
code=error_code,
message=str(exc.detail)
)]
)
return JSONResponse(
status_code=exc.status_code,
content=error_response.model_dump(),
headers=exc.headers
)
async def unhandled_exception_handler(request: Request, exc: Exception) -> JSONResponse:
"""
Handler for unhandled exceptions.
Logs the full exception and returns a generic error response to avoid
leaking sensitive information in production.
"""
logger.error(
f"Unhandled exception: {type(exc).__name__} - {str(exc)} "
f"(path: {request.url.path})",
exc_info=True
)
# In production, don't expose internal error details
from app.core.config import settings
if settings.ENVIRONMENT == "production":
message = "An internal error occurred. Please try again later."
else:
message = f"{type(exc).__name__}: {str(exc)}"
error_response = ErrorResponse(
errors=[ErrorDetail(
code=ErrorCode.INTERNAL_ERROR,
message=message
)]
)
return JSONResponse(
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
content=error_response.model_dump()
)

View File

@@ -1,8 +1,14 @@
from typing import Any, Dict, Generic, List, Optional, Type, TypeVar, Union
from typing import Any, Dict, Generic, List, Optional, Type, TypeVar, Union, Tuple
from fastapi.encoders import jsonable_encoder
from pydantic import BaseModel
from sqlalchemy.orm import Session
from sqlalchemy.exc import IntegrityError, OperationalError, DataError
from sqlalchemy import func
from app.core.database import Base
import logging
import uuid
logger = logging.getLogger(__name__)
ModelType = TypeVar("ModelType", bound=Base)
CreateSchemaType = TypeVar("CreateSchemaType", bound=BaseModel)
@@ -20,20 +26,63 @@ class CRUDBase(Generic[ModelType, CreateSchemaType, UpdateSchemaType]):
self.model = model
def get(self, db: Session, id: str) -> Optional[ModelType]:
return db.query(self.model).filter(self.model.id == id).first()
"""Get a single record by ID with UUID validation."""
# Validate UUID format
try:
uuid.UUID(id)
except (ValueError, AttributeError):
logger.warning(f"Invalid UUID format: {id}")
return None
try:
return db.query(self.model).filter(self.model.id == id).first()
except Exception as e:
logger.error(f"Error retrieving {self.model.__name__} with id {id}: {str(e)}")
raise
def get_multi(
self, db: Session, *, skip: int = 0, limit: int = 100
) -> List[ModelType]:
return db.query(self.model).offset(skip).limit(limit).all()
"""Get multiple records with pagination validation."""
# Validate pagination parameters
if skip < 0:
raise ValueError("skip must be non-negative")
if limit < 0:
raise ValueError("limit must be non-negative")
if limit > 1000:
raise ValueError("Maximum limit is 1000")
try:
return db.query(self.model).offset(skip).limit(limit).all()
except Exception as e:
logger.error(f"Error retrieving multiple {self.model.__name__} records: {str(e)}")
raise
def create(self, db: Session, *, obj_in: CreateSchemaType) -> ModelType:
obj_in_data = jsonable_encoder(obj_in)
db_obj = self.model(**obj_in_data)
db.add(db_obj)
db.commit()
db.refresh(db_obj)
return db_obj
"""Create a new record with error handling."""
try:
obj_in_data = jsonable_encoder(obj_in)
db_obj = self.model(**obj_in_data)
db.add(db_obj)
db.commit()
db.refresh(db_obj)
return db_obj
except IntegrityError as e:
db.rollback()
error_msg = str(e.orig) if hasattr(e, 'orig') else str(e)
if "unique" in error_msg.lower() or "duplicate" in error_msg.lower():
logger.warning(f"Duplicate entry attempted for {self.model.__name__}: {error_msg}")
raise ValueError(f"A {self.model.__name__} with this data already exists")
logger.error(f"Integrity error creating {self.model.__name__}: {error_msg}")
raise ValueError(f"Database integrity error: {error_msg}")
except (OperationalError, DataError) as e:
db.rollback()
logger.error(f"Database error creating {self.model.__name__}: {str(e)}")
raise ValueError(f"Database operation failed: {str(e)}")
except Exception as e:
db.rollback()
logger.error(f"Unexpected error creating {self.model.__name__}: {str(e)}", exc_info=True)
raise
def update(
self,
@@ -42,21 +91,90 @@ class CRUDBase(Generic[ModelType, CreateSchemaType, UpdateSchemaType]):
db_obj: ModelType,
obj_in: Union[UpdateSchemaType, Dict[str, Any]]
) -> ModelType:
obj_data = jsonable_encoder(db_obj)
if isinstance(obj_in, dict):
update_data = obj_in
else:
update_data = obj_in.model_dump(exclude_unset=True)
for field in obj_data:
if field in update_data:
setattr(db_obj, field, update_data[field])
db.add(db_obj)
db.commit()
db.refresh(db_obj)
return db_obj
"""Update a record with error handling."""
try:
obj_data = jsonable_encoder(db_obj)
if isinstance(obj_in, dict):
update_data = obj_in
else:
update_data = obj_in.model_dump(exclude_unset=True)
for field in obj_data:
if field in update_data:
setattr(db_obj, field, update_data[field])
db.add(db_obj)
db.commit()
db.refresh(db_obj)
return db_obj
except IntegrityError as e:
db.rollback()
error_msg = str(e.orig) if hasattr(e, 'orig') else str(e)
if "unique" in error_msg.lower() or "duplicate" in error_msg.lower():
logger.warning(f"Duplicate entry attempted for {self.model.__name__}: {error_msg}")
raise ValueError(f"A {self.model.__name__} with this data already exists")
logger.error(f"Integrity error updating {self.model.__name__}: {error_msg}")
raise ValueError(f"Database integrity error: {error_msg}")
except (OperationalError, DataError) as e:
db.rollback()
logger.error(f"Database error updating {self.model.__name__}: {str(e)}")
raise ValueError(f"Database operation failed: {str(e)}")
except Exception as e:
db.rollback()
logger.error(f"Unexpected error updating {self.model.__name__}: {str(e)}", exc_info=True)
raise
def remove(self, db: Session, *, id: str) -> ModelType:
obj = db.query(self.model).get(id)
db.delete(obj)
db.commit()
return obj
def remove(self, db: Session, *, id: str) -> Optional[ModelType]:
"""Delete a record with error handling and null check."""
# Validate UUID format
try:
uuid.UUID(id)
except (ValueError, AttributeError):
logger.warning(f"Invalid UUID format for deletion: {id}")
return None
try:
obj = db.query(self.model).filter(self.model.id == id).first()
if obj is None:
logger.warning(f"{self.model.__name__} with id {id} not found for deletion")
return None
db.delete(obj)
db.commit()
return obj
except IntegrityError as e:
db.rollback()
error_msg = str(e.orig) if hasattr(e, 'orig') else str(e)
logger.error(f"Integrity error deleting {self.model.__name__}: {error_msg}")
raise ValueError(f"Cannot delete {self.model.__name__}: referenced by other records")
except Exception as e:
db.rollback()
logger.error(f"Error deleting {self.model.__name__} with id {id}: {str(e)}", exc_info=True)
raise
def get_multi_with_total(
self, db: Session, *, skip: int = 0, limit: int = 100
) -> Tuple[List[ModelType], int]:
"""
Get multiple records with total count for pagination.
Returns:
Tuple of (items, total_count)
"""
# Validate pagination parameters
if skip < 0:
raise ValueError("skip must be non-negative")
if limit < 0:
raise ValueError("limit must be non-negative")
if limit > 1000:
raise ValueError("Maximum limit is 1000")
try:
# Get total count
total = db.query(func.count(self.model.id)).scalar()
# Get paginated items
items = db.query(self.model).offset(skip).limit(limit).all()
return items, total
except Exception as e:
logger.error(f"Error retrieving paginated {self.model.__name__} records: {str(e)}")
raise

View File

@@ -1,10 +1,14 @@
# app/crud/user.py
from typing import Optional, Union, Dict, Any
from sqlalchemy.orm import Session
from sqlalchemy.exc import IntegrityError
from app.crud.base import CRUDBase
from app.models.user import User
from app.schemas.users import UserCreate, UserUpdate
from app.core.auth import get_password_hash
import logging
logger = logging.getLogger(__name__)
class CRUDUser(CRUDBase[User, UserCreate, UserUpdate]):
@@ -12,19 +16,33 @@ class CRUDUser(CRUDBase[User, UserCreate, UserUpdate]):
return db.query(User).filter(User.email == email).first()
def create(self, db: Session, *, obj_in: UserCreate) -> User:
db_obj = User(
email=obj_in.email,
password_hash=get_password_hash(obj_in.password),
first_name=obj_in.first_name,
last_name=obj_in.last_name,
phone_number=obj_in.phone_number if hasattr(obj_in, 'phone_number') else None,
is_superuser=obj_in.is_superuser if hasattr(obj_in, 'is_superuser') else False,
preferences={}
)
db.add(db_obj)
db.commit()
db.refresh(db_obj)
return db_obj
"""Create a new user with password hashing and error handling."""
try:
db_obj = User(
email=obj_in.email,
password_hash=get_password_hash(obj_in.password),
first_name=obj_in.first_name,
last_name=obj_in.last_name,
phone_number=obj_in.phone_number if hasattr(obj_in, 'phone_number') else None,
is_superuser=obj_in.is_superuser if hasattr(obj_in, 'is_superuser') else False,
preferences={}
)
db.add(db_obj)
db.commit()
db.refresh(db_obj)
return db_obj
except IntegrityError as e:
db.rollback()
error_msg = str(e.orig) if hasattr(e, 'orig') else str(e)
if "email" in error_msg.lower():
logger.warning(f"Duplicate email attempted: {obj_in.email}")
raise ValueError(f"User with email {obj_in.email} already exists")
logger.error(f"Integrity error creating user: {error_msg}")
raise ValueError(f"Database integrity error: {error_msg}")
except Exception as e:
db.rollback()
logger.error(f"Unexpected error creating user: {str(e)}", exc_info=True)
raise
def update(
self,

View File

@@ -19,7 +19,7 @@ def init_db(db: Session) -> Optional[UserCreate]:
"""
# Use default values if not set in environment variables
superuser_email = settings.FIRST_SUPERUSER_EMAIL or "admin@example.com"
superuser_password = settings.FIRST_SUPERUSER_PASSWORD or "admin123"
superuser_password = settings.FIRST_SUPERUSER_PASSWORD or "Admin123!Change"
if not settings.FIRST_SUPERUSER_EMAIL or not settings.FIRST_SUPERUSER_PASSWORD:
logger.warning(

View File

@@ -3,9 +3,10 @@ from datetime import datetime
from typing import Dict, Any
from apscheduler.schedulers.asyncio import AsyncIOScheduler
from fastapi import FastAPI, status, Request
from fastapi import FastAPI, status, Request, HTTPException
from fastapi.middleware.cors import CORSMiddleware
from fastapi.responses import HTMLResponse, JSONResponse
from fastapi.exceptions import RequestValidationError
from slowapi import Limiter, _rate_limit_exceeded_handler
from slowapi.util import get_remote_address
from slowapi.errors import RateLimitExceeded
@@ -13,7 +14,14 @@ from sqlalchemy import text
from app.api.main import api_router
from app.core.config import settings
from app.core.database import get_db
from app.core.database import get_db, check_database_health
from app.core.exceptions import (
APIException,
api_exception_handler,
validation_exception_handler,
http_exception_handler,
unhandled_exception_handler
)
scheduler = AsyncIOScheduler()
@@ -33,13 +41,30 @@ app = FastAPI(
app.state.limiter = limiter
app.add_exception_handler(RateLimitExceeded, _rate_limit_exceeded_handler)
# Set up CORS middleware
# Register custom exception handlers (order matters - most specific first)
app.add_exception_handler(APIException, api_exception_handler)
app.add_exception_handler(RequestValidationError, validation_exception_handler)
app.add_exception_handler(HTTPException, http_exception_handler)
app.add_exception_handler(Exception, unhandled_exception_handler)
# Set up CORS middleware with explicit allowed methods and headers
app.add_middleware(
CORSMiddleware,
allow_origins=settings.BACKEND_CORS_ORIGINS,
allow_credentials=True,
allow_methods=["*"],
allow_headers=["*"],
allow_methods=["GET", "POST", "PUT", "PATCH", "DELETE", "OPTIONS"], # Explicit methods only
allow_headers=[
"Content-Type",
"Authorization",
"Accept",
"Origin",
"User-Agent",
"DNT",
"Cache-Control",
"X-Requested-With",
], # Explicit headers only
expose_headers=["Content-Length"],
max_age=600, # Cache preflight requests for 10 minutes
)
@@ -120,15 +145,16 @@ async def health_check() -> JSONResponse:
response_status = status.HTTP_200_OK
# Database health check
# Database health check using dedicated health check function
try:
db = next(get_db())
db.execute(text("SELECT 1"))
health_status["checks"]["database"] = {
"status": "healthy",
"message": "Database connection successful"
}
db.close()
db_healthy = check_database_health()
if db_healthy:
health_status["checks"]["database"] = {
"status": "healthy",
"message": "Database connection successful"
}
else:
raise Exception("Database health check returned unhealthy status")
except Exception as e:
health_status["status"] = "unhealthy"
health_status["checks"]["database"] = {

View File

@@ -0,0 +1,139 @@
"""
Common schemas used across the API for pagination, responses, etc.
"""
from typing import Generic, TypeVar, List, Optional
from pydantic import BaseModel, Field
from math import ceil
T = TypeVar('T')
class PaginationParams(BaseModel):
"""Parameters for pagination."""
page: int = Field(
default=1,
ge=1,
description="Page number (1-indexed)"
)
limit: int = Field(
default=20,
ge=1,
le=100,
description="Number of items per page (max 100)"
)
@property
def offset(self) -> int:
"""Calculate the offset for database queries."""
return (self.page - 1) * self.limit
@property
def skip(self) -> int:
"""Alias for offset (compatibility with existing code)."""
return self.offset
model_config = {
"json_schema_extra": {
"example": {
"page": 1,
"limit": 20
}
}
}
class PaginationMeta(BaseModel):
"""Metadata for paginated responses."""
total: int = Field(..., description="Total number of items")
page: int = Field(..., description="Current page number")
page_size: int = Field(..., description="Number of items in current page")
total_pages: int = Field(..., description="Total number of pages")
has_next: bool = Field(..., description="Whether there is a next page")
has_prev: bool = Field(..., description="Whether there is a previous page")
model_config = {
"json_schema_extra": {
"example": {
"total": 150,
"page": 1,
"page_size": 20,
"total_pages": 8,
"has_next": True,
"has_prev": False
}
}
}
class PaginatedResponse(BaseModel, Generic[T]):
"""Generic paginated response wrapper."""
data: List[T] = Field(..., description="List of items")
pagination: PaginationMeta = Field(..., description="Pagination metadata")
model_config = {
"json_schema_extra": {
"example": {
"data": [
{"id": "123", "name": "Example Item"}
],
"pagination": {
"total": 150,
"page": 1,
"page_size": 20,
"total_pages": 8,
"has_next": True,
"has_prev": False
}
}
}
}
class MessageResponse(BaseModel):
"""Simple message response."""
success: bool = Field(default=True, description="Operation success status")
message: str = Field(..., description="Human-readable message")
model_config = {
"json_schema_extra": {
"example": {
"success": True,
"message": "Operation completed successfully"
}
}
}
def create_pagination_meta(
total: int,
page: int,
limit: int,
items_count: int
) -> PaginationMeta:
"""
Helper function to create pagination metadata.
Args:
total: Total number of items
page: Current page number
limit: Items per page
items_count: Number of items in current page
Returns:
PaginationMeta object with calculated values
"""
total_pages = ceil(total / limit) if limit > 0 else 0
return PaginationMeta(
total=total,
page=page,
page_size=items_count,
total_pages=total_pages,
has_next=page < total_pages,
has_prev=page > 1
)

View File

@@ -0,0 +1,85 @@
"""
Error schemas for standardized API error responses.
"""
from enum import Enum
from typing import List, Optional
from pydantic import BaseModel, Field
class ErrorCode(str, Enum):
"""Standard error codes for the API."""
# Authentication errors (AUTH_xxx)
INVALID_CREDENTIALS = "AUTH_001"
TOKEN_EXPIRED = "AUTH_002"
TOKEN_INVALID = "AUTH_003"
INSUFFICIENT_PERMISSIONS = "AUTH_004"
USER_INACTIVE = "AUTH_005"
AUTHENTICATION_REQUIRED = "AUTH_006"
# User errors (USER_xxx)
USER_NOT_FOUND = "USER_001"
USER_ALREADY_EXISTS = "USER_002"
USER_CREATION_FAILED = "USER_003"
USER_UPDATE_FAILED = "USER_004"
USER_DELETION_FAILED = "USER_005"
# Validation errors (VAL_xxx)
VALIDATION_ERROR = "VAL_001"
INVALID_PASSWORD = "VAL_002"
INVALID_EMAIL = "VAL_003"
INVALID_PHONE_NUMBER = "VAL_004"
INVALID_UUID = "VAL_005"
INVALID_INPUT = "VAL_006"
# Database errors (DB_xxx)
DATABASE_ERROR = "DB_001"
DUPLICATE_ENTRY = "DB_002"
FOREIGN_KEY_VIOLATION = "DB_003"
RECORD_NOT_FOUND = "DB_004"
# Generic errors (SYS_xxx)
INTERNAL_ERROR = "SYS_001"
NOT_FOUND = "SYS_002"
METHOD_NOT_ALLOWED = "SYS_003"
RATE_LIMIT_EXCEEDED = "SYS_004"
class ErrorDetail(BaseModel):
"""Detailed information about a single error."""
code: ErrorCode = Field(..., description="Machine-readable error code")
message: str = Field(..., description="Human-readable error message")
field: Optional[str] = Field(None, description="Field name if error is field-specific")
model_config = {
"json_schema_extra": {
"example": {
"code": "VAL_002",
"message": "Password must be at least 8 characters long",
"field": "password"
}
}
}
class ErrorResponse(BaseModel):
"""Standardized error response format."""
success: bool = Field(default=False, description="Always false for error responses")
errors: List[ErrorDetail] = Field(..., description="List of errors that occurred")
model_config = {
"json_schema_extra": {
"example": {
"success": False,
"errors": [
{
"code": "AUTH_001",
"message": "Invalid email or password",
"field": None
}
]
}
}
}

View File

@@ -123,7 +123,26 @@ class TokenData(BaseModel):
is_superuser: bool = False
class PasswordChange(BaseModel):
"""Schema for changing password (requires current password)."""
current_password: str
new_password: str
@field_validator('new_password')
@classmethod
def password_strength(cls, v: str) -> str:
"""Basic password strength validation"""
if len(v) < 8:
raise ValueError('Password must be at least 8 characters')
if not any(char.isdigit() for char in v):
raise ValueError('Password must contain at least one digit')
if not any(char.isupper() for char in v):
raise ValueError('Password must contain at least one uppercase letter')
return v
class PasswordReset(BaseModel):
"""Schema for resetting password (via email token)."""
token: str
new_password: str

1
backend/coverage.json Normal file

File diff suppressed because one or more lines are too long

View File

@@ -0,0 +1,223 @@
# tests/test_init_db.py
import pytest
from unittest.mock import patch, MagicMock
from sqlalchemy.orm import Session
from app.init_db import init_db
from app.models.user import User
from app.schemas.users import UserCreate
class TestInitDB:
"""Tests for database initialization"""
def test_init_db_creates_superuser_when_not_exists(self, db_session, monkeypatch):
"""Test that init_db creates superuser when it doesn't exist"""
# Set environment variables
monkeypatch.setenv("FIRST_SUPERUSER_EMAIL", "admin@test.com")
monkeypatch.setenv("FIRST_SUPERUSER_PASSWORD", "TestPassword123!")
# Reload settings to pick up environment variables
from app.core import config
import importlib
importlib.reload(config)
from app.core.config import settings
# Mock user_crud to return None (user doesn't exist)
with patch('app.init_db.user_crud') as mock_crud:
mock_crud.get_by_email.return_value = None
# Create a mock user to return from create
from datetime import datetime, timezone
import uuid
mock_user = User(
id=uuid.uuid4(),
email="admin@test.com",
password_hash="hashed",
first_name="Admin",
last_name="User",
is_active=True,
is_superuser=True,
created_at=datetime.now(timezone.utc),
updated_at=datetime.now(timezone.utc)
)
mock_crud.create.return_value = mock_user
# Call init_db
user = init_db(db_session)
# Verify user was created
assert user is not None
assert user.email == "admin@test.com"
assert user.is_superuser is True
mock_crud.create.assert_called_once()
def test_init_db_returns_existing_superuser(self, db_session, monkeypatch):
"""Test that init_db returns existing superuser without creating new one"""
# Set environment variables
monkeypatch.setenv("FIRST_SUPERUSER_EMAIL", "existing@test.com")
monkeypatch.setenv("FIRST_SUPERUSER_PASSWORD", "TestPassword123!")
# Reload settings
from app.core import config
import importlib
importlib.reload(config)
# Mock user_crud to return existing user
with patch('app.init_db.user_crud') as mock_crud:
from datetime import datetime, timezone
import uuid
existing_user = User(
id=uuid.uuid4(),
email="existing@test.com",
password_hash="hashed",
first_name="Existing",
last_name="User",
is_active=True,
is_superuser=True,
created_at=datetime.now(timezone.utc),
updated_at=datetime.now(timezone.utc)
)
mock_crud.get_by_email.return_value = existing_user
# Call init_db
user = init_db(db_session)
# Verify existing user was returned
assert user is not None
assert user.email == "existing@test.com"
# create should NOT be called
mock_crud.create.assert_not_called()
def test_init_db_uses_defaults_when_env_not_set(self, db_session):
"""Test that init_db uses default credentials when env vars not set"""
# Mock settings to return None for superuser credentials
with patch('app.init_db.settings') as mock_settings:
mock_settings.FIRST_SUPERUSER_EMAIL = None
mock_settings.FIRST_SUPERUSER_PASSWORD = None
# Mock user_crud
with patch('app.init_db.user_crud') as mock_crud:
mock_crud.get_by_email.return_value = None
from datetime import datetime, timezone
import uuid
mock_user = User(
id=uuid.uuid4(),
email="admin@example.com",
password_hash="hashed",
first_name="Admin",
last_name="User",
is_active=True,
is_superuser=True,
created_at=datetime.now(timezone.utc),
updated_at=datetime.now(timezone.utc)
)
mock_crud.create.return_value = mock_user
# Call init_db
with patch('app.init_db.logger') as mock_logger:
user = init_db(db_session)
# Verify default email was used
mock_crud.get_by_email.assert_called_with(db_session, email="admin@example.com")
# Verify warning was logged since credentials not set
assert mock_logger.warning.called
def test_init_db_handles_creation_error(self, db_session, monkeypatch):
"""Test that init_db handles errors during user creation"""
# Set environment variables
monkeypatch.setenv("FIRST_SUPERUSER_EMAIL", "admin@test.com")
monkeypatch.setenv("FIRST_SUPERUSER_PASSWORD", "TestPassword123!")
# Reload settings
from app.core import config
import importlib
importlib.reload(config)
# Mock user_crud to raise an exception
with patch('app.init_db.user_crud') as mock_crud:
mock_crud.get_by_email.return_value = None
mock_crud.create.side_effect = Exception("Database error")
# Call init_db and expect exception
with pytest.raises(Exception) as exc_info:
init_db(db_session)
assert "Database error" in str(exc_info.value)
def test_init_db_logs_superuser_creation(self, db_session, monkeypatch):
"""Test that init_db logs appropriate messages"""
# Set environment variables
monkeypatch.setenv("FIRST_SUPERUSER_EMAIL", "admin@test.com")
monkeypatch.setenv("FIRST_SUPERUSER_PASSWORD", "TestPassword123!")
# Reload settings
from app.core import config
import importlib
importlib.reload(config)
# Mock user_crud
with patch('app.init_db.user_crud') as mock_crud:
mock_crud.get_by_email.return_value = None
from datetime import datetime, timezone
import uuid
mock_user = User(
id=uuid.uuid4(),
email="admin@test.com",
password_hash="hashed",
first_name="Admin",
last_name="User",
is_active=True,
is_superuser=True,
created_at=datetime.now(timezone.utc),
updated_at=datetime.now(timezone.utc)
)
mock_crud.create.return_value = mock_user
# Call init_db with logger mock
with patch('app.init_db.logger') as mock_logger:
user = init_db(db_session)
# Verify info log was called
assert mock_logger.info.called
info_call_args = str(mock_logger.info.call_args)
assert "Created first superuser" in info_call_args
def test_init_db_logs_existing_user(self, db_session, monkeypatch):
"""Test that init_db logs when user already exists"""
# Set environment variables
monkeypatch.setenv("FIRST_SUPERUSER_EMAIL", "existing@test.com")
monkeypatch.setenv("FIRST_SUPERUSER_PASSWORD", "TestPassword123!")
# Reload settings
from app.core import config
import importlib
importlib.reload(config)
# Mock user_crud to return existing user
with patch('app.init_db.user_crud') as mock_crud:
from datetime import datetime, timezone
import uuid
existing_user = User(
id=uuid.uuid4(),
email="existing@test.com",
password_hash="hashed",
first_name="Existing",
last_name="User",
is_active=True,
is_superuser=True,
created_at=datetime.now(timezone.utc),
updated_at=datetime.now(timezone.utc)
)
mock_crud.get_by_email.return_value = existing_user
# Call init_db with logger mock
with patch('app.init_db.logger') as mock_logger:
user = init_db(db_session)
# Verify info log was called
assert mock_logger.info.called
info_call_args = str(mock_logger.info.call_args)
assert "already exists" in info_call_args.lower()