Add config parsing support in backend
Signed-off-by: Felipe Cardoso <felipe.cardoso@hotmail.it>
This commit is contained in:
93
backend/app/models/config.py
Normal file
93
backend/app/models/config.py
Normal file
@@ -0,0 +1,93 @@
|
||||
from typing import List, Optional, Dict, Any
|
||||
|
||||
from pydantic import BaseModel
|
||||
|
||||
|
||||
class SampleConfig(BaseModel):
|
||||
sampler: str
|
||||
sample_every: int
|
||||
width: int
|
||||
height: int
|
||||
prompts: List[str]
|
||||
neg: str
|
||||
seed: int
|
||||
walk_seed: bool
|
||||
guidance_scale: float
|
||||
sample_steps: int
|
||||
|
||||
|
||||
class DatasetConfig(BaseModel):
|
||||
folder_path: str
|
||||
caption_ext: Optional[str] = None
|
||||
caption_dropout_rate: Optional[float] = None
|
||||
shuffle_tokens: Optional[bool] = False
|
||||
resolution: Optional[List[int]] = None
|
||||
|
||||
|
||||
class EMAConfig(BaseModel):
|
||||
use_ema: Optional[bool] = False
|
||||
ema_decay: Optional[float] = None
|
||||
|
||||
|
||||
class TrainConfig(BaseModel):
|
||||
batch_size: int
|
||||
bypass_guidance_embedding: Optional[bool] = False
|
||||
timestep_type: Optional[str] = None
|
||||
steps: int
|
||||
gradient_accumulation: Optional[int] = 1
|
||||
train_unet: Optional[bool] = True
|
||||
train_text_encoder: Optional[bool] = False
|
||||
gradient_checkpointing: Optional[bool] = False
|
||||
noise_scheduler: Optional[str] = None
|
||||
optimizer: Optional[str] = None
|
||||
lr: Optional[float] = None
|
||||
ema_config: Optional[EMAConfig] = None
|
||||
dtype: Optional[str] = None
|
||||
do_paramiter_swapping: Optional[bool] = False
|
||||
paramiter_swapping_factor: Optional[float] = None
|
||||
skip_first_sample: Optional[bool] = False
|
||||
disable_sampling: Optional[bool] = False
|
||||
|
||||
|
||||
class ModelConfig(BaseModel):
|
||||
name_or_path: str
|
||||
is_flux: Optional[bool] = False
|
||||
quantize: Optional[bool] = False
|
||||
quantize_te: Optional[bool] = False
|
||||
|
||||
|
||||
class SaveConfig(BaseModel):
|
||||
dtype: Optional[str] = None
|
||||
save_every: Optional[int] = None
|
||||
max_step_saves_to_keep: Optional[int] = None
|
||||
save_format: Optional[str] = None
|
||||
|
||||
|
||||
class ProcessConfig(BaseModel):
|
||||
type: str
|
||||
training_folder: str
|
||||
performance_log_every: Optional[int] = None
|
||||
device: Optional[str] = None
|
||||
trigger_word: Optional[str] = None
|
||||
save: Optional[SaveConfig] = None
|
||||
datasets: List[DatasetConfig]
|
||||
train: TrainConfig
|
||||
model: ModelConfig
|
||||
sample: SampleConfig
|
||||
|
||||
|
||||
class MetaConfig(BaseModel):
|
||||
name: Optional[str] = None
|
||||
version: Optional[str] = None
|
||||
|
||||
|
||||
class TrainingConfig(BaseModel):
|
||||
job: str
|
||||
config: Dict[str, Any] # This will contain 'name' and 'process'
|
||||
meta: MetaConfig
|
||||
|
||||
|
||||
# And a Config class to represent the middle layer:
|
||||
class Config(BaseModel):
|
||||
name: str
|
||||
process: List[ProcessConfig]
|
||||
Reference in New Issue
Block a user