Add CRUD operations and tests for Event model
This commit introduces a new CRUDEvent class to manage event-related database operations, including retrieval, creation, updating, and deletion of events. It includes corresponding unit tests to ensure the correctness of these functionalities, updates event schemas for enhanced validation, and refines timezone handling for event dates and deadlines.
This commit is contained in:
113
backend/app/crud/event.py
Normal file
113
backend/app/crud/event.py
Normal file
@@ -0,0 +1,113 @@
|
||||
from datetime import timezone
|
||||
from typing import List, Optional, Union, Dict, Any
|
||||
from uuid import UUID
|
||||
|
||||
from sqlalchemy import desc
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from app.crud.base import CRUDBase
|
||||
from app.models.event import Event
|
||||
from app.schemas.events import EventCreate, EventUpdate
|
||||
|
||||
|
||||
class CRUDEvent(CRUDBase[Event, EventCreate, EventUpdate]):
|
||||
def get_by_slug(self, db: Session, *, slug: str) -> Optional[Event]:
|
||||
"""Get event by slug."""
|
||||
return db.query(Event).filter(Event.slug == slug).first()
|
||||
|
||||
def get_user_events(
|
||||
self,
|
||||
db: Session,
|
||||
*,
|
||||
user_id: UUID,
|
||||
skip: int = 0,
|
||||
limit: int = 100,
|
||||
include_inactive: bool = False
|
||||
) -> List[Event]:
|
||||
"""Get all events created by a specific user."""
|
||||
query = db.query(Event).filter(Event.created_by == user_id)
|
||||
|
||||
if not include_inactive:
|
||||
query = query.filter(Event.is_active == True)
|
||||
|
||||
return query.order_by(desc(Event.event_date)).offset(skip).limit(limit).all()
|
||||
|
||||
def create_with_owner(
|
||||
self,
|
||||
db: Session,
|
||||
*,
|
||||
obj_in: EventCreate,
|
||||
owner_id: UUID
|
||||
) -> Event:
|
||||
"""Create a new event with owner ID."""
|
||||
obj_in_data = obj_in.model_dump()
|
||||
db_obj = Event(**obj_in_data, created_by=owner_id)
|
||||
db.add(db_obj)
|
||||
db.commit()
|
||||
db.refresh(db_obj)
|
||||
return db_obj
|
||||
|
||||
def update(
|
||||
self,
|
||||
db: Session,
|
||||
*,
|
||||
db_obj: Event,
|
||||
obj_in: Union[EventUpdate, Dict[str, Any]]
|
||||
) -> Event:
|
||||
"""Update an event."""
|
||||
if isinstance(obj_in, dict):
|
||||
update_data = obj_in
|
||||
else:
|
||||
update_data = obj_in.model_dump(exclude_unset=True)
|
||||
|
||||
return super().update(db, db_obj=db_obj, obj_in=update_data)
|
||||
|
||||
def get_upcoming_events(
|
||||
self,
|
||||
db: Session,
|
||||
*,
|
||||
skip: int = 0,
|
||||
limit: int = 100
|
||||
) -> List[Event]:
|
||||
"""Get upcoming active events ordered by date."""
|
||||
from datetime import datetime
|
||||
|
||||
now = datetime.now(tz=timezone.utc)
|
||||
return (
|
||||
db.query(Event)
|
||||
.filter(Event.is_active == True)
|
||||
.filter(Event.event_date >= now)
|
||||
.order_by(Event.event_date)
|
||||
.offset(skip)
|
||||
.limit(limit)
|
||||
.all()
|
||||
)
|
||||
|
||||
def get_public_event(
|
||||
self,
|
||||
db: Session,
|
||||
*,
|
||||
slug: str,
|
||||
access_code: Optional[str] = None
|
||||
) -> Optional[Event]:
|
||||
"""Get a public event by slug, optionally checking access code."""
|
||||
query = db.query(Event).filter(
|
||||
Event.slug == slug,
|
||||
Event.is_active == True
|
||||
)
|
||||
|
||||
event = query.first()
|
||||
if not event:
|
||||
return None
|
||||
|
||||
if event.is_public:
|
||||
return event
|
||||
|
||||
if access_code and event.access_code == access_code:
|
||||
return event
|
||||
|
||||
return None
|
||||
|
||||
|
||||
# Create a singleton instance for use across the application
|
||||
event = CRUDEvent(Event)
|
||||
@@ -1,8 +1,8 @@
|
||||
from datetime import datetime, time, timezone
|
||||
from typing import Dict, Optional, List, Union
|
||||
from typing import Dict, Optional
|
||||
from uuid import UUID
|
||||
|
||||
from pydantic import BaseModel, Field, field_validator
|
||||
from zoneinfo import ZoneInfo
|
||||
|
||||
|
||||
class EventBase(BaseModel):
|
||||
@@ -31,6 +31,7 @@ class EventBase(BaseModel):
|
||||
|
||||
@field_validator('timezone')
|
||||
def validate_timezone(cls, v):
|
||||
from zoneinfo import ZoneInfo
|
||||
try:
|
||||
ZoneInfo(v)
|
||||
return v
|
||||
@@ -39,15 +40,21 @@ class EventBase(BaseModel):
|
||||
|
||||
@field_validator('event_date')
|
||||
def validate_event_date(cls, v):
|
||||
if v < datetime.now(tz=timezone.utc):
|
||||
if not v.tzinfo:
|
||||
raise ValueError("Event date must be timezone-aware")
|
||||
now = datetime.now(tz=timezone.utc)
|
||||
if v < now:
|
||||
raise ValueError("Event date cannot be in the past")
|
||||
return v
|
||||
|
||||
@field_validator('rsvp_deadline')
|
||||
def validate_rsvp_deadline(cls, v, values):
|
||||
if v and 'event_date' in values.data:
|
||||
if v > values.data['event_date']:
|
||||
raise ValueError("RSVP deadline must be before event date")
|
||||
if v:
|
||||
if not v.tzinfo:
|
||||
raise ValueError("RSVP deadline must be timezone-aware")
|
||||
if 'event_date' in values.data:
|
||||
if v > values.data['event_date']:
|
||||
raise ValueError("RSVP deadline must be before event date")
|
||||
return v
|
||||
|
||||
|
||||
|
||||
@@ -1,6 +1,6 @@
|
||||
# tests/conftest.py
|
||||
import uuid
|
||||
from datetime import datetime, timezone
|
||||
from datetime import datetime, timezone, timedelta
|
||||
from typing import Dict
|
||||
|
||||
import pytest
|
||||
@@ -73,28 +73,25 @@ def mock_user(db_session):
|
||||
|
||||
@pytest.fixture
|
||||
def event_fixture(db_session, mock_user):
|
||||
"""
|
||||
Fixture to create and return an Event instance with valid data.
|
||||
"""
|
||||
event = Event(
|
||||
id=uuid.uuid4(),
|
||||
title="Birthday Party",
|
||||
slug="birthday-party-1", # Required unique slug
|
||||
description="A special 1st birthday celebration event.",
|
||||
event_date=datetime(2023, 12, 25, tzinfo=timezone.utc),
|
||||
timezone="UTC", # Required timezone
|
||||
created_by=mock_user.id, # Reference to a valid mock_user
|
||||
is_public=False, # Default value
|
||||
is_active=True, # Default value
|
||||
rsvp_enabled=False, # Default value
|
||||
gift_registry_enabled=True, # Default value
|
||||
updates_enabled=True # Default value
|
||||
)
|
||||
"""Create a test event fixture."""
|
||||
event_data = {
|
||||
"title": "Birthday Party",
|
||||
"slug": "birthday-party",
|
||||
"description": "A test birthday party",
|
||||
"event_date": datetime.now(tz=timezone.utc) + timedelta(days=30),
|
||||
"timezone": "UTC",
|
||||
"is_public": True, # Make sure this is set to True
|
||||
"created_by": mock_user.id
|
||||
}
|
||||
|
||||
event = Event(**event_data)
|
||||
db_session.add(event)
|
||||
db_session.commit()
|
||||
db_session.refresh(event)
|
||||
return event
|
||||
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def gift_item_fixture(db_session, mock_user):
|
||||
"""
|
||||
|
||||
186
backend/tests/crud/test_event.py
Normal file
186
backend/tests/crud/test_event.py
Normal file
@@ -0,0 +1,186 @@
|
||||
import pytest
|
||||
from datetime import datetime, timedelta, timezone
|
||||
from uuid import UUID, uuid4
|
||||
from zoneinfo import ZoneInfo
|
||||
|
||||
from app.crud.event import event as crud_event
|
||||
from app.schemas.events import EventCreate, EventUpdate
|
||||
from app.models.event import Event
|
||||
|
||||
|
||||
def test_create_event(db_session, mock_user):
|
||||
"""Test creating a new event."""
|
||||
event_data = {
|
||||
"title": "Test Birthday Party",
|
||||
"slug": "test-birthday-party",
|
||||
"description": "A test birthday celebration",
|
||||
"event_date": datetime.now(tz=timezone.utc) + timedelta(days=30),
|
||||
"timezone": "UTC",
|
||||
"is_public": True
|
||||
}
|
||||
event_in = EventCreate(**event_data)
|
||||
event = crud_event.create_with_owner(
|
||||
db=db_session, obj_in=event_in, owner_id=mock_user.id
|
||||
)
|
||||
|
||||
assert event.title == event_data["title"]
|
||||
assert event.slug == event_data["slug"]
|
||||
assert event.created_by == mock_user.id
|
||||
|
||||
|
||||
def test_get_event(db_session, event_fixture):
|
||||
"""Test retrieving an event by ID."""
|
||||
stored_event = crud_event.get(db=db_session, id=event_fixture.id)
|
||||
assert stored_event
|
||||
assert stored_event.id == event_fixture.id
|
||||
assert stored_event.title == event_fixture.title
|
||||
|
||||
|
||||
def test_get_event_by_slug(db_session, event_fixture):
|
||||
"""Test retrieving an event by slug."""
|
||||
stored_event = crud_event.get_by_slug(db=db_session, slug=event_fixture.slug)
|
||||
assert stored_event
|
||||
assert stored_event.id == event_fixture.id
|
||||
assert stored_event.slug == event_fixture.slug
|
||||
|
||||
|
||||
def test_get_user_events(db_session, event_fixture, mock_user):
|
||||
"""Test retrieving all events for a specific user."""
|
||||
events = crud_event.get_user_events(
|
||||
db=db_session,
|
||||
user_id=mock_user.id,
|
||||
skip=0,
|
||||
limit=100
|
||||
)
|
||||
assert len(events) > 0
|
||||
assert any(event.id == event_fixture.id for event in events)
|
||||
|
||||
|
||||
def test_update_event(db_session, event_fixture):
|
||||
"""Test updating an event."""
|
||||
update_data = EventUpdate(
|
||||
title="Updated Birthday Party",
|
||||
description="Updated description"
|
||||
)
|
||||
updated_event = crud_event.update(
|
||||
db=db_session,
|
||||
db_obj=event_fixture,
|
||||
obj_in=update_data
|
||||
)
|
||||
assert updated_event.title == "Updated Birthday Party"
|
||||
assert updated_event.description == "Updated description"
|
||||
|
||||
|
||||
def test_delete_event(db_session, event_fixture):
|
||||
"""Test deleting an event."""
|
||||
event = crud_event.remove(db=db_session, id=event_fixture.id)
|
||||
assert event.id == event_fixture.id
|
||||
deleted_event = crud_event.get(db=db_session, id=event_fixture.id)
|
||||
assert deleted_event is None
|
||||
|
||||
|
||||
def test_get_upcoming_events(db_session):
|
||||
"""Test retrieving upcoming events."""
|
||||
now = datetime.now(tz=timezone.utc)
|
||||
future_date = now + timedelta(days=30)
|
||||
|
||||
future_event_data = {
|
||||
"title": "Future Event",
|
||||
"slug": "future-event",
|
||||
"description": "An upcoming event",
|
||||
"event_date": future_date,
|
||||
"timezone": "UTC",
|
||||
"is_public": True
|
||||
}
|
||||
event_in = EventCreate(**future_event_data)
|
||||
crud_event.create_with_owner(
|
||||
db=db_session,
|
||||
obj_in=event_in,
|
||||
owner_id=uuid4()
|
||||
)
|
||||
|
||||
upcoming_events = crud_event.get_upcoming_events(
|
||||
db=db_session,
|
||||
skip=0,
|
||||
limit=100
|
||||
)
|
||||
assert len(upcoming_events) > 0
|
||||
|
||||
current_time = datetime.now(tz=timezone.utc)
|
||||
for event in upcoming_events:
|
||||
# Add timezone info to event_date if it's naive
|
||||
event_date = event.event_date
|
||||
if event_date.tzinfo is None:
|
||||
event_date = event_date.replace(tzinfo=timezone.utc)
|
||||
assert event_date >= current_time
|
||||
|
||||
|
||||
def test_get_public_event(db_session, event_fixture):
|
||||
"""Test retrieving a public event."""
|
||||
public_event = crud_event.get_public_event(
|
||||
db=db_session,
|
||||
slug=event_fixture.slug
|
||||
)
|
||||
assert public_event is not None
|
||||
assert public_event.is_public is True
|
||||
|
||||
|
||||
def test_get_private_event_with_access_code(db_session):
|
||||
"""Test retrieving a private event with access code."""
|
||||
private_event_data = {
|
||||
"title": "Private Party",
|
||||
"slug": "private-party",
|
||||
"event_date": datetime.now(tz=timezone.utc) + timedelta(days=30),
|
||||
"timezone": "UTC",
|
||||
"is_public": False,
|
||||
"access_code": "secret123"
|
||||
}
|
||||
event_in = EventCreate(**private_event_data)
|
||||
private_event = crud_event.create_with_owner(
|
||||
db=db_session,
|
||||
obj_in=event_in,
|
||||
owner_id=uuid4()
|
||||
)
|
||||
|
||||
# Try accessing with correct access code
|
||||
retrieved_event = crud_event.get_public_event(
|
||||
db=db_session,
|
||||
slug="private-party",
|
||||
access_code="secret123"
|
||||
)
|
||||
assert retrieved_event is not None
|
||||
assert retrieved_event.id == private_event.id
|
||||
|
||||
# Try accessing with wrong access code
|
||||
retrieved_event = crud_event.get_public_event(
|
||||
db=db_session,
|
||||
slug="private-party",
|
||||
access_code="wrong"
|
||||
)
|
||||
assert retrieved_event is None
|
||||
|
||||
|
||||
def test_get_non_existent_event(db_session):
|
||||
"""Test retrieving a non-existent event."""
|
||||
non_existent_id = uuid4()
|
||||
event = crud_event.get(db=db_session, id=non_existent_id)
|
||||
assert event is None
|
||||
|
||||
|
||||
def test_get_non_existent_event_by_slug(db_session):
|
||||
"""Test retrieving a non-existent event by slug."""
|
||||
event = crud_event.get_by_slug(db=db_session, slug="non-existent-slug")
|
||||
assert event is None
|
||||
|
||||
|
||||
def test_create_event_with_invalid_date(db_session, mock_user):
|
||||
"""Test creating an event with a past date."""
|
||||
past_date = datetime.now(tz=timezone.utc) - timedelta(days=1)
|
||||
event_data = {
|
||||
"title": "Past Event",
|
||||
"slug": "past-event",
|
||||
"event_date": past_date,
|
||||
"timezone": "UTC"
|
||||
}
|
||||
with pytest.raises(ValueError):
|
||||
event_in = EventCreate(**event_data)
|
||||
@@ -1,13 +1,12 @@
|
||||
import pytest
|
||||
from datetime import datetime, time, timedelta
|
||||
from datetime import datetime, time, timedelta, timezone
|
||||
from uuid import uuid4, UUID
|
||||
from zoneinfo import ZoneInfo
|
||||
|
||||
from app.schemas.events import EventCreate, EventUpdate, EventResponse
|
||||
|
||||
|
||||
def test_valid_event_create():
|
||||
event_date = datetime.now(ZoneInfo('UTC')) + timedelta(days=1)
|
||||
event_date = datetime.now(tz=timezone.utc) + timedelta(days=1)
|
||||
event_data = {
|
||||
"title": "Emma's First Birthday",
|
||||
"slug": "emmas-first-birthday",
|
||||
@@ -25,7 +24,7 @@ def test_valid_event_create():
|
||||
|
||||
|
||||
def test_invalid_event_create():
|
||||
event_date = datetime.now(ZoneInfo('UTC')) - timedelta(days=1)
|
||||
event_date = datetime.now(tz=timezone.utc) - timedelta(days=1)
|
||||
with pytest.raises(ValueError):
|
||||
EventCreate(
|
||||
title="Past Event",
|
||||
@@ -40,7 +39,7 @@ def test_invalid_timezone():
|
||||
EventCreate(
|
||||
title="Test Event",
|
||||
slug="test-event",
|
||||
event_date=datetime.now(ZoneInfo('UTC')) + timedelta(days=1),
|
||||
event_date=datetime.now(tz=timezone.utc) + timedelta(days=1),
|
||||
timezone="Invalid/Timezone"
|
||||
)
|
||||
|
||||
@@ -56,7 +55,7 @@ def test_event_update_partial():
|
||||
|
||||
|
||||
def test_event_response():
|
||||
event_date = datetime.now(ZoneInfo('UTC')) + timedelta(days=1)
|
||||
event_date = datetime.now(tz=timezone.utc) + timedelta(days=1)
|
||||
event_data = {
|
||||
"id": uuid4(),
|
||||
"title": "Test Event",
|
||||
@@ -64,8 +63,8 @@ def test_event_response():
|
||||
"event_date": event_date,
|
||||
"timezone": "UTC",
|
||||
"created_by": uuid4(),
|
||||
"created_at": datetime.now(ZoneInfo('UTC')),
|
||||
"updated_at": datetime.now(ZoneInfo('UTC'))
|
||||
"created_at": datetime.now(tz=timezone.utc),
|
||||
"updated_at": datetime.now(tz=timezone.utc)
|
||||
}
|
||||
event_response = EventResponse(**event_data)
|
||||
assert event_response.title == "Test Event"
|
||||
@@ -73,7 +72,7 @@ def test_event_response():
|
||||
|
||||
|
||||
def test_invalid_slug_format():
|
||||
event_date = datetime.now(ZoneInfo('UTC')) + timedelta(days=1)
|
||||
event_date = datetime.now(tz=timezone.utc) + timedelta(days=1)
|
||||
with pytest.raises(ValueError):
|
||||
EventCreate(
|
||||
title="Test Event",
|
||||
@@ -84,7 +83,7 @@ def test_invalid_slug_format():
|
||||
|
||||
|
||||
def test_rsvp_deadline_validation():
|
||||
event_date = datetime.now(ZoneInfo('UTC')) + timedelta(days=10)
|
||||
event_date = datetime.now(tz=timezone.utc) + timedelta(days=10)
|
||||
invalid_deadline = event_date + timedelta(days=1)
|
||||
|
||||
with pytest.raises(ValueError):
|
||||
|
||||
Reference in New Issue
Block a user