Add UUID handling, sorting, filtering, and soft delete functionality to CRUD operations

- Enhanced UUID validation by supporting both string and `UUID` formats.
- Added advanced filtering and sorting capabilities to `get_multi_with_total` method.
- Introduced soft delete and restore functionality for models with `deleted_at` column.
- Updated tests to reflect new endpoints and rate-limiting logic.
- Improved schema definitions with `SortParams` and `SortOrder` for consistent API inputs.
This commit is contained in:
Felipe Cardoso
2025-10-30 16:44:15 +01:00
parent 2c600290a1
commit c684f2ba95
5 changed files with 200 additions and 56 deletions

View File

@@ -1,9 +1,10 @@
from typing import Any, Dict, Generic, List, Optional, Type, TypeVar, Union, Tuple from typing import Any, Dict, Generic, List, Optional, Type, TypeVar, Union, Tuple
from datetime import datetime, timezone
from fastapi.encoders import jsonable_encoder from fastapi.encoders import jsonable_encoder
from pydantic import BaseModel from pydantic import BaseModel
from sqlalchemy.orm import Session from sqlalchemy.orm import Session
from sqlalchemy.exc import IntegrityError, OperationalError, DataError from sqlalchemy.exc import IntegrityError, OperationalError, DataError
from sqlalchemy import func from sqlalchemy import func, asc, desc
from app.core.database import Base from app.core.database import Base
import logging import logging
import uuid import uuid
@@ -27,15 +28,18 @@ class CRUDBase(Generic[ModelType, CreateSchemaType, UpdateSchemaType]):
def get(self, db: Session, id: str) -> Optional[ModelType]: def get(self, db: Session, id: str) -> Optional[ModelType]:
"""Get a single record by ID with UUID validation.""" """Get a single record by ID with UUID validation."""
# Validate UUID format # Validate UUID format and convert to UUID object if string
try: try:
uuid.UUID(id) if isinstance(id, uuid.UUID):
except (ValueError, AttributeError): uuid_obj = id
logger.warning(f"Invalid UUID format: {id}") else:
uuid_obj = uuid.UUID(str(id))
except (ValueError, AttributeError, TypeError) as e:
logger.warning(f"Invalid UUID format: {id} - {str(e)}")
return None return None
try: try:
return db.query(self.model).filter(self.model.id == id).first() return db.query(self.model).filter(self.model.id == uuid_obj).first()
except Exception as e: except Exception as e:
logger.error(f"Error retrieving {self.model.__name__} with id {id}: {str(e)}") logger.error(f"Error retrieving {self.model.__name__} with id {id}: {str(e)}")
raise raise
@@ -124,15 +128,18 @@ class CRUDBase(Generic[ModelType, CreateSchemaType, UpdateSchemaType]):
def remove(self, db: Session, *, id: str) -> Optional[ModelType]: def remove(self, db: Session, *, id: str) -> Optional[ModelType]:
"""Delete a record with error handling and null check.""" """Delete a record with error handling and null check."""
# Validate UUID format # Validate UUID format and convert to UUID object if string
try: try:
uuid.UUID(id) if isinstance(id, uuid.UUID):
except (ValueError, AttributeError): uuid_obj = id
logger.warning(f"Invalid UUID format for deletion: {id}") else:
uuid_obj = uuid.UUID(str(id))
except (ValueError, AttributeError, TypeError) as e:
logger.warning(f"Invalid UUID format for deletion: {id} - {str(e)}")
return None return None
try: try:
obj = db.query(self.model).filter(self.model.id == id).first() obj = db.query(self.model).filter(self.model.id == uuid_obj).first()
if obj is None: if obj is None:
logger.warning(f"{self.model.__name__} with id {id} not found for deletion") logger.warning(f"{self.model.__name__} with id {id} not found for deletion")
return None return None
@@ -151,10 +158,25 @@ class CRUDBase(Generic[ModelType, CreateSchemaType, UpdateSchemaType]):
raise raise
def get_multi_with_total( def get_multi_with_total(
self, db: Session, *, skip: int = 0, limit: int = 100 self,
db: Session,
*,
skip: int = 0,
limit: int = 100,
sort_by: Optional[str] = None,
sort_order: str = "asc",
filters: Optional[Dict[str, Any]] = None
) -> Tuple[List[ModelType], int]: ) -> Tuple[List[ModelType], int]:
""" """
Get multiple records with total count for pagination. Get multiple records with total count, filtering, and sorting.
Args:
db: Database session
skip: Number of records to skip
limit: Maximum number of records to return
sort_by: Field name to sort by (must be a valid model attribute)
sort_order: Sort order ("asc" or "desc")
filters: Dictionary of filters (field_name: value)
Returns: Returns:
Tuple of (items, total_count) Tuple of (items, total_count)
@@ -168,13 +190,115 @@ class CRUDBase(Generic[ModelType, CreateSchemaType, UpdateSchemaType]):
raise ValueError("Maximum limit is 1000") raise ValueError("Maximum limit is 1000")
try: try:
# Get total count # Build base query
total = db.query(func.count(self.model.id)).scalar() query = db.query(self.model)
# Get paginated items # Exclude soft-deleted records by default
items = db.query(self.model).offset(skip).limit(limit).all() if hasattr(self.model, 'deleted_at'):
query = query.filter(self.model.deleted_at.is_(None))
# Apply filters
if filters:
for field, value in filters.items():
if hasattr(self.model, field) and value is not None:
query = query.filter(getattr(self.model, field) == value)
# Get total count (before pagination)
total = query.count()
# Apply sorting
if sort_by and hasattr(self.model, sort_by):
sort_column = getattr(self.model, sort_by)
if sort_order.lower() == "desc":
query = query.order_by(desc(sort_column))
else:
query = query.order_by(asc(sort_column))
# Apply pagination
items = query.offset(skip).limit(limit).all()
return items, total return items, total
except Exception as e: except Exception as e:
logger.error(f"Error retrieving paginated {self.model.__name__} records: {str(e)}") logger.error(f"Error retrieving paginated {self.model.__name__} records: {str(e)}")
raise
def soft_delete(self, db: Session, *, id: str) -> Optional[ModelType]:
"""
Soft delete a record by setting deleted_at timestamp.
Only works if the model has a 'deleted_at' column.
"""
# Validate UUID format and convert to UUID object if string
try:
if isinstance(id, uuid.UUID):
uuid_obj = id
else:
uuid_obj = uuid.UUID(str(id))
except (ValueError, AttributeError, TypeError) as e:
logger.warning(f"Invalid UUID format for soft deletion: {id} - {str(e)}")
return None
try:
obj = db.query(self.model).filter(self.model.id == uuid_obj).first()
if obj is None:
logger.warning(f"{self.model.__name__} with id {id} not found for soft deletion")
return None
# Check if model supports soft deletes
if not hasattr(self.model, 'deleted_at'):
logger.error(f"{self.model.__name__} does not support soft deletes")
raise ValueError(f"{self.model.__name__} does not have a deleted_at column")
# Set deleted_at timestamp
obj.deleted_at = datetime.now(timezone.utc)
db.add(obj)
db.commit()
db.refresh(obj)
return obj
except Exception as e:
db.rollback()
logger.error(f"Error soft deleting {self.model.__name__} with id {id}: {str(e)}", exc_info=True)
raise
def restore(self, db: Session, *, id: str) -> Optional[ModelType]:
"""
Restore a soft-deleted record by clearing the deleted_at timestamp.
Only works if the model has a 'deleted_at' column.
"""
# Validate UUID format
try:
if isinstance(id, uuid.UUID):
uuid_obj = id
else:
uuid_obj = uuid.UUID(str(id))
except (ValueError, AttributeError, TypeError) as e:
logger.warning(f"Invalid UUID format for restoration: {id} - {str(e)}")
return None
try:
# Find the soft-deleted record
if hasattr(self.model, 'deleted_at'):
obj = db.query(self.model).filter(
self.model.id == uuid_obj,
self.model.deleted_at.isnot(None)
).first()
else:
logger.error(f"{self.model.__name__} does not support soft deletes")
raise ValueError(f"{self.model.__name__} does not have a deleted_at column")
if obj is None:
logger.warning(f"Soft-deleted {self.model.__name__} with id {id} not found for restoration")
return None
# Clear deleted_at timestamp
obj.deleted_at = None
db.add(obj)
db.commit()
db.refresh(obj)
return obj
except Exception as e:
db.rollback()
logger.error(f"Error restoring {self.model.__name__} with id {id}: {str(e)}", exc_info=True)
raise raise

