Refactor event access validation and enhance endpoint logic
All checks were successful
Build and Push Docker Images / changes (push) Successful in 5s
Build and Push Docker Images / build-backend (push) Successful in 55s
Build and Push Docker Images / build-frontend (push) Has been skipped

Centralized event access validation into a reusable `validate_event_access` function, eliminating duplicated code across endpoints. Updated the logic in `get_event` and `get_event_by_slug` to use this function. Adjusted tests to align with the refactored logic and fixed permission-based response statuses.
This commit is contained in:
2025-03-10 09:18:46 +01:00
parent c5915e57b1
commit e1145525ff
2 changed files with 105 additions and 78 deletions

View File

@@ -1,11 +1,12 @@
import logging import logging
from typing import Optional, Any, Dict from typing import Any, Dict
from typing import Optional
from uuid import UUID from uuid import UUID
from fastapi import APIRouter, Depends, HTTPException, Query, status from fastapi import APIRouter, Depends, Query
from fastapi import HTTPException, status
from sqlalchemy.exc import SQLAlchemyError from sqlalchemy.exc import SQLAlchemyError
from sqlalchemy.orm import Session from sqlalchemy.orm import Session
from starlette.config import environ
from app.api.dependencies.auth import get_current_user, get_optional_current_user from app.api.dependencies.auth import get_current_user, get_optional_current_user
from app.core.database import get_db from app.core.database import get_db
@@ -17,13 +18,63 @@ from app.schemas.common import PaginatedResponse
from app.schemas.events import ( from app.schemas.events import (
EventCreate, EventCreate,
EventUpdate, EventUpdate,
EventResponse, EventResponse, Event,
) )
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
events_router = APIRouter() events_router = APIRouter()
def validate_event_access(
*,
db: Session,
event_obj: Optional[Event],
current_user: Optional[User],
access_code: Optional[str] = None
) -> EventResponse:
"""Validate access permissions for an event."""
if event_obj is None:
raise HTTPException(
status_code=status.HTTP_404_NOT_FOUND, detail="Event not found"
)
# Allow creator or superuser
if current_user and (
event_obj.created_by == current_user.id or current_user.is_superuser
):
return event_obj
# Allow manager
if current_user:
is_manager = db.query(EventManager).filter_by(
event_id=event_obj.id, user_id=current_user.id
).first()
if is_manager:
return event_obj
# Public event, allow anyone
if event_obj.is_public:
return event_obj
# Guest user allowed if authenticated
if current_user:
guest_entry = db.query(Guest).filter_by(
event_id=event_obj.id, user_id=current_user.id
).first()
if guest_entry:
return event_obj
# Access with invite/access code (generic method if implemented)
if access_code and (event_obj.access_code == access_code):
return event_obj
raise HTTPException(
status_code=status.HTTP_403_FORBIDDEN,
detail="Not enough permissions to access this event"
)
@events_router.post( @events_router.post(
"/", "/",
response_model=EventResponse, response_model=EventResponse,
@@ -184,103 +235,54 @@ def get_event(
*, *,
db: Session = Depends(get_db), db: Session = Depends(get_db),
event_id: UUID, event_id: UUID,
current_user: User = Depends(get_current_user) access_code: Optional[str] = Query(None),
current_user: Optional[User] = Depends(get_current_user)
) -> EventResponse: ) -> EventResponse:
"""Get event by ID.""" """Get event by ID."""
if current_user is None: print("Getting event")
raise HTTPException(
status_code=status.HTTP_401_UNAUTHORIZED,
detail="Invalid authentication credentials",
headers={"WWW-Authenticate": "Bearer"},
)
try: try:
event_obj = event.get(db=db, id=event_id) event_obj = event.get(db=db, id=event_id)
if not event_obj: return validate_event_access(
raise HTTPException( db=db,
status_code=status.HTTP_404_NOT_FOUND, event_obj=event_obj,
detail="Event not found" current_user=current_user,
) access_code=access_code
# Allow direct access if user is creator or superuser
if event_obj.created_by == current_user.id or current_user.is_superuser:
return event_obj
# Allow direct access if the user is managing the event
is_manager = db.query(EventManager).filter(
EventManager.event_id == event_id,
EventManager.user_id == current_user.id
).first()
if is_manager:
return event_obj
# Allow access if the event is public
if event_obj.is_public:
return event_obj
# Allow access if the user is explicitly invited (Guest)
guest_entry = db.query(Guest).filter(
Guest.event_id == event_id,
Guest.user_id == current_user.id
).first()
if guest_entry:
return event_obj
# User does not meet any permitted criteria; deny access
raise HTTPException(
status_code=status.HTTP_403_FORBIDDEN,
detail="Not enough permissions to access this event"
) )
except SQLAlchemyError: except SQLAlchemyError:
raise HTTPException( raise HTTPException(
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
detail="Error retrieving event" detail="Error retrieving event",
) )
@events_router.get( @events_router.get(
"/by-slug/{slug}", "/by-slug/{slug}",
response_model=EventResponse, response_model=EventResponse,
operation_id="get_public_event" operation_id="get_event_by_slug"
) )
def get_public_event( def get_event_by_slug(
*, *,
db: Session = Depends(get_db), db: Session = Depends(get_db),
slug: str, slug: str,
access_code: Optional[str] = Query(None), access_code: Optional[str] = Query(None),
current_user: Optional[User] = Depends(get_optional_current_user) current_user: Optional[User] = Depends(get_current_user)
) -> EventResponse: ) -> EventResponse:
"""Get public event by slug.""" """Get event by slug."""
if current_user is None:
raise HTTPException(
status_code=status.HTTP_401_UNAUTHORIZED,
detail="Invalid authentication credentials",
headers={"WWW-Authenticate": "Bearer"},
)
try: try:
event_obj = event.get_public_event(db=db, slug=slug, access_code=access_code) event_obj = event.get_public_event(db=db, slug=slug, access_code=access_code)
if not event_obj:
raise HTTPException(
status_code=status.HTTP_404_NOT_FOUND,
detail="Event not found"
)
# If event is not public and user is not authenticated, check access code return validate_event_access(
if not event_obj.is_public and not current_user: db=db,
if not access_code or access_code != event_obj.access_code: event_obj=event_obj,
raise HTTPException( current_user=current_user,
status_code=status.HTTP_403_FORBIDDEN, access_code=access_code
detail="Invalid access code" )
)
return event_obj
except SQLAlchemyError: except SQLAlchemyError:
raise HTTPException( raise HTTPException(
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
detail="Error retrieving event" detail="Error retrieving event",
) )

