Add CRUD operations and tests for Event model
This commit introduces a new CRUDEvent class to manage event-related database operations, including retrieval, creation, updating, and deletion of events. It includes corresponding unit tests to ensure the correctness of these functionalities, updates event schemas for enhanced validation, and refines timezone handling for event dates and deadlines.
This commit is contained in:
113
backend/app/crud/event.py
Normal file
113
backend/app/crud/event.py
Normal file
@@ -0,0 +1,113 @@
|
||||
from datetime import timezone
|
||||
from typing import List, Optional, Union, Dict, Any
|
||||
from uuid import UUID
|
||||
|
||||
from sqlalchemy import desc
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from app.crud.base import CRUDBase
|
||||
from app.models.event import Event
|
||||
from app.schemas.events import EventCreate, EventUpdate
|
||||
|
||||
|
||||
class CRUDEvent(CRUDBase[Event, EventCreate, EventUpdate]):
|
||||
def get_by_slug(self, db: Session, *, slug: str) -> Optional[Event]:
|
||||
"""Get event by slug."""
|
||||
return db.query(Event).filter(Event.slug == slug).first()
|
||||
|
||||
def get_user_events(
|
||||
self,
|
||||
db: Session,
|
||||
*,
|
||||
user_id: UUID,
|
||||
skip: int = 0,
|
||||
limit: int = 100,
|
||||
include_inactive: bool = False
|
||||
) -> List[Event]:
|
||||
"""Get all events created by a specific user."""
|
||||
query = db.query(Event).filter(Event.created_by == user_id)
|
||||
|
||||
if not include_inactive:
|
||||
query = query.filter(Event.is_active == True)
|
||||
|
||||
return query.order_by(desc(Event.event_date)).offset(skip).limit(limit).all()
|
||||
|
||||
def create_with_owner(
|
||||
self,
|
||||
db: Session,
|
||||
*,
|
||||
obj_in: EventCreate,
|
||||
owner_id: UUID
|
||||
) -> Event:
|
||||
"""Create a new event with owner ID."""
|
||||
obj_in_data = obj_in.model_dump()
|
||||
db_obj = Event(**obj_in_data, created_by=owner_id)
|
||||
db.add(db_obj)
|
||||
db.commit()
|
||||
db.refresh(db_obj)
|
||||
return db_obj
|
||||
|
||||
def update(
|
||||
self,
|
||||
db: Session,
|
||||
*,
|
||||
db_obj: Event,
|
||||
obj_in: Union[EventUpdate, Dict[str, Any]]
|
||||
) -> Event:
|
||||
"""Update an event."""
|
||||
if isinstance(obj_in, dict):
|
||||
update_data = obj_in
|
||||
else:
|
||||
update_data = obj_in.model_dump(exclude_unset=True)
|
||||
|
||||
return super().update(db, db_obj=db_obj, obj_in=update_data)
|
||||
|
||||
def get_upcoming_events(
|
||||
self,
|
||||
db: Session,
|
||||
*,
|
||||
skip: int = 0,
|
||||
limit: int = 100
|
||||
) -> List[Event]:
|
||||
"""Get upcoming active events ordered by date."""
|
||||
from datetime import datetime
|
||||
|
||||
now = datetime.now(tz=timezone.utc)
|
||||
return (
|
||||
db.query(Event)
|
||||
.filter(Event.is_active == True)
|
||||
.filter(Event.event_date >= now)
|
||||
.order_by(Event.event_date)
|
||||
.offset(skip)
|
||||
.limit(limit)
|
||||
.all()
|
||||
)
|
||||
|
||||
def get_public_event(
|
||||
self,
|
||||
db: Session,
|
||||
*,
|
||||
slug: str,
|
||||
access_code: Optional[str] = None
|
||||
) -> Optional[Event]:
|
||||
"""Get a public event by slug, optionally checking access code."""
|
||||
query = db.query(Event).filter(
|
||||
Event.slug == slug,
|
||||
Event.is_active == True
|
||||
)
|
||||
|
||||
event = query.first()
|
||||
if not event:
|
||||
return None
|
||||
|
||||
if event.is_public:
|
||||
return event
|
||||
|
||||
if access_code and event.access_code == access_code:
|
||||
return event
|
||||
|
||||
return None
|
||||
|
||||
|
||||
# Create a singleton instance for use across the application
|
||||
event = CRUDEvent(Event)
|
||||
@@ -1,8 +1,8 @@
|
||||
from datetime import datetime, time, timezone
|
||||
from typing import Dict, Optional, List, Union
|
||||
from typing import Dict, Optional
|
||||
from uuid import UUID
|
||||
|
||||
from pydantic import BaseModel, Field, field_validator
|
||||
from zoneinfo import ZoneInfo
|
||||
|
||||
|
||||
class EventBase(BaseModel):
|
||||
@@ -31,6 +31,7 @@ class EventBase(BaseModel):
|
||||
|
||||
@field_validator('timezone')
|
||||
def validate_timezone(cls, v):
|
||||
from zoneinfo import ZoneInfo
|
||||
try:
|
||||
ZoneInfo(v)
|
||||
return v
|
||||
@@ -39,15 +40,21 @@ class EventBase(BaseModel):
|
||||
|
||||
@field_validator('event_date')
|
||||
def validate_event_date(cls, v):
|
||||
if v < datetime.now(tz=timezone.utc):
|
||||
if not v.tzinfo:
|
||||
raise ValueError("Event date must be timezone-aware")
|
||||
now = datetime.now(tz=timezone.utc)
|
||||
if v < now:
|
||||
raise ValueError("Event date cannot be in the past")
|
||||
return v
|
||||
|
||||
@field_validator('rsvp_deadline')
|
||||
def validate_rsvp_deadline(cls, v, values):
|
||||
if v and 'event_date' in values.data:
|
||||
if v > values.data['event_date']:
|
||||
raise ValueError("RSVP deadline must be before event date")
|
||||
if v:
|
||||
if not v.tzinfo:
|
||||
raise ValueError("RSVP deadline must be timezone-aware")
|
||||
if 'event_date' in values.data:
|
||||
if v > values.data['event_date']:
|
||||
raise ValueError("RSVP deadline must be before event date")
|
||||
return v
|
||||
|
||||
|
||||
|
||||
Reference in New Issue
Block a user