View File

@@ -1,7 +1,8 @@
""" """
Common schemas used across the API for pagination, responses, etc. Common schemas used across the API for pagination, responses, filtering, and sorting.
""" """
from typing import Generic, TypeVar, List, Optional from typing import Generic, TypeVar, List, Optional
from enum import Enum
from pydantic import BaseModel, Field from pydantic import BaseModel, Field
from math import ceil from math import ceil
@@ -9,6 +10,12 @@ from math import ceil
T = TypeVar('T') T = TypeVar('T')
class SortOrder(str, Enum):
"""Sort order options."""
ASC = "asc"
DESC = "desc"
class PaginationParams(BaseModel): class PaginationParams(BaseModel):
"""Parameters for pagination.""" """Parameters for pagination."""
@@ -44,6 +51,28 @@ class PaginationParams(BaseModel):
} }
class SortParams(BaseModel):
"""Parameters for sorting."""
sort_by: Optional[str] = Field(
default=None,
description="Field name to sort by"
)
sort_order: SortOrder = Field(
default=SortOrder.ASC,
description="Sort order (asc or desc)"
)
model_config = {
"json_schema_extra": {
"example": {
"sort_by": "created_at",
"sort_order": "desc"
}
}
}
class PaginationMeta(BaseModel): class PaginationMeta(BaseModel):
"""Metadata for paginated responses.""" """Metadata for paginated responses."""

