# 94 lines, 2.3 KiB, Python
from typing import List, Optional, Dict, Any
from pydantic import BaseModel


class SampleConfig(BaseModel):
    """Sampling settings for a process.

    Every field is required (none has a default), so a config that omits any
    of them fails validation.
    """

    sampler: str
    sample_every: int  # sampling interval; unit (presumably steps) not defined here
    width: int
    height: int
    prompts: List[str]
    neg: str  # presumably the negative prompt -- confirm against the consumer
    seed: int
    walk_seed: bool  # NOTE(review): meaning defined by the consumer, not this schema
    guidance_scale: float
    sample_steps: int


class DatasetConfig(BaseModel):
    """One dataset entry of a process; only ``folder_path`` is required."""

    folder_path: str
    caption_ext: Optional[str] = None  # presumably a caption file extension -- confirm
    caption_dropout_rate: Optional[float] = None  # no range check is applied here
    shuffle_tokens: Optional[bool] = False
    # List of ints; what each entry means is decided by the consumer.
    resolution: Optional[List[int]] = None


class EMAConfig(BaseModel):
    """EMA options; ``use_ema`` defaults to False and ``ema_decay`` to None."""

    use_ema: Optional[bool] = False
    ema_decay: Optional[float] = None  # no range check is applied here


class TrainConfig(BaseModel):
    """Training options for a process.

    Only ``batch_size`` and ``steps`` are required; every other field has a
    default. Fields typed ``Optional[...]`` also accept an explicit ``None``.
    """

    batch_size: int
    bypass_guidance_embedding: Optional[bool] = False
    timestep_type: Optional[str] = None
    # Required field after defaulted ones: legal for pydantic models
    # (unlike dataclasses), so the original field order is preserved.
    steps: int
    gradient_accumulation: Optional[int] = 1
    train_unet: Optional[bool] = True
    train_text_encoder: Optional[bool] = False
    gradient_checkpointing: Optional[bool] = False
    noise_scheduler: Optional[str] = None
    optimizer: Optional[str] = None
    lr: Optional[float] = None
    ema_config: Optional[EMAConfig] = None
    dtype: Optional[str] = None
    # NOTE: "paramiter" is misspelled, but the field name is the config key
    # that gets read; renaming it would stop existing configs' values from
    # being picked up. Keep the spelling as-is.
    do_paramiter_swapping: Optional[bool] = False
    paramiter_swapping_factor: Optional[float] = None
    skip_first_sample: Optional[bool] = False
    disable_sampling: Optional[bool] = False


class ModelConfig(BaseModel):
    """Model options for a process; only ``name_or_path`` is required."""

    name_or_path: str
    is_flux: Optional[bool] = False
    quantize: Optional[bool] = False
    quantize_te: Optional[bool] = False


class SaveConfig(BaseModel):
    """Saving options for a process; every field is optional (default None)."""

    dtype: Optional[str] = None
    save_every: Optional[int] = None
    max_step_saves_to_keep: Optional[int] = None
    save_format: Optional[str] = None


class ProcessConfig(BaseModel):
    """A single process definition together with its nested sub-configs.

    Required: ``type``, ``training_folder``, ``datasets``, ``train``,
    ``model`` and ``sample``. The remaining fields default to None.
    """

    type: str
    training_folder: str
    performance_log_every: Optional[int] = None
    device: Optional[str] = None
    trigger_word: Optional[str] = None
    save: Optional[SaveConfig] = None
    datasets: List[DatasetConfig]
    train: TrainConfig
    model: ModelConfig
    sample: SampleConfig


class MetaConfig(BaseModel):
    """Job metadata; both ``name`` and ``version`` are optional strings."""

    name: Optional[str] = None
    version: Optional[str] = None


class TrainingConfig(BaseModel):
    """Top-level job configuration: ``job``, the raw ``config`` section and ``meta``."""

    job: str
    # Kept as a plain dict for backward compatibility with callers that index
    # it directly; it is expected to hold the 'name' and 'process' keys.
    # Call ``parsed_config()`` to get a validated view.
    config: Dict[str, Any]
    meta: MetaConfig

    def parsed_config(self) -> "Config":
        """Validate ``self.config`` against the module-level ``Config`` model.

        Returns:
            A new ``Config`` instance built from ``self.config``; the stored
            dict itself is not modified.

        Raises:
            pydantic.ValidationError: if ``name`` or ``process`` is missing
                or does not match the ``Config`` schema.
        """
        # Resolves to the module-level ``Config`` model (method bodies do not
        # see class-scope names), not pydantic v1's inner ``BaseModel.Config``.
        return Config(**self.config)


# Middle layer of the job config: intended to describe the ``config`` section
# of a training job (its ``name`` plus the list of ``process`` entries).
class Config(BaseModel):
    """Named job body holding a list of process definitions (both required)."""

    name: str
    process: List[ProcessConfig]