"""Pydantic schemas describing a training job configuration.

Nesting, top to bottom:

* ``TrainingConfig`` holds a ``job`` string, a free-form ``config`` dict and
  a ``meta`` block.
* ``config`` is expected to carry ``name`` and ``process`` keys, i.e. the
  shape modelled by ``Config``. It is NOT validated against ``Config``
  here; it stays a plain dict.
* Each ``ProcessConfig`` bundles dataset, training, model, saving and
  sampling sub-configs.
"""
from typing import Any, Dict, List, Optional

from pydantic import BaseModel


class SampleConfig(BaseModel):
    """Sampling settings. Every field is required (no defaults)."""

    sampler: str
    sample_every: int  # presumably an interval in training steps -- confirm
    width: int
    height: int
    prompts: List[str]
    neg: str  # presumably the negative prompt -- confirm with the consumer
    seed: int
    walk_seed: bool
    guidance_scale: float
    sample_steps: int


class DatasetConfig(BaseModel):
    """A single dataset entry. Only ``folder_path`` is required."""

    folder_path: str
    caption_ext: Optional[str] = None
    caption_dropout_rate: Optional[float] = None
    shuffle_tokens: Optional[bool] = False
    resolution: Optional[List[int]] = None


class EMAConfig(BaseModel):
    """EMA toggle and decay. EMA is off unless ``use_ema`` is set."""

    use_ema: Optional[bool] = False
    ema_decay: Optional[float] = None


class TrainConfig(BaseModel):
    """Training-loop settings. ``batch_size`` and ``steps`` are required."""

    batch_size: int
    bypass_guidance_embedding: Optional[bool] = False
    timestep_type: Optional[str] = None
    steps: int
    gradient_accumulation: Optional[int] = 1
    train_unet: Optional[bool] = True
    train_text_encoder: Optional[bool] = False
    gradient_checkpointing: Optional[bool] = False
    noise_scheduler: Optional[str] = None
    optimizer: Optional[str] = None
    lr: Optional[float] = None
    ema_config: Optional[EMAConfig] = None
    dtype: Optional[str] = None
    # NOTE: the "paramiter" spelling is intentional to preserve. Pydantic
    # matches input keys to field names, so renaming these fields would
    # change which config key populates them.
    do_paramiter_swapping: Optional[bool] = False
    paramiter_swapping_factor: Optional[float] = None
    skip_first_sample: Optional[bool] = False
    disable_sampling: Optional[bool] = False


class ModelConfig(BaseModel):
    """Model source plus feature/quantization flags (all flags default off)."""

    name_or_path: str
    is_flux: Optional[bool] = False
    quantize: Optional[bool] = False
    quantize_te: Optional[bool] = False


class SaveConfig(BaseModel):
    """Checkpoint-saving options. Every field is optional and defaults to None."""

    dtype: Optional[str] = None
    save_every: Optional[int] = None
    max_step_saves_to_keep: Optional[int] = None
    save_format: Optional[str] = None


class ProcessConfig(BaseModel):
    """One process entry tying together datasets, training, model and sampling."""

    type: str  # field name mirrors the input key, so it shadows the builtin
    training_folder: str
    performance_log_every: Optional[int] = None
    device: Optional[str] = None
    trigger_word: Optional[str] = None
    save: Optional[SaveConfig] = None
    datasets: List[DatasetConfig]
    train: TrainConfig
    model: ModelConfig
    sample: SampleConfig


class MetaConfig(BaseModel):
    """Optional descriptive metadata for the job."""

    name: Optional[str] = None
    version: Optional[str] = None


class TrainingConfig(BaseModel):
    """Top-level document: job identifier, raw config dict and metadata."""

    job: str
    # Expected to hold 'name' and 'process' keys (the layout Config
    # describes). It is kept as an unvalidated dict.
    config: Dict[str, Any]
    meta: MetaConfig


class Config(BaseModel):
    """Middle layer: a named list of process entries."""

    name: str
    process: List[ProcessConfig]