Refactor event API and extend authentication utilities
Refactored the event API routes to improve error handling, add logging, and provide enhanced response structures with pagination. Updated tests to use new fixtures and include additional authentication utilities to facilitate testing with FastAPI's dependency injection. Also resolved issues with timezone awareness in event schemas.
This commit is contained in:
@@ -1,24 +1,34 @@
|
||||
from datetime import datetime
|
||||
from typing import List, Optional
|
||||
from datetime import timezone
|
||||
from typing import Optional, Any, Dict
|
||||
from uuid import UUID
|
||||
|
||||
from fastapi import APIRouter, Depends, HTTPException, Query
|
||||
from fastapi import APIRouter, Depends, HTTPException, Query, status
|
||||
from sqlalchemy.exc import SQLAlchemyError
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from api.dependencies.auth import get_current_user
|
||||
from app.api.dependencies.auth import get_current_user, get_optional_current_user
|
||||
from app.core.database import get_db
|
||||
from app.crud.event import event
|
||||
from app.models.event_manager import EventManager
|
||||
from app.models.user import User
|
||||
from app.schemas.common import PaginatedResponse
|
||||
from app.schemas.events import (
|
||||
EventCreate,
|
||||
EventUpdate,
|
||||
EventResponse
|
||||
EventResponse,
|
||||
)
|
||||
|
||||
import logging
|
||||
logger = logging.getLogger(__name__)
|
||||
events_router = APIRouter()
|
||||
|
||||
|
||||
@events_router.post("/", response_model=EventResponse, operation_id="create_event")
|
||||
@events_router.post(
|
||||
"/",
|
||||
response_model=EventResponse,
|
||||
status_code=status.HTTP_201_CREATED,
|
||||
operation_id="create_event"
|
||||
)
|
||||
def create_event(
|
||||
*,
|
||||
db: Session = Depends(get_db),
|
||||
@@ -26,50 +36,129 @@ def create_event(
|
||||
current_user: User = Depends(get_current_user)
|
||||
) -> EventResponse:
|
||||
"""Create a new event."""
|
||||
# Check if slug is already taken
|
||||
if event.get_by_slug(db, slug=event_in.slug):
|
||||
try:
|
||||
# Check if slug is already taken
|
||||
if event.get_by_slug(db, slug=event_in.slug):
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_400_BAD_REQUEST,
|
||||
detail="An event with this slug already exists"
|
||||
)
|
||||
|
||||
created_event = event.create_with_owner(db=db, obj_in=event_in, owner_id=current_user.id)
|
||||
logger.info(f"Event created by {current_user.email}: {created_event.slug}")
|
||||
return created_event
|
||||
except SQLAlchemyError as e:
|
||||
db.rollback()
|
||||
raise HTTPException(
|
||||
status_code=400,
|
||||
detail="An event with this slug already exists"
|
||||
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
|
||||
detail=f"Database error occurred"
|
||||
)
|
||||
|
||||
return event.create_with_owner(
|
||||
db=db,
|
||||
obj_in=event_in,
|
||||
owner_id=current_user.id
|
||||
)
|
||||
|
||||
|
||||
@events_router.get("/me", response_model=List[EventResponse], operation_id="get_user_events")
|
||||
@events_router.get(
|
||||
"/me",
|
||||
response_model=PaginatedResponse[EventResponse],
|
||||
operation_id="get_user_events"
|
||||
)
|
||||
def get_user_events(
|
||||
*,
|
||||
db: Session = Depends(get_db),
|
||||
skip: int = 0,
|
||||
limit: int = 100,
|
||||
include_inactive: bool = False,
|
||||
skip: int = Query(0, ge=0),
|
||||
limit: int = Query(100, ge=1, le=500),
|
||||
include_inactive: bool = Query(False),
|
||||
current_user: User = Depends(get_current_user)
|
||||
) -> List[EventResponse]:
|
||||
"""Get all events created by the current user."""
|
||||
return event.get_user_events(
|
||||
db=db,
|
||||
user_id=current_user.id,
|
||||
skip=skip,
|
||||
limit=limit,
|
||||
include_inactive=include_inactive
|
||||
)
|
||||
) -> Dict[str, Any]:
|
||||
"""Get all events created by the current user with pagination."""
|
||||
try:
|
||||
total = event.count_user_events(
|
||||
db=db,
|
||||
user_id=current_user.id,
|
||||
include_inactive=include_inactive
|
||||
)
|
||||
items = event.get_user_events(
|
||||
db=db,
|
||||
user_id=current_user.id,
|
||||
skip=skip,
|
||||
limit=limit,
|
||||
include_inactive=include_inactive
|
||||
)
|
||||
return {
|
||||
"total": total,
|
||||
"items": items,
|
||||
"page": skip // limit + 1 if limit > 0 else 1,
|
||||
"size": limit
|
||||
}
|
||||
except SQLAlchemyError:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
|
||||
detail="Error retrieving events"
|
||||
)
|
||||
|
||||
|
||||
@events_router.get("/upcoming", response_model=List[EventResponse], operation_id="get_upcoming_events")
|
||||
@events_router.get(
|
||||
"/upcoming",
|
||||
response_model=PaginatedResponse[EventResponse],
|
||||
operation_id="get_upcoming_events"
|
||||
)
|
||||
def get_upcoming_events(
|
||||
*,
|
||||
db: Session = Depends(get_db),
|
||||
skip: int = 0,
|
||||
limit: int = 100
|
||||
) -> List[EventResponse]:
|
||||
"""Get upcoming public events."""
|
||||
return event.get_upcoming_events(db=db, skip=skip, limit=limit)
|
||||
skip: int = Query(0, ge=0),
|
||||
limit: int = Query(100, ge=1, le=500)
|
||||
) -> Dict[str, Any]:
|
||||
"""Get upcoming public events with pagination."""
|
||||
try:
|
||||
items = event.get_upcoming_events(db=db, skip=skip, limit=limit)
|
||||
# Count total upcoming events for pagination
|
||||
total = event.count_upcoming_events(db=db)
|
||||
|
||||
return {
|
||||
"total": total,
|
||||
"items": items,
|
||||
"page": skip // limit + 1 if limit > 0 else 1,
|
||||
"size": limit
|
||||
}
|
||||
except SQLAlchemyError:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
|
||||
detail="Error retrieving upcoming events"
|
||||
)
|
||||
|
||||
|
||||
@events_router.get("/{event_id}", response_model=EventResponse, operation_id="get_event")
|
||||
@events_router.get(
|
||||
"/public",
|
||||
response_model=PaginatedResponse[EventResponse],
|
||||
operation_id="get_public_events"
|
||||
)
|
||||
def get_public_events(
|
||||
*,
|
||||
db: Session = Depends(get_db),
|
||||
skip: int = Query(0, ge=0),
|
||||
limit: int = Query(100, ge=1, le=500)
|
||||
) -> Dict[str, Any]:
|
||||
"""Get all public events with pagination."""
|
||||
try:
|
||||
items = event.get_public_events(db=db, skip=skip, limit=limit)
|
||||
total = event.count_public_events(db=db)
|
||||
|
||||
return {
|
||||
"total": total,
|
||||
"items": items,
|
||||
"page": skip // limit + 1 if limit > 0 else 1,
|
||||
"size": limit
|
||||
}
|
||||
except SQLAlchemyError:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
|
||||
detail="Error retrieving public events"
|
||||
)
|
||||
|
||||
|
||||
@events_router.get(
|
||||
"/{event_id}",
|
||||
response_model=EventResponse,
|
||||
operation_id="get_event"
|
||||
)
|
||||
def get_event(
|
||||
*,
|
||||
db: Session = Depends(get_db),
|
||||
@@ -77,27 +166,78 @@ def get_event(
|
||||
current_user: User = Depends(get_current_user)
|
||||
) -> EventResponse:
|
||||
"""Get event by ID."""
|
||||
event_obj = event.get(db=db, id=event_id)
|
||||
if not event_obj:
|
||||
raise HTTPException(status_code=404, detail="Event not found")
|
||||
return event_obj
|
||||
try:
|
||||
event_obj = event.get(db=db, id=event_id)
|
||||
if not event_obj:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_404_NOT_FOUND,
|
||||
detail="Event not found"
|
||||
)
|
||||
|
||||
# Check if user is creator or manager
|
||||
if event_obj.created_by != current_user.id:
|
||||
# Check if user is a manager
|
||||
is_manager = db.query(EventManager).filter(
|
||||
EventManager.event_id == event_id,
|
||||
EventManager.user_id == current_user.id
|
||||
).first()
|
||||
|
||||
if not is_manager and not current_user.is_superuser:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_403_FORBIDDEN,
|
||||
detail="Not enough permissions"
|
||||
)
|
||||
|
||||
return event_obj
|
||||
except SQLAlchemyError:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
|
||||
detail="Error retrieving event"
|
||||
)
|
||||
|
||||
|
||||
@events_router.get("/by-slug/{slug}", response_model=EventResponse, operation_id="get_public_event")
|
||||
@events_router.get(
|
||||
"/by-slug/{slug}",
|
||||
response_model=EventResponse,
|
||||
operation_id="get_public_event"
|
||||
)
|
||||
def get_public_event(
|
||||
*,
|
||||
db: Session = Depends(get_db),
|
||||
slug: str,
|
||||
access_code: Optional[str] = Query(None)
|
||||
access_code: Optional[str] = Query(None),
|
||||
current_user: Optional[User] = Depends(get_optional_current_user)
|
||||
) -> EventResponse:
|
||||
"""Get public event by slug."""
|
||||
event_obj = event.get_public_event(db=db, slug=slug, access_code=access_code)
|
||||
if not event_obj:
|
||||
raise HTTPException(status_code=404, detail="Event not found")
|
||||
return event_obj
|
||||
try:
|
||||
event_obj = event.get_public_event(db=db, slug=slug, access_code=access_code)
|
||||
if not event_obj:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_404_NOT_FOUND,
|
||||
detail="Event not found"
|
||||
)
|
||||
|
||||
# If event is not public and user is not authenticated, check access code
|
||||
if not event_obj.is_public and not current_user:
|
||||
if not access_code or access_code != event_obj.access_code:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_403_FORBIDDEN,
|
||||
detail="Invalid access code"
|
||||
)
|
||||
|
||||
return event_obj
|
||||
except SQLAlchemyError:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
|
||||
detail="Error retrieving event"
|
||||
)
|
||||
|
||||
|
||||
@events_router.put("/{event_id}", response_model=EventResponse, operation_id="update_event")
|
||||
@events_router.put(
|
||||
"/{event_id}",
|
||||
response_model=EventResponse,
|
||||
operation_id="update_event"
|
||||
)
|
||||
def update_event(
|
||||
*,
|
||||
db: Session = Depends(get_db),
|
||||
@@ -106,36 +246,95 @@ def update_event(
|
||||
current_user: User = Depends(get_current_user)
|
||||
) -> EventResponse:
|
||||
"""Update event."""
|
||||
event_obj = event.get(db=db, id=event_id)
|
||||
if not event_obj:
|
||||
raise HTTPException(status_code=404, detail="Event not found")
|
||||
if event_obj.created_by != current_user.id:
|
||||
raise HTTPException(status_code=403, detail="Not enough permissions")
|
||||
|
||||
# If slug is being updated, check if new slug is available
|
||||
if event_in.slug and event_in.slug != event_obj.slug:
|
||||
if event.get_by_slug(db, slug=event_in.slug):
|
||||
try:
|
||||
event_obj = event.get(db=db, id=event_id)
|
||||
if not event_obj:
|
||||
raise HTTPException(
|
||||
status_code=400,
|
||||
detail="An event with this slug already exists"
|
||||
status_code=status.HTTP_404_NOT_FOUND,
|
||||
detail="Event not found"
|
||||
)
|
||||
|
||||
return event.update(db=db, db_obj=event_obj, obj_in=event_in)
|
||||
# Check permissions (creator or manager with edit rights)
|
||||
has_permission = False
|
||||
|
||||
if event_obj.created_by == current_user.id or current_user.is_superuser:
|
||||
has_permission = True
|
||||
else:
|
||||
manager = db.query(EventManager).filter(
|
||||
EventManager.event_id == event_id,
|
||||
EventManager.user_id == current_user.id,
|
||||
EventManager.can_edit == True
|
||||
).first()
|
||||
has_permission = manager is not None
|
||||
|
||||
if not has_permission:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_403_FORBIDDEN,
|
||||
detail="Not enough permissions to edit this event"
|
||||
)
|
||||
|
||||
# If slug is being updated, check if new slug is available and different
|
||||
if event_in.slug and event_in.slug != event_obj.slug:
|
||||
existing = event.get_by_slug(db, slug=event_in.slug)
|
||||
if existing and existing.id != event_id:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_400_BAD_REQUEST,
|
||||
detail="An event with this slug already exists"
|
||||
)
|
||||
|
||||
return event.update(db=db, db_obj=event_obj, obj_in=event_in)
|
||||
except SQLAlchemyError:
|
||||
db.rollback()
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
|
||||
detail="Error updating event"
|
||||
)
|
||||
|
||||
|
||||
@events_router.delete("/{event_id}", operation_id="delete_event")
|
||||
@events_router.delete(
|
||||
"/{event_id}",
|
||||
status_code=status.HTTP_204_NO_CONTENT,
|
||||
operation_id="delete_event"
|
||||
)
|
||||
def delete_event(
|
||||
*,
|
||||
db: Session = Depends(get_db),
|
||||
event_id: UUID,
|
||||
current_user: User = Depends(get_current_user)
|
||||
current_user: User = Depends(get_current_user),
|
||||
hard_delete: bool = Query(False, description="Perform hard delete instead of soft delete")
|
||||
):
|
||||
"""Delete event."""
|
||||
event_obj = event.get(db=db, id=event_id)
|
||||
if not event_obj:
|
||||
raise HTTPException(status_code=404, detail="Event not found")
|
||||
if event_obj.created_by != current_user.id:
|
||||
raise HTTPException(status_code=403, detail="Not enough permissions")
|
||||
"""Delete event (soft delete by default)."""
|
||||
try:
|
||||
event_obj = event.get(db=db, id=event_id)
|
||||
if not event_obj:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_404_NOT_FOUND,
|
||||
detail="Event not found"
|
||||
)
|
||||
|
||||
event.remove(db=db, id=event_id)
|
||||
return {"message": "Event deleted successfully"}
|
||||
# Only creator or superuser can delete
|
||||
if event_obj.created_by != current_user.id and not current_user.is_superuser:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_403_FORBIDDEN,
|
||||
detail="Not enough permissions to delete this event"
|
||||
)
|
||||
|
||||
if hard_delete:
|
||||
# Hard delete - only for superusers
|
||||
if not current_user.is_superuser:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_403_FORBIDDEN,
|
||||
detail="Only administrators can perform hard delete"
|
||||
)
|
||||
event.remove(db=db, id=event_id)
|
||||
else:
|
||||
# Soft delete - set is_active to False
|
||||
event.update(db=db, db_obj=event_obj, obj_in={"is_active": False})
|
||||
|
||||
return None # 204 No Content
|
||||
except SQLAlchemyError:
|
||||
db.rollback()
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
|
||||
detail="Error deleting event"
|
||||
)
|
||||
|
||||
@@ -41,7 +41,7 @@ class EventBase(BaseModel):
|
||||
@field_validator('event_date')
|
||||
def validate_event_date(cls, v):
|
||||
if not v.tzinfo:
|
||||
raise ValueError("Event date must be timezone-aware")
|
||||
v = v.replace(tzinfo=timezone.utc)
|
||||
now = datetime.now(tz=timezone.utc)
|
||||
if v < now:
|
||||
raise ValueError("Event date cannot be in the past")
|
||||
@@ -81,7 +81,12 @@ class EventInDBBase(EventBase):
|
||||
|
||||
|
||||
class EventResponse(EventInDBBase):
|
||||
pass
|
||||
@field_validator('event_date')
|
||||
def validate_datetime(cls, v):
|
||||
if v.tzinfo is None:
|
||||
v.event_date = v.event_date.replace(tzinfo=timezone.utc)
|
||||
return v
|
||||
|
||||
|
||||
|
||||
class Event(EventInDBBase):
|
||||
|
||||
130
backend/app/utils/auth_test_utils.py
Normal file
130
backend/app/utils/auth_test_utils.py
Normal file
@@ -0,0 +1,130 @@
|
||||
"""
|
||||
Authentication utilities for testing.
|
||||
This module provides tools to bypass FastAPI's authentication in tests.
|
||||
"""
|
||||
from typing import Callable, Dict, Optional
|
||||
|
||||
from fastapi import FastAPI
|
||||
from fastapi.security import OAuth2PasswordBearer
|
||||
from starlette.testclient import TestClient
|
||||
|
||||
# Import these from wherever they are defined in your app
|
||||
from api.dependencies.auth import get_current_user, get_optional_current_user
|
||||
from app.models.user import User
|
||||
|
||||
|
||||
def create_test_auth_client(
|
||||
app: FastAPI,
|
||||
test_user: User,
|
||||
extra_overrides: Optional[Dict[Callable, Callable]] = None
|
||||
) -> TestClient:
|
||||
"""
|
||||
Create a test client with authentication pre-configured.
|
||||
|
||||
This bypasses the OAuth2 token validation and directly returns the test user.
|
||||
|
||||
Args:
|
||||
app: The FastAPI app to test
|
||||
test_user: The user object to use for authentication
|
||||
extra_overrides: Additional dependency overrides to apply
|
||||
|
||||
Returns:
|
||||
TestClient with authentication configured
|
||||
"""
|
||||
# First override the oauth2_scheme dependency to return a dummy token
|
||||
# This prevents FastAPI from trying to extract a real bearer token from the request
|
||||
oauth2_scheme = OAuth2PasswordBearer(tokenUrl="/api/v1/auth/login")
|
||||
app.dependency_overrides[oauth2_scheme] = lambda: "dummy_token_for_testing"
|
||||
|
||||
# Then override the get_current_user dependency to return our test user
|
||||
app.dependency_overrides[get_current_user] = lambda: test_user
|
||||
|
||||
# Apply any extra overrides
|
||||
if extra_overrides:
|
||||
for dep, override in extra_overrides.items():
|
||||
app.dependency_overrides[dep] = override
|
||||
|
||||
# Create and return the client
|
||||
return TestClient(app)
|
||||
|
||||
|
||||
def create_test_optional_auth_client(
|
||||
app: FastAPI,
|
||||
test_user: User
|
||||
) -> TestClient:
|
||||
"""
|
||||
Create a test client with optional authentication pre-configured.
|
||||
|
||||
This is useful for testing endpoints that use get_optional_current_user.
|
||||
|
||||
Args:
|
||||
app: The FastAPI app to test
|
||||
test_user: The user object to use for authentication
|
||||
|
||||
Returns:
|
||||
TestClient with optional authentication configured
|
||||
"""
|
||||
# Override the get_optional_current_user dependency
|
||||
app.dependency_overrides[get_optional_current_user] = lambda: test_user
|
||||
|
||||
# Create and return the client
|
||||
return TestClient(app)
|
||||
|
||||
|
||||
def create_test_superuser_client(
|
||||
app: FastAPI,
|
||||
test_user: User
|
||||
) -> TestClient:
|
||||
"""
|
||||
Create a test client with superuser authentication pre-configured.
|
||||
|
||||
Args:
|
||||
app: The FastAPI app to test
|
||||
test_user: The user object to use as superuser
|
||||
|
||||
Returns:
|
||||
TestClient with superuser authentication
|
||||
"""
|
||||
# Make sure user is a superuser
|
||||
test_user.is_superuser = True
|
||||
|
||||
# Use the auth client creation with superuser
|
||||
return create_test_auth_client(app, test_user)
|
||||
|
||||
|
||||
def create_test_unauthenticated_client(app: FastAPI) -> TestClient:
|
||||
"""
|
||||
Create a test client that will fail authentication checks.
|
||||
|
||||
This is useful for testing the unauthorized case of protected endpoints.
|
||||
|
||||
Args:
|
||||
app: The FastAPI app to test
|
||||
|
||||
Returns:
|
||||
TestClient without authentication
|
||||
"""
|
||||
# Any authentication attempts will fail
|
||||
return TestClient(app)
|
||||
|
||||
|
||||
def cleanup_test_client_auth(app: FastAPI) -> None:
|
||||
"""
|
||||
Clean up authentication overrides from the FastAPI app.
|
||||
|
||||
Call this after your tests to restore normal authentication behavior.
|
||||
|
||||
Args:
|
||||
app: The FastAPI app to clean up
|
||||
"""
|
||||
# Get all auth dependencies
|
||||
auth_deps = [
|
||||
get_current_user,
|
||||
get_optional_current_user,
|
||||
OAuth2PasswordBearer(tokenUrl="/api/v1/auth/login")
|
||||
]
|
||||
|
||||
# Remove overrides
|
||||
for dep in auth_deps:
|
||||
if dep in app.dependency_overrides:
|
||||
del app.dependency_overrides[dep]
|
||||
Reference in New Issue
Block a user