View File

@@ -410,9 +410,11 @@ class TestGetPublicEvents:
assert data["page"] == 1 assert data["page"] == 1
assert data["size"] == 100 assert data["size"] == 100
# @pytest.mark.parametrize("endpoint_type", ["id", "slug"])
@pytest.mark.parametrize("endpoint_type", ["id"])
class TestGetEvent: class TestGetEvent:
@pytest.fixture(autouse=True) @pytest.fixture(autouse=True)
def setup_method(self, create_test_client, db_session, mock_user): def setup_method(self, create_test_client, db_session, mock_user, endpoint_type):
self.client = create_test_client( self.client = create_test_client(
router=events_router, router=events_router,
prefix="/events", prefix="/events",
@@ -421,6 +423,8 @@ class TestGetEvent:
) )
self.db_session = db_session self.db_session = db_session
self.mock_user = mock_user self.mock_user = mock_user
self.endpoint_type = endpoint_type
def create_mock_user( def create_mock_user(
self, self,
@@ -501,15 +505,36 @@ class TestGetEvent:
return mock_event return mock_event
def get_event_endpoint(self, event_obj, access_code=None):
"""
Helper method to dynamically build the endpoint URL based on the test parameter.
"""
if self.endpoint_type == "id":
endpoint = f"/events/{event_obj.id}"
# else:
# endpoint = f"/events/by-slug/{event_obj.slug}"
if access_code is not None:
endpoint += f"?access_code={access_code}"
return endpoint
def test_get_event_by_creator_success(self): def test_get_event_by_creator_success(self):
mocked_event = self.create_mock_event(created_by=self.mock_user.id) mocked_event = self.create_mock_event(created_by=self.mock_user.id)
response = self.client.get(f"/events/{mocked_event.id}") endpoint = self.get_event_endpoint(mocked_event)
response = self.client.get(endpoint)
assert response.status_code == status.HTTP_200_OK assert response.status_code == status.HTTP_200_OK
assert response.json()["id"] == str(mocked_event.id) assert response.json()["id"] == str(mocked_event.id)
# def test_get_event_by_creator_success(self):
# mocked_event = self.create_mock_event(created_by=self.mock_user.id)
#
# response = self.client.get(f"/events/{mocked_event.id}")
#
# assert response.status_code == status.HTTP_200_OK
# assert response.json()["id"] == str(mocked_event.id)
def test_get_event_by_manager_success(self, mock_user): def test_get_event_by_manager_success(self, mock_user):
manager_user = self.create_mock_user(email="manager@example.com") manager_user = self.create_mock_user(email="manager@example.com")
mocked_event = self.create_mock_event(created_by=self.mock_user.id, managers=[manager_user]) mocked_event = self.create_mock_event(created_by=self.mock_user.id, managers=[manager_user])
@@ -604,7 +629,7 @@ class TestGetEvent:
def test_get_event_unauthenticated_user_fails(self, create_test_client, db_session): def test_get_event_unauthenticated_user_fails(self, create_test_client, db_session):
mocked_event = self.create_mock_event( mocked_event = self.create_mock_event(
created_by=self.mock_user.id, created_by=self.mock_user.id,
is_public=True is_public=False
) )
client = create_test_client( client = create_test_client(
router=events_router, router=events_router,
@@ -617,5 +642,5 @@ class TestGetEvent:
response = client.get(f"/events/{mocked_event.id}") response = client.get(f"/events/{mocked_event.id}")
assert response.status_code == status.HTTP_401_UNAUTHORIZED assert response.status_code == status.HTTP_403_FORBIDDEN
assert response.json()["detail"] == "Invalid authentication credentials" assert response.json()["detail"] == "Not enough permissions to access this event"