# ai-training-monitor/backend/app/models/config.py

from typing import List, Optional, Dict, Any
from pydantic import BaseModel


class SampleConfig(BaseModel):
    sampler: str
    sample_every: int
    width: int
    height: int
    prompts: List[str]
    neg: str
    seed: int
    walk_seed: bool
    guidance_scale: float
    sample_steps: int


class DatasetConfig(BaseModel):
    folder_path: str
    caption_ext: Optional[str] = None
    caption_dropout_rate: Optional[float] = None
    shuffle_tokens: Optional[bool] = False
    resolution: Optional[List[int]] = None


class EMAConfig(BaseModel):
    use_ema: Optional[bool] = False
    ema_decay: Optional[float] = None


class TrainConfig(BaseModel):
    batch_size: int
    bypass_guidance_embedding: Optional[bool] = False
    timestep_type: Optional[str] = None
    steps: int
    gradient_accumulation: Optional[int] = 1
    train_unet: Optional[bool] = True
    train_text_encoder: Optional[bool] = False
    gradient_checkpointing: Optional[bool] = False
    noise_scheduler: Optional[str] = None
    optimizer: Optional[str] = None
    lr: Optional[float] = None
    ema_config: Optional[EMAConfig] = None
    dtype: Optional[str] = None
    # 'paramiter' spelling kept as-is so existing config keys still parse
    do_paramiter_swapping: Optional[bool] = False
    paramiter_swapping_factor: Optional[float] = None
    skip_first_sample: Optional[bool] = False
    disable_sampling: Optional[bool] = False


class ModelConfig(BaseModel):
    name_or_path: str
    is_flux: Optional[bool] = False
    quantize: Optional[bool] = False
    quantize_te: Optional[bool] = False


class SaveConfig(BaseModel):
    dtype: Optional[str] = None
    save_every: Optional[int] = None
    max_step_saves_to_keep: Optional[int] = None
    save_format: Optional[str] = None


class ProcessConfig(BaseModel):
    # One training process entry; nests the dataset, train, model and sample configs above
    type: str
    training_folder: str
    performance_log_every: Optional[int] = None
    device: Optional[str] = None
    trigger_word: Optional[str] = None
    save: Optional[SaveConfig] = None
    datasets: List[DatasetConfig]
    train: TrainConfig
    model: ModelConfig
    sample: SampleConfig


class MetaConfig(BaseModel):
    name: Optional[str] = None
    version: Optional[str] = None


class TrainingConfig(BaseModel):
    job: str
    config: Dict[str, Any]  # holds 'name' and 'process'; see the Config model below
    meta: MetaConfig


# A Config model for the middle layer stored in TrainingConfig.config:
class Config(BaseModel):
    name: str
    process: List[ProcessConfig]
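

# The sketch below is a minimal, illustrative usage example (an addition, not
# part of the original module): it builds a TrainingConfig from a hand-written
# dict shaped like an ai-toolkit-style job file, then re-validates the untyped
# 'config' payload with the Config model above. The concrete values
# ('sd_trainer', 'flowmatch', the model path, etc.) are placeholders.
if __name__ == "__main__":
    raw = {
        "job": "extension",
        "config": {
            "name": "example_lora",
            "process": [
                {
                    "type": "sd_trainer",
                    "training_folder": "output",
                    "datasets": [{"folder_path": "data/images", "caption_ext": "txt"}],
                    "train": {"batch_size": 1, "steps": 2000, "lr": 1e-4},
                    "model": {"name_or_path": "models/flux-dev", "is_flux": True},
                    "sample": {
                        "sampler": "flowmatch",
                        "sample_every": 250,
                        "width": 1024,
                        "height": 1024,
                        "prompts": ["a photo of a cat"],
                        "neg": "",
                        "seed": 42,
                        "walk_seed": True,
                        "guidance_scale": 4.0,
                        "sample_steps": 20,
                    },
                }
            ],
        },
        "meta": {"name": "example_lora", "version": "1.0"},
    }

    training = TrainingConfig(**raw)
    # TrainingConfig keeps 'config' as a plain dict; parse the middle layer explicitly.
    middle = Config(**training.config)
    print(middle.name, middle.process[0].train.steps)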