Refactor event API and extend authentication utilities
Refactored the event API routes to improve error handling, add logging, and provide enhanced response structures with pagination. Updated tests to use new fixtures and include additional authentication utilities to facilitate testing with FastAPI's dependency injection. Also resolved issues with timezone awareness in event schemas.
This commit is contained in:
@@ -1,24 +1,34 @@
|
|||||||
from datetime import datetime
|
from datetime import timezone
|
||||||
from typing import List, Optional
|
from typing import Optional, Any, Dict
|
||||||
from uuid import UUID
|
from uuid import UUID
|
||||||
|
|
||||||
from fastapi import APIRouter, Depends, HTTPException, Query
|
from fastapi import APIRouter, Depends, HTTPException, Query, status
|
||||||
|
from sqlalchemy.exc import SQLAlchemyError
|
||||||
from sqlalchemy.orm import Session
|
from sqlalchemy.orm import Session
|
||||||
|
|
||||||
from api.dependencies.auth import get_current_user
|
from app.api.dependencies.auth import get_current_user, get_optional_current_user
|
||||||
from app.core.database import get_db
|
from app.core.database import get_db
|
||||||
from app.crud.event import event
|
from app.crud.event import event
|
||||||
|
from app.models.event_manager import EventManager
|
||||||
from app.models.user import User
|
from app.models.user import User
|
||||||
|
from app.schemas.common import PaginatedResponse
|
||||||
from app.schemas.events import (
|
from app.schemas.events import (
|
||||||
EventCreate,
|
EventCreate,
|
||||||
EventUpdate,
|
EventUpdate,
|
||||||
EventResponse
|
EventResponse,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
import logging
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
events_router = APIRouter()
|
events_router = APIRouter()
|
||||||
|
|
||||||
|
|
||||||
@events_router.post("/", response_model=EventResponse, operation_id="create_event")
|
@events_router.post(
|
||||||
|
"/",
|
||||||
|
response_model=EventResponse,
|
||||||
|
status_code=status.HTTP_201_CREATED,
|
||||||
|
operation_id="create_event"
|
||||||
|
)
|
||||||
def create_event(
|
def create_event(
|
||||||
*,
|
*,
|
||||||
db: Session = Depends(get_db),
|
db: Session = Depends(get_db),
|
||||||
@@ -26,50 +36,129 @@ def create_event(
|
|||||||
current_user: User = Depends(get_current_user)
|
current_user: User = Depends(get_current_user)
|
||||||
) -> EventResponse:
|
) -> EventResponse:
|
||||||
"""Create a new event."""
|
"""Create a new event."""
|
||||||
|
try:
|
||||||
# Check if slug is already taken
|
# Check if slug is already taken
|
||||||
if event.get_by_slug(db, slug=event_in.slug):
|
if event.get_by_slug(db, slug=event_in.slug):
|
||||||
raise HTTPException(
|
raise HTTPException(
|
||||||
status_code=400,
|
status_code=status.HTTP_400_BAD_REQUEST,
|
||||||
detail="An event with this slug already exists"
|
detail="An event with this slug already exists"
|
||||||
)
|
)
|
||||||
|
|
||||||
return event.create_with_owner(
|
created_event = event.create_with_owner(db=db, obj_in=event_in, owner_id=current_user.id)
|
||||||
db=db,
|
logger.info(f"Event created by {current_user.email}: {created_event.slug}")
|
||||||
obj_in=event_in,
|
return created_event
|
||||||
owner_id=current_user.id
|
except SQLAlchemyError as e:
|
||||||
|
db.rollback()
|
||||||
|
raise HTTPException(
|
||||||
|
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
|
||||||
|
detail=f"Database error occurred"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
@events_router.get("/me", response_model=List[EventResponse], operation_id="get_user_events")
|
@events_router.get(
|
||||||
|
"/me",
|
||||||
|
response_model=PaginatedResponse[EventResponse],
|
||||||
|
operation_id="get_user_events"
|
||||||
|
)
|
||||||
def get_user_events(
|
def get_user_events(
|
||||||
|
*,
|
||||||
db: Session = Depends(get_db),
|
db: Session = Depends(get_db),
|
||||||
skip: int = 0,
|
skip: int = Query(0, ge=0),
|
||||||
limit: int = 100,
|
limit: int = Query(100, ge=1, le=500),
|
||||||
include_inactive: bool = False,
|
include_inactive: bool = Query(False),
|
||||||
current_user: User = Depends(get_current_user)
|
current_user: User = Depends(get_current_user)
|
||||||
) -> List[EventResponse]:
|
) -> Dict[str, Any]:
|
||||||
"""Get all events created by the current user."""
|
"""Get all events created by the current user with pagination."""
|
||||||
return event.get_user_events(
|
try:
|
||||||
|
total = event.count_user_events(
|
||||||
|
db=db,
|
||||||
|
user_id=current_user.id,
|
||||||
|
include_inactive=include_inactive
|
||||||
|
)
|
||||||
|
items = event.get_user_events(
|
||||||
db=db,
|
db=db,
|
||||||
user_id=current_user.id,
|
user_id=current_user.id,
|
||||||
skip=skip,
|
skip=skip,
|
||||||
limit=limit,
|
limit=limit,
|
||||||
include_inactive=include_inactive
|
include_inactive=include_inactive
|
||||||
)
|
)
|
||||||
|
return {
|
||||||
|
"total": total,
|
||||||
|
"items": items,
|
||||||
|
"page": skip // limit + 1 if limit > 0 else 1,
|
||||||
|
"size": limit
|
||||||
|
}
|
||||||
|
except SQLAlchemyError:
|
||||||
|
raise HTTPException(
|
||||||
|
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
|
||||||
|
detail="Error retrieving events"
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
@events_router.get("/upcoming", response_model=List[EventResponse], operation_id="get_upcoming_events")
|
@events_router.get(
|
||||||
|
"/upcoming",
|
||||||
|
response_model=PaginatedResponse[EventResponse],
|
||||||
|
operation_id="get_upcoming_events"
|
||||||
|
)
|
||||||
def get_upcoming_events(
|
def get_upcoming_events(
|
||||||
*,
|
*,
|
||||||
db: Session = Depends(get_db),
|
db: Session = Depends(get_db),
|
||||||
skip: int = 0,
|
skip: int = Query(0, ge=0),
|
||||||
limit: int = 100
|
limit: int = Query(100, ge=1, le=500)
|
||||||
) -> List[EventResponse]:
|
) -> Dict[str, Any]:
|
||||||
"""Get upcoming public events."""
|
"""Get upcoming public events with pagination."""
|
||||||
return event.get_upcoming_events(db=db, skip=skip, limit=limit)
|
try:
|
||||||
|
items = event.get_upcoming_events(db=db, skip=skip, limit=limit)
|
||||||
|
# Count total upcoming events for pagination
|
||||||
|
total = event.count_upcoming_events(db=db)
|
||||||
|
|
||||||
|
return {
|
||||||
|
"total": total,
|
||||||
|
"items": items,
|
||||||
|
"page": skip // limit + 1 if limit > 0 else 1,
|
||||||
|
"size": limit
|
||||||
|
}
|
||||||
|
except SQLAlchemyError:
|
||||||
|
raise HTTPException(
|
||||||
|
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
|
||||||
|
detail="Error retrieving upcoming events"
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
@events_router.get("/{event_id}", response_model=EventResponse, operation_id="get_event")
|
@events_router.get(
|
||||||
|
"/public",
|
||||||
|
response_model=PaginatedResponse[EventResponse],
|
||||||
|
operation_id="get_public_events"
|
||||||
|
)
|
||||||
|
def get_public_events(
|
||||||
|
*,
|
||||||
|
db: Session = Depends(get_db),
|
||||||
|
skip: int = Query(0, ge=0),
|
||||||
|
limit: int = Query(100, ge=1, le=500)
|
||||||
|
) -> Dict[str, Any]:
|
||||||
|
"""Get all public events with pagination."""
|
||||||
|
try:
|
||||||
|
items = event.get_public_events(db=db, skip=skip, limit=limit)
|
||||||
|
total = event.count_public_events(db=db)
|
||||||
|
|
||||||
|
return {
|
||||||
|
"total": total,
|
||||||
|
"items": items,
|
||||||
|
"page": skip // limit + 1 if limit > 0 else 1,
|
||||||
|
"size": limit
|
||||||
|
}
|
||||||
|
except SQLAlchemyError:
|
||||||
|
raise HTTPException(
|
||||||
|
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
|
||||||
|
detail="Error retrieving public events"
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
@events_router.get(
|
||||||
|
"/{event_id}",
|
||||||
|
response_model=EventResponse,
|
||||||
|
operation_id="get_event"
|
||||||
|
)
|
||||||
def get_event(
|
def get_event(
|
||||||
*,
|
*,
|
||||||
db: Session = Depends(get_db),
|
db: Session = Depends(get_db),
|
||||||
@@ -77,27 +166,78 @@ def get_event(
|
|||||||
current_user: User = Depends(get_current_user)
|
current_user: User = Depends(get_current_user)
|
||||||
) -> EventResponse:
|
) -> EventResponse:
|
||||||
"""Get event by ID."""
|
"""Get event by ID."""
|
||||||
|
try:
|
||||||
event_obj = event.get(db=db, id=event_id)
|
event_obj = event.get(db=db, id=event_id)
|
||||||
if not event_obj:
|
if not event_obj:
|
||||||
raise HTTPException(status_code=404, detail="Event not found")
|
raise HTTPException(
|
||||||
|
status_code=status.HTTP_404_NOT_FOUND,
|
||||||
|
detail="Event not found"
|
||||||
|
)
|
||||||
|
|
||||||
|
# Check if user is creator or manager
|
||||||
|
if event_obj.created_by != current_user.id:
|
||||||
|
# Check if user is a manager
|
||||||
|
is_manager = db.query(EventManager).filter(
|
||||||
|
EventManager.event_id == event_id,
|
||||||
|
EventManager.user_id == current_user.id
|
||||||
|
).first()
|
||||||
|
|
||||||
|
if not is_manager and not current_user.is_superuser:
|
||||||
|
raise HTTPException(
|
||||||
|
status_code=status.HTTP_403_FORBIDDEN,
|
||||||
|
detail="Not enough permissions"
|
||||||
|
)
|
||||||
|
|
||||||
return event_obj
|
return event_obj
|
||||||
|
except SQLAlchemyError:
|
||||||
|
raise HTTPException(
|
||||||
|
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
|
||||||
|
detail="Error retrieving event"
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
@events_router.get("/by-slug/{slug}", response_model=EventResponse, operation_id="get_public_event")
|
@events_router.get(
|
||||||
|
"/by-slug/{slug}",
|
||||||
|
response_model=EventResponse,
|
||||||
|
operation_id="get_public_event"
|
||||||
|
)
|
||||||
def get_public_event(
|
def get_public_event(
|
||||||
*,
|
*,
|
||||||
db: Session = Depends(get_db),
|
db: Session = Depends(get_db),
|
||||||
slug: str,
|
slug: str,
|
||||||
access_code: Optional[str] = Query(None)
|
access_code: Optional[str] = Query(None),
|
||||||
|
current_user: Optional[User] = Depends(get_optional_current_user)
|
||||||
) -> EventResponse:
|
) -> EventResponse:
|
||||||
"""Get public event by slug."""
|
"""Get public event by slug."""
|
||||||
|
try:
|
||||||
event_obj = event.get_public_event(db=db, slug=slug, access_code=access_code)
|
event_obj = event.get_public_event(db=db, slug=slug, access_code=access_code)
|
||||||
if not event_obj:
|
if not event_obj:
|
||||||
raise HTTPException(status_code=404, detail="Event not found")
|
raise HTTPException(
|
||||||
|
status_code=status.HTTP_404_NOT_FOUND,
|
||||||
|
detail="Event not found"
|
||||||
|
)
|
||||||
|
|
||||||
|
# If event is not public and user is not authenticated, check access code
|
||||||
|
if not event_obj.is_public and not current_user:
|
||||||
|
if not access_code or access_code != event_obj.access_code:
|
||||||
|
raise HTTPException(
|
||||||
|
status_code=status.HTTP_403_FORBIDDEN,
|
||||||
|
detail="Invalid access code"
|
||||||
|
)
|
||||||
|
|
||||||
return event_obj
|
return event_obj
|
||||||
|
except SQLAlchemyError:
|
||||||
|
raise HTTPException(
|
||||||
|
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
|
||||||
|
detail="Error retrieving event"
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
@events_router.put("/{event_id}", response_model=EventResponse, operation_id="update_event")
|
@events_router.put(
|
||||||
|
"/{event_id}",
|
||||||
|
response_model=EventResponse,
|
||||||
|
operation_id="update_event"
|
||||||
|
)
|
||||||
def update_event(
|
def update_event(
|
||||||
*,
|
*,
|
||||||
db: Session = Depends(get_db),
|
db: Session = Depends(get_db),
|
||||||
@@ -106,36 +246,95 @@ def update_event(
|
|||||||
current_user: User = Depends(get_current_user)
|
current_user: User = Depends(get_current_user)
|
||||||
) -> EventResponse:
|
) -> EventResponse:
|
||||||
"""Update event."""
|
"""Update event."""
|
||||||
|
try:
|
||||||
event_obj = event.get(db=db, id=event_id)
|
event_obj = event.get(db=db, id=event_id)
|
||||||
if not event_obj:
|
if not event_obj:
|
||||||
raise HTTPException(status_code=404, detail="Event not found")
|
|
||||||
if event_obj.created_by != current_user.id:
|
|
||||||
raise HTTPException(status_code=403, detail="Not enough permissions")
|
|
||||||
|
|
||||||
# If slug is being updated, check if new slug is available
|
|
||||||
if event_in.slug and event_in.slug != event_obj.slug:
|
|
||||||
if event.get_by_slug(db, slug=event_in.slug):
|
|
||||||
raise HTTPException(
|
raise HTTPException(
|
||||||
status_code=400,
|
status_code=status.HTTP_404_NOT_FOUND,
|
||||||
|
detail="Event not found"
|
||||||
|
)
|
||||||
|
|
||||||
|
# Check permissions (creator or manager with edit rights)
|
||||||
|
has_permission = False
|
||||||
|
|
||||||
|
if event_obj.created_by == current_user.id or current_user.is_superuser:
|
||||||
|
has_permission = True
|
||||||
|
else:
|
||||||
|
manager = db.query(EventManager).filter(
|
||||||
|
EventManager.event_id == event_id,
|
||||||
|
EventManager.user_id == current_user.id,
|
||||||
|
EventManager.can_edit == True
|
||||||
|
).first()
|
||||||
|
has_permission = manager is not None
|
||||||
|
|
||||||
|
if not has_permission:
|
||||||
|
raise HTTPException(
|
||||||
|
status_code=status.HTTP_403_FORBIDDEN,
|
||||||
|
detail="Not enough permissions to edit this event"
|
||||||
|
)
|
||||||
|
|
||||||
|
# If slug is being updated, check if new slug is available and different
|
||||||
|
if event_in.slug and event_in.slug != event_obj.slug:
|
||||||
|
existing = event.get_by_slug(db, slug=event_in.slug)
|
||||||
|
if existing and existing.id != event_id:
|
||||||
|
raise HTTPException(
|
||||||
|
status_code=status.HTTP_400_BAD_REQUEST,
|
||||||
detail="An event with this slug already exists"
|
detail="An event with this slug already exists"
|
||||||
)
|
)
|
||||||
|
|
||||||
return event.update(db=db, db_obj=event_obj, obj_in=event_in)
|
return event.update(db=db, db_obj=event_obj, obj_in=event_in)
|
||||||
|
except SQLAlchemyError:
|
||||||
|
db.rollback()
|
||||||
|
raise HTTPException(
|
||||||
|
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
|
||||||
|
detail="Error updating event"
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
@events_router.delete("/{event_id}", operation_id="delete_event")
|
@events_router.delete(
|
||||||
|
"/{event_id}",
|
||||||
|
status_code=status.HTTP_204_NO_CONTENT,
|
||||||
|
operation_id="delete_event"
|
||||||
|
)
|
||||||
def delete_event(
|
def delete_event(
|
||||||
*,
|
*,
|
||||||
db: Session = Depends(get_db),
|
db: Session = Depends(get_db),
|
||||||
event_id: UUID,
|
event_id: UUID,
|
||||||
current_user: User = Depends(get_current_user)
|
current_user: User = Depends(get_current_user),
|
||||||
|
hard_delete: bool = Query(False, description="Perform hard delete instead of soft delete")
|
||||||
):
|
):
|
||||||
"""Delete event."""
|
"""Delete event (soft delete by default)."""
|
||||||
|
try:
|
||||||
event_obj = event.get(db=db, id=event_id)
|
event_obj = event.get(db=db, id=event_id)
|
||||||
if not event_obj:
|
if not event_obj:
|
||||||
raise HTTPException(status_code=404, detail="Event not found")
|
raise HTTPException(
|
||||||
if event_obj.created_by != current_user.id:
|
status_code=status.HTTP_404_NOT_FOUND,
|
||||||
raise HTTPException(status_code=403, detail="Not enough permissions")
|
detail="Event not found"
|
||||||
|
)
|
||||||
|
|
||||||
|
# Only creator or superuser can delete
|
||||||
|
if event_obj.created_by != current_user.id and not current_user.is_superuser:
|
||||||
|
raise HTTPException(
|
||||||
|
status_code=status.HTTP_403_FORBIDDEN,
|
||||||
|
detail="Not enough permissions to delete this event"
|
||||||
|
)
|
||||||
|
|
||||||
|
if hard_delete:
|
||||||
|
# Hard delete - only for superusers
|
||||||
|
if not current_user.is_superuser:
|
||||||
|
raise HTTPException(
|
||||||
|
status_code=status.HTTP_403_FORBIDDEN,
|
||||||
|
detail="Only administrators can perform hard delete"
|
||||||
|
)
|
||||||
event.remove(db=db, id=event_id)
|
event.remove(db=db, id=event_id)
|
||||||
return {"message": "Event deleted successfully"}
|
else:
|
||||||
|
# Soft delete - set is_active to False
|
||||||
|
event.update(db=db, db_obj=event_obj, obj_in={"is_active": False})
|
||||||
|
|
||||||
|
return None # 204 No Content
|
||||||
|
except SQLAlchemyError:
|
||||||
|
db.rollback()
|
||||||
|
raise HTTPException(
|
||||||
|
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
|
||||||
|
detail="Error deleting event"
|
||||||
|
)
|
||||||
|
|||||||
@@ -41,7 +41,7 @@ class EventBase(BaseModel):
|
|||||||
@field_validator('event_date')
|
@field_validator('event_date')
|
||||||
def validate_event_date(cls, v):
|
def validate_event_date(cls, v):
|
||||||
if not v.tzinfo:
|
if not v.tzinfo:
|
||||||
raise ValueError("Event date must be timezone-aware")
|
v = v.replace(tzinfo=timezone.utc)
|
||||||
now = datetime.now(tz=timezone.utc)
|
now = datetime.now(tz=timezone.utc)
|
||||||
if v < now:
|
if v < now:
|
||||||
raise ValueError("Event date cannot be in the past")
|
raise ValueError("Event date cannot be in the past")
|
||||||
@@ -81,7 +81,12 @@ class EventInDBBase(EventBase):
|
|||||||
|
|
||||||
|
|
||||||
class EventResponse(EventInDBBase):
|
class EventResponse(EventInDBBase):
|
||||||
pass
|
@field_validator('event_date')
|
||||||
|
def validate_datetime(cls, v):
|
||||||
|
if v.tzinfo is None:
|
||||||
|
v.event_date = v.event_date.replace(tzinfo=timezone.utc)
|
||||||
|
return v
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
class Event(EventInDBBase):
|
class Event(EventInDBBase):
|
||||||
|
|||||||
130
backend/app/utils/auth_test_utils.py
Normal file
130
backend/app/utils/auth_test_utils.py
Normal file
@@ -0,0 +1,130 @@
|
|||||||
|
"""
|
||||||
|
Authentication utilities for testing.
|
||||||
|
This module provides tools to bypass FastAPI's authentication in tests.
|
||||||
|
"""
|
||||||
|
from typing import Callable, Dict, Optional
|
||||||
|
|
||||||
|
from fastapi import FastAPI
|
||||||
|
from fastapi.security import OAuth2PasswordBearer
|
||||||
|
from starlette.testclient import TestClient
|
||||||
|
|
||||||
|
# Import these from wherever they are defined in your app
|
||||||
|
from api.dependencies.auth import get_current_user, get_optional_current_user
|
||||||
|
from app.models.user import User
|
||||||
|
|
||||||
|
|
||||||
|
def create_test_auth_client(
|
||||||
|
app: FastAPI,
|
||||||
|
test_user: User,
|
||||||
|
extra_overrides: Optional[Dict[Callable, Callable]] = None
|
||||||
|
) -> TestClient:
|
||||||
|
"""
|
||||||
|
Create a test client with authentication pre-configured.
|
||||||
|
|
||||||
|
This bypasses the OAuth2 token validation and directly returns the test user.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
app: The FastAPI app to test
|
||||||
|
test_user: The user object to use for authentication
|
||||||
|
extra_overrides: Additional dependency overrides to apply
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
TestClient with authentication configured
|
||||||
|
"""
|
||||||
|
# First override the oauth2_scheme dependency to return a dummy token
|
||||||
|
# This prevents FastAPI from trying to extract a real bearer token from the request
|
||||||
|
oauth2_scheme = OAuth2PasswordBearer(tokenUrl="/api/v1/auth/login")
|
||||||
|
app.dependency_overrides[oauth2_scheme] = lambda: "dummy_token_for_testing"
|
||||||
|
|
||||||
|
# Then override the get_current_user dependency to return our test user
|
||||||
|
app.dependency_overrides[get_current_user] = lambda: test_user
|
||||||
|
|
||||||
|
# Apply any extra overrides
|
||||||
|
if extra_overrides:
|
||||||
|
for dep, override in extra_overrides.items():
|
||||||
|
app.dependency_overrides[dep] = override
|
||||||
|
|
||||||
|
# Create and return the client
|
||||||
|
return TestClient(app)
|
||||||
|
|
||||||
|
|
||||||
|
def create_test_optional_auth_client(
|
||||||
|
app: FastAPI,
|
||||||
|
test_user: User
|
||||||
|
) -> TestClient:
|
||||||
|
"""
|
||||||
|
Create a test client with optional authentication pre-configured.
|
||||||
|
|
||||||
|
This is useful for testing endpoints that use get_optional_current_user.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
app: The FastAPI app to test
|
||||||
|
test_user: The user object to use for authentication
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
TestClient with optional authentication configured
|
||||||
|
"""
|
||||||
|
# Override the get_optional_current_user dependency
|
||||||
|
app.dependency_overrides[get_optional_current_user] = lambda: test_user
|
||||||
|
|
||||||
|
# Create and return the client
|
||||||
|
return TestClient(app)
|
||||||
|
|
||||||
|
|
||||||
|
def create_test_superuser_client(
|
||||||
|
app: FastAPI,
|
||||||
|
test_user: User
|
||||||
|
) -> TestClient:
|
||||||
|
"""
|
||||||
|
Create a test client with superuser authentication pre-configured.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
app: The FastAPI app to test
|
||||||
|
test_user: The user object to use as superuser
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
TestClient with superuser authentication
|
||||||
|
"""
|
||||||
|
# Make sure user is a superuser
|
||||||
|
test_user.is_superuser = True
|
||||||
|
|
||||||
|
# Use the auth client creation with superuser
|
||||||
|
return create_test_auth_client(app, test_user)
|
||||||
|
|
||||||
|
|
||||||
|
def create_test_unauthenticated_client(app: FastAPI) -> TestClient:
|
||||||
|
"""
|
||||||
|
Create a test client that will fail authentication checks.
|
||||||
|
|
||||||
|
This is useful for testing the unauthorized case of protected endpoints.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
app: The FastAPI app to test
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
TestClient without authentication
|
||||||
|
"""
|
||||||
|
# Any authentication attempts will fail
|
||||||
|
return TestClient(app)
|
||||||
|
|
||||||
|
|
||||||
|
def cleanup_test_client_auth(app: FastAPI) -> None:
|
||||||
|
"""
|
||||||
|
Clean up authentication overrides from the FastAPI app.
|
||||||
|
|
||||||
|
Call this after your tests to restore normal authentication behavior.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
app: The FastAPI app to clean up
|
||||||
|
"""
|
||||||
|
# Get all auth dependencies
|
||||||
|
auth_deps = [
|
||||||
|
get_current_user,
|
||||||
|
get_optional_current_user,
|
||||||
|
OAuth2PasswordBearer(tokenUrl="/api/v1/auth/login")
|
||||||
|
]
|
||||||
|
|
||||||
|
# Remove overrides
|
||||||
|
for dep in auth_deps:
|
||||||
|
if dep in app.dependency_overrides:
|
||||||
|
del app.dependency_overrides[dep]
|
||||||
0
backend/tests/api/routes/events/__init__.py
Normal file
0
backend/tests/api/routes/events/__init__.py
Normal file
@@ -1,10 +1,17 @@
|
|||||||
# tests/conftest.py
|
|
||||||
import uuid
|
import uuid
|
||||||
from datetime import datetime, timezone, timedelta
|
from datetime import datetime, timezone, timedelta
|
||||||
from typing import Dict
|
from typing import Dict
|
||||||
|
from typing import Optional
|
||||||
|
|
||||||
import pytest
|
import pytest
|
||||||
|
from fastapi import FastAPI
|
||||||
|
from fastapi.routing import APIRouter
|
||||||
|
from fastapi.security import OAuth2PasswordBearer
|
||||||
|
from fastapi.testclient import TestClient
|
||||||
|
from sqlalchemy.orm import Session
|
||||||
|
|
||||||
|
from app.api.dependencies.auth import get_current_user
|
||||||
|
from app.core.database import get_db
|
||||||
from app.models import Event, GiftItem, GiftStatus, GiftPriority, GiftCategory, GiftPurchase, RSVP, RSVPStatus, \
|
from app.models import Event, GiftItem, GiftStatus, GiftPriority, GiftCategory, GiftPurchase, RSVP, RSVPStatus, \
|
||||||
EventMedia, MediaType, MediaPurpose, \
|
EventMedia, MediaType, MediaPurpose, \
|
||||||
EventTheme, Guest, GuestStatus, ActivityType, ActivityLog, EmailTemplate, TemplateType, NotificationLog, \
|
EventTheme, Guest, GuestStatus, ActivityType, ActivityLog, EmailTemplate, TemplateType, NotificationLog, \
|
||||||
@@ -39,6 +46,7 @@ async def async_test_db():
|
|||||||
yield test_engine, AsyncTestingSessionLocal
|
yield test_engine, AsyncTestingSessionLocal
|
||||||
await teardown_async_test_db(test_engine)
|
await teardown_async_test_db(test_engine)
|
||||||
|
|
||||||
|
|
||||||
@pytest.fixture
|
@pytest.fixture
|
||||||
def user_create_data():
|
def user_create_data():
|
||||||
return {
|
return {
|
||||||
@@ -72,7 +80,7 @@ def mock_user(db_session):
|
|||||||
|
|
||||||
|
|
||||||
@pytest.fixture
|
@pytest.fixture
|
||||||
def event_fixture(db_session, mock_user):
|
def mock_event(db_session, mock_user):
|
||||||
"""Create a test event fixture."""
|
"""Create a test event fixture."""
|
||||||
event_data = {
|
event_data = {
|
||||||
"title": "Birthday Party",
|
"title": "Birthday Party",
|
||||||
@@ -91,7 +99,6 @@ def event_fixture(db_session, mock_user):
|
|||||||
return event
|
return event
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
@pytest.fixture
|
@pytest.fixture
|
||||||
def gift_item_fixture(db_session, mock_user):
|
def gift_item_fixture(db_session, mock_user):
|
||||||
"""
|
"""
|
||||||
@@ -308,7 +315,7 @@ def gift_purchase_fixture(db_session, gift_item_fixture, guest_fixture):
|
|||||||
|
|
||||||
|
|
||||||
@pytest.fixture
|
@pytest.fixture
|
||||||
def gift_category_fixture(db_session, mock_user, event_fixture):
|
def gift_category_fixture(db_session, mock_user, mock_event):
|
||||||
"""
|
"""
|
||||||
Fixture to create and return a GiftCategory instance.
|
Fixture to create and return a GiftCategory instance.
|
||||||
"""
|
"""
|
||||||
@@ -316,7 +323,7 @@ def gift_category_fixture(db_session, mock_user, event_fixture):
|
|||||||
id=uuid.uuid4(),
|
id=uuid.uuid4(),
|
||||||
name="Electronics",
|
name="Electronics",
|
||||||
description="Category for electronic gifts",
|
description="Category for electronic gifts",
|
||||||
event_id=event_fixture.id,
|
event_id=mock_event.id,
|
||||||
created_by=mock_user.id,
|
created_by=mock_user.id,
|
||||||
display_order=0,
|
display_order=0,
|
||||||
is_visible=True
|
is_visible=True
|
||||||
@@ -344,3 +351,32 @@ def theme_data() -> Dict:
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture
|
||||||
|
def oauth2_scheme():
|
||||||
|
"""Return the OAuth2PasswordBearer instance used by the app."""
|
||||||
|
return OAuth2PasswordBearer(tokenUrl="/api/v1/auth/login")
|
||||||
|
|
||||||
|
@pytest.fixture
|
||||||
|
def create_test_client():
|
||||||
|
def _create_test_client(router: APIRouter, prefix: str, db_session, user):
|
||||||
|
from fastapi import FastAPI
|
||||||
|
app = FastAPI()
|
||||||
|
|
||||||
|
# Mimic your dependency overrides here
|
||||||
|
def get_db_override():
|
||||||
|
yield db_session
|
||||||
|
|
||||||
|
def get_current_user_override():
|
||||||
|
return user
|
||||||
|
|
||||||
|
# Include your router
|
||||||
|
app.include_router(router, prefix=prefix)
|
||||||
|
|
||||||
|
# Override dependencies
|
||||||
|
app.dependency_overrides[get_db] = get_db_override
|
||||||
|
app.dependency_overrides[get_current_user] = get_current_user_override
|
||||||
|
|
||||||
|
return TestClient(app)
|
||||||
|
|
||||||
|
return _create_test_client
|
||||||
|
|||||||
@@ -28,23 +28,23 @@ def test_create_event(db_session, mock_user):
|
|||||||
assert event.created_by == mock_user.id
|
assert event.created_by == mock_user.id
|
||||||
|
|
||||||
|
|
||||||
def test_get_event(db_session, event_fixture):
|
def test_get_event(db_session, mock_event):
|
||||||
"""Test retrieving an event by ID."""
|
"""Test retrieving an event by ID."""
|
||||||
stored_event = crud_event.get(db=db_session, id=event_fixture.id)
|
stored_event = crud_event.get(db=db_session, id=mock_event.id)
|
||||||
assert stored_event
|
assert stored_event
|
||||||
assert stored_event.id == event_fixture.id
|
assert stored_event.id == mock_event.id
|
||||||
assert stored_event.title == event_fixture.title
|
assert stored_event.title == mock_event.title
|
||||||
|
|
||||||
|
|
||||||
def test_get_event_by_slug(db_session, event_fixture):
|
def test_get_event_by_slug(db_session, mock_event):
|
||||||
"""Test retrieving an event by slug."""
|
"""Test retrieving an event by slug."""
|
||||||
stored_event = crud_event.get_by_slug(db=db_session, slug=event_fixture.slug)
|
stored_event = crud_event.get_by_slug(db=db_session, slug=mock_event.slug)
|
||||||
assert stored_event
|
assert stored_event
|
||||||
assert stored_event.id == event_fixture.id
|
assert stored_event.id == mock_event.id
|
||||||
assert stored_event.slug == event_fixture.slug
|
assert stored_event.slug == mock_event.slug
|
||||||
|
|
||||||
|
|
||||||
def test_get_user_events(db_session, event_fixture, mock_user):
|
def test_get_user_events(db_session, mock_event, mock_user):
|
||||||
"""Test retrieving all events for a specific user."""
|
"""Test retrieving all events for a specific user."""
|
||||||
events = crud_event.get_user_events(
|
events = crud_event.get_user_events(
|
||||||
db=db_session,
|
db=db_session,
|
||||||
@@ -53,10 +53,10 @@ def test_get_user_events(db_session, event_fixture, mock_user):
|
|||||||
limit=100
|
limit=100
|
||||||
)
|
)
|
||||||
assert len(events) > 0
|
assert len(events) > 0
|
||||||
assert any(event.id == event_fixture.id for event in events)
|
assert any(event.id == mock_event.id for event in events)
|
||||||
|
|
||||||
|
|
||||||
def test_update_event(db_session, event_fixture):
|
def test_update_event(db_session, mock_event):
|
||||||
"""Test updating an event."""
|
"""Test updating an event."""
|
||||||
update_data = EventUpdate(
|
update_data = EventUpdate(
|
||||||
title="Updated Birthday Party",
|
title="Updated Birthday Party",
|
||||||
@@ -64,18 +64,18 @@ def test_update_event(db_session, event_fixture):
|
|||||||
)
|
)
|
||||||
updated_event = crud_event.update(
|
updated_event = crud_event.update(
|
||||||
db=db_session,
|
db=db_session,
|
||||||
db_obj=event_fixture,
|
db_obj=mock_event,
|
||||||
obj_in=update_data
|
obj_in=update_data
|
||||||
)
|
)
|
||||||
assert updated_event.title == "Updated Birthday Party"
|
assert updated_event.title == "Updated Birthday Party"
|
||||||
assert updated_event.description == "Updated description"
|
assert updated_event.description == "Updated description"
|
||||||
|
|
||||||
|
|
||||||
def test_delete_event(db_session, event_fixture):
|
def test_delete_event(db_session, mock_event):
|
||||||
"""Test deleting an event."""
|
"""Test deleting an event."""
|
||||||
event = crud_event.remove(db=db_session, id=event_fixture.id)
|
event = crud_event.remove(db=db_session, id=mock_event.id)
|
||||||
assert event.id == event_fixture.id
|
assert event.id == mock_event.id
|
||||||
deleted_event = crud_event.get(db=db_session, id=event_fixture.id)
|
deleted_event = crud_event.get(db=db_session, id=mock_event.id)
|
||||||
assert deleted_event is None
|
assert deleted_event is None
|
||||||
|
|
||||||
|
|
||||||
@@ -115,11 +115,11 @@ def test_get_upcoming_events(db_session):
|
|||||||
assert event_date >= current_time
|
assert event_date >= current_time
|
||||||
|
|
||||||
|
|
||||||
def test_get_public_event(db_session, event_fixture):
|
def test_get_public_event(db_session, mock_event):
|
||||||
"""Test retrieving a public event."""
|
"""Test retrieving a public event."""
|
||||||
public_event = crud_event.get_public_event(
|
public_event = crud_event.get_public_event(
|
||||||
db=db_session,
|
db=db_session,
|
||||||
slug=event_fixture.slug
|
slug=mock_event.slug
|
||||||
)
|
)
|
||||||
assert public_event is not None
|
assert public_event is not None
|
||||||
assert public_event.is_public is True
|
assert public_event.is_public is True
|
||||||
|
|||||||
@@ -77,9 +77,9 @@ def test_delete_event_theme(db_session: Session, event_theme_fixture) -> None:
|
|||||||
assert theme is None
|
assert theme is None
|
||||||
|
|
||||||
|
|
||||||
def test_get_active_themes(db_session: Session, event_theme_fixture, event_fixture) -> None:
|
def test_get_active_themes(db_session: Session, event_theme_fixture, mock_event) -> None:
|
||||||
# First, ensure the theme is associated with an event
|
# First, ensure the theme is associated with an event
|
||||||
event_fixture.theme_id = event_theme_fixture.id
|
mock_event.theme_id = event_theme_fixture.id
|
||||||
db_session.commit()
|
db_session.commit()
|
||||||
|
|
||||||
# Get active themes
|
# Get active themes
|
||||||
|
|||||||
Reference in New Issue
Block a user