Refactor event retrieval and improve test coverage
All checks were successful
Build and Push Docker Images / changes (push) Successful in 4s
Build and Push Docker Images / build-backend (push) Successful in 51s
Build and Push Docker Images / build-frontend (push) Has been skipped

Removed redundant code for event retrieval and standardized logic by introducing a unified method for generating event endpoints. Updated tests to align with these changes, adding support for slug-based access and handling finer permission cases. Minor issues with test formatting and comments were also addressed.
This commit is contained in:
2025-03-11 06:47:58 +01:00
parent 80ff350053
commit f245145087
2 changed files with 37 additions and 40 deletions

View File

@@ -239,7 +239,6 @@ def get_event(
current_user: Optional[User] = Depends(get_current_user) current_user: Optional[User] = Depends(get_current_user)
) -> EventResponse: ) -> EventResponse:
"""Get event by ID.""" """Get event by ID."""
print("Getting event")
try: try:
event_obj = event.get(db=db, id=event_id) event_obj = event.get(db=db, id=event_id)
return validate_event_access( return validate_event_access(
@@ -270,8 +269,7 @@ def get_event_by_slug(
) -> EventResponse: ) -> EventResponse:
"""Get event by slug.""" """Get event by slug."""
try: try:
event_obj = event.get_public_event(db=db, slug=slug, access_code=access_code) event_obj = event.get_by_slug(db=db, slug=slug)
return validate_event_access( return validate_event_access(
db=db, db=db,
event_obj=event_obj, event_obj=event_obj,

View File

@@ -1,5 +1,6 @@
import uuid import uuid
from datetime import datetime, timedelta, timezone from datetime import datetime, timedelta, timezone
from types import SimpleNamespace
from typing import Optional, List from typing import Optional, List
from uuid import UUID, uuid4 from uuid import UUID, uuid4
@@ -98,8 +99,7 @@ class TestCreateEvent:
assert response.status_code == status.HTTP_422_UNPROCESSABLE_ENTITY, response.json() assert response.status_code == status.HTTP_422_UNPROCESSABLE_ENTITY, response.json()
assert "Event date cannot be in the past" in response.text assert "Event date cannot be in the past" in response.text
def test_create_event_invalid_slug_fails(self, db_session, mock_user, invalid_slug_event_data): def test_create_event_invalid_slug_fails(self, db_session, mock_user, invalid_slug_event_data):
response = self.client.post("/events/", json=invalid_slug_event_data) response = self.client.post("/events/", json=invalid_slug_event_data)
assert response.status_code == status.HTTP_422_UNPROCESSABLE_ENTITY, response.json() assert response.status_code == status.HTTP_422_UNPROCESSABLE_ENTITY, response.json()
@@ -271,7 +271,6 @@ class TestGetUpcomingEvents:
assert data["total"] == 15 # check total is correct assert data["total"] == 15 # check total is correct
assert data["size"] == 100 # default size assert data["size"] == 100 # default size
def test_get_upcoming_events_with_query_limit(self): def test_get_upcoming_events_with_query_limit(self):
"""Verifies the limit parameter works correctly.""" """Verifies the limit parameter works correctly."""
# create 5 events # create 5 events
@@ -326,6 +325,7 @@ class TestGetUpcomingEvents:
data = response.json() data = response.json()
assert isinstance(data["items"], list) and len(data["items"]) == 0 assert isinstance(data["items"], list) and len(data["items"]) == 0
class TestGetPublicEvents: class TestGetPublicEvents:
@pytest.fixture(autouse=True) @pytest.fixture(autouse=True)
def setup_method(self, create_test_client, db_session, mock_user): def setup_method(self, create_test_client, db_session, mock_user):
@@ -410,8 +410,9 @@ class TestGetPublicEvents:
assert data["page"] == 1 assert data["page"] == 1
assert data["size"] == 100 assert data["size"] == 100
# @pytest.mark.parametrize("endpoint_type", ["id", "slug"])
@pytest.mark.parametrize("endpoint_type", ["id"]) @pytest.mark.parametrize("endpoint_type", ["id", "slug"])
# @pytest.mark.parametrize("endpoint_type", ["id"])
class TestGetEvent: class TestGetEvent:
@pytest.fixture(autouse=True) @pytest.fixture(autouse=True)
def setup_method(self, create_test_client, db_session, mock_user, endpoint_type): def setup_method(self, create_test_client, db_session, mock_user, endpoint_type):
@@ -425,7 +426,6 @@ class TestGetEvent:
self.mock_user = mock_user self.mock_user = mock_user
self.endpoint_type = endpoint_type self.endpoint_type = endpoint_type
def create_mock_user( def create_mock_user(
self, self,
email="testuser@example.com", email="testuser@example.com",
@@ -458,10 +458,12 @@ class TestGetEvent:
is_public: bool = False, is_public: bool = False,
is_active: bool = True, is_active: bool = True,
managers: Optional[List] = None, managers: Optional[List] = None,
access_code: Optional[str] = None,
manager_role: EventManagerRole = EventManagerRole.ADMIN, manager_role: EventManagerRole = EventManagerRole.ADMIN,
): ):
if event_date is None: if event_date is None:
event_date = datetime.now(timezone.utc) + timedelta(days=10) # Default to 10 days in future if not specified event_date = datetime.now(timezone.utc) + timedelta(
days=10) # Default to 10 days in future if not specified
# Create Event instance # Create Event instance
mock_event = Event( mock_event = Event(
@@ -476,6 +478,7 @@ class TestGetEvent:
rsvp_enabled=True, rsvp_enabled=True,
gift_registry_enabled=True, gift_registry_enabled=True,
updates_enabled=True, updates_enabled=True,
access_code=access_code
) )
self.db_session.add(mock_event) self.db_session.add(mock_event)
@@ -511,14 +514,14 @@ class TestGetEvent:
""" """
if self.endpoint_type == "id": if self.endpoint_type == "id":
endpoint = f"/events/{event_obj.id}" endpoint = f"/events/{event_obj.id}"
# else: else:
# endpoint = f"/events/by-slug/{event_obj.slug}" endpoint = f"/events/by-slug/{event_obj.slug}"
if access_code is not None: if access_code is not None:
endpoint += f"?access_code={access_code}" endpoint += f"?access_code={access_code}"
return endpoint return endpoint
def test_get_event_by_creator_success(self): def test_get_event_by_creator_success(self):
mocked_event = self.create_mock_event(created_by=self.mock_user.id) mocked_event = self.create_mock_event(created_by=self.mock_user.id, is_public=True)
endpoint = self.get_event_endpoint(mocked_event) endpoint = self.get_event_endpoint(mocked_event)
response = self.client.get(endpoint) response = self.client.get(endpoint)
@@ -526,22 +529,14 @@ class TestGetEvent:
assert response.status_code == status.HTTP_200_OK assert response.status_code == status.HTTP_200_OK
assert response.json()["id"] == str(mocked_event.id) assert response.json()["id"] == str(mocked_event.id)
# def test_get_event_by_creator_success(self):
# mocked_event = self.create_mock_event(created_by=self.mock_user.id)
#
# response = self.client.get(f"/events/{mocked_event.id}")
#
# assert response.status_code == status.HTTP_200_OK
# assert response.json()["id"] == str(mocked_event.id)
def test_get_event_by_manager_success(self, mock_user): def test_get_event_by_manager_success(self, mock_user):
manager_user = self.create_mock_user(email="manager@example.com") manager_user = self.create_mock_user(email="manager@example.com")
mocked_event = self.create_mock_event(created_by=self.mock_user.id, managers=[manager_user]) mocked_event = self.create_mock_event(created_by=self.mock_user.id, managers=[manager_user], is_public=True)
self.client.user = manager_user self.client.user = manager_user
response = self.client.get(f"/events/{mocked_event.id}") endpoint = self.get_event_endpoint(mocked_event)
response = self.client.get(endpoint)
assert response.status_code == status.HTTP_200_OK assert response.status_code == status.HTTP_200_OK
assert response.json()["id"] == str(mocked_event.id) assert response.json()["id"] == str(mocked_event.id)
@@ -549,23 +544,24 @@ class TestGetEvent:
def test_get_event_by_superuser_success(self): def test_get_event_by_superuser_success(self):
superuser = self.create_mock_user(email="superuser@example.com", is_superuser=True) superuser = self.create_mock_user(email="superuser@example.com", is_superuser=True)
mocked_event = self.create_mock_event(created_by=self.mock_user.id) mocked_event = self.create_mock_event(created_by=self.mock_user.id, is_public=True)
self.client.user = superuser self.client.user = superuser
response = self.client.get(f"/events/{mocked_event.id}") endpoint = self.get_event_endpoint(mocked_event)
response = self.client.get(endpoint)
assert response.status_code == status.HTTP_200_OK assert response.status_code == status.HTTP_200_OK
assert response.json()["id"] == str(mocked_event.id) assert response.json()["id"] == str(mocked_event.id)
def test_get_event_not_exists(self):
random_event_id = uuid4()
def test_get_event_not_exists(self): endpoint = self.get_event_endpoint(SimpleNamespace(**{"id": random_event_id, "slug": "random-slug"}))
random_event_id = uuid4() response = self.client.get(endpoint)
response = self.client.get(f"/events/{random_event_id}") assert response.status_code == status.HTTP_404_NOT_FOUND
assert response.json()["detail"] == "Event not found"
assert response.status_code == status.HTTP_404_NOT_FOUND
assert response.json()["detail"] == "Event not found"
def test_get_public_event_by_non_related_user_success(self): def test_get_public_event_by_non_related_user_success(self):
other_user = self.create_mock_user(email="randomuser@example.com") other_user = self.create_mock_user(email="randomuser@example.com")
@@ -577,12 +573,13 @@ class TestGetEvent:
self.client.user = other_user self.client.user = other_user
response = self.client.get(f"/events/{mocked_event.id}") endpoint = self.get_event_endpoint(mocked_event)
response = self.client.get(endpoint)
assert response.status_code == status.HTTP_200_OK assert response.status_code == status.HTTP_200_OK
assert response.json()["id"] == str(mocked_event.id) assert response.json()["id"] == str(mocked_event.id)
def test_get_private_event_by_guest_user_success(self): def test_get_private_event_by_guest_user_success(self):
guest_user = self.create_mock_user(email="guestuser@example.com") guest_user = self.create_mock_user(email="guestuser@example.com")
@@ -603,29 +600,30 @@ class TestGetEvent:
self.client.user = guest_user self.client.user = guest_user
response = self.client.get(f"/events/{mocked_event.id}") endpoint = self.get_event_endpoint(mocked_event)
response = self.client.get(endpoint)
assert response.status_code == status.HTTP_200_OK assert response.status_code == status.HTTP_200_OK
assert response.json()["id"] == str(mocked_event.id) assert response.json()["id"] == str(mocked_event.id)
def test_get_private_event_non_related_user_forbidden(self): def test_get_private_event_non_related_user_forbidden(self):
creator_user = self.create_mock_user(email="creator@example.com") creator_user = self.create_mock_user(email="creator@example.com")
other_user = self.create_mock_user(email="nonrelated@example.com", is_superuser=False) other_user = self.create_mock_user(email="nonrelated@example.com", is_superuser=False)
mocked_event = self.create_mock_event( mocked_event = self.create_mock_event(
created_by=creator_user.id, created_by=creator_user.id,
access_code="1234",
is_public=False is_public=False
) )
self.client.user = other_user self.client.user = other_user
response = self.client.get(f"/events/{mocked_event.id}") endpoint = self.get_event_endpoint(mocked_event, access_code="123")
response = self.client.get(endpoint)
assert response.status_code == status.HTTP_403_FORBIDDEN assert response.status_code == status.HTTP_403_FORBIDDEN
assert response.json()["detail"] == "Not enough permissions to access this event" assert response.json()["detail"] == "Not enough permissions to access this event"
def test_get_event_unauthenticated_user_fails(self, create_test_client, db_session): def test_get_event_unauthenticated_user_fails(self, create_test_client, db_session):
mocked_event = self.create_mock_event( mocked_event = self.create_mock_event(
created_by=self.mock_user.id, created_by=self.mock_user.id,
@@ -640,7 +638,8 @@ class TestGetEvent:
self.client.user = None # Simulate no authenticated user self.client.user = None # Simulate no authenticated user
response = client.get(f"/events/{mocked_event.id}") endpoint = self.get_event_endpoint(mocked_event)
response = client.get(endpoint)
assert response.status_code == status.HTTP_403_FORBIDDEN assert response.status_code == status.HTTP_403_FORBIDDEN
assert response.json()["detail"] == "Not enough permissions to access this event" assert response.json()["detail"] == "Not enough permissions to access this event"