Refactor event API and extend authentication utilities
All checks were successful
Build and Push Docker Images / changes (push) Successful in 6s
Build and Push Docker Images / build-backend (push) Successful in 55s
Build and Push Docker Images / build-frontend (push) Has been skipped

Refactored the event API routes to improve error handling, add logging, and provide enhanced response structures with pagination. Updated tests to use new fixtures and include additional authentication utilities to facilitate testing with FastAPI's dependency injection. Also resolved issues with timezone awareness in event schemas.
This commit is contained in:
2025-03-09 16:04:51 +01:00
parent fe2bcbd6e7
commit c2cdc3c110
7 changed files with 466 additions and 96 deletions

View File

@@ -1,24 +1,34 @@
from datetime import datetime from datetime import timezone
from typing import List, Optional from typing import Optional, Any, Dict
from uuid import UUID from uuid import UUID
from fastapi import APIRouter, Depends, HTTPException, Query from fastapi import APIRouter, Depends, HTTPException, Query, status
from sqlalchemy.exc import SQLAlchemyError
from sqlalchemy.orm import Session from sqlalchemy.orm import Session
from api.dependencies.auth import get_current_user from app.api.dependencies.auth import get_current_user, get_optional_current_user
from app.core.database import get_db from app.core.database import get_db
from app.crud.event import event from app.crud.event import event
from app.models.event_manager import EventManager
from app.models.user import User from app.models.user import User
from app.schemas.common import PaginatedResponse
from app.schemas.events import ( from app.schemas.events import (
EventCreate, EventCreate,
EventUpdate, EventUpdate,
EventResponse EventResponse,
) )
import logging
logger = logging.getLogger(__name__)
events_router = APIRouter() events_router = APIRouter()
@events_router.post("/", response_model=EventResponse, operation_id="create_event") @events_router.post(
"/",
response_model=EventResponse,
status_code=status.HTTP_201_CREATED,
operation_id="create_event"
)
def create_event( def create_event(
*, *,
db: Session = Depends(get_db), db: Session = Depends(get_db),
@@ -26,50 +36,129 @@ def create_event(
current_user: User = Depends(get_current_user) current_user: User = Depends(get_current_user)
) -> EventResponse: ) -> EventResponse:
"""Create a new event.""" """Create a new event."""
# Check if slug is already taken try:
if event.get_by_slug(db, slug=event_in.slug): # Check if slug is already taken
if event.get_by_slug(db, slug=event_in.slug):
raise HTTPException(
status_code=status.HTTP_400_BAD_REQUEST,
detail="An event with this slug already exists"
)
created_event = event.create_with_owner(db=db, obj_in=event_in, owner_id=current_user.id)
logger.info(f"Event created by {current_user.email}: {created_event.slug}")
return created_event
except SQLAlchemyError as e:
db.rollback()
raise HTTPException( raise HTTPException(
status_code=400, status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
detail="An event with this slug already exists" detail=f"Database error occurred"
) )
return event.create_with_owner(
db=db,
obj_in=event_in,
owner_id=current_user.id
)
@events_router.get(
@events_router.get("/me", response_model=List[EventResponse], operation_id="get_user_events") "/me",
response_model=PaginatedResponse[EventResponse],
operation_id="get_user_events"
)
def get_user_events( def get_user_events(
*,
db: Session = Depends(get_db), db: Session = Depends(get_db),
skip: int = 0, skip: int = Query(0, ge=0),
limit: int = 100, limit: int = Query(100, ge=1, le=500),
include_inactive: bool = False, include_inactive: bool = Query(False),
current_user: User = Depends(get_current_user) current_user: User = Depends(get_current_user)
) -> List[EventResponse]: ) -> Dict[str, Any]:
"""Get all events created by the current user.""" """Get all events created by the current user with pagination."""
return event.get_user_events( try:
db=db, total = event.count_user_events(
user_id=current_user.id, db=db,
skip=skip, user_id=current_user.id,
limit=limit, include_inactive=include_inactive
include_inactive=include_inactive )
) items = event.get_user_events(
db=db,
user_id=current_user.id,
skip=skip,
limit=limit,
include_inactive=include_inactive
)
return {
"total": total,
"items": items,
"page": skip // limit + 1 if limit > 0 else 1,
"size": limit
}
except SQLAlchemyError:
raise HTTPException(
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
detail="Error retrieving events"
)
@events_router.get("/upcoming", response_model=List[EventResponse], operation_id="get_upcoming_events") @events_router.get(
"/upcoming",
response_model=PaginatedResponse[EventResponse],
operation_id="get_upcoming_events"
)
def get_upcoming_events( def get_upcoming_events(
*, *,
db: Session = Depends(get_db), db: Session = Depends(get_db),
skip: int = 0, skip: int = Query(0, ge=0),
limit: int = 100 limit: int = Query(100, ge=1, le=500)
) -> List[EventResponse]: ) -> Dict[str, Any]:
"""Get upcoming public events.""" """Get upcoming public events with pagination."""
return event.get_upcoming_events(db=db, skip=skip, limit=limit) try:
items = event.get_upcoming_events(db=db, skip=skip, limit=limit)
# Count total upcoming events for pagination
total = event.count_upcoming_events(db=db)
return {
"total": total,
"items": items,
"page": skip // limit + 1 if limit > 0 else 1,
"size": limit
}
except SQLAlchemyError:
raise HTTPException(
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
detail="Error retrieving upcoming events"
)
@events_router.get("/{event_id}", response_model=EventResponse, operation_id="get_event") @events_router.get(
"/public",
response_model=PaginatedResponse[EventResponse],
operation_id="get_public_events"
)
def get_public_events(
*,
db: Session = Depends(get_db),
skip: int = Query(0, ge=0),
limit: int = Query(100, ge=1, le=500)
) -> Dict[str, Any]:
"""Get all public events with pagination."""
try:
items = event.get_public_events(db=db, skip=skip, limit=limit)
total = event.count_public_events(db=db)
return {
"total": total,
"items": items,
"page": skip // limit + 1 if limit > 0 else 1,
"size": limit
}
except SQLAlchemyError:
raise HTTPException(
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
detail="Error retrieving public events"
)
@events_router.get(
"/{event_id}",
response_model=EventResponse,
operation_id="get_event"
)
def get_event( def get_event(
*, *,
db: Session = Depends(get_db), db: Session = Depends(get_db),
@@ -77,27 +166,78 @@ def get_event(
current_user: User = Depends(get_current_user) current_user: User = Depends(get_current_user)
) -> EventResponse: ) -> EventResponse:
"""Get event by ID.""" """Get event by ID."""
event_obj = event.get(db=db, id=event_id) try:
if not event_obj: event_obj = event.get(db=db, id=event_id)
raise HTTPException(status_code=404, detail="Event not found") if not event_obj:
return event_obj raise HTTPException(
status_code=status.HTTP_404_NOT_FOUND,
detail="Event not found"
)
# Check if user is creator or manager
if event_obj.created_by != current_user.id:
# Check if user is a manager
is_manager = db.query(EventManager).filter(
EventManager.event_id == event_id,
EventManager.user_id == current_user.id
).first()
if not is_manager and not current_user.is_superuser:
raise HTTPException(
status_code=status.HTTP_403_FORBIDDEN,
detail="Not enough permissions"
)
return event_obj
except SQLAlchemyError:
raise HTTPException(
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
detail="Error retrieving event"
)
@events_router.get("/by-slug/{slug}", response_model=EventResponse, operation_id="get_public_event") @events_router.get(
"/by-slug/{slug}",
response_model=EventResponse,
operation_id="get_public_event"
)
def get_public_event( def get_public_event(
*, *,
db: Session = Depends(get_db), db: Session = Depends(get_db),
slug: str, slug: str,
access_code: Optional[str] = Query(None) access_code: Optional[str] = Query(None),
current_user: Optional[User] = Depends(get_optional_current_user)
) -> EventResponse: ) -> EventResponse:
"""Get public event by slug.""" """Get public event by slug."""
event_obj = event.get_public_event(db=db, slug=slug, access_code=access_code) try:
if not event_obj: event_obj = event.get_public_event(db=db, slug=slug, access_code=access_code)
raise HTTPException(status_code=404, detail="Event not found") if not event_obj:
return event_obj raise HTTPException(
status_code=status.HTTP_404_NOT_FOUND,
detail="Event not found"
)
# If event is not public and user is not authenticated, check access code
if not event_obj.is_public and not current_user:
if not access_code or access_code != event_obj.access_code:
raise HTTPException(
status_code=status.HTTP_403_FORBIDDEN,
detail="Invalid access code"
)
return event_obj
except SQLAlchemyError:
raise HTTPException(
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
detail="Error retrieving event"
)
@events_router.put("/{event_id}", response_model=EventResponse, operation_id="update_event") @events_router.put(
"/{event_id}",
response_model=EventResponse,
operation_id="update_event"
)
def update_event( def update_event(
*, *,
db: Session = Depends(get_db), db: Session = Depends(get_db),
@@ -106,36 +246,95 @@ def update_event(
current_user: User = Depends(get_current_user) current_user: User = Depends(get_current_user)
) -> EventResponse: ) -> EventResponse:
"""Update event.""" """Update event."""
event_obj = event.get(db=db, id=event_id) try:
if not event_obj: event_obj = event.get(db=db, id=event_id)
raise HTTPException(status_code=404, detail="Event not found") if not event_obj:
if event_obj.created_by != current_user.id:
raise HTTPException(status_code=403, detail="Not enough permissions")
# If slug is being updated, check if new slug is available
if event_in.slug and event_in.slug != event_obj.slug:
if event.get_by_slug(db, slug=event_in.slug):
raise HTTPException( raise HTTPException(
status_code=400, status_code=status.HTTP_404_NOT_FOUND,
detail="An event with this slug already exists" detail="Event not found"
) )
return event.update(db=db, db_obj=event_obj, obj_in=event_in) # Check permissions (creator or manager with edit rights)
has_permission = False
if event_obj.created_by == current_user.id or current_user.is_superuser:
has_permission = True
else:
manager = db.query(EventManager).filter(
EventManager.event_id == event_id,
EventManager.user_id == current_user.id,
EventManager.can_edit == True
).first()
has_permission = manager is not None
if not has_permission:
raise HTTPException(
status_code=status.HTTP_403_FORBIDDEN,
detail="Not enough permissions to edit this event"
)
# If slug is being updated, check if new slug is available and different
if event_in.slug and event_in.slug != event_obj.slug:
existing = event.get_by_slug(db, slug=event_in.slug)
if existing and existing.id != event_id:
raise HTTPException(
status_code=status.HTTP_400_BAD_REQUEST,
detail="An event with this slug already exists"
)
return event.update(db=db, db_obj=event_obj, obj_in=event_in)
except SQLAlchemyError:
db.rollback()
raise HTTPException(
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
detail="Error updating event"
)
@events_router.delete("/{event_id}", operation_id="delete_event") @events_router.delete(
"/{event_id}",
status_code=status.HTTP_204_NO_CONTENT,
operation_id="delete_event"
)
def delete_event( def delete_event(
*, *,
db: Session = Depends(get_db), db: Session = Depends(get_db),
event_id: UUID, event_id: UUID,
current_user: User = Depends(get_current_user) current_user: User = Depends(get_current_user),
hard_delete: bool = Query(False, description="Perform hard delete instead of soft delete")
): ):
"""Delete event.""" """Delete event (soft delete by default)."""
event_obj = event.get(db=db, id=event_id) try:
if not event_obj: event_obj = event.get(db=db, id=event_id)
raise HTTPException(status_code=404, detail="Event not found") if not event_obj:
if event_obj.created_by != current_user.id: raise HTTPException(
raise HTTPException(status_code=403, detail="Not enough permissions") status_code=status.HTTP_404_NOT_FOUND,
detail="Event not found"
)
event.remove(db=db, id=event_id) # Only creator or superuser can delete
return {"message": "Event deleted successfully"} if event_obj.created_by != current_user.id and not current_user.is_superuser:
raise HTTPException(
status_code=status.HTTP_403_FORBIDDEN,
detail="Not enough permissions to delete this event"
)
if hard_delete:
# Hard delete - only for superusers
if not current_user.is_superuser:
raise HTTPException(
status_code=status.HTTP_403_FORBIDDEN,
detail="Only administrators can perform hard delete"
)
event.remove(db=db, id=event_id)
else:
# Soft delete - set is_active to False
event.update(db=db, db_obj=event_obj, obj_in={"is_active": False})
return None # 204 No Content
except SQLAlchemyError:
db.rollback()
raise HTTPException(
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
detail="Error deleting event"
)

View File

@@ -41,7 +41,7 @@ class EventBase(BaseModel):
@field_validator('event_date') @field_validator('event_date')
def validate_event_date(cls, v): def validate_event_date(cls, v):
if not v.tzinfo: if not v.tzinfo:
raise ValueError("Event date must be timezone-aware") v = v.replace(tzinfo=timezone.utc)
now = datetime.now(tz=timezone.utc) now = datetime.now(tz=timezone.utc)
if v < now: if v < now:
raise ValueError("Event date cannot be in the past") raise ValueError("Event date cannot be in the past")
@@ -81,7 +81,12 @@ class EventInDBBase(EventBase):
class EventResponse(EventInDBBase): class EventResponse(EventInDBBase):
pass @field_validator('event_date')
def validate_datetime(cls, v):
if v.tzinfo is None:
v.event_date = v.event_date.replace(tzinfo=timezone.utc)
return v
class Event(EventInDBBase): class Event(EventInDBBase):

View File

@@ -0,0 +1,130 @@
"""
Authentication utilities for testing.
This module provides tools to bypass FastAPI's authentication in tests.
"""
from typing import Callable, Dict, Optional
from fastapi import FastAPI
from fastapi.security import OAuth2PasswordBearer
from starlette.testclient import TestClient
# Import these from wherever they are defined in your app
from api.dependencies.auth import get_current_user, get_optional_current_user
from app.models.user import User
def create_test_auth_client(
app: FastAPI,
test_user: User,
extra_overrides: Optional[Dict[Callable, Callable]] = None
) -> TestClient:
"""
Create a test client with authentication pre-configured.
This bypasses the OAuth2 token validation and directly returns the test user.
Args:
app: The FastAPI app to test
test_user: The user object to use for authentication
extra_overrides: Additional dependency overrides to apply
Returns:
TestClient with authentication configured
"""
# First override the oauth2_scheme dependency to return a dummy token
# This prevents FastAPI from trying to extract a real bearer token from the request
oauth2_scheme = OAuth2PasswordBearer(tokenUrl="/api/v1/auth/login")
app.dependency_overrides[oauth2_scheme] = lambda: "dummy_token_for_testing"
# Then override the get_current_user dependency to return our test user
app.dependency_overrides[get_current_user] = lambda: test_user
# Apply any extra overrides
if extra_overrides:
for dep, override in extra_overrides.items():
app.dependency_overrides[dep] = override
# Create and return the client
return TestClient(app)
def create_test_optional_auth_client(
app: FastAPI,
test_user: User
) -> TestClient:
"""
Create a test client with optional authentication pre-configured.
This is useful for testing endpoints that use get_optional_current_user.
Args:
app: The FastAPI app to test
test_user: The user object to use for authentication
Returns:
TestClient with optional authentication configured
"""
# Override the get_optional_current_user dependency
app.dependency_overrides[get_optional_current_user] = lambda: test_user
# Create and return the client
return TestClient(app)
def create_test_superuser_client(
app: FastAPI,
test_user: User
) -> TestClient:
"""
Create a test client with superuser authentication pre-configured.
Args:
app: The FastAPI app to test
test_user: The user object to use as superuser
Returns:
TestClient with superuser authentication
"""
# Make sure user is a superuser
test_user.is_superuser = True
# Use the auth client creation with superuser
return create_test_auth_client(app, test_user)
def create_test_unauthenticated_client(app: FastAPI) -> TestClient:
"""
Create a test client that will fail authentication checks.
This is useful for testing the unauthorized case of protected endpoints.
Args:
app: The FastAPI app to test
Returns:
TestClient without authentication
"""
# Any authentication attempts will fail
return TestClient(app)
def cleanup_test_client_auth(app: FastAPI) -> None:
"""
Clean up authentication overrides from the FastAPI app.
Call this after your tests to restore normal authentication behavior.
Args:
app: The FastAPI app to clean up
"""
# Get all auth dependencies
auth_deps = [
get_current_user,
get_optional_current_user,
OAuth2PasswordBearer(tokenUrl="/api/v1/auth/login")
]
# Remove overrides
for dep in auth_deps:
if dep in app.dependency_overrides:
del app.dependency_overrides[dep]

View File

@@ -1,10 +1,17 @@
# tests/conftest.py
import uuid import uuid
from datetime import datetime, timezone, timedelta from datetime import datetime, timezone, timedelta
from typing import Dict from typing import Dict
from typing import Optional
import pytest import pytest
from fastapi import FastAPI
from fastapi.routing import APIRouter
from fastapi.security import OAuth2PasswordBearer
from fastapi.testclient import TestClient
from sqlalchemy.orm import Session
from app.api.dependencies.auth import get_current_user
from app.core.database import get_db
from app.models import Event, GiftItem, GiftStatus, GiftPriority, GiftCategory, GiftPurchase, RSVP, RSVPStatus, \ from app.models import Event, GiftItem, GiftStatus, GiftPriority, GiftCategory, GiftPurchase, RSVP, RSVPStatus, \
EventMedia, MediaType, MediaPurpose, \ EventMedia, MediaType, MediaPurpose, \
EventTheme, Guest, GuestStatus, ActivityType, ActivityLog, EmailTemplate, TemplateType, NotificationLog, \ EventTheme, Guest, GuestStatus, ActivityType, ActivityLog, EmailTemplate, TemplateType, NotificationLog, \
@@ -39,6 +46,7 @@ async def async_test_db():
yield test_engine, AsyncTestingSessionLocal yield test_engine, AsyncTestingSessionLocal
await teardown_async_test_db(test_engine) await teardown_async_test_db(test_engine)
@pytest.fixture @pytest.fixture
def user_create_data(): def user_create_data():
return { return {
@@ -72,7 +80,7 @@ def mock_user(db_session):
@pytest.fixture @pytest.fixture
def event_fixture(db_session, mock_user): def mock_event(db_session, mock_user):
"""Create a test event fixture.""" """Create a test event fixture."""
event_data = { event_data = {
"title": "Birthday Party", "title": "Birthday Party",
@@ -91,7 +99,6 @@ def event_fixture(db_session, mock_user):
return event return event
@pytest.fixture @pytest.fixture
def gift_item_fixture(db_session, mock_user): def gift_item_fixture(db_session, mock_user):
""" """
@@ -308,7 +315,7 @@ def gift_purchase_fixture(db_session, gift_item_fixture, guest_fixture):
@pytest.fixture @pytest.fixture
def gift_category_fixture(db_session, mock_user, event_fixture): def gift_category_fixture(db_session, mock_user, mock_event):
""" """
Fixture to create and return a GiftCategory instance. Fixture to create and return a GiftCategory instance.
""" """
@@ -316,7 +323,7 @@ def gift_category_fixture(db_session, mock_user, event_fixture):
id=uuid.uuid4(), id=uuid.uuid4(),
name="Electronics", name="Electronics",
description="Category for electronic gifts", description="Category for electronic gifts",
event_id=event_fixture.id, event_id=mock_event.id,
created_by=mock_user.id, created_by=mock_user.id,
display_order=0, display_order=0,
is_visible=True is_visible=True
@@ -344,3 +351,32 @@ def theme_data() -> Dict:
} }
} }
@pytest.fixture
def oauth2_scheme():
"""Return the OAuth2PasswordBearer instance used by the app."""
return OAuth2PasswordBearer(tokenUrl="/api/v1/auth/login")
@pytest.fixture
def create_test_client():
def _create_test_client(router: APIRouter, prefix: str, db_session, user):
from fastapi import FastAPI
app = FastAPI()
# Mimic your dependency overrides here
def get_db_override():
yield db_session
def get_current_user_override():
return user
# Include your router
app.include_router(router, prefix=prefix)
# Override dependencies
app.dependency_overrides[get_db] = get_db_override
app.dependency_overrides[get_current_user] = get_current_user_override
return TestClient(app)
return _create_test_client

View File

@@ -28,23 +28,23 @@ def test_create_event(db_session, mock_user):
assert event.created_by == mock_user.id assert event.created_by == mock_user.id
def test_get_event(db_session, event_fixture): def test_get_event(db_session, mock_event):
"""Test retrieving an event by ID.""" """Test retrieving an event by ID."""
stored_event = crud_event.get(db=db_session, id=event_fixture.id) stored_event = crud_event.get(db=db_session, id=mock_event.id)
assert stored_event assert stored_event
assert stored_event.id == event_fixture.id assert stored_event.id == mock_event.id
assert stored_event.title == event_fixture.title assert stored_event.title == mock_event.title
def test_get_event_by_slug(db_session, event_fixture): def test_get_event_by_slug(db_session, mock_event):
"""Test retrieving an event by slug.""" """Test retrieving an event by slug."""
stored_event = crud_event.get_by_slug(db=db_session, slug=event_fixture.slug) stored_event = crud_event.get_by_slug(db=db_session, slug=mock_event.slug)
assert stored_event assert stored_event
assert stored_event.id == event_fixture.id assert stored_event.id == mock_event.id
assert stored_event.slug == event_fixture.slug assert stored_event.slug == mock_event.slug
def test_get_user_events(db_session, event_fixture, mock_user): def test_get_user_events(db_session, mock_event, mock_user):
"""Test retrieving all events for a specific user.""" """Test retrieving all events for a specific user."""
events = crud_event.get_user_events( events = crud_event.get_user_events(
db=db_session, db=db_session,
@@ -53,10 +53,10 @@ def test_get_user_events(db_session, event_fixture, mock_user):
limit=100 limit=100
) )
assert len(events) > 0 assert len(events) > 0
assert any(event.id == event_fixture.id for event in events) assert any(event.id == mock_event.id for event in events)
def test_update_event(db_session, event_fixture): def test_update_event(db_session, mock_event):
"""Test updating an event.""" """Test updating an event."""
update_data = EventUpdate( update_data = EventUpdate(
title="Updated Birthday Party", title="Updated Birthday Party",
@@ -64,18 +64,18 @@ def test_update_event(db_session, event_fixture):
) )
updated_event = crud_event.update( updated_event = crud_event.update(
db=db_session, db=db_session,
db_obj=event_fixture, db_obj=mock_event,
obj_in=update_data obj_in=update_data
) )
assert updated_event.title == "Updated Birthday Party" assert updated_event.title == "Updated Birthday Party"
assert updated_event.description == "Updated description" assert updated_event.description == "Updated description"
def test_delete_event(db_session, event_fixture): def test_delete_event(db_session, mock_event):
"""Test deleting an event.""" """Test deleting an event."""
event = crud_event.remove(db=db_session, id=event_fixture.id) event = crud_event.remove(db=db_session, id=mock_event.id)
assert event.id == event_fixture.id assert event.id == mock_event.id
deleted_event = crud_event.get(db=db_session, id=event_fixture.id) deleted_event = crud_event.get(db=db_session, id=mock_event.id)
assert deleted_event is None assert deleted_event is None
@@ -115,11 +115,11 @@ def test_get_upcoming_events(db_session):
assert event_date >= current_time assert event_date >= current_time
def test_get_public_event(db_session, event_fixture): def test_get_public_event(db_session, mock_event):
"""Test retrieving a public event.""" """Test retrieving a public event."""
public_event = crud_event.get_public_event( public_event = crud_event.get_public_event(
db=db_session, db=db_session,
slug=event_fixture.slug slug=mock_event.slug
) )
assert public_event is not None assert public_event is not None
assert public_event.is_public is True assert public_event.is_public is True

View File

@@ -77,9 +77,9 @@ def test_delete_event_theme(db_session: Session, event_theme_fixture) -> None:
assert theme is None assert theme is None
def test_get_active_themes(db_session: Session, event_theme_fixture, event_fixture) -> None: def test_get_active_themes(db_session: Session, event_theme_fixture, mock_event) -> None:
# First, ensure the theme is associated with an event # First, ensure the theme is associated with an event
event_fixture.theme_id = event_theme_fixture.id mock_event.theme_id = event_theme_fixture.id
db_session.commit() db_session.commit()
# Get active themes # Get active themes