diff --git a/backend/app/api/routes/events/guests.py b/backend/app/api/routes/events/guests.py index 07fa513..a13c64f 100644 --- a/backend/app/api/routes/events/guests.py +++ b/backend/app/api/routes/events/guests.py @@ -13,6 +13,14 @@ router = APIRouter() @router.post("/", response_model=GuestRead, operation_id="create_guest") def create_guest(guest_in: GuestCreate, db: Session = Depends(get_db)): + if guest_in.invitation_code: + existing_guest = guest_crud.get_by_invitation_code(db, guest_in.invitation_code) + if existing_guest: + raise HTTPException( + status_code=400, detail="Guest with this invitation code already exists" + ) + else: + guest_in.invitation_code = str(uuid.uuid4())[:8] guest = guest_crud.create(db, obj_in=guest_in) return guest diff --git a/backend/app/crud/guest.py b/backend/app/crud/guest.py index 73a6533..6ace207 100644 --- a/backend/app/crud/guest.py +++ b/backend/app/crud/guest.py @@ -11,6 +11,10 @@ import uuid class CRUDGuest(CRUDBase[Guest, GuestCreate, GuestUpdate]): def create(self, db, obj_in: GuestCreate): + + if obj_in.invitation_code is None: + obj_in.invitation_code = str(uuid.uuid4())[:8] + db_guest = Guest( event_id=uuid.UUID(obj_in.event_id) if isinstance(obj_in.event_id, str) else obj_in.event_id, # explicit casting invited_by=uuid.UUID(obj_in.invited_by) if isinstance(obj_in.invited_by, str) else obj_in.invited_by, diff --git a/backend/app/schemas/guests.py b/backend/app/schemas/guests.py index 0985dfc..44cb411 100644 --- a/backend/app/schemas/guests.py +++ b/backend/app/schemas/guests.py @@ -21,7 +21,7 @@ class GuestBase(BaseModel): class GuestCreate(GuestBase): - invitation_code: str + invitation_code: Optional[str] = None class GuestUpdate(BaseModel): @@ -45,4 +45,5 @@ class GuestRead(GuestBase): response_date: Optional[datetime] = None actual_additional_guests: int is_blocked: bool - model_config = ConfigDict(from_attributes=True) \ No newline at end of file + model_config = ConfigDict(from_attributes=True) + invitation_code: str \ No newline at end of file diff --git a/backend/tests/api/routes/events/test_guests.py b/backend/tests/api/routes/events/test_guests.py index d4ef96a..10a9be8 100644 --- a/backend/tests/api/routes/events/test_guests.py +++ b/backend/tests/api/routes/events/test_guests.py @@ -39,6 +39,33 @@ class TestGuestsRouter: assert data["full_name"] == guest_data["full_name"] assert data["email"] == guest_data["email"] + def test_create_guest_fails_on_duplicate_invitation_code(self, guest_data): + # First create the guest successfully + response_initial = self.client.post(self.endpoint, json=guest_data) + assert response_initial.status_code == 200 + + # Attempt to create another guest with the same invitation code + new_guest_data = guest_data.copy() + new_guest_data["email"] = "new.email@example.com" + response_duplicate = self.client.post(self.endpoint, json=new_guest_data) + + assert response_duplicate.status_code == 400 + assert response_duplicate.json()["detail"] == "Guest with this invitation code already exists" + + + def test_create_guest_generates_invitation_code_if_not_provided(self, guest_data): + # Remove invitation_code to test auto-generation + guest_data_without_code = guest_data.copy() + guest_data_without_code.pop("invitation_code", None) + + response = self.client.post(self.endpoint, json=guest_data_without_code) + assert response.status_code == 200 + + data = response.json() + print(data) + assert "invitation_code" in data + assert len(data["invitation_code"]) == 8 + def test_create_guest_missing_required_fields_fails(self): incomplete_payload = { "email": "john.doe@example.com"