Restrict event access and add extensive event tests
All checks were successful
Build and Push Docker Images / changes (push) Successful in 4s
Build and Push Docker Images / build-backend (push) Successful in 51s
Build and Push Docker Images / build-frontend (push) Has been skipped

Updated event API to enforce stricter access controls based on user roles, including creators, managers, superusers, and guests. Added robust test cases for creating, fetching, and handling event access scenarios to ensure consistent behavior across endpoints.
This commit is contained in:
2025-03-09 17:37:48 +01:00
parent 4192911538
commit c5915e57b1
3 changed files with 681 additions and 16 deletions

View File

@@ -1,14 +1,16 @@
from datetime import timezone
import logging
from typing import Optional, Any, Dict
from uuid import UUID
from fastapi import APIRouter, Depends, HTTPException, Query, status
from sqlalchemy.exc import SQLAlchemyError
from sqlalchemy.orm import Session
from starlette.config import environ
from app.api.dependencies.auth import get_current_user, get_optional_current_user
from app.core.database import get_db
from app.crud.event import event
from app.models import Guest
from app.models.event_manager import EventManager
from app.models.user import User
from app.schemas.common import PaginatedResponse
@@ -18,7 +20,6 @@ from app.schemas.events import (
EventResponse,
)
import logging
logger = logging.getLogger(__name__)
events_router = APIRouter()
@@ -116,9 +117,17 @@ def get_upcoming_events(
*,
db: Session = Depends(get_db),
skip: int = Query(0, ge=0),
limit: int = Query(100, ge=1, le=500)
limit: int = Query(100, ge=1, le=500),
current_user: User = Depends(get_current_user)
) -> Dict[str, Any]:
"""Get upcoming public events with pagination."""
if current_user is None:
raise HTTPException(
status_code=status.HTTP_401_UNAUTHORIZED,
detail="Invalid authentication credentials",
headers={"WWW-Authenticate": "Bearer"},
)
try:
items = event.get_upcoming_events(db=db, skip=skip, limit=limit)
# Count total upcoming events for pagination
@@ -184,6 +193,7 @@ def get_event(
detail="Invalid authentication credentials",
headers={"WWW-Authenticate": "Bearer"},
)
try:
event_obj = event.get(db=db, id=event_id)
if not event_obj:
@@ -192,21 +202,38 @@ def get_event(
detail="Event not found"
)
# Check if user is creator or manager
if event_obj.created_by != current_user.id:
# Check if user is a manager
is_manager = db.query(EventManager).filter(
EventManager.event_id == event_id,
EventManager.user_id == current_user.id
).first()
# Allow direct access if user is creator or superuser
if event_obj.created_by == current_user.id or current_user.is_superuser:
return event_obj
if not is_manager and not current_user.is_superuser:
raise HTTPException(
status_code=status.HTTP_403_FORBIDDEN,
detail="Not enough permissions"
)
# Allow direct access if the user is managing the event
is_manager = db.query(EventManager).filter(
EventManager.event_id == event_id,
EventManager.user_id == current_user.id
).first()
if is_manager:
return event_obj
# Allow access if the event is public
if event_obj.is_public:
return event_obj
# Allow access if the user is explicitly invited (Guest)
guest_entry = db.query(Guest).filter(
Guest.event_id == event_id,
Guest.user_id == current_user.id
).first()
if guest_entry:
return event_obj
# User does not meet any permitted criteria; deny access
raise HTTPException(
status_code=status.HTTP_403_FORBIDDEN,
detail="Not enough permissions to access this event"
)
return event_obj
except SQLAlchemyError:
raise HTTPException(
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,

View File

@@ -0,0 +1,621 @@
import uuid
from datetime import datetime, timedelta, timezone
from typing import Optional, List
from uuid import UUID, uuid4
import pytest
from fastapi import status
from app.api.routes.events.router import events_router
from app.models.event import Event
from app.models import EventManager
from app.models import User
from app.models.event_manager import EventManagerRole, ROLE_PERMISSIONS
from app.models import Guest
@pytest.fixture
def event_data():
future_date = datetime.now(tz=timezone.utc) + timedelta(days=30)
slug = f"test-event-{uuid4().hex[:8]}"
return {
"title": "Test Event",
"slug": slug,
"description": "Test description",
"event_date": future_date.isoformat(),
"timezone": "UTC",
"is_public": True
}
@pytest.fixture
def past_event_data():
past_date = datetime.now(tz=timezone.utc) - timedelta(days=1)
slug = f"past-event-{uuid4().hex[:8]}"
return {
"title": "Past Event",
"slug": slug,
"description": "This event date is in the past",
"event_date": past_date.isoformat(),
"timezone": "UTC",
"is_public": True
}
@pytest.fixture
def invalid_slug_event_data():
future_date = datetime.now(tz=timezone.utc) + timedelta(days=30)
return {
"title": "Invalid Slug Event",
"slug": "INVALID Slug!!",
"description": "Event with invalid slug",
"event_date": future_date.isoformat(),
"timezone": "UTC",
"is_public": True
}
class TestCreateEvent:
"""Test scenarios for the create_event endpoint."""
@pytest.fixture(autouse=True)
def setup_method(self, create_test_client, db_session, mock_user):
self.client = create_test_client(
router=events_router,
prefix="/events",
db_session=db_session,
user=mock_user
)
self.db_session = db_session
self.mock_user = mock_user
def test_create_event_success(self, db_session, mock_user, event_data):
response = self.client.post("/events/", json=event_data)
assert response.status_code == status.HTTP_201_CREATED, response.json()
resp_json = response.json()
assert resp_json["title"] == event_data["title"]
assert resp_json["slug"] == event_data["slug"]
event_id = UUID(resp_json["id"])
db_event = db_session.query(Event).filter(Event.id == event_id).first()
assert db_event is not None
def test_create_event_missing_required_fields(self, db_session, mock_user):
incomplete_data = {
"title": "Incomplete Event"
# Missing required fields like slug and event_date
}
response = self.client.post("/events/", json=incomplete_data)
assert response.status_code == status.HTTP_422_UNPROCESSABLE_ENTITY
def test_create_event_with_past_date_fails(self, db_session, mock_user, past_event_data):
response = self.client.post("/events/", json=past_event_data)
assert response.status_code == status.HTTP_422_UNPROCESSABLE_ENTITY, response.json()
assert "Event date cannot be in the past" in response.text
def test_create_event_invalid_slug_fails(self, db_session, mock_user, invalid_slug_event_data):
response = self.client.post("/events/", json=invalid_slug_event_data)
assert response.status_code == status.HTTP_422_UNPROCESSABLE_ENTITY, response.json()
assert "String should match pattern" in response.text
def test_create_event_unauthorized_fails(self, create_test_client, db_session, event_data):
client = create_test_client(
router=events_router,
prefix="/events",
db_session=db_session,
user=None
)
response = client.post("/events/", json=event_data)
assert response.status_code == status.HTTP_401_UNAUTHORIZED, response.json()
assert "Invalid authentication credentials" in response.text
class TestGetUserEvents:
"""Tests for the get_user_events endpoint."""
@pytest.fixture(autouse=True)
def setup_method(self, create_test_client, db_session, mock_user):
self.client = create_test_client(
router=events_router,
prefix="/events",
db_session=db_session,
user=mock_user
)
self.db_session = db_session
self.mock_user = mock_user
def create_mock_events(self, num_events, active=True, public=True):
"""Utility function to create mock events in the database."""
events = []
for i in range(num_events):
event_date = datetime.now(tz=timezone.utc) + timedelta(days=i + 1)
mock_event = Event(
id=uuid4(),
title=f"Mock Event {i}",
slug=f"mock-event-{uuid4().hex[:8]}",
event_date=event_date,
timezone="UTC",
is_public=public,
is_active=active,
created_by=self.mock_user.id,
created_at=datetime.now(tz=timezone.utc),
updated_at=datetime.now(tz=timezone.utc),
)
self.db_session.add(mock_event)
events.append(mock_event)
self.db_session.commit()
return events
def test_get_user_events_success(self):
"""User gets their events correctly."""
created_events = self.create_mock_events(3)
response = self.client.get("/events/me")
assert response.status_code == status.HTTP_200_OK, response.json()
data = response.json()
assert "items" in data
assert len(data["items"]) == 3
returned_titles = {event["title"] for event in data["items"]}
assert all(event.title in returned_titles for event in created_events)
def test_get_user_events_pagination(self):
"""User events are returned paginated."""
self.create_mock_events(10)
response = self.client.get("/events/me?skip=5&limit=3")
assert response.status_code == status.HTTP_200_OK, response.json()
data = response.json()
assert len(data["items"]) == 3
assert data["total"] == 10
def test_get_user_events_include_inactive(self):
"""Inactive events should be included if requested explicitly."""
# Create active and inactive events
self.create_mock_events(2, active=False)
self.create_mock_events(1, active=True)
response_default = self.client.get("/events/me")
data_default = response_default.json()
assert response_default.status_code == status.HTTP_200_OK
assert len(data_default["items"]) == 1 # Active included only by default
response_include_inactive = self.client.get("/events/me?include_inactive=true")
data_inactive = response_include_inactive.json()
assert response_include_inactive.status_code == status.HTTP_200_OK
assert len(data_inactive["items"]) == 3 # Inactive explicitly included
def test_get_user_events_unauthenticated_fails(self, create_test_client):
"""Endpoint must not allow access without authentication."""
unauth_client = create_test_client(
router=events_router,
prefix="/events",
db_session=self.db_session,
user=None
)
response = unauth_client.get("/events/me")
assert response.status_code == status.HTTP_401_UNAUTHORIZED, response.json()
assert "Invalid authentication credentials" in response.text
def test_get_user_events_no_events(self):
"""User without events should receive an empty list."""
response = self.client.get("/events/me")
assert response.status_code == status.HTTP_200_OK, response.json()
data = response.json()
assert "items" in data
assert data["items"] == []
assert data["total"] == 0
class TestGetUpcomingEvents:
"""Tests for the get_upcoming_events endpoint."""
@pytest.fixture(autouse=True)
def setup_method(self, create_test_client, db_session, mock_user):
self.client = create_test_client(
router=events_router,
prefix="/events",
db_session=db_session,
user=mock_user
)
self.db_session = db_session
self.mock_user = mock_user
def create_mock_event(self, days_from_now: int, is_active=True, is_public=True):
"""Utility method for creating mocked event data in database."""
event = Event(
id=uuid4(),
title=f"Upcoming Event {uuid4().hex[:4]}",
slug=f"upcoming-event-{uuid4().hex[:8]}",
description="Upcoming event description",
event_date=datetime.now(tz=timezone.utc) + timedelta(days=days_from_now),
timezone="UTC",
is_public=is_public,
is_active=is_active,
created_by=self.mock_user.id,
created_at=datetime.now(timezone.utc),
updated_at=datetime.now(timezone.utc),
)
self.db_session.add(event)
self.db_session.commit()
return event
def test_get_upcoming_events_success_default_limit(self):
"""Ensure upcoming events are retrieved according to default limit."""
for i in range(15):
self.create_mock_event(days_from_now=i + 1)
response = self.client.get("/events/upcoming")
assert response.status_code == status.HTTP_200_OK, response.json()
data = response.json()
assert "items" in data # asserting structure explicitly
assert len(data["items"]) == 15 # explicitly checking items list length
assert data["total"] == 15 # check total is correct
assert data["size"] == 100 # default size
def test_get_upcoming_events_with_query_limit(self):
"""Verifies the limit parameter works correctly."""
# create 5 events
for i in range(5):
self.create_mock_event(days_from_now=i + 1)
response = self.client.get("/events/upcoming?limit=3")
assert response.status_code == status.HTTP_200_OK, response.json()
data = response.json()
assert "items" in data # asserting structure explicitly
assert len(data["items"]) == 3
assert data["total"] == 5
assert data["size"] == 3
def test_get_upcoming_events_only_active_and_future_events_returned(self):
"""Ensure only future active events are returned."""
# Past events and inactive events shouldn't be returned
self.create_mock_event(days_from_now=-5) # past event
self.create_mock_event(days_from_now=3, is_active=False) # inactive event
valid_event = self.create_mock_event(days_from_now=5) # active future event
response = self.client.get("/events/upcoming")
assert response.status_code == status.HTTP_200_OK, response.json()
data = response.json()
assert len(data["items"]) == 1
assert data["items"][0]["slug"] == valid_event.slug
def test_get_upcoming_events_unauthenticated_fails(self, create_test_client, db_session):
"""Verify unauthorized users cannot access this endpoint."""
client = create_test_client(
router=events_router,
prefix="/events",
db_session=db_session,
user=None,
)
response = client.get("/events/upcoming")
assert response.status_code == status.HTTP_401_UNAUTHORIZED, response.json()
assert "Invalid authentication credentials" in response.text
def test_get_upcoming_events_no_events_available(self):
"""Check the response if no upcoming events exist."""
response = self.client.get("/events/upcoming")
assert response.status_code == status.HTTP_200_OK, response.json()
data = response.json()
assert isinstance(data["items"], list) and len(data["items"]) == 0
class TestGetPublicEvents:
@pytest.fixture(autouse=True)
def setup_method(self, create_test_client, db_session, mock_user):
self.client = create_test_client(
router=events_router,
prefix="/events",
db_session=db_session,
user=mock_user
)
self.db_session = db_session
self.mock_user = mock_user
def create_mock_event(self, days_from_now: int, is_public=True, is_active=True):
event = Event(
id=uuid4(),
title=f"Event {uuid4().hex[:4]}",
slug=f"event-{uuid4().hex[:8]}",
description="Event description",
event_date=datetime.now(tz=timezone.utc) + timedelta(days=days_from_now),
timezone="UTC",
is_public=is_public,
is_active=is_active,
created_by=self.mock_user.id,
created_at=datetime.now(timezone.utc),
updated_at=datetime.now(timezone.utc),
)
self.db_session.add(event)
self.db_session.commit()
return event
def test_public_events_success_default_limit(self):
"""Return default limit 100 public events"""
for i in range(120):
self.create_mock_event(days_from_now=i + 1, is_public=True)
response = self.client.get("/events/public")
assert response.status_code == status.HTTP_200_OK
data = response.json()
assert len(data["items"]) == 100 # default limit
assert data["total"] == 120
assert data["page"] == 1
assert data["size"] == 100
def test_public_events_pagination(self):
"""Return paginated events"""
for i in range(30):
self.create_mock_event(days_from_now=i + 1, is_public=True)
response = self.client.get("/events/public?skip=10&limit=5")
assert response.status_code == status.HTTP_200_OK
data = response.json()
assert len(data["items"]) == 5
assert data["total"] == 30
assert data["page"] == 3 # skip=10, limit=5 means page=3
assert data["size"] == 5
def test_non_public_events_never_returned(self):
"""Ensure events marked is_public=False aren't retrieved"""
for i in range(10):
self.create_mock_event(days_from_now=i + 1, is_public=False)
for i in range(5):
self.create_mock_event(days_from_now=i + 1, is_public=True)
response = self.client.get("/events/public")
assert response.status_code == status.HTTP_200_OK
data = response.json()
assert len(data["items"]) == 5 # only public events returned
assert data["total"] == 5
def test_public_events_no_events_available(self):
"""Edge case: when no public events in db"""
response = self.client.get("/events/public")
assert response.status_code == status.HTTP_200_OK
data = response.json()
assert len(data["items"]) == 0
assert data["total"] == 0
assert data["page"] == 1
assert data["size"] == 100
class TestGetEvent:
@pytest.fixture(autouse=True)
def setup_method(self, create_test_client, db_session, mock_user):
self.client = create_test_client(
router=events_router,
prefix="/events",
db_session=db_session,
user=mock_user
)
self.db_session = db_session
self.mock_user = mock_user
def create_mock_user(
self,
email="testuser@example.com",
is_superuser=False,
is_active=True
):
user = User(
id=uuid.uuid4(),
email=email,
password_hash="mockhashedpassword",
first_name="Test",
last_name="User",
phone_number="1234567890",
is_active=is_active,
is_superuser=is_superuser,
preferences=None,
)
self.db_session.add(user)
self.db_session.commit()
return user
def create_mock_event(
self,
created_by: UUID,
title: str = "Test Event",
slug: str = "test-event",
description: str = "A sample event for testing purposes.",
event_date: Optional[datetime] = None,
timezone_str: str = "UTC",
is_public: bool = False,
is_active: bool = True,
managers: Optional[List] = None,
manager_role: EventManagerRole = EventManagerRole.ADMIN,
):
if event_date is None:
event_date = datetime.now(timezone.utc) + timedelta(days=10) # Default to 10 days in future if not specified
# Create Event instance
mock_event = Event(
title=title,
slug=slug,
description=description,
event_date=event_date,
timezone=timezone_str,
created_by=created_by,
is_public=is_public,
is_active=is_active,
rsvp_enabled=True,
gift_registry_enabled=True,
updates_enabled=True,
)
self.db_session.add(mock_event)
self.db_session.flush() # Use flush here to get a valid event_id before assigning managers
# If managers are provided, set EventManager instances
if managers:
for manager in managers:
role_permissions = ROLE_PERMISSIONS.get(manager_role, {})
event_manager = EventManager(
user_id=manager.id,
event_id=mock_event.id,
assigned_by=created_by,
role=manager_role,
can_edit=role_permissions.get("can_edit", False),
can_invite=role_permissions.get("can_invite", False),
can_manage_gifts=role_permissions.get("can_manage_gifts", False),
can_send_updates=role_permissions.get("can_send_updates", False),
can_view_analytics=role_permissions.get("can_view_analytics", False),
assigned_at=datetime.now(timezone.utc),
)
self.db_session.add(event_manager)
# persist the changes
self.db_session.commit()
self.db_session.refresh(mock_event)
return mock_event
def test_get_event_by_creator_success(self):
mocked_event = self.create_mock_event(created_by=self.mock_user.id)
response = self.client.get(f"/events/{mocked_event.id}")
assert response.status_code == status.HTTP_200_OK
assert response.json()["id"] == str(mocked_event.id)
def test_get_event_by_manager_success(self, mock_user):
manager_user = self.create_mock_user(email="manager@example.com")
mocked_event = self.create_mock_event(created_by=self.mock_user.id, managers=[manager_user])
self.client.user = manager_user
response = self.client.get(f"/events/{mocked_event.id}")
assert response.status_code == status.HTTP_200_OK
assert response.json()["id"] == str(mocked_event.id)
def test_get_event_by_superuser_success(self):
superuser = self.create_mock_user(email="superuser@example.com", is_superuser=True)
mocked_event = self.create_mock_event(created_by=self.mock_user.id)
self.client.user = superuser
response = self.client.get(f"/events/{mocked_event.id}")
assert response.status_code == status.HTTP_200_OK
assert response.json()["id"] == str(mocked_event.id)
def test_get_event_not_exists(self):
random_event_id = uuid4()
response = self.client.get(f"/events/{random_event_id}")
assert response.status_code == status.HTTP_404_NOT_FOUND
assert response.json()["detail"] == "Event not found"
def test_get_public_event_by_non_related_user_success(self):
other_user = self.create_mock_user(email="randomuser@example.com")
mocked_event = self.create_mock_event(
created_by=self.mock_user.id,
is_public=True
)
self.client.user = other_user
response = self.client.get(f"/events/{mocked_event.id}")
assert response.status_code == status.HTTP_200_OK
assert response.json()["id"] == str(mocked_event.id)
def test_get_private_event_by_guest_user_success(self):
guest_user = self.create_mock_user(email="guestuser@example.com")
mocked_event = self.create_mock_event(
created_by=self.mock_user.id,
is_public=False
)
guest_entry = Guest(
full_name="Guest User",
invitation_code="0000",
user_id=guest_user.id,
event_id=mocked_event.id,
invited_by=self.mock_user.id,
)
self.db_session.add(guest_entry)
self.db_session.commit()
self.client.user = guest_user
response = self.client.get(f"/events/{mocked_event.id}")
assert response.status_code == status.HTTP_200_OK
assert response.json()["id"] == str(mocked_event.id)
def test_get_private_event_non_related_user_forbidden(self):
creator_user = self.create_mock_user(email="creator@example.com")
other_user = self.create_mock_user(email="nonrelated@example.com", is_superuser=False)
mocked_event = self.create_mock_event(
created_by=creator_user.id,
is_public=False
)
self.client.user = other_user
response = self.client.get(f"/events/{mocked_event.id}")
assert response.status_code == status.HTTP_403_FORBIDDEN
assert response.json()["detail"] == "Not enough permissions to access this event"
def test_get_event_unauthenticated_user_fails(self, create_test_client, db_session):
mocked_event = self.create_mock_event(
created_by=self.mock_user.id,
is_public=True
)
client = create_test_client(
router=events_router,
prefix="/events",
db_session=db_session,
user=None,
)
self.client.user = None # Simulate no authenticated user
response = client.get(f"/events/{mocked_event.id}")
assert response.status_code == status.HTTP_401_UNAUTHORIZED
assert response.json()["detail"] == "Invalid authentication credentials"

View File

@@ -78,6 +78,23 @@ def mock_user(db_session):
db_session.commit()
return mock_user
@pytest.fixture
def mock_superuser(db_session):
"""Fixture to create and return a mock User instance."""
mock_user = User(
id=uuid.uuid4(),
email="mocksuperuser@example.com",
password_hash="mockhashedpassword",
first_name="Mock",
last_name="SuperUser",
phone_number="1234567890",
is_active=True,
is_superuser=True,
preferences=None,
)
db_session.add(mock_user)
db_session.commit()
return mock_user
@pytest.fixture
def mock_event(db_session, mock_user):