Add UUID handling, sorting, filtering, and soft delete functionality to CRUD operations

- Enhanced UUID validation by supporting both string and `UUID` formats.
- Added advanced filtering and sorting capabilities to `get_multi_with_total` method.
- Introduced soft delete and restore functionality for models with `deleted_at` column.
- Updated tests to reflect new endpoints and rate-limiting logic.
- Improved schema definitions with `SortParams` and `SortOrder` for consistent API inputs.
This commit is contained in:
Felipe Cardoso
2025-10-30 16:44:15 +01:00
parent 2c600290a1
commit c684f2ba95
5 changed files with 200 additions and 56 deletions

View File

@@ -1,9 +1,10 @@
from typing import Any, Dict, Generic, List, Optional, Type, TypeVar, Union, Tuple
from datetime import datetime, timezone
from fastapi.encoders import jsonable_encoder
from pydantic import BaseModel
from sqlalchemy.orm import Session
from sqlalchemy.exc import IntegrityError, OperationalError, DataError
from sqlalchemy import func
from sqlalchemy import func, asc, desc
from app.core.database import Base
import logging
import uuid
@@ -27,15 +28,18 @@ class CRUDBase(Generic[ModelType, CreateSchemaType, UpdateSchemaType]):
def get(self, db: Session, id: str) -> Optional[ModelType]:
"""Get a single record by ID with UUID validation."""
# Validate UUID format
# Validate UUID format and convert to UUID object if string
try:
uuid.UUID(id)
except (ValueError, AttributeError):
logger.warning(f"Invalid UUID format: {id}")
if isinstance(id, uuid.UUID):
uuid_obj = id
else:
uuid_obj = uuid.UUID(str(id))
except (ValueError, AttributeError, TypeError) as e:
logger.warning(f"Invalid UUID format: {id} - {str(e)}")
return None
try:
return db.query(self.model).filter(self.model.id == id).first()
return db.query(self.model).filter(self.model.id == uuid_obj).first()
except Exception as e:
logger.error(f"Error retrieving {self.model.__name__} with id {id}: {str(e)}")
raise
@@ -124,15 +128,18 @@ class CRUDBase(Generic[ModelType, CreateSchemaType, UpdateSchemaType]):
def remove(self, db: Session, *, id: str) -> Optional[ModelType]:
"""Delete a record with error handling and null check."""
# Validate UUID format
# Validate UUID format and convert to UUID object if string
try:
uuid.UUID(id)
except (ValueError, AttributeError):
logger.warning(f"Invalid UUID format for deletion: {id}")
if isinstance(id, uuid.UUID):
uuid_obj = id
else:
uuid_obj = uuid.UUID(str(id))
except (ValueError, AttributeError, TypeError) as e:
logger.warning(f"Invalid UUID format for deletion: {id} - {str(e)}")
return None
try:
obj = db.query(self.model).filter(self.model.id == id).first()
obj = db.query(self.model).filter(self.model.id == uuid_obj).first()
if obj is None:
logger.warning(f"{self.model.__name__} with id {id} not found for deletion")
return None
@@ -151,10 +158,25 @@ class CRUDBase(Generic[ModelType, CreateSchemaType, UpdateSchemaType]):
raise
def get_multi_with_total(
self, db: Session, *, skip: int = 0, limit: int = 100
self,
db: Session,
*,
skip: int = 0,
limit: int = 100,
sort_by: Optional[str] = None,
sort_order: str = "asc",
filters: Optional[Dict[str, Any]] = None
) -> Tuple[List[ModelType], int]:
"""
Get multiple records with total count for pagination.
Get multiple records with total count, filtering, and sorting.
Args:
db: Database session
skip: Number of records to skip
limit: Maximum number of records to return
sort_by: Field name to sort by (must be a valid model attribute)
sort_order: Sort order ("asc" or "desc")
filters: Dictionary of filters (field_name: value)
Returns:
Tuple of (items, total_count)
@@ -168,13 +190,115 @@ class CRUDBase(Generic[ModelType, CreateSchemaType, UpdateSchemaType]):
raise ValueError("Maximum limit is 1000")
try:
# Get total count
total = db.query(func.count(self.model.id)).scalar()
# Build base query
query = db.query(self.model)
# Get paginated items
items = db.query(self.model).offset(skip).limit(limit).all()
# Exclude soft-deleted records by default
if hasattr(self.model, 'deleted_at'):
query = query.filter(self.model.deleted_at.is_(None))
# Apply filters
if filters:
for field, value in filters.items():
if hasattr(self.model, field) and value is not None:
query = query.filter(getattr(self.model, field) == value)
# Get total count (before pagination)
total = query.count()
# Apply sorting
if sort_by and hasattr(self.model, sort_by):
sort_column = getattr(self.model, sort_by)
if sort_order.lower() == "desc":
query = query.order_by(desc(sort_column))
else:
query = query.order_by(asc(sort_column))
# Apply pagination
items = query.offset(skip).limit(limit).all()
return items, total
except Exception as e:
logger.error(f"Error retrieving paginated {self.model.__name__} records: {str(e)}")
raise
def soft_delete(self, db: Session, *, id: str) -> Optional[ModelType]:
"""
Soft delete a record by setting deleted_at timestamp.
Only works if the model has a 'deleted_at' column.
"""
# Validate UUID format and convert to UUID object if string
try:
if isinstance(id, uuid.UUID):
uuid_obj = id
else:
uuid_obj = uuid.UUID(str(id))
except (ValueError, AttributeError, TypeError) as e:
logger.warning(f"Invalid UUID format for soft deletion: {id} - {str(e)}")
return None
try:
obj = db.query(self.model).filter(self.model.id == uuid_obj).first()
if obj is None:
logger.warning(f"{self.model.__name__} with id {id} not found for soft deletion")
return None
# Check if model supports soft deletes
if not hasattr(self.model, 'deleted_at'):
logger.error(f"{self.model.__name__} does not support soft deletes")
raise ValueError(f"{self.model.__name__} does not have a deleted_at column")
# Set deleted_at timestamp
obj.deleted_at = datetime.now(timezone.utc)
db.add(obj)
db.commit()
db.refresh(obj)
return obj
except Exception as e:
db.rollback()
logger.error(f"Error soft deleting {self.model.__name__} with id {id}: {str(e)}", exc_info=True)
raise
def restore(self, db: Session, *, id: str) -> Optional[ModelType]:
"""
Restore a soft-deleted record by clearing the deleted_at timestamp.
Only works if the model has a 'deleted_at' column.
"""
# Validate UUID format
try:
if isinstance(id, uuid.UUID):
uuid_obj = id
else:
uuid_obj = uuid.UUID(str(id))
except (ValueError, AttributeError, TypeError) as e:
logger.warning(f"Invalid UUID format for restoration: {id} - {str(e)}")
return None
try:
# Find the soft-deleted record
if hasattr(self.model, 'deleted_at'):
obj = db.query(self.model).filter(
self.model.id == uuid_obj,
self.model.deleted_at.isnot(None)
).first()
else:
logger.error(f"{self.model.__name__} does not support soft deletes")
raise ValueError(f"{self.model.__name__} does not have a deleted_at column")
if obj is None:
logger.warning(f"Soft-deleted {self.model.__name__} with id {id} not found for restoration")
return None
# Clear deleted_at timestamp
obj.deleted_at = None
db.add(obj)
db.commit()
db.refresh(obj)
return obj
except Exception as e:
db.rollback()
logger.error(f"Error restoring {self.model.__name__} with id {id}: {str(e)}", exc_info=True)
raise

View File

@@ -1,7 +1,8 @@
"""
Common schemas used across the API for pagination, responses, etc.
Common schemas used across the API for pagination, responses, filtering, and sorting.
"""
from typing import Generic, TypeVar, List, Optional
from enum import Enum
from pydantic import BaseModel, Field
from math import ceil
@@ -9,6 +10,12 @@ from math import ceil
T = TypeVar('T')
class SortOrder(str, Enum):
"""Sort order options."""
ASC = "asc"
DESC = "desc"
class PaginationParams(BaseModel):
"""Parameters for pagination."""
@@ -44,6 +51,28 @@ class PaginationParams(BaseModel):
}
class SortParams(BaseModel):
"""Parameters for sorting."""
sort_by: Optional[str] = Field(
default=None,
description="Field name to sort by"
)
sort_order: SortOrder = Field(
default=SortOrder.ASC,
description="Sort order (asc or desc)"
)
model_config = {
"json_schema_extra": {
"example": {
"sort_by": "created_at",
"sort_order": "desc"
}
}
}
class PaginationMeta(BaseModel):
"""Metadata for paginated responses."""

View File

@@ -10,6 +10,7 @@ from fastapi.testclient import TestClient
from sqlalchemy.orm import Session
from app.api.routes.auth import router as auth_router
from app.api.routes.users import router as users_router
from app.core.auth import get_password_hash
from app.core.database import get_db
from app.models.user import User
@@ -29,6 +30,7 @@ def app(override_get_db):
"""Create a FastAPI test application with overridden dependencies."""
app = FastAPI()
app.include_router(auth_router, prefix="/auth", tags=["auth"])
app.include_router(users_router, prefix="/api/v1/users", tags=["users"])
# Override the get_db dependency
app.dependency_overrides[get_db] = lambda: override_get_db
@@ -280,9 +282,9 @@ class TestChangePassword:
# Mock password change to return success
with patch.object(AuthService, 'change_password', return_value=True):
# Test request
response = client.post(
"/auth/change-password",
# Test request (new endpoint)
response = client.patch(
"/api/v1/users/me/password",
json={
"current_password": "OldPassword123",
"new_password": "NewPassword123"
@@ -291,7 +293,8 @@ class TestChangePassword:
# Assertions
assert response.status_code == 200
assert "success" in response.json()["message"].lower()
assert response.json()["success"] is True
assert "message" in response.json()
# Clean up override
if get_current_user in app.dependency_overrides:
@@ -312,18 +315,20 @@ class TestChangePassword:
# Mock password change to raise error
with patch.object(AuthService, 'change_password',
side_effect=AuthenticationError("Current password is incorrect")):
# Test request
response = client.post(
"/auth/change-password",
# Test request (new endpoint)
response = client.patch(
"/api/v1/users/me/password",
json={
"current_password": "WrongPassword",
"new_password": "NewPassword123"
}
)
# Assertions
assert response.status_code == 400
assert "incorrect" in response.json()["detail"].lower()
# Assertions - Now returns standardized error response
assert response.status_code == 403
# The response has standardized error format
data = response.json()
assert "detail" in data or "errors" in data
# Clean up override
if get_current_user in app.dependency_overrides:

View File

@@ -13,17 +13,10 @@ from app.core.database import get_db
@pytest.fixture
def client():
"""Create a FastAPI test client for the main app with mocked database."""
# Mock get_db to avoid connecting to the actual database
with patch("app.main.get_db") as mock_get_db:
def mock_session_generator():
mock_session = MagicMock()
# Mock the execute method to return successfully
mock_session.execute.return_value = None
mock_session.close.return_value = None
yield mock_session
# Return a new generator each time get_db is called
mock_get_db.side_effect = lambda: mock_session_generator()
# Mock check_database_health to avoid connecting to the actual database
with patch("app.main.check_database_health") as mock_health_check:
# By default, return True (healthy)
mock_health_check.return_value = True
yield TestClient(app)
@@ -90,23 +83,14 @@ class TestHealthEndpoint:
assert data["environment"] == settings.ENVIRONMENT
def test_health_check_database_connection_failure(self, client):
def test_health_check_database_connection_failure(self):
"""Test that health check returns unhealthy when database is not accessible"""
# Mock the database session to raise an exception
with patch("app.main.get_db") as mock_get_db:
def mock_session():
from unittest.mock import MagicMock
mock = MagicMock()
mock.execute.side_effect = OperationalError(
"Connection refused",
params=None,
orig=Exception("Connection refused")
)
yield mock
# Mock check_database_health to return False (unhealthy)
with patch("app.main.check_database_health") as mock_health_check:
mock_health_check.return_value = False
mock_get_db.return_value = mock_session()
response = client.get("/health")
test_client = TestClient(app)
response = test_client.get("/health")
assert response.status_code == status.HTTP_503_SERVICE_UNAVAILABLE
data = response.json()

View File

@@ -5,6 +5,7 @@ from fastapi.testclient import TestClient
from unittest.mock import patch, MagicMock
from app.api.routes.auth import router as auth_router, limiter
from app.api.routes.users import router as users_router
from app.core.database import get_db
@@ -26,6 +27,7 @@ def app(override_get_db):
app.state.limiter = limiter
app.add_exception_handler(RateLimitExceeded, _rate_limit_exceeded_handler)
app.include_router(auth_router, prefix="/auth", tags=["auth"])
app.include_router(users_router, prefix="/api/v1/users", tags=["users"])
# Override the get_db dependency
app.dependency_overrides[get_db] = lambda: override_get_db
@@ -159,10 +161,10 @@ class TestChangePasswordRateLimiting:
"new_password": "NewPassword123!"
}
# Make 6 requests (limit is 5/minute)
# Make 6 requests (limit is 5/minute) - using new endpoint
responses = []
for i in range(6):
response = client.post("/auth/change-password", json=password_data)
response = client.patch("/api/v1/users/me/password", json=password_data)
responses.append(response)
# Last request should be rate limited