Refactor event access validation and enhance endpoint logic
Centralized event access validation into a reusable `validate_event_access` function, eliminating duplicated code across endpoints. Updated the logic in `get_event` and `get_event_by_slug` to use this function. Adjusted tests to align with the refactored logic and fixed permission-based response statuses.
This commit is contained in:
@@ -1,11 +1,12 @@
|
|||||||
import logging
|
import logging
|
||||||
from typing import Optional, Any, Dict
|
from typing import Any, Dict
|
||||||
|
from typing import Optional
|
||||||
from uuid import UUID
|
from uuid import UUID
|
||||||
|
|
||||||
from fastapi import APIRouter, Depends, HTTPException, Query, status
|
from fastapi import APIRouter, Depends, Query
|
||||||
|
from fastapi import HTTPException, status
|
||||||
from sqlalchemy.exc import SQLAlchemyError
|
from sqlalchemy.exc import SQLAlchemyError
|
||||||
from sqlalchemy.orm import Session
|
from sqlalchemy.orm import Session
|
||||||
from starlette.config import environ
|
|
||||||
|
|
||||||
from app.api.dependencies.auth import get_current_user, get_optional_current_user
|
from app.api.dependencies.auth import get_current_user, get_optional_current_user
|
||||||
from app.core.database import get_db
|
from app.core.database import get_db
|
||||||
@@ -17,13 +18,63 @@ from app.schemas.common import PaginatedResponse
|
|||||||
from app.schemas.events import (
|
from app.schemas.events import (
|
||||||
EventCreate,
|
EventCreate,
|
||||||
EventUpdate,
|
EventUpdate,
|
||||||
EventResponse,
|
EventResponse, Event,
|
||||||
)
|
)
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
events_router = APIRouter()
|
events_router = APIRouter()
|
||||||
|
|
||||||
|
|
||||||
|
def validate_event_access(
|
||||||
|
*,
|
||||||
|
db: Session,
|
||||||
|
event_obj: Optional[Event],
|
||||||
|
current_user: Optional[User],
|
||||||
|
access_code: Optional[str] = None
|
||||||
|
) -> EventResponse:
|
||||||
|
"""Validate access permissions for an event."""
|
||||||
|
|
||||||
|
if event_obj is None:
|
||||||
|
raise HTTPException(
|
||||||
|
status_code=status.HTTP_404_NOT_FOUND, detail="Event not found"
|
||||||
|
)
|
||||||
|
|
||||||
|
# Allow creator or superuser
|
||||||
|
if current_user and (
|
||||||
|
event_obj.created_by == current_user.id or current_user.is_superuser
|
||||||
|
):
|
||||||
|
return event_obj
|
||||||
|
|
||||||
|
# Allow manager
|
||||||
|
if current_user:
|
||||||
|
is_manager = db.query(EventManager).filter_by(
|
||||||
|
event_id=event_obj.id, user_id=current_user.id
|
||||||
|
).first()
|
||||||
|
if is_manager:
|
||||||
|
return event_obj
|
||||||
|
|
||||||
|
# Public event, allow anyone
|
||||||
|
if event_obj.is_public:
|
||||||
|
return event_obj
|
||||||
|
|
||||||
|
# Guest user allowed if authenticated
|
||||||
|
if current_user:
|
||||||
|
guest_entry = db.query(Guest).filter_by(
|
||||||
|
event_id=event_obj.id, user_id=current_user.id
|
||||||
|
).first()
|
||||||
|
if guest_entry:
|
||||||
|
return event_obj
|
||||||
|
|
||||||
|
# Access with invite/access code (generic method if implemented)
|
||||||
|
if access_code and (event_obj.access_code == access_code):
|
||||||
|
return event_obj
|
||||||
|
|
||||||
|
raise HTTPException(
|
||||||
|
status_code=status.HTTP_403_FORBIDDEN,
|
||||||
|
detail="Not enough permissions to access this event"
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
@events_router.post(
|
@events_router.post(
|
||||||
"/",
|
"/",
|
||||||
response_model=EventResponse,
|
response_model=EventResponse,
|
||||||
@@ -184,103 +235,54 @@ def get_event(
|
|||||||
*,
|
*,
|
||||||
db: Session = Depends(get_db),
|
db: Session = Depends(get_db),
|
||||||
event_id: UUID,
|
event_id: UUID,
|
||||||
current_user: User = Depends(get_current_user)
|
access_code: Optional[str] = Query(None),
|
||||||
|
current_user: Optional[User] = Depends(get_current_user)
|
||||||
) -> EventResponse:
|
) -> EventResponse:
|
||||||
"""Get event by ID."""
|
"""Get event by ID."""
|
||||||
if current_user is None:
|
print("Getting event")
|
||||||
raise HTTPException(
|
|
||||||
status_code=status.HTTP_401_UNAUTHORIZED,
|
|
||||||
detail="Invalid authentication credentials",
|
|
||||||
headers={"WWW-Authenticate": "Bearer"},
|
|
||||||
)
|
|
||||||
|
|
||||||
try:
|
try:
|
||||||
event_obj = event.get(db=db, id=event_id)
|
event_obj = event.get(db=db, id=event_id)
|
||||||
if not event_obj:
|
return validate_event_access(
|
||||||
raise HTTPException(
|
db=db,
|
||||||
status_code=status.HTTP_404_NOT_FOUND,
|
event_obj=event_obj,
|
||||||
detail="Event not found"
|
current_user=current_user,
|
||||||
)
|
access_code=access_code
|
||||||
|
|
||||||
# Allow direct access if user is creator or superuser
|
|
||||||
if event_obj.created_by == current_user.id or current_user.is_superuser:
|
|
||||||
return event_obj
|
|
||||||
|
|
||||||
# Allow direct access if the user is managing the event
|
|
||||||
is_manager = db.query(EventManager).filter(
|
|
||||||
EventManager.event_id == event_id,
|
|
||||||
EventManager.user_id == current_user.id
|
|
||||||
).first()
|
|
||||||
|
|
||||||
if is_manager:
|
|
||||||
return event_obj
|
|
||||||
|
|
||||||
# Allow access if the event is public
|
|
||||||
if event_obj.is_public:
|
|
||||||
return event_obj
|
|
||||||
|
|
||||||
# Allow access if the user is explicitly invited (Guest)
|
|
||||||
guest_entry = db.query(Guest).filter(
|
|
||||||
Guest.event_id == event_id,
|
|
||||||
Guest.user_id == current_user.id
|
|
||||||
).first()
|
|
||||||
|
|
||||||
if guest_entry:
|
|
||||||
return event_obj
|
|
||||||
|
|
||||||
# User does not meet any permitted criteria; deny access
|
|
||||||
raise HTTPException(
|
|
||||||
status_code=status.HTTP_403_FORBIDDEN,
|
|
||||||
detail="Not enough permissions to access this event"
|
|
||||||
)
|
)
|
||||||
|
|
||||||
except SQLAlchemyError:
|
except SQLAlchemyError:
|
||||||
raise HTTPException(
|
raise HTTPException(
|
||||||
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
|
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
|
||||||
detail="Error retrieving event"
|
detail="Error retrieving event",
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
@events_router.get(
|
@events_router.get(
|
||||||
"/by-slug/{slug}",
|
"/by-slug/{slug}",
|
||||||
response_model=EventResponse,
|
response_model=EventResponse,
|
||||||
operation_id="get_public_event"
|
operation_id="get_event_by_slug"
|
||||||
)
|
)
|
||||||
def get_public_event(
|
def get_event_by_slug(
|
||||||
*,
|
*,
|
||||||
db: Session = Depends(get_db),
|
db: Session = Depends(get_db),
|
||||||
slug: str,
|
slug: str,
|
||||||
access_code: Optional[str] = Query(None),
|
access_code: Optional[str] = Query(None),
|
||||||
current_user: Optional[User] = Depends(get_optional_current_user)
|
current_user: Optional[User] = Depends(get_current_user)
|
||||||
) -> EventResponse:
|
) -> EventResponse:
|
||||||
"""Get public event by slug."""
|
"""Get event by slug."""
|
||||||
if current_user is None:
|
|
||||||
raise HTTPException(
|
|
||||||
status_code=status.HTTP_401_UNAUTHORIZED,
|
|
||||||
detail="Invalid authentication credentials",
|
|
||||||
headers={"WWW-Authenticate": "Bearer"},
|
|
||||||
)
|
|
||||||
try:
|
try:
|
||||||
event_obj = event.get_public_event(db=db, slug=slug, access_code=access_code)
|
event_obj = event.get_public_event(db=db, slug=slug, access_code=access_code)
|
||||||
if not event_obj:
|
|
||||||
raise HTTPException(
|
|
||||||
status_code=status.HTTP_404_NOT_FOUND,
|
|
||||||
detail="Event not found"
|
|
||||||
)
|
|
||||||
|
|
||||||
# If event is not public and user is not authenticated, check access code
|
return validate_event_access(
|
||||||
if not event_obj.is_public and not current_user:
|
db=db,
|
||||||
if not access_code or access_code != event_obj.access_code:
|
event_obj=event_obj,
|
||||||
raise HTTPException(
|
current_user=current_user,
|
||||||
status_code=status.HTTP_403_FORBIDDEN,
|
access_code=access_code
|
||||||
detail="Invalid access code"
|
)
|
||||||
)
|
|
||||||
|
|
||||||
return event_obj
|
|
||||||
except SQLAlchemyError:
|
except SQLAlchemyError:
|
||||||
raise HTTPException(
|
raise HTTPException(
|
||||||
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
|
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
|
||||||
detail="Error retrieving event"
|
detail="Error retrieving event",
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
|
|||||||
@@ -410,9 +410,11 @@ class TestGetPublicEvents:
|
|||||||
assert data["page"] == 1
|
assert data["page"] == 1
|
||||||
assert data["size"] == 100
|
assert data["size"] == 100
|
||||||
|
|
||||||
|
# @pytest.mark.parametrize("endpoint_type", ["id", "slug"])
|
||||||
|
@pytest.mark.parametrize("endpoint_type", ["id"])
|
||||||
class TestGetEvent:
|
class TestGetEvent:
|
||||||
@pytest.fixture(autouse=True)
|
@pytest.fixture(autouse=True)
|
||||||
def setup_method(self, create_test_client, db_session, mock_user):
|
def setup_method(self, create_test_client, db_session, mock_user, endpoint_type):
|
||||||
self.client = create_test_client(
|
self.client = create_test_client(
|
||||||
router=events_router,
|
router=events_router,
|
||||||
prefix="/events",
|
prefix="/events",
|
||||||
@@ -421,6 +423,8 @@ class TestGetEvent:
|
|||||||
)
|
)
|
||||||
self.db_session = db_session
|
self.db_session = db_session
|
||||||
self.mock_user = mock_user
|
self.mock_user = mock_user
|
||||||
|
self.endpoint_type = endpoint_type
|
||||||
|
|
||||||
|
|
||||||
def create_mock_user(
|
def create_mock_user(
|
||||||
self,
|
self,
|
||||||
@@ -501,15 +505,36 @@ class TestGetEvent:
|
|||||||
|
|
||||||
return mock_event
|
return mock_event
|
||||||
|
|
||||||
|
def get_event_endpoint(self, event_obj, access_code=None):
|
||||||
|
"""
|
||||||
|
Helper method to dynamically build the endpoint URL based on the test parameter.
|
||||||
|
"""
|
||||||
|
if self.endpoint_type == "id":
|
||||||
|
endpoint = f"/events/{event_obj.id}"
|
||||||
|
# else:
|
||||||
|
# endpoint = f"/events/by-slug/{event_obj.slug}"
|
||||||
|
if access_code is not None:
|
||||||
|
endpoint += f"?access_code={access_code}"
|
||||||
|
return endpoint
|
||||||
|
|
||||||
def test_get_event_by_creator_success(self):
|
def test_get_event_by_creator_success(self):
|
||||||
mocked_event = self.create_mock_event(created_by=self.mock_user.id)
|
mocked_event = self.create_mock_event(created_by=self.mock_user.id)
|
||||||
|
|
||||||
response = self.client.get(f"/events/{mocked_event.id}")
|
endpoint = self.get_event_endpoint(mocked_event)
|
||||||
|
response = self.client.get(endpoint)
|
||||||
|
|
||||||
assert response.status_code == status.HTTP_200_OK
|
assert response.status_code == status.HTTP_200_OK
|
||||||
assert response.json()["id"] == str(mocked_event.id)
|
assert response.json()["id"] == str(mocked_event.id)
|
||||||
|
|
||||||
|
|
||||||
|
# def test_get_event_by_creator_success(self):
|
||||||
|
# mocked_event = self.create_mock_event(created_by=self.mock_user.id)
|
||||||
|
#
|
||||||
|
# response = self.client.get(f"/events/{mocked_event.id}")
|
||||||
|
#
|
||||||
|
# assert response.status_code == status.HTTP_200_OK
|
||||||
|
# assert response.json()["id"] == str(mocked_event.id)
|
||||||
|
|
||||||
def test_get_event_by_manager_success(self, mock_user):
|
def test_get_event_by_manager_success(self, mock_user):
|
||||||
manager_user = self.create_mock_user(email="manager@example.com")
|
manager_user = self.create_mock_user(email="manager@example.com")
|
||||||
mocked_event = self.create_mock_event(created_by=self.mock_user.id, managers=[manager_user])
|
mocked_event = self.create_mock_event(created_by=self.mock_user.id, managers=[manager_user])
|
||||||
@@ -604,7 +629,7 @@ class TestGetEvent:
|
|||||||
def test_get_event_unauthenticated_user_fails(self, create_test_client, db_session):
|
def test_get_event_unauthenticated_user_fails(self, create_test_client, db_session):
|
||||||
mocked_event = self.create_mock_event(
|
mocked_event = self.create_mock_event(
|
||||||
created_by=self.mock_user.id,
|
created_by=self.mock_user.id,
|
||||||
is_public=True
|
is_public=False
|
||||||
)
|
)
|
||||||
client = create_test_client(
|
client = create_test_client(
|
||||||
router=events_router,
|
router=events_router,
|
||||||
@@ -617,5 +642,5 @@ class TestGetEvent:
|
|||||||
|
|
||||||
response = client.get(f"/events/{mocked_event.id}")
|
response = client.get(f"/events/{mocked_event.id}")
|
||||||
|
|
||||||
assert response.status_code == status.HTTP_401_UNAUTHORIZED
|
assert response.status_code == status.HTTP_403_FORBIDDEN
|
||||||
assert response.json()["detail"] == "Invalid authentication credentials"
|
assert response.json()["detail"] == "Not enough permissions to access this event"
|
||||||
Reference in New Issue
Block a user