From c2cdc3c110277c3df18a7acb83b96166a320d3a5 Mon Sep 17 00:00:00 2001 From: Felipe Cardoso Date: Sun, 9 Mar 2025 16:04:51 +0100 Subject: [PATCH] Refactor event API and extend authentication utilities Refactored the event API routes to improve error handling, add logging, and provide enhanced response structures with pagination. Updated tests to use new fixtures and include additional authentication utilities to facilitate testing with FastAPI's dependency injection. Also resolved issues with timezone awareness in event schemas. --- backend/app/api/routes/events/router.py | 337 ++++++++++++++++---- backend/app/schemas/events.py | 9 +- backend/app/utils/auth_test_utils.py | 130 ++++++++ backend/tests/api/routes/events/__init__.py | 0 backend/tests/conftest.py | 46 ++- backend/tests/crud/test_event.py | 36 +-- backend/tests/crud/test_event_theme.py | 4 +- 7 files changed, 466 insertions(+), 96 deletions(-) create mode 100644 backend/app/utils/auth_test_utils.py create mode 100644 backend/tests/api/routes/events/__init__.py diff --git a/backend/app/api/routes/events/router.py b/backend/app/api/routes/events/router.py index bba5f16..832637a 100644 --- a/backend/app/api/routes/events/router.py +++ b/backend/app/api/routes/events/router.py @@ -1,24 +1,34 @@ -from datetime import datetime -from typing import List, Optional +from datetime import timezone +from typing import Optional, Any, Dict from uuid import UUID -from fastapi import APIRouter, Depends, HTTPException, Query +from fastapi import APIRouter, Depends, HTTPException, Query, status +from sqlalchemy.exc import SQLAlchemyError from sqlalchemy.orm import Session -from api.dependencies.auth import get_current_user +from app.api.dependencies.auth import get_current_user, get_optional_current_user from app.core.database import get_db from app.crud.event import event +from app.models.event_manager import EventManager from app.models.user import User +from app.schemas.common import PaginatedResponse from app.schemas.events import ( EventCreate, EventUpdate, - EventResponse + EventResponse, ) +import logging +logger = logging.getLogger(__name__) events_router = APIRouter() -@events_router.post("/", response_model=EventResponse, operation_id="create_event") +@events_router.post( + "/", + response_model=EventResponse, + status_code=status.HTTP_201_CREATED, + operation_id="create_event" +) def create_event( *, db: Session = Depends(get_db), @@ -26,50 +36,129 @@ def create_event( current_user: User = Depends(get_current_user) ) -> EventResponse: """Create a new event.""" - # Check if slug is already taken - if event.get_by_slug(db, slug=event_in.slug): + try: + # Check if slug is already taken + if event.get_by_slug(db, slug=event_in.slug): + raise HTTPException( + status_code=status.HTTP_400_BAD_REQUEST, + detail="An event with this slug already exists" + ) + + created_event = event.create_with_owner(db=db, obj_in=event_in, owner_id=current_user.id) + logger.info(f"Event created by {current_user.email}: {created_event.slug}") + return created_event + except SQLAlchemyError as e: + db.rollback() raise HTTPException( - status_code=400, - detail="An event with this slug already exists" + status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, + detail=f"Database error occurred" ) - return event.create_with_owner( - db=db, - obj_in=event_in, - owner_id=current_user.id - ) - -@events_router.get("/me", response_model=List[EventResponse], operation_id="get_user_events") +@events_router.get( + "/me", + response_model=PaginatedResponse[EventResponse], + operation_id="get_user_events" +) def get_user_events( + *, db: Session = Depends(get_db), - skip: int = 0, - limit: int = 100, - include_inactive: bool = False, + skip: int = Query(0, ge=0), + limit: int = Query(100, ge=1, le=500), + include_inactive: bool = Query(False), current_user: User = Depends(get_current_user) -) -> List[EventResponse]: - """Get all events created by the current user.""" - return event.get_user_events( - db=db, - user_id=current_user.id, - skip=skip, - limit=limit, - include_inactive=include_inactive - ) +) -> Dict[str, Any]: + """Get all events created by the current user with pagination.""" + try: + total = event.count_user_events( + db=db, + user_id=current_user.id, + include_inactive=include_inactive + ) + items = event.get_user_events( + db=db, + user_id=current_user.id, + skip=skip, + limit=limit, + include_inactive=include_inactive + ) + return { + "total": total, + "items": items, + "page": skip // limit + 1 if limit > 0 else 1, + "size": limit + } + except SQLAlchemyError: + raise HTTPException( + status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, + detail="Error retrieving events" + ) -@events_router.get("/upcoming", response_model=List[EventResponse], operation_id="get_upcoming_events") +@events_router.get( + "/upcoming", + response_model=PaginatedResponse[EventResponse], + operation_id="get_upcoming_events" +) def get_upcoming_events( *, db: Session = Depends(get_db), - skip: int = 0, - limit: int = 100 -) -> List[EventResponse]: - """Get upcoming public events.""" - return event.get_upcoming_events(db=db, skip=skip, limit=limit) + skip: int = Query(0, ge=0), + limit: int = Query(100, ge=1, le=500) +) -> Dict[str, Any]: + """Get upcoming public events with pagination.""" + try: + items = event.get_upcoming_events(db=db, skip=skip, limit=limit) + # Count total upcoming events for pagination + total = event.count_upcoming_events(db=db) + + return { + "total": total, + "items": items, + "page": skip // limit + 1 if limit > 0 else 1, + "size": limit + } + except SQLAlchemyError: + raise HTTPException( + status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, + detail="Error retrieving upcoming events" + ) -@events_router.get("/{event_id}", response_model=EventResponse, operation_id="get_event") +@events_router.get( + "/public", + response_model=PaginatedResponse[EventResponse], + operation_id="get_public_events" +) +def get_public_events( + *, + db: Session = Depends(get_db), + skip: int = Query(0, ge=0), + limit: int = Query(100, ge=1, le=500) +) -> Dict[str, Any]: + """Get all public events with pagination.""" + try: + items = event.get_public_events(db=db, skip=skip, limit=limit) + total = event.count_public_events(db=db) + + return { + "total": total, + "items": items, + "page": skip // limit + 1 if limit > 0 else 1, + "size": limit + } + except SQLAlchemyError: + raise HTTPException( + status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, + detail="Error retrieving public events" + ) + + +@events_router.get( + "/{event_id}", + response_model=EventResponse, + operation_id="get_event" +) def get_event( *, db: Session = Depends(get_db), @@ -77,27 +166,78 @@ def get_event( current_user: User = Depends(get_current_user) ) -> EventResponse: """Get event by ID.""" - event_obj = event.get(db=db, id=event_id) - if not event_obj: - raise HTTPException(status_code=404, detail="Event not found") - return event_obj + try: + event_obj = event.get(db=db, id=event_id) + if not event_obj: + raise HTTPException( + status_code=status.HTTP_404_NOT_FOUND, + detail="Event not found" + ) + + # Check if user is creator or manager + if event_obj.created_by != current_user.id: + # Check if user is a manager + is_manager = db.query(EventManager).filter( + EventManager.event_id == event_id, + EventManager.user_id == current_user.id + ).first() + + if not is_manager and not current_user.is_superuser: + raise HTTPException( + status_code=status.HTTP_403_FORBIDDEN, + detail="Not enough permissions" + ) + + return event_obj + except SQLAlchemyError: + raise HTTPException( + status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, + detail="Error retrieving event" + ) -@events_router.get("/by-slug/{slug}", response_model=EventResponse, operation_id="get_public_event") +@events_router.get( + "/by-slug/{slug}", + response_model=EventResponse, + operation_id="get_public_event" +) def get_public_event( *, db: Session = Depends(get_db), slug: str, - access_code: Optional[str] = Query(None) + access_code: Optional[str] = Query(None), + current_user: Optional[User] = Depends(get_optional_current_user) ) -> EventResponse: """Get public event by slug.""" - event_obj = event.get_public_event(db=db, slug=slug, access_code=access_code) - if not event_obj: - raise HTTPException(status_code=404, detail="Event not found") - return event_obj + try: + event_obj = event.get_public_event(db=db, slug=slug, access_code=access_code) + if not event_obj: + raise HTTPException( + status_code=status.HTTP_404_NOT_FOUND, + detail="Event not found" + ) + + # If event is not public and user is not authenticated, check access code + if not event_obj.is_public and not current_user: + if not access_code or access_code != event_obj.access_code: + raise HTTPException( + status_code=status.HTTP_403_FORBIDDEN, + detail="Invalid access code" + ) + + return event_obj + except SQLAlchemyError: + raise HTTPException( + status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, + detail="Error retrieving event" + ) -@events_router.put("/{event_id}", response_model=EventResponse, operation_id="update_event") +@events_router.put( + "/{event_id}", + response_model=EventResponse, + operation_id="update_event" +) def update_event( *, db: Session = Depends(get_db), @@ -106,36 +246,95 @@ def update_event( current_user: User = Depends(get_current_user) ) -> EventResponse: """Update event.""" - event_obj = event.get(db=db, id=event_id) - if not event_obj: - raise HTTPException(status_code=404, detail="Event not found") - if event_obj.created_by != current_user.id: - raise HTTPException(status_code=403, detail="Not enough permissions") - - # If slug is being updated, check if new slug is available - if event_in.slug and event_in.slug != event_obj.slug: - if event.get_by_slug(db, slug=event_in.slug): + try: + event_obj = event.get(db=db, id=event_id) + if not event_obj: raise HTTPException( - status_code=400, - detail="An event with this slug already exists" + status_code=status.HTTP_404_NOT_FOUND, + detail="Event not found" ) - return event.update(db=db, db_obj=event_obj, obj_in=event_in) + # Check permissions (creator or manager with edit rights) + has_permission = False + + if event_obj.created_by == current_user.id or current_user.is_superuser: + has_permission = True + else: + manager = db.query(EventManager).filter( + EventManager.event_id == event_id, + EventManager.user_id == current_user.id, + EventManager.can_edit == True + ).first() + has_permission = manager is not None + + if not has_permission: + raise HTTPException( + status_code=status.HTTP_403_FORBIDDEN, + detail="Not enough permissions to edit this event" + ) + + # If slug is being updated, check if new slug is available and different + if event_in.slug and event_in.slug != event_obj.slug: + existing = event.get_by_slug(db, slug=event_in.slug) + if existing and existing.id != event_id: + raise HTTPException( + status_code=status.HTTP_400_BAD_REQUEST, + detail="An event with this slug already exists" + ) + + return event.update(db=db, db_obj=event_obj, obj_in=event_in) + except SQLAlchemyError: + db.rollback() + raise HTTPException( + status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, + detail="Error updating event" + ) -@events_router.delete("/{event_id}", operation_id="delete_event") +@events_router.delete( + "/{event_id}", + status_code=status.HTTP_204_NO_CONTENT, + operation_id="delete_event" +) def delete_event( *, db: Session = Depends(get_db), event_id: UUID, - current_user: User = Depends(get_current_user) + current_user: User = Depends(get_current_user), + hard_delete: bool = Query(False, description="Perform hard delete instead of soft delete") ): - """Delete event.""" - event_obj = event.get(db=db, id=event_id) - if not event_obj: - raise HTTPException(status_code=404, detail="Event not found") - if event_obj.created_by != current_user.id: - raise HTTPException(status_code=403, detail="Not enough permissions") + """Delete event (soft delete by default).""" + try: + event_obj = event.get(db=db, id=event_id) + if not event_obj: + raise HTTPException( + status_code=status.HTTP_404_NOT_FOUND, + detail="Event not found" + ) - event.remove(db=db, id=event_id) - return {"message": "Event deleted successfully"} \ No newline at end of file + # Only creator or superuser can delete + if event_obj.created_by != current_user.id and not current_user.is_superuser: + raise HTTPException( + status_code=status.HTTP_403_FORBIDDEN, + detail="Not enough permissions to delete this event" + ) + + if hard_delete: + # Hard delete - only for superusers + if not current_user.is_superuser: + raise HTTPException( + status_code=status.HTTP_403_FORBIDDEN, + detail="Only administrators can perform hard delete" + ) + event.remove(db=db, id=event_id) + else: + # Soft delete - set is_active to False + event.update(db=db, db_obj=event_obj, obj_in={"is_active": False}) + + return None # 204 No Content + except SQLAlchemyError: + db.rollback() + raise HTTPException( + status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, + detail="Error deleting event" + ) diff --git a/backend/app/schemas/events.py b/backend/app/schemas/events.py index d860011..9ceba50 100644 --- a/backend/app/schemas/events.py +++ b/backend/app/schemas/events.py @@ -41,7 +41,7 @@ class EventBase(BaseModel): @field_validator('event_date') def validate_event_date(cls, v): if not v.tzinfo: - raise ValueError("Event date must be timezone-aware") + v = v.replace(tzinfo=timezone.utc) now = datetime.now(tz=timezone.utc) if v < now: raise ValueError("Event date cannot be in the past") @@ -81,7 +81,12 @@ class EventInDBBase(EventBase): class EventResponse(EventInDBBase): - pass + @field_validator('event_date') + def validate_datetime(cls, v): + if v.tzinfo is None: + v.event_date = v.event_date.replace(tzinfo=timezone.utc) + return v + class Event(EventInDBBase): diff --git a/backend/app/utils/auth_test_utils.py b/backend/app/utils/auth_test_utils.py new file mode 100644 index 0000000..1580cdd --- /dev/null +++ b/backend/app/utils/auth_test_utils.py @@ -0,0 +1,130 @@ +""" +Authentication utilities for testing. +This module provides tools to bypass FastAPI's authentication in tests. +""" +from typing import Callable, Dict, Optional + +from fastapi import FastAPI +from fastapi.security import OAuth2PasswordBearer +from starlette.testclient import TestClient + +# Import these from wherever they are defined in your app +from api.dependencies.auth import get_current_user, get_optional_current_user +from app.models.user import User + + +def create_test_auth_client( + app: FastAPI, + test_user: User, + extra_overrides: Optional[Dict[Callable, Callable]] = None +) -> TestClient: + """ + Create a test client with authentication pre-configured. + + This bypasses the OAuth2 token validation and directly returns the test user. + + Args: + app: The FastAPI app to test + test_user: The user object to use for authentication + extra_overrides: Additional dependency overrides to apply + + Returns: + TestClient with authentication configured + """ + # First override the oauth2_scheme dependency to return a dummy token + # This prevents FastAPI from trying to extract a real bearer token from the request + oauth2_scheme = OAuth2PasswordBearer(tokenUrl="/api/v1/auth/login") + app.dependency_overrides[oauth2_scheme] = lambda: "dummy_token_for_testing" + + # Then override the get_current_user dependency to return our test user + app.dependency_overrides[get_current_user] = lambda: test_user + + # Apply any extra overrides + if extra_overrides: + for dep, override in extra_overrides.items(): + app.dependency_overrides[dep] = override + + # Create and return the client + return TestClient(app) + + +def create_test_optional_auth_client( + app: FastAPI, + test_user: User +) -> TestClient: + """ + Create a test client with optional authentication pre-configured. + + This is useful for testing endpoints that use get_optional_current_user. + + Args: + app: The FastAPI app to test + test_user: The user object to use for authentication + + Returns: + TestClient with optional authentication configured + """ + # Override the get_optional_current_user dependency + app.dependency_overrides[get_optional_current_user] = lambda: test_user + + # Create and return the client + return TestClient(app) + + +def create_test_superuser_client( + app: FastAPI, + test_user: User +) -> TestClient: + """ + Create a test client with superuser authentication pre-configured. + + Args: + app: The FastAPI app to test + test_user: The user object to use as superuser + + Returns: + TestClient with superuser authentication + """ + # Make sure user is a superuser + test_user.is_superuser = True + + # Use the auth client creation with superuser + return create_test_auth_client(app, test_user) + + +def create_test_unauthenticated_client(app: FastAPI) -> TestClient: + """ + Create a test client that will fail authentication checks. + + This is useful for testing the unauthorized case of protected endpoints. + + Args: + app: The FastAPI app to test + + Returns: + TestClient without authentication + """ + # Any authentication attempts will fail + return TestClient(app) + + +def cleanup_test_client_auth(app: FastAPI) -> None: + """ + Clean up authentication overrides from the FastAPI app. + + Call this after your tests to restore normal authentication behavior. + + Args: + app: The FastAPI app to clean up + """ + # Get all auth dependencies + auth_deps = [ + get_current_user, + get_optional_current_user, + OAuth2PasswordBearer(tokenUrl="/api/v1/auth/login") + ] + + # Remove overrides + for dep in auth_deps: + if dep in app.dependency_overrides: + del app.dependency_overrides[dep] diff --git a/backend/tests/api/routes/events/__init__.py b/backend/tests/api/routes/events/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/backend/tests/conftest.py b/backend/tests/conftest.py index 5453e0d..79556e4 100644 --- a/backend/tests/conftest.py +++ b/backend/tests/conftest.py @@ -1,10 +1,17 @@ -# tests/conftest.py import uuid from datetime import datetime, timezone, timedelta from typing import Dict +from typing import Optional import pytest +from fastapi import FastAPI +from fastapi.routing import APIRouter +from fastapi.security import OAuth2PasswordBearer +from fastapi.testclient import TestClient +from sqlalchemy.orm import Session +from app.api.dependencies.auth import get_current_user +from app.core.database import get_db from app.models import Event, GiftItem, GiftStatus, GiftPriority, GiftCategory, GiftPurchase, RSVP, RSVPStatus, \ EventMedia, MediaType, MediaPurpose, \ EventTheme, Guest, GuestStatus, ActivityType, ActivityLog, EmailTemplate, TemplateType, NotificationLog, \ @@ -39,6 +46,7 @@ async def async_test_db(): yield test_engine, AsyncTestingSessionLocal await teardown_async_test_db(test_engine) + @pytest.fixture def user_create_data(): return { @@ -72,7 +80,7 @@ def mock_user(db_session): @pytest.fixture -def event_fixture(db_session, mock_user): +def mock_event(db_session, mock_user): """Create a test event fixture.""" event_data = { "title": "Birthday Party", @@ -91,7 +99,6 @@ def event_fixture(db_session, mock_user): return event - @pytest.fixture def gift_item_fixture(db_session, mock_user): """ @@ -308,7 +315,7 @@ def gift_purchase_fixture(db_session, gift_item_fixture, guest_fixture): @pytest.fixture -def gift_category_fixture(db_session, mock_user, event_fixture): +def gift_category_fixture(db_session, mock_user, mock_event): """ Fixture to create and return a GiftCategory instance. """ @@ -316,7 +323,7 @@ def gift_category_fixture(db_session, mock_user, event_fixture): id=uuid.uuid4(), name="Electronics", description="Category for electronic gifts", - event_id=event_fixture.id, + event_id=mock_event.id, created_by=mock_user.id, display_order=0, is_visible=True @@ -344,3 +351,32 @@ def theme_data() -> Dict: } } + +@pytest.fixture +def oauth2_scheme(): + """Return the OAuth2PasswordBearer instance used by the app.""" + return OAuth2PasswordBearer(tokenUrl="/api/v1/auth/login") + +@pytest.fixture +def create_test_client(): + def _create_test_client(router: APIRouter, prefix: str, db_session, user): + from fastapi import FastAPI + app = FastAPI() + + # Mimic your dependency overrides here + def get_db_override(): + yield db_session + + def get_current_user_override(): + return user + + # Include your router + app.include_router(router, prefix=prefix) + + # Override dependencies + app.dependency_overrides[get_db] = get_db_override + app.dependency_overrides[get_current_user] = get_current_user_override + + return TestClient(app) + + return _create_test_client diff --git a/backend/tests/crud/test_event.py b/backend/tests/crud/test_event.py index 1a671e3..31dd96e 100644 --- a/backend/tests/crud/test_event.py +++ b/backend/tests/crud/test_event.py @@ -28,23 +28,23 @@ def test_create_event(db_session, mock_user): assert event.created_by == mock_user.id -def test_get_event(db_session, event_fixture): +def test_get_event(db_session, mock_event): """Test retrieving an event by ID.""" - stored_event = crud_event.get(db=db_session, id=event_fixture.id) + stored_event = crud_event.get(db=db_session, id=mock_event.id) assert stored_event - assert stored_event.id == event_fixture.id - assert stored_event.title == event_fixture.title + assert stored_event.id == mock_event.id + assert stored_event.title == mock_event.title -def test_get_event_by_slug(db_session, event_fixture): +def test_get_event_by_slug(db_session, mock_event): """Test retrieving an event by slug.""" - stored_event = crud_event.get_by_slug(db=db_session, slug=event_fixture.slug) + stored_event = crud_event.get_by_slug(db=db_session, slug=mock_event.slug) assert stored_event - assert stored_event.id == event_fixture.id - assert stored_event.slug == event_fixture.slug + assert stored_event.id == mock_event.id + assert stored_event.slug == mock_event.slug -def test_get_user_events(db_session, event_fixture, mock_user): +def test_get_user_events(db_session, mock_event, mock_user): """Test retrieving all events for a specific user.""" events = crud_event.get_user_events( db=db_session, @@ -53,10 +53,10 @@ def test_get_user_events(db_session, event_fixture, mock_user): limit=100 ) assert len(events) > 0 - assert any(event.id == event_fixture.id for event in events) + assert any(event.id == mock_event.id for event in events) -def test_update_event(db_session, event_fixture): +def test_update_event(db_session, mock_event): """Test updating an event.""" update_data = EventUpdate( title="Updated Birthday Party", @@ -64,18 +64,18 @@ def test_update_event(db_session, event_fixture): ) updated_event = crud_event.update( db=db_session, - db_obj=event_fixture, + db_obj=mock_event, obj_in=update_data ) assert updated_event.title == "Updated Birthday Party" assert updated_event.description == "Updated description" -def test_delete_event(db_session, event_fixture): +def test_delete_event(db_session, mock_event): """Test deleting an event.""" - event = crud_event.remove(db=db_session, id=event_fixture.id) - assert event.id == event_fixture.id - deleted_event = crud_event.get(db=db_session, id=event_fixture.id) + event = crud_event.remove(db=db_session, id=mock_event.id) + assert event.id == mock_event.id + deleted_event = crud_event.get(db=db_session, id=mock_event.id) assert deleted_event is None @@ -115,11 +115,11 @@ def test_get_upcoming_events(db_session): assert event_date >= current_time -def test_get_public_event(db_session, event_fixture): +def test_get_public_event(db_session, mock_event): """Test retrieving a public event.""" public_event = crud_event.get_public_event( db=db_session, - slug=event_fixture.slug + slug=mock_event.slug ) assert public_event is not None assert public_event.is_public is True diff --git a/backend/tests/crud/test_event_theme.py b/backend/tests/crud/test_event_theme.py index d588a8b..041a75a 100644 --- a/backend/tests/crud/test_event_theme.py +++ b/backend/tests/crud/test_event_theme.py @@ -77,9 +77,9 @@ def test_delete_event_theme(db_session: Session, event_theme_fixture) -> None: assert theme is None -def test_get_active_themes(db_session: Session, event_theme_fixture, event_fixture) -> None: +def test_get_active_themes(db_session: Session, event_theme_fixture, mock_event) -> None: # First, ensure the theme is associated with an event - event_fixture.theme_id = event_theme_fixture.id + mock_event.theme_id = event_theme_fixture.id db_session.commit() # Get active themes