Add UUID handling, sorting, filtering, and soft delete functionality to CRUD operations

- Enhanced UUID validation by supporting both string and `UUID` formats.
- Added advanced filtering and sorting capabilities to `get_multi_with_total` method.
- Introduced soft delete and restore functionality for models with `deleted_at` column.
- Updated tests to reflect new endpoints and rate-limiting logic.
- Improved schema definitions with `SortParams` and `SortOrder` for consistent API inputs.
This commit is contained in:
Felipe Cardoso
2025-10-30 16:44:15 +01:00
parent 2c600290a1
commit c684f2ba95
5 changed files with 200 additions and 56 deletions

View File

@@ -1,9 +1,10 @@
from typing import Any, Dict, Generic, List, Optional, Type, TypeVar, Union, Tuple
from datetime import datetime, timezone
from fastapi.encoders import jsonable_encoder
from pydantic import BaseModel
from sqlalchemy.orm import Session
from sqlalchemy.exc import IntegrityError, OperationalError, DataError
from sqlalchemy import func
from sqlalchemy import func, asc, desc
from app.core.database import Base
import logging
import uuid
@@ -27,15 +28,18 @@ class CRUDBase(Generic[ModelType, CreateSchemaType, UpdateSchemaType]):
def get(self, db: Session, id: str) -> Optional[ModelType]:
"""Get a single record by ID with UUID validation."""
# Validate UUID format
# Validate UUID format and convert to UUID object if string
try:
uuid.UUID(id)
except (ValueError, AttributeError):
logger.warning(f"Invalid UUID format: {id}")
if isinstance(id, uuid.UUID):
uuid_obj = id
else:
uuid_obj = uuid.UUID(str(id))
except (ValueError, AttributeError, TypeError) as e:
logger.warning(f"Invalid UUID format: {id} - {str(e)}")
return None
try:
return db.query(self.model).filter(self.model.id == id).first()
return db.query(self.model).filter(self.model.id == uuid_obj).first()
except Exception as e:
logger.error(f"Error retrieving {self.model.__name__} with id {id}: {str(e)}")
raise
@@ -124,15 +128,18 @@ class CRUDBase(Generic[ModelType, CreateSchemaType, UpdateSchemaType]):
def remove(self, db: Session, *, id: str) -> Optional[ModelType]:
"""Delete a record with error handling and null check."""
# Validate UUID format
# Validate UUID format and convert to UUID object if string
try:
uuid.UUID(id)
except (ValueError, AttributeError):
logger.warning(f"Invalid UUID format for deletion: {id}")
if isinstance(id, uuid.UUID):
uuid_obj = id
else:
uuid_obj = uuid.UUID(str(id))
except (ValueError, AttributeError, TypeError) as e:
logger.warning(f"Invalid UUID format for deletion: {id} - {str(e)}")
return None
try:
obj = db.query(self.model).filter(self.model.id == id).first()
obj = db.query(self.model).filter(self.model.id == uuid_obj).first()
if obj is None:
logger.warning(f"{self.model.__name__} with id {id} not found for deletion")
return None
@@ -151,10 +158,25 @@ class CRUDBase(Generic[ModelType, CreateSchemaType, UpdateSchemaType]):
raise
def get_multi_with_total(
self, db: Session, *, skip: int = 0, limit: int = 100
self,
db: Session,
*,
skip: int = 0,
limit: int = 100,
sort_by: Optional[str] = None,
sort_order: str = "asc",
filters: Optional[Dict[str, Any]] = None
) -> Tuple[List[ModelType], int]:
"""
Get multiple records with total count for pagination.
Get multiple records with total count, filtering, and sorting.
Args:
db: Database session
skip: Number of records to skip
limit: Maximum number of records to return
sort_by: Field name to sort by (must be a valid model attribute)
sort_order: Sort order ("asc" or "desc")
filters: Dictionary of filters (field_name: value)
Returns:
Tuple of (items, total_count)
@@ -168,13 +190,115 @@ class CRUDBase(Generic[ModelType, CreateSchemaType, UpdateSchemaType]):
raise ValueError("Maximum limit is 1000")
try:
# Get total count
total = db.query(func.count(self.model.id)).scalar()
# Build base query
query = db.query(self.model)
# Get paginated items
items = db.query(self.model).offset(skip).limit(limit).all()
# Exclude soft-deleted records by default
if hasattr(self.model, 'deleted_at'):
query = query.filter(self.model.deleted_at.is_(None))
# Apply filters
if filters:
for field, value in filters.items():
if hasattr(self.model, field) and value is not None:
query = query.filter(getattr(self.model, field) == value)
# Get total count (before pagination)
total = query.count()
# Apply sorting
if sort_by and hasattr(self.model, sort_by):
sort_column = getattr(self.model, sort_by)
if sort_order.lower() == "desc":
query = query.order_by(desc(sort_column))
else:
query = query.order_by(asc(sort_column))
# Apply pagination
items = query.offset(skip).limit(limit).all()
return items, total
except Exception as e:
logger.error(f"Error retrieving paginated {self.model.__name__} records: {str(e)}")
raise
def soft_delete(self, db: Session, *, id: str) -> Optional[ModelType]:
"""
Soft delete a record by setting deleted_at timestamp.
Only works if the model has a 'deleted_at' column.
"""
# Validate UUID format and convert to UUID object if string
try:
if isinstance(id, uuid.UUID):
uuid_obj = id
else:
uuid_obj = uuid.UUID(str(id))
except (ValueError, AttributeError, TypeError) as e:
logger.warning(f"Invalid UUID format for soft deletion: {id} - {str(e)}")
return None
try:
obj = db.query(self.model).filter(self.model.id == uuid_obj).first()
if obj is None:
logger.warning(f"{self.model.__name__} with id {id} not found for soft deletion")
return None
# Check if model supports soft deletes
if not hasattr(self.model, 'deleted_at'):
logger.error(f"{self.model.__name__} does not support soft deletes")
raise ValueError(f"{self.model.__name__} does not have a deleted_at column")
# Set deleted_at timestamp
obj.deleted_at = datetime.now(timezone.utc)
db.add(obj)
db.commit()
db.refresh(obj)
return obj
except Exception as e:
db.rollback()
logger.error(f"Error soft deleting {self.model.__name__} with id {id}: {str(e)}", exc_info=True)
raise
def restore(self, db: Session, *, id: str) -> Optional[ModelType]:
"""
Restore a soft-deleted record by clearing the deleted_at timestamp.
Only works if the model has a 'deleted_at' column.
"""
# Validate UUID format
try:
if isinstance(id, uuid.UUID):
uuid_obj = id
else:
uuid_obj = uuid.UUID(str(id))
except (ValueError, AttributeError, TypeError) as e:
logger.warning(f"Invalid UUID format for restoration: {id} - {str(e)}")
return None
try:
# Find the soft-deleted record
if hasattr(self.model, 'deleted_at'):
obj = db.query(self.model).filter(
self.model.id == uuid_obj,
self.model.deleted_at.isnot(None)
).first()
else:
logger.error(f"{self.model.__name__} does not support soft deletes")
raise ValueError(f"{self.model.__name__} does not have a deleted_at column")
if obj is None:
logger.warning(f"Soft-deleted {self.model.__name__} with id {id} not found for restoration")
return None
# Clear deleted_at timestamp
obj.deleted_at = None
db.add(obj)
db.commit()
db.refresh(obj)
return obj
except Exception as e:
db.rollback()
logger.error(f"Error restoring {self.model.__name__} with id {id}: {str(e)}", exc_info=True)
raise

View File

@@ -1,7 +1,8 @@
"""
Common schemas used across the API for pagination, responses, etc.
Common schemas used across the API for pagination, responses, filtering, and sorting.
"""
from typing import Generic, TypeVar, List, Optional
from enum import Enum
from pydantic import BaseModel, Field
from math import ceil
@@ -9,6 +10,12 @@ from math import ceil
T = TypeVar('T')
class SortOrder(str, Enum):
"""Sort order options."""
ASC = "asc"
DESC = "desc"
class PaginationParams(BaseModel):
"""Parameters for pagination."""
@@ -44,6 +51,28 @@ class PaginationParams(BaseModel):
}
class SortParams(BaseModel):
"""Parameters for sorting."""
sort_by: Optional[str] = Field(
default=None,
description="Field name to sort by"
)
sort_order: SortOrder = Field(
default=SortOrder.ASC,
description="Sort order (asc or desc)"
)
model_config = {
"json_schema_extra": {
"example": {
"sort_by": "created_at",
"sort_order": "desc"
}
}
}
class PaginationMeta(BaseModel):
"""Metadata for paginated responses."""