View File

@@ -10,6 +10,7 @@ from fastapi.testclient import TestClient
from sqlalchemy.orm import Session from sqlalchemy.orm import Session
from app.api.routes.auth import router as auth_router from app.api.routes.auth import router as auth_router
from app.api.routes.users import router as users_router
from app.core.auth import get_password_hash from app.core.auth import get_password_hash
from app.core.database import get_db from app.core.database import get_db
from app.models.user import User from app.models.user import User
@@ -29,6 +30,7 @@ def app(override_get_db):
"""Create a FastAPI test application with overridden dependencies.""" """Create a FastAPI test application with overridden dependencies."""
app = FastAPI() app = FastAPI()
app.include_router(auth_router, prefix="/auth", tags=["auth"]) app.include_router(auth_router, prefix="/auth", tags=["auth"])
app.include_router(users_router, prefix="/api/v1/users", tags=["users"])
# Override the get_db dependency # Override the get_db dependency
app.dependency_overrides[get_db] = lambda: override_get_db app.dependency_overrides[get_db] = lambda: override_get_db
@@ -280,9 +282,9 @@ class TestChangePassword:
# Mock password change to return success # Mock password change to return success
with patch.object(AuthService, 'change_password', return_value=True): with patch.object(AuthService, 'change_password', return_value=True):
# Test request # Test request (new endpoint)
response = client.post( response = client.patch(
"/auth/change-password", "/api/v1/users/me/password",
json={ json={
"current_password": "OldPassword123", "current_password": "OldPassword123",
"new_password": "NewPassword123" "new_password": "NewPassword123"
@@ -291,7 +293,8 @@ class TestChangePassword:
# Assertions # Assertions
assert response.status_code == 200 assert response.status_code == 200
assert "success" in response.json()["message"].lower() assert response.json()["success"] is True
assert "message" in response.json()
# Clean up override # Clean up override
if get_current_user in app.dependency_overrides: if get_current_user in app.dependency_overrides:
@@ -312,18 +315,20 @@ class TestChangePassword:
# Mock password change to raise error # Mock password change to raise error
with patch.object(AuthService, 'change_password', with patch.object(AuthService, 'change_password',
side_effect=AuthenticationError("Current password is incorrect")): side_effect=AuthenticationError("Current password is incorrect")):
# Test request # Test request (new endpoint)
response = client.post( response = client.patch(
"/auth/change-password", "/api/v1/users/me/password",
json={ json={
"current_password": "WrongPassword", "current_password": "WrongPassword",
"new_password": "NewPassword123" "new_password": "NewPassword123"
} }
) )
# Assertions # Assertions - Now returns standardized error response
assert response.status_code == 400 assert response.status_code == 403
assert "incorrect" in response.json()["detail"].lower() # The response has standardized error format
data = response.json()
assert "detail" in data or "errors" in data
# Clean up override # Clean up override
if get_current_user in app.dependency_overrides: if get_current_user in app.dependency_overrides:

View File

@@ -13,17 +13,10 @@ from app.core.database import get_db
@pytest.fixture @pytest.fixture
def client(): def client():
"""Create a FastAPI test client for the main app with mocked database.""" """Create a FastAPI test client for the main app with mocked database."""
# Mock get_db to avoid connecting to the actual database # Mock check_database_health to avoid connecting to the actual database
with patch("app.main.get_db") as mock_get_db: with patch("app.main.check_database_health") as mock_health_check:
def mock_session_generator(): # By default, return True (healthy)
mock_session = MagicMock() mock_health_check.return_value = True
# Mock the execute method to return successfully
mock_session.execute.return_value = None
mock_session.close.return_value = None
yield mock_session
# Return a new generator each time get_db is called
mock_get_db.side_effect = lambda: mock_session_generator()
yield TestClient(app) yield TestClient(app)
@@ -90,23 +83,14 @@ class TestHealthEndpoint:
assert data["environment"] == settings.ENVIRONMENT assert data["environment"] == settings.ENVIRONMENT
def test_health_check_database_connection_failure(self, client): def test_health_check_database_connection_failure(self):
"""Test that health check returns unhealthy when database is not accessible""" """Test that health check returns unhealthy when database is not accessible"""
# Mock the database session to raise an exception # Mock check_database_health to return False (unhealthy)
with patch("app.main.get_db") as mock_get_db: with patch("app.main.check_database_health") as mock_health_check:
def mock_session(): mock_health_check.return_value = False
from unittest.mock import MagicMock
mock = MagicMock()
mock.execute.side_effect = OperationalError(
"Connection refused",
params=None,
orig=Exception("Connection refused")
)
yield mock
mock_get_db.return_value = mock_session() test_client = TestClient(app)
response = test_client.get("/health")
response = client.get("/health")
assert response.status_code == status.HTTP_503_SERVICE_UNAVAILABLE assert response.status_code == status.HTTP_503_SERVICE_UNAVAILABLE
data = response.json() data = response.json()

View File

@@ -5,6 +5,7 @@ from fastapi.testclient import TestClient
from unittest.mock import patch, MagicMock from unittest.mock import patch, MagicMock
from app.api.routes.auth import router as auth_router, limiter from app.api.routes.auth import router as auth_router, limiter
from app.api.routes.users import router as users_router
from app.core.database import get_db from app.core.database import get_db
@@ -26,6 +27,7 @@ def app(override_get_db):
app.state.limiter = limiter app.state.limiter = limiter
app.add_exception_handler(RateLimitExceeded, _rate_limit_exceeded_handler) app.add_exception_handler(RateLimitExceeded, _rate_limit_exceeded_handler)
app.include_router(auth_router, prefix="/auth", tags=["auth"]) app.include_router(auth_router, prefix="/auth", tags=["auth"])
app.include_router(users_router, prefix="/api/v1/users", tags=["users"])
# Override the get_db dependency # Override the get_db dependency
app.dependency_overrides[get_db] = lambda: override_get_db app.dependency_overrides[get_db] = lambda: override_get_db
@@ -159,10 +161,10 @@ class TestChangePasswordRateLimiting:
"new_password": "NewPassword123!" "new_password": "NewPassword123!"
} }
# Make 6 requests (limit is 5/minute) # Make 6 requests (limit is 5/minute) - using new endpoint
responses = [] responses = []
for i in range(6): for i in range(6):
response = client.post("/auth/change-password", json=password_data) response = client.patch("/api/v1/users/me/password", json=password_data)
responses.append(response) responses.append(response)
# Last request should be rate limited # Last request should be rate limited