Refactor event access validation and enhance endpoint logic
All checks were successful
Build and Push Docker Images / changes (push) Successful in 5s
Build and Push Docker Images / build-backend (push) Successful in 55s
Build and Push Docker Images / build-frontend (push) Has been skipped

Centralized event access validation into a reusable `validate_event_access` function, eliminating duplicated code across endpoints. Updated the logic in `get_event` and `get_event_by_slug` to use this function. Adjusted tests to align with the refactored logic and fixed permission-based response statuses.
This commit is contained in:
2025-03-10 09:18:46 +01:00
parent c5915e57b1
commit e1145525ff
2 changed files with 105 additions and 78 deletions

View File

@@ -1,11 +1,12 @@
import logging
from typing import Optional, Any, Dict
from typing import Any, Dict
from typing import Optional
from uuid import UUID
from fastapi import APIRouter, Depends, HTTPException, Query, status
from fastapi import APIRouter, Depends, Query
from fastapi import HTTPException, status
from sqlalchemy.exc import SQLAlchemyError
from sqlalchemy.orm import Session
from starlette.config import environ
from app.api.dependencies.auth import get_current_user, get_optional_current_user
from app.core.database import get_db
@@ -17,13 +18,63 @@ from app.schemas.common import PaginatedResponse
from app.schemas.events import (
EventCreate,
EventUpdate,
EventResponse,
EventResponse, Event,
)
logger = logging.getLogger(__name__)
events_router = APIRouter()
def validate_event_access(
*,
db: Session,
event_obj: Optional[Event],
current_user: Optional[User],
access_code: Optional[str] = None
) -> EventResponse:
"""Validate access permissions for an event."""
if event_obj is None:
raise HTTPException(
status_code=status.HTTP_404_NOT_FOUND, detail="Event not found"
)
# Allow creator or superuser
if current_user and (
event_obj.created_by == current_user.id or current_user.is_superuser
):
return event_obj
# Allow manager
if current_user:
is_manager = db.query(EventManager).filter_by(
event_id=event_obj.id, user_id=current_user.id
).first()
if is_manager:
return event_obj
# Public event, allow anyone
if event_obj.is_public:
return event_obj
# Guest user allowed if authenticated
if current_user:
guest_entry = db.query(Guest).filter_by(
event_id=event_obj.id, user_id=current_user.id
).first()
if guest_entry:
return event_obj
# Access with invite/access code (generic method if implemented)
if access_code and (event_obj.access_code == access_code):
return event_obj
raise HTTPException(
status_code=status.HTTP_403_FORBIDDEN,
detail="Not enough permissions to access this event"
)
@events_router.post(
"/",
response_model=EventResponse,
@@ -184,103 +235,54 @@ def get_event(
*,
db: Session = Depends(get_db),
event_id: UUID,
current_user: User = Depends(get_current_user)
access_code: Optional[str] = Query(None),
current_user: Optional[User] = Depends(get_current_user)
) -> EventResponse:
"""Get event by ID."""
if current_user is None:
raise HTTPException(
status_code=status.HTTP_401_UNAUTHORIZED,
detail="Invalid authentication credentials",
headers={"WWW-Authenticate": "Bearer"},
)
print("Getting event")
try:
event_obj = event.get(db=db, id=event_id)
if not event_obj:
raise HTTPException(
status_code=status.HTTP_404_NOT_FOUND,
detail="Event not found"
)
# Allow direct access if user is creator or superuser
if event_obj.created_by == current_user.id or current_user.is_superuser:
return event_obj
# Allow direct access if the user is managing the event
is_manager = db.query(EventManager).filter(
EventManager.event_id == event_id,
EventManager.user_id == current_user.id
).first()
if is_manager:
return event_obj
# Allow access if the event is public
if event_obj.is_public:
return event_obj
# Allow access if the user is explicitly invited (Guest)
guest_entry = db.query(Guest).filter(
Guest.event_id == event_id,
Guest.user_id == current_user.id
).first()
if guest_entry:
return event_obj
# User does not meet any permitted criteria; deny access
raise HTTPException(
status_code=status.HTTP_403_FORBIDDEN,
detail="Not enough permissions to access this event"
return validate_event_access(
db=db,
event_obj=event_obj,
current_user=current_user,
access_code=access_code
)
except SQLAlchemyError:
raise HTTPException(
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
detail="Error retrieving event"
detail="Error retrieving event",
)
@events_router.get(
"/by-slug/{slug}",
response_model=EventResponse,
operation_id="get_public_event"
operation_id="get_event_by_slug"
)
def get_public_event(
def get_event_by_slug(
*,
db: Session = Depends(get_db),
slug: str,
access_code: Optional[str] = Query(None),
current_user: Optional[User] = Depends(get_optional_current_user)
current_user: Optional[User] = Depends(get_current_user)
) -> EventResponse:
"""Get public event by slug."""
if current_user is None:
raise HTTPException(
status_code=status.HTTP_401_UNAUTHORIZED,
detail="Invalid authentication credentials",
headers={"WWW-Authenticate": "Bearer"},
)
"""Get event by slug."""
try:
event_obj = event.get_public_event(db=db, slug=slug, access_code=access_code)
if not event_obj:
raise HTTPException(
status_code=status.HTTP_404_NOT_FOUND,
detail="Event not found"
)
# If event is not public and user is not authenticated, check access code
if not event_obj.is_public and not current_user:
if not access_code or access_code != event_obj.access_code:
raise HTTPException(
status_code=status.HTTP_403_FORBIDDEN,
detail="Invalid access code"
)
return validate_event_access(
db=db,
event_obj=event_obj,
current_user=current_user,
access_code=access_code
)
return event_obj
except SQLAlchemyError:
raise HTTPException(
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
detail="Error retrieving event"
detail="Error retrieving event",
)

View File

@@ -410,9 +410,11 @@ class TestGetPublicEvents:
assert data["page"] == 1
assert data["size"] == 100
# @pytest.mark.parametrize("endpoint_type", ["id", "slug"])
@pytest.mark.parametrize("endpoint_type", ["id"])
class TestGetEvent:
@pytest.fixture(autouse=True)
def setup_method(self, create_test_client, db_session, mock_user):
def setup_method(self, create_test_client, db_session, mock_user, endpoint_type):
self.client = create_test_client(
router=events_router,
prefix="/events",
@@ -421,6 +423,8 @@ class TestGetEvent:
)
self.db_session = db_session
self.mock_user = mock_user
self.endpoint_type = endpoint_type
def create_mock_user(
self,
@@ -501,15 +505,36 @@ class TestGetEvent:
return mock_event
def get_event_endpoint(self, event_obj, access_code=None):
"""
Helper method to dynamically build the endpoint URL based on the test parameter.
"""
if self.endpoint_type == "id":
endpoint = f"/events/{event_obj.id}"
# else:
# endpoint = f"/events/by-slug/{event_obj.slug}"
if access_code is not None:
endpoint += f"?access_code={access_code}"
return endpoint
def test_get_event_by_creator_success(self):
mocked_event = self.create_mock_event(created_by=self.mock_user.id)
response = self.client.get(f"/events/{mocked_event.id}")
endpoint = self.get_event_endpoint(mocked_event)
response = self.client.get(endpoint)
assert response.status_code == status.HTTP_200_OK
assert response.json()["id"] == str(mocked_event.id)
# def test_get_event_by_creator_success(self):
# mocked_event = self.create_mock_event(created_by=self.mock_user.id)
#
# response = self.client.get(f"/events/{mocked_event.id}")
#
# assert response.status_code == status.HTTP_200_OK
# assert response.json()["id"] == str(mocked_event.id)
def test_get_event_by_manager_success(self, mock_user):
manager_user = self.create_mock_user(email="manager@example.com")
mocked_event = self.create_mock_event(created_by=self.mock_user.id, managers=[manager_user])
@@ -604,7 +629,7 @@ class TestGetEvent:
def test_get_event_unauthenticated_user_fails(self, create_test_client, db_session):
mocked_event = self.create_mock_event(
created_by=self.mock_user.id,
is_public=True
is_public=False
)
client = create_test_client(
router=events_router,
@@ -617,5 +642,5 @@ class TestGetEvent:
response = client.get(f"/events/{mocked_event.id}")
assert response.status_code == status.HTTP_401_UNAUTHORIZED
assert response.json()["detail"] == "Invalid authentication credentials"
assert response.status_code == status.HTTP_403_FORBIDDEN
assert response.json()["detail"] == "Not enough permissions to access this event"