Refactor event retrieval and improve test coverage
Removed redundant code for event retrieval and standardized logic by introducing a unified method for generating event endpoints. Updated tests to align with these changes, adding support for slug-based access and handling finer permission cases. Minor issues with test formatting and comments were also addressed.
This commit is contained in:
@@ -239,7 +239,6 @@ def get_event(
|
|||||||
current_user: Optional[User] = Depends(get_current_user)
|
current_user: Optional[User] = Depends(get_current_user)
|
||||||
) -> EventResponse:
|
) -> EventResponse:
|
||||||
"""Get event by ID."""
|
"""Get event by ID."""
|
||||||
print("Getting event")
|
|
||||||
try:
|
try:
|
||||||
event_obj = event.get(db=db, id=event_id)
|
event_obj = event.get(db=db, id=event_id)
|
||||||
return validate_event_access(
|
return validate_event_access(
|
||||||
@@ -270,8 +269,7 @@ def get_event_by_slug(
|
|||||||
) -> EventResponse:
|
) -> EventResponse:
|
||||||
"""Get event by slug."""
|
"""Get event by slug."""
|
||||||
try:
|
try:
|
||||||
event_obj = event.get_public_event(db=db, slug=slug, access_code=access_code)
|
event_obj = event.get_by_slug(db=db, slug=slug)
|
||||||
|
|
||||||
return validate_event_access(
|
return validate_event_access(
|
||||||
db=db,
|
db=db,
|
||||||
event_obj=event_obj,
|
event_obj=event_obj,
|
||||||
|
|||||||
@@ -1,5 +1,6 @@
|
|||||||
import uuid
|
import uuid
|
||||||
from datetime import datetime, timedelta, timezone
|
from datetime import datetime, timedelta, timezone
|
||||||
|
from types import SimpleNamespace
|
||||||
from typing import Optional, List
|
from typing import Optional, List
|
||||||
from uuid import UUID, uuid4
|
from uuid import UUID, uuid4
|
||||||
|
|
||||||
@@ -98,8 +99,7 @@ class TestCreateEvent:
|
|||||||
assert response.status_code == status.HTTP_422_UNPROCESSABLE_ENTITY, response.json()
|
assert response.status_code == status.HTTP_422_UNPROCESSABLE_ENTITY, response.json()
|
||||||
assert "Event date cannot be in the past" in response.text
|
assert "Event date cannot be in the past" in response.text
|
||||||
|
|
||||||
def test_create_event_invalid_slug_fails(self, db_session, mock_user, invalid_slug_event_data):
|
def test_create_event_invalid_slug_fails(self, db_session, mock_user, invalid_slug_event_data):
|
||||||
|
|
||||||
response = self.client.post("/events/", json=invalid_slug_event_data)
|
response = self.client.post("/events/", json=invalid_slug_event_data)
|
||||||
|
|
||||||
assert response.status_code == status.HTTP_422_UNPROCESSABLE_ENTITY, response.json()
|
assert response.status_code == status.HTTP_422_UNPROCESSABLE_ENTITY, response.json()
|
||||||
@@ -271,7 +271,6 @@ class TestGetUpcomingEvents:
|
|||||||
assert data["total"] == 15 # check total is correct
|
assert data["total"] == 15 # check total is correct
|
||||||
assert data["size"] == 100 # default size
|
assert data["size"] == 100 # default size
|
||||||
|
|
||||||
|
|
||||||
def test_get_upcoming_events_with_query_limit(self):
|
def test_get_upcoming_events_with_query_limit(self):
|
||||||
"""Verifies the limit parameter works correctly."""
|
"""Verifies the limit parameter works correctly."""
|
||||||
# create 5 events
|
# create 5 events
|
||||||
@@ -326,6 +325,7 @@ class TestGetUpcomingEvents:
|
|||||||
data = response.json()
|
data = response.json()
|
||||||
assert isinstance(data["items"], list) and len(data["items"]) == 0
|
assert isinstance(data["items"], list) and len(data["items"]) == 0
|
||||||
|
|
||||||
|
|
||||||
class TestGetPublicEvents:
|
class TestGetPublicEvents:
|
||||||
@pytest.fixture(autouse=True)
|
@pytest.fixture(autouse=True)
|
||||||
def setup_method(self, create_test_client, db_session, mock_user):
|
def setup_method(self, create_test_client, db_session, mock_user):
|
||||||
@@ -410,8 +410,9 @@ class TestGetPublicEvents:
|
|||||||
assert data["page"] == 1
|
assert data["page"] == 1
|
||||||
assert data["size"] == 100
|
assert data["size"] == 100
|
||||||
|
|
||||||
# @pytest.mark.parametrize("endpoint_type", ["id", "slug"])
|
|
||||||
@pytest.mark.parametrize("endpoint_type", ["id"])
|
@pytest.mark.parametrize("endpoint_type", ["id", "slug"])
|
||||||
|
# @pytest.mark.parametrize("endpoint_type", ["id"])
|
||||||
class TestGetEvent:
|
class TestGetEvent:
|
||||||
@pytest.fixture(autouse=True)
|
@pytest.fixture(autouse=True)
|
||||||
def setup_method(self, create_test_client, db_session, mock_user, endpoint_type):
|
def setup_method(self, create_test_client, db_session, mock_user, endpoint_type):
|
||||||
@@ -425,7 +426,6 @@ class TestGetEvent:
|
|||||||
self.mock_user = mock_user
|
self.mock_user = mock_user
|
||||||
self.endpoint_type = endpoint_type
|
self.endpoint_type = endpoint_type
|
||||||
|
|
||||||
|
|
||||||
def create_mock_user(
|
def create_mock_user(
|
||||||
self,
|
self,
|
||||||
email="testuser@example.com",
|
email="testuser@example.com",
|
||||||
@@ -458,10 +458,12 @@ class TestGetEvent:
|
|||||||
is_public: bool = False,
|
is_public: bool = False,
|
||||||
is_active: bool = True,
|
is_active: bool = True,
|
||||||
managers: Optional[List] = None,
|
managers: Optional[List] = None,
|
||||||
|
access_code: Optional[str] = None,
|
||||||
manager_role: EventManagerRole = EventManagerRole.ADMIN,
|
manager_role: EventManagerRole = EventManagerRole.ADMIN,
|
||||||
):
|
):
|
||||||
if event_date is None:
|
if event_date is None:
|
||||||
event_date = datetime.now(timezone.utc) + timedelta(days=10) # Default to 10 days in future if not specified
|
event_date = datetime.now(timezone.utc) + timedelta(
|
||||||
|
days=10) # Default to 10 days in future if not specified
|
||||||
|
|
||||||
# Create Event instance
|
# Create Event instance
|
||||||
mock_event = Event(
|
mock_event = Event(
|
||||||
@@ -476,6 +478,7 @@ class TestGetEvent:
|
|||||||
rsvp_enabled=True,
|
rsvp_enabled=True,
|
||||||
gift_registry_enabled=True,
|
gift_registry_enabled=True,
|
||||||
updates_enabled=True,
|
updates_enabled=True,
|
||||||
|
access_code=access_code
|
||||||
)
|
)
|
||||||
|
|
||||||
self.db_session.add(mock_event)
|
self.db_session.add(mock_event)
|
||||||
@@ -511,14 +514,14 @@ class TestGetEvent:
|
|||||||
"""
|
"""
|
||||||
if self.endpoint_type == "id":
|
if self.endpoint_type == "id":
|
||||||
endpoint = f"/events/{event_obj.id}"
|
endpoint = f"/events/{event_obj.id}"
|
||||||
# else:
|
else:
|
||||||
# endpoint = f"/events/by-slug/{event_obj.slug}"
|
endpoint = f"/events/by-slug/{event_obj.slug}"
|
||||||
if access_code is not None:
|
if access_code is not None:
|
||||||
endpoint += f"?access_code={access_code}"
|
endpoint += f"?access_code={access_code}"
|
||||||
return endpoint
|
return endpoint
|
||||||
|
|
||||||
def test_get_event_by_creator_success(self):
|
def test_get_event_by_creator_success(self):
|
||||||
mocked_event = self.create_mock_event(created_by=self.mock_user.id)
|
mocked_event = self.create_mock_event(created_by=self.mock_user.id, is_public=True)
|
||||||
|
|
||||||
endpoint = self.get_event_endpoint(mocked_event)
|
endpoint = self.get_event_endpoint(mocked_event)
|
||||||
response = self.client.get(endpoint)
|
response = self.client.get(endpoint)
|
||||||
@@ -526,22 +529,14 @@ class TestGetEvent:
|
|||||||
assert response.status_code == status.HTTP_200_OK
|
assert response.status_code == status.HTTP_200_OK
|
||||||
assert response.json()["id"] == str(mocked_event.id)
|
assert response.json()["id"] == str(mocked_event.id)
|
||||||
|
|
||||||
|
|
||||||
# def test_get_event_by_creator_success(self):
|
|
||||||
# mocked_event = self.create_mock_event(created_by=self.mock_user.id)
|
|
||||||
#
|
|
||||||
# response = self.client.get(f"/events/{mocked_event.id}")
|
|
||||||
#
|
|
||||||
# assert response.status_code == status.HTTP_200_OK
|
|
||||||
# assert response.json()["id"] == str(mocked_event.id)
|
|
||||||
|
|
||||||
def test_get_event_by_manager_success(self, mock_user):
|
def test_get_event_by_manager_success(self, mock_user):
|
||||||
manager_user = self.create_mock_user(email="manager@example.com")
|
manager_user = self.create_mock_user(email="manager@example.com")
|
||||||
mocked_event = self.create_mock_event(created_by=self.mock_user.id, managers=[manager_user])
|
mocked_event = self.create_mock_event(created_by=self.mock_user.id, managers=[manager_user], is_public=True)
|
||||||
|
|
||||||
self.client.user = manager_user
|
self.client.user = manager_user
|
||||||
|
|
||||||
response = self.client.get(f"/events/{mocked_event.id}")
|
endpoint = self.get_event_endpoint(mocked_event)
|
||||||
|
response = self.client.get(endpoint)
|
||||||
|
|
||||||
assert response.status_code == status.HTTP_200_OK
|
assert response.status_code == status.HTTP_200_OK
|
||||||
assert response.json()["id"] == str(mocked_event.id)
|
assert response.json()["id"] == str(mocked_event.id)
|
||||||
@@ -549,23 +544,24 @@ class TestGetEvent:
|
|||||||
def test_get_event_by_superuser_success(self):
|
def test_get_event_by_superuser_success(self):
|
||||||
superuser = self.create_mock_user(email="superuser@example.com", is_superuser=True)
|
superuser = self.create_mock_user(email="superuser@example.com", is_superuser=True)
|
||||||
|
|
||||||
mocked_event = self.create_mock_event(created_by=self.mock_user.id)
|
mocked_event = self.create_mock_event(created_by=self.mock_user.id, is_public=True)
|
||||||
|
|
||||||
self.client.user = superuser
|
self.client.user = superuser
|
||||||
|
|
||||||
response = self.client.get(f"/events/{mocked_event.id}")
|
endpoint = self.get_event_endpoint(mocked_event)
|
||||||
|
response = self.client.get(endpoint)
|
||||||
|
|
||||||
assert response.status_code == status.HTTP_200_OK
|
assert response.status_code == status.HTTP_200_OK
|
||||||
assert response.json()["id"] == str(mocked_event.id)
|
assert response.json()["id"] == str(mocked_event.id)
|
||||||
|
|
||||||
|
def test_get_event_not_exists(self):
|
||||||
|
random_event_id = uuid4()
|
||||||
|
|
||||||
def test_get_event_not_exists(self):
|
endpoint = self.get_event_endpoint(SimpleNamespace(**{"id": random_event_id, "slug": "random-slug"}))
|
||||||
random_event_id = uuid4()
|
response = self.client.get(endpoint)
|
||||||
|
|
||||||
response = self.client.get(f"/events/{random_event_id}")
|
assert response.status_code == status.HTTP_404_NOT_FOUND
|
||||||
|
assert response.json()["detail"] == "Event not found"
|
||||||
assert response.status_code == status.HTTP_404_NOT_FOUND
|
|
||||||
assert response.json()["detail"] == "Event not found"
|
|
||||||
|
|
||||||
def test_get_public_event_by_non_related_user_success(self):
|
def test_get_public_event_by_non_related_user_success(self):
|
||||||
other_user = self.create_mock_user(email="randomuser@example.com")
|
other_user = self.create_mock_user(email="randomuser@example.com")
|
||||||
@@ -577,12 +573,13 @@ class TestGetEvent:
|
|||||||
|
|
||||||
self.client.user = other_user
|
self.client.user = other_user
|
||||||
|
|
||||||
response = self.client.get(f"/events/{mocked_event.id}")
|
endpoint = self.get_event_endpoint(mocked_event)
|
||||||
|
response = self.client.get(endpoint)
|
||||||
|
|
||||||
|
|
||||||
assert response.status_code == status.HTTP_200_OK
|
assert response.status_code == status.HTTP_200_OK
|
||||||
assert response.json()["id"] == str(mocked_event.id)
|
assert response.json()["id"] == str(mocked_event.id)
|
||||||
|
|
||||||
|
|
||||||
def test_get_private_event_by_guest_user_success(self):
|
def test_get_private_event_by_guest_user_success(self):
|
||||||
guest_user = self.create_mock_user(email="guestuser@example.com")
|
guest_user = self.create_mock_user(email="guestuser@example.com")
|
||||||
|
|
||||||
@@ -603,29 +600,30 @@ class TestGetEvent:
|
|||||||
|
|
||||||
self.client.user = guest_user
|
self.client.user = guest_user
|
||||||
|
|
||||||
response = self.client.get(f"/events/{mocked_event.id}")
|
endpoint = self.get_event_endpoint(mocked_event)
|
||||||
|
response = self.client.get(endpoint)
|
||||||
|
|
||||||
assert response.status_code == status.HTTP_200_OK
|
assert response.status_code == status.HTTP_200_OK
|
||||||
assert response.json()["id"] == str(mocked_event.id)
|
assert response.json()["id"] == str(mocked_event.id)
|
||||||
|
|
||||||
|
|
||||||
def test_get_private_event_non_related_user_forbidden(self):
|
def test_get_private_event_non_related_user_forbidden(self):
|
||||||
creator_user = self.create_mock_user(email="creator@example.com")
|
creator_user = self.create_mock_user(email="creator@example.com")
|
||||||
other_user = self.create_mock_user(email="nonrelated@example.com", is_superuser=False)
|
other_user = self.create_mock_user(email="nonrelated@example.com", is_superuser=False)
|
||||||
|
|
||||||
mocked_event = self.create_mock_event(
|
mocked_event = self.create_mock_event(
|
||||||
created_by=creator_user.id,
|
created_by=creator_user.id,
|
||||||
|
access_code="1234",
|
||||||
is_public=False
|
is_public=False
|
||||||
)
|
)
|
||||||
|
|
||||||
self.client.user = other_user
|
self.client.user = other_user
|
||||||
|
|
||||||
response = self.client.get(f"/events/{mocked_event.id}")
|
endpoint = self.get_event_endpoint(mocked_event, access_code="123")
|
||||||
|
response = self.client.get(endpoint)
|
||||||
|
|
||||||
assert response.status_code == status.HTTP_403_FORBIDDEN
|
assert response.status_code == status.HTTP_403_FORBIDDEN
|
||||||
assert response.json()["detail"] == "Not enough permissions to access this event"
|
assert response.json()["detail"] == "Not enough permissions to access this event"
|
||||||
|
|
||||||
|
|
||||||
def test_get_event_unauthenticated_user_fails(self, create_test_client, db_session):
|
def test_get_event_unauthenticated_user_fails(self, create_test_client, db_session):
|
||||||
mocked_event = self.create_mock_event(
|
mocked_event = self.create_mock_event(
|
||||||
created_by=self.mock_user.id,
|
created_by=self.mock_user.id,
|
||||||
@@ -640,7 +638,8 @@ class TestGetEvent:
|
|||||||
|
|
||||||
self.client.user = None # Simulate no authenticated user
|
self.client.user = None # Simulate no authenticated user
|
||||||
|
|
||||||
response = client.get(f"/events/{mocked_event.id}")
|
endpoint = self.get_event_endpoint(mocked_event)
|
||||||
|
response = client.get(endpoint)
|
||||||
|
|
||||||
assert response.status_code == status.HTTP_403_FORBIDDEN
|
assert response.status_code == status.HTTP_403_FORBIDDEN
|
||||||
assert response.json()["detail"] == "Not enough permissions to access this event"
|
assert response.json()["detail"] == "Not enough permissions to access this event"
|
||||||
|
|||||||
Reference in New Issue
Block a user