import uuid
from datetime import datetime, timezone
from typing import Optional, Any

from sqlalchemy.orm import Session

from app.crud.base import CRUDBase
from app.models.rsvp import RSVP
from app.schemas.rsvp import RSVPSchemaCreate, RSVPSchemaUpdate, RSVPStatus


def _to_uuid(value: str | uuid.UUID) -> uuid.UUID:
    """Return *value* as a ``uuid.UUID``, parsing it first when it is a string.

    Non-string values are passed through untouched.

    Raises:
        ValueError: If *value* is a string that ``uuid.UUID`` cannot parse.
    """
    return uuid.UUID(value) if isinstance(value, str) else value


class CRUDRSVP(CRUDBase[RSVP, RSVPSchemaCreate, RSVPSchemaUpdate]):
    """CRUD operations for ``RSVP`` rows, addressed by event and guest."""

    def create(self, db: Session, *, obj_in: RSVPSchemaCreate) -> RSVP:
        """Create, commit and return a new RSVP.

        String ``event_id`` / ``guest_id`` values are parsed into
        ``uuid.UUID`` objects, ``number_of_guests`` is raised to at least 1,
        and ``response_date`` is set to the current UTC time.

        Args:
            db: Session used to add and commit the new row.
            obj_in: Payload describing the RSVP to create.

        Returns:
            The RSVP instance after commit and refresh.

        Raises:
            ValueError: If ``event_id`` or ``guest_id`` is a string that is
                not a valid UUID.
        """
        rsvp_obj = RSVP(
            id=str(uuid.uuid4()),
            event_id=_to_uuid(obj_in.event_id),
            guest_id=_to_uuid(obj_in.guest_id),
            status=obj_in.status,
            # An RSVP always counts at least one guest.
            number_of_guests=max(1, obj_in.number_of_guests),
            response_message=obj_in.response_message,
            dietary_requirements=obj_in.dietary_requirements,
            additional_info=obj_in.additional_info,
            response_date=datetime.now(timezone.utc),
        )
        db.add(rsvp_obj)
        db.commit()
        db.refresh(rsvp_obj)
        return rsvp_obj

    @staticmethod
    def get_rsvp_by_event_and_guest(
        db: Session,
        *,
        event_id: str | uuid.UUID,
        guest_id: str | uuid.UUID,
    ) -> Optional[RSVP]:
        """Return the first RSVP matching the event/guest pair, if any.

        Args:
            db: Session used to run the query.
            event_id: Event identifier, as a UUID or its string form.
            guest_id: Guest identifier, as a UUID or its string form.

        Returns:
            The first matching RSVP, or ``None`` when no row matches.

        Raises:
            ValueError: If either identifier is a string that is not a
                valid UUID.
        """
        return (
            db.query(RSVP)
            .filter(
                RSVP.event_id == _to_uuid(event_id),
                RSVP.guest_id == _to_uuid(guest_id),
            )
            .first()
        )

    def update_rsvp_status(
        self,
        db: Session,
        *,
        db_obj: RSVP,
        status: RSVPStatus,
        number_of_guests: Optional[int] = None,
        response_message: Optional[str] = None,
        dietary_requirements: Optional[str] = None,
        additional_info: Optional[dict[str, Any]] = None,
    ) -> RSVP:
        """Update the status and response details of an existing RSVP.

        Args:
            db: Session forwarded to the base ``update``.
            db_obj: The RSVP to modify.
            status: New status value.
            number_of_guests: New guest count, raised to at least 1. When
                ``None`` the count already stored on ``db_obj`` is reused.
            response_message: Value passed for ``response_message``.
            dietary_requirements: Value passed for ``dietary_requirements``.
            additional_info: Value passed for ``additional_info``.

        Returns:
            Whatever the base class ``update`` returns for ``db_obj``.
        """
        # Test against None rather than truthiness: an explicit 0 must be
        # clamped to 1 (as ``create`` does), not mistaken for "not provided".
        if number_of_guests is not None:
            guest_count = max(1, number_of_guests)
        else:
            guest_count = db_obj.number_of_guests

        # NOTE(review): the three fields below are forwarded even when they
        # are None; presumably CRUDBase.update then overwrites the stored
        # values with None - confirm this is intended.
        return super().update(
            db,
            db_obj=db_obj,
            obj_in={
                "status": status,
                "number_of_guests": guest_count,
                "response_message": response_message,
                "dietary_requirements": dietary_requirements,
                "additional_info": additional_info,
            },
        )

    def delete_by_event_and_guest(
        self,
        db: Session,
        *,
        event_id: str | uuid.UUID,
        guest_id: str | uuid.UUID,
    ) -> RSVP:
        """Delete the RSVP for the given event/guest pair and commit.

        Args:
            db: Session used to look up, delete and commit.
            event_id: Event identifier, as a UUID or its string form.
            guest_id: Guest identifier, as a UUID or its string form.

        Returns:
            The RSVP instance that was deleted.

        Raises:
            ValueError: If no matching RSVP exists, or if either identifier
                is a string that is not a valid UUID.
        """
        db_obj = self.get_rsvp_by_event_and_guest(
            db, event_id=event_id, guest_id=guest_id
        )
        if not db_obj:
            raise ValueError("RSVP not found")
        db.delete(db_obj)
        db.commit()
        return db_obj


# Shared module-level instance bound to the RSVP model.
crud_rsvp = CRUDRSVP(RSVP)