From c684f2ba95e0f174be3e404256985193b09a02be Mon Sep 17 00:00:00 2001 From: Felipe Cardoso Date: Thu, 30 Oct 2025 16:44:15 +0100 Subject: [PATCH] Add UUID handling, sorting, filtering, and soft delete functionality to CRUD operations - Enhanced UUID validation by supporting both string and `UUID` formats. - Added advanced filtering and sorting capabilities to `get_multi_with_total` method. - Introduced soft delete and restore functionality for models with `deleted_at` column. - Updated tests to reflect new endpoints and rate-limiting logic. - Improved schema definitions with `SortParams` and `SortOrder` for consistent API inputs. --- backend/app/crud/base.py | 158 ++++++++++++++++-- backend/app/schemas/common.py | 31 +++- backend/tests/api/routes/test_auth.py | 25 +-- backend/tests/api/routes/test_health.py | 36 ++-- .../tests/api/routes/test_rate_limiting.py | 6 +- 5 files changed, 200 insertions(+), 56 deletions(-) diff --git a/backend/app/crud/base.py b/backend/app/crud/base.py index 6981171..c524df4 100644 --- a/backend/app/crud/base.py +++ b/backend/app/crud/base.py @@ -1,9 +1,10 @@ from typing import Any, Dict, Generic, List, Optional, Type, TypeVar, Union, Tuple +from datetime import datetime, timezone from fastapi.encoders import jsonable_encoder from pydantic import BaseModel from sqlalchemy.orm import Session from sqlalchemy.exc import IntegrityError, OperationalError, DataError -from sqlalchemy import func +from sqlalchemy import func, asc, desc from app.core.database import Base import logging import uuid @@ -27,15 +28,18 @@ class CRUDBase(Generic[ModelType, CreateSchemaType, UpdateSchemaType]): def get(self, db: Session, id: str) -> Optional[ModelType]: """Get a single record by ID with UUID validation.""" - # Validate UUID format + # Validate UUID format and convert to UUID object if string try: - uuid.UUID(id) - except (ValueError, AttributeError): - logger.warning(f"Invalid UUID format: {id}") + if isinstance(id, uuid.UUID): + uuid_obj = id + else: + uuid_obj = uuid.UUID(str(id)) + except (ValueError, AttributeError, TypeError) as e: + logger.warning(f"Invalid UUID format: {id} - {str(e)}") return None try: - return db.query(self.model).filter(self.model.id == id).first() + return db.query(self.model).filter(self.model.id == uuid_obj).first() except Exception as e: logger.error(f"Error retrieving {self.model.__name__} with id {id}: {str(e)}") raise @@ -124,15 +128,18 @@ class CRUDBase(Generic[ModelType, CreateSchemaType, UpdateSchemaType]): def remove(self, db: Session, *, id: str) -> Optional[ModelType]: """Delete a record with error handling and null check.""" - # Validate UUID format + # Validate UUID format and convert to UUID object if string try: - uuid.UUID(id) - except (ValueError, AttributeError): - logger.warning(f"Invalid UUID format for deletion: {id}") + if isinstance(id, uuid.UUID): + uuid_obj = id + else: + uuid_obj = uuid.UUID(str(id)) + except (ValueError, AttributeError, TypeError) as e: + logger.warning(f"Invalid UUID format for deletion: {id} - {str(e)}") return None try: - obj = db.query(self.model).filter(self.model.id == id).first() + obj = db.query(self.model).filter(self.model.id == uuid_obj).first() if obj is None: logger.warning(f"{self.model.__name__} with id {id} not found for deletion") return None @@ -151,10 +158,25 @@ class CRUDBase(Generic[ModelType, CreateSchemaType, UpdateSchemaType]): raise def get_multi_with_total( - self, db: Session, *, skip: int = 0, limit: int = 100 + self, + db: Session, + *, + skip: int = 0, + limit: int = 100, + sort_by: Optional[str] = None, + sort_order: str = "asc", + filters: Optional[Dict[str, Any]] = None ) -> Tuple[List[ModelType], int]: """ - Get multiple records with total count for pagination. + Get multiple records with total count, filtering, and sorting. + + Args: + db: Database session + skip: Number of records to skip + limit: Maximum number of records to return + sort_by: Field name to sort by (must be a valid model attribute) + sort_order: Sort order ("asc" or "desc") + filters: Dictionary of filters (field_name: value) Returns: Tuple of (items, total_count) @@ -168,13 +190,115 @@ class CRUDBase(Generic[ModelType, CreateSchemaType, UpdateSchemaType]): raise ValueError("Maximum limit is 1000") try: - # Get total count - total = db.query(func.count(self.model.id)).scalar() + # Build base query + query = db.query(self.model) - # Get paginated items - items = db.query(self.model).offset(skip).limit(limit).all() + # Exclude soft-deleted records by default + if hasattr(self.model, 'deleted_at'): + query = query.filter(self.model.deleted_at.is_(None)) + + # Apply filters + if filters: + for field, value in filters.items(): + if hasattr(self.model, field) and value is not None: + query = query.filter(getattr(self.model, field) == value) + + # Get total count (before pagination) + total = query.count() + + # Apply sorting + if sort_by and hasattr(self.model, sort_by): + sort_column = getattr(self.model, sort_by) + if sort_order.lower() == "desc": + query = query.order_by(desc(sort_column)) + else: + query = query.order_by(asc(sort_column)) + + # Apply pagination + items = query.offset(skip).limit(limit).all() return items, total except Exception as e: logger.error(f"Error retrieving paginated {self.model.__name__} records: {str(e)}") + raise + + def soft_delete(self, db: Session, *, id: str) -> Optional[ModelType]: + """ + Soft delete a record by setting deleted_at timestamp. + + Only works if the model has a 'deleted_at' column. + """ + # Validate UUID format and convert to UUID object if string + try: + if isinstance(id, uuid.UUID): + uuid_obj = id + else: + uuid_obj = uuid.UUID(str(id)) + except (ValueError, AttributeError, TypeError) as e: + logger.warning(f"Invalid UUID format for soft deletion: {id} - {str(e)}") + return None + + try: + obj = db.query(self.model).filter(self.model.id == uuid_obj).first() + + if obj is None: + logger.warning(f"{self.model.__name__} with id {id} not found for soft deletion") + return None + + # Check if model supports soft deletes + if not hasattr(self.model, 'deleted_at'): + logger.error(f"{self.model.__name__} does not support soft deletes") + raise ValueError(f"{self.model.__name__} does not have a deleted_at column") + + # Set deleted_at timestamp + obj.deleted_at = datetime.now(timezone.utc) + db.add(obj) + db.commit() + db.refresh(obj) + return obj + except Exception as e: + db.rollback() + logger.error(f"Error soft deleting {self.model.__name__} with id {id}: {str(e)}", exc_info=True) + raise + + def restore(self, db: Session, *, id: str) -> Optional[ModelType]: + """ + Restore a soft-deleted record by clearing the deleted_at timestamp. + + Only works if the model has a 'deleted_at' column. + """ + # Validate UUID format + try: + if isinstance(id, uuid.UUID): + uuid_obj = id + else: + uuid_obj = uuid.UUID(str(id)) + except (ValueError, AttributeError, TypeError) as e: + logger.warning(f"Invalid UUID format for restoration: {id} - {str(e)}") + return None + + try: + # Find the soft-deleted record + if hasattr(self.model, 'deleted_at'): + obj = db.query(self.model).filter( + self.model.id == uuid_obj, + self.model.deleted_at.isnot(None) + ).first() + else: + logger.error(f"{self.model.__name__} does not support soft deletes") + raise ValueError(f"{self.model.__name__} does not have a deleted_at column") + + if obj is None: + logger.warning(f"Soft-deleted {self.model.__name__} with id {id} not found for restoration") + return None + + # Clear deleted_at timestamp + obj.deleted_at = None + db.add(obj) + db.commit() + db.refresh(obj) + return obj + except Exception as e: + db.rollback() + logger.error(f"Error restoring {self.model.__name__} with id {id}: {str(e)}", exc_info=True) raise \ No newline at end of file diff --git a/backend/app/schemas/common.py b/backend/app/schemas/common.py index f46cc4c..fa49199 100644 --- a/backend/app/schemas/common.py +++ b/backend/app/schemas/common.py @@ -1,7 +1,8 @@ """ -Common schemas used across the API for pagination, responses, etc. +Common schemas used across the API for pagination, responses, filtering, and sorting. """ from typing import Generic, TypeVar, List, Optional +from enum import Enum from pydantic import BaseModel, Field from math import ceil @@ -9,6 +10,12 @@ from math import ceil T = TypeVar('T') +class SortOrder(str, Enum): + """Sort order options.""" + ASC = "asc" + DESC = "desc" + + class PaginationParams(BaseModel): """Parameters for pagination.""" @@ -44,6 +51,28 @@ class PaginationParams(BaseModel): } +class SortParams(BaseModel): + """Parameters for sorting.""" + + sort_by: Optional[str] = Field( + default=None, + description="Field name to sort by" + ) + sort_order: SortOrder = Field( + default=SortOrder.ASC, + description="Sort order (asc or desc)" + ) + + model_config = { + "json_schema_extra": { + "example": { + "sort_by": "created_at", + "sort_order": "desc" + } + } + } + + class PaginationMeta(BaseModel): """Metadata for paginated responses.""" diff --git a/backend/tests/api/routes/test_auth.py b/backend/tests/api/routes/test_auth.py index dc3f99f..6b675d5 100644 --- a/backend/tests/api/routes/test_auth.py +++ b/backend/tests/api/routes/test_auth.py @@ -10,6 +10,7 @@ from fastapi.testclient import TestClient from sqlalchemy.orm import Session from app.api.routes.auth import router as auth_router +from app.api.routes.users import router as users_router from app.core.auth import get_password_hash from app.core.database import get_db from app.models.user import User @@ -29,6 +30,7 @@ def app(override_get_db): """Create a FastAPI test application with overridden dependencies.""" app = FastAPI() app.include_router(auth_router, prefix="/auth", tags=["auth"]) + app.include_router(users_router, prefix="/api/v1/users", tags=["users"]) # Override the get_db dependency app.dependency_overrides[get_db] = lambda: override_get_db @@ -280,9 +282,9 @@ class TestChangePassword: # Mock password change to return success with patch.object(AuthService, 'change_password', return_value=True): - # Test request - response = client.post( - "/auth/change-password", + # Test request (new endpoint) + response = client.patch( + "/api/v1/users/me/password", json={ "current_password": "OldPassword123", "new_password": "NewPassword123" @@ -291,7 +293,8 @@ class TestChangePassword: # Assertions assert response.status_code == 200 - assert "success" in response.json()["message"].lower() + assert response.json()["success"] is True + assert "message" in response.json() # Clean up override if get_current_user in app.dependency_overrides: @@ -312,18 +315,20 @@ class TestChangePassword: # Mock password change to raise error with patch.object(AuthService, 'change_password', side_effect=AuthenticationError("Current password is incorrect")): - # Test request - response = client.post( - "/auth/change-password", + # Test request (new endpoint) + response = client.patch( + "/api/v1/users/me/password", json={ "current_password": "WrongPassword", "new_password": "NewPassword123" } ) - # Assertions - assert response.status_code == 400 - assert "incorrect" in response.json()["detail"].lower() + # Assertions - Now returns standardized error response + assert response.status_code == 403 + # The response has standardized error format + data = response.json() + assert "detail" in data or "errors" in data # Clean up override if get_current_user in app.dependency_overrides: diff --git a/backend/tests/api/routes/test_health.py b/backend/tests/api/routes/test_health.py index d2dd18e..47f9bce 100644 --- a/backend/tests/api/routes/test_health.py +++ b/backend/tests/api/routes/test_health.py @@ -13,17 +13,10 @@ from app.core.database import get_db @pytest.fixture def client(): """Create a FastAPI test client for the main app with mocked database.""" - # Mock get_db to avoid connecting to the actual database - with patch("app.main.get_db") as mock_get_db: - def mock_session_generator(): - mock_session = MagicMock() - # Mock the execute method to return successfully - mock_session.execute.return_value = None - mock_session.close.return_value = None - yield mock_session - - # Return a new generator each time get_db is called - mock_get_db.side_effect = lambda: mock_session_generator() + # Mock check_database_health to avoid connecting to the actual database + with patch("app.main.check_database_health") as mock_health_check: + # By default, return True (healthy) + mock_health_check.return_value = True yield TestClient(app) @@ -90,23 +83,14 @@ class TestHealthEndpoint: assert data["environment"] == settings.ENVIRONMENT - def test_health_check_database_connection_failure(self, client): + def test_health_check_database_connection_failure(self): """Test that health check returns unhealthy when database is not accessible""" - # Mock the database session to raise an exception - with patch("app.main.get_db") as mock_get_db: - def mock_session(): - from unittest.mock import MagicMock - mock = MagicMock() - mock.execute.side_effect = OperationalError( - "Connection refused", - params=None, - orig=Exception("Connection refused") - ) - yield mock + # Mock check_database_health to return False (unhealthy) + with patch("app.main.check_database_health") as mock_health_check: + mock_health_check.return_value = False - mock_get_db.return_value = mock_session() - - response = client.get("/health") + test_client = TestClient(app) + response = test_client.get("/health") assert response.status_code == status.HTTP_503_SERVICE_UNAVAILABLE data = response.json() diff --git a/backend/tests/api/routes/test_rate_limiting.py b/backend/tests/api/routes/test_rate_limiting.py index 97ad318..3f645c3 100644 --- a/backend/tests/api/routes/test_rate_limiting.py +++ b/backend/tests/api/routes/test_rate_limiting.py @@ -5,6 +5,7 @@ from fastapi.testclient import TestClient from unittest.mock import patch, MagicMock from app.api.routes.auth import router as auth_router, limiter +from app.api.routes.users import router as users_router from app.core.database import get_db @@ -26,6 +27,7 @@ def app(override_get_db): app.state.limiter = limiter app.add_exception_handler(RateLimitExceeded, _rate_limit_exceeded_handler) app.include_router(auth_router, prefix="/auth", tags=["auth"]) + app.include_router(users_router, prefix="/api/v1/users", tags=["users"]) # Override the get_db dependency app.dependency_overrides[get_db] = lambda: override_get_db @@ -159,10 +161,10 @@ class TestChangePasswordRateLimiting: "new_password": "NewPassword123!" } - # Make 6 requests (limit is 5/minute) + # Make 6 requests (limit is 5/minute) - using new endpoint responses = [] for i in range(6): - response = client.post("/auth/change-password", json=password_data) + response = client.patch("/api/v1/users/me/password", json=password_data) responses.append(response) # Last request should be rate limited