Add config parsing support in backend
Signed-off-by: Felipe Cardoso <felipe.cardoso@hotmail.it>
This commit is contained in:
12
backend/app/api/routes/config.py
Normal file
12
backend/app/api/routes/config.py
Normal file
@@ -0,0 +1,12 @@
|
|||||||
|
from fastapi import APIRouter, Request
|
||||||
|
|
||||||
|
from app.models.config import TrainingConfig
|
||||||
|
|
||||||
|
router = APIRouter()


@router.get("/config", response_model=TrainingConfig)
async def get_training_config(request: Request):
    """Return the currently loaded training configuration.

    The shared ``ConfigManager`` lives on ``app.state`` (set during
    application startup); this endpoint only delegates to it, so any
    caching or remote-fetch behavior is the manager's concern.
    """
    manager = request.app.state.config_manager
    config = await manager.get_config()
    return config
|
||||||
@@ -11,10 +11,12 @@ class Settings(BaseSettings):
|
|||||||
SFTP_PATH: Optional[str] = None
|
SFTP_PATH: Optional[str] = None
|
||||||
SFTP_PORT: int = 22
|
SFTP_PORT: int = 22
|
||||||
TRAINING_LOG_REMOTE_PATH: Optional[str] = None
|
TRAINING_LOG_REMOTE_PATH: Optional[str] = None
|
||||||
|
TRAINING_CONFIG_LOCAL_PATH: Optional[str] = None
|
||||||
|
|
||||||
# Local Settings (Optional)
|
# Local Settings (Optional)
|
||||||
LOCAL_PATH: Optional[str] = None
|
LOCAL_PATH: Optional[str] = None
|
||||||
TRAINING_LOG_LOCAL_PATH: Optional[str] = None
|
TRAINING_LOG_LOCAL_PATH: Optional[str] = None
|
||||||
|
TRAINING_CONFIG_REMOTE_PATH: Optional[str] = None
|
||||||
|
|
||||||
# API Settings
|
# API Settings
|
||||||
PROJECT_NAME: str = "Training Monitor"
|
PROJECT_NAME: str = "Training Monitor"
|
||||||
|
|||||||
@@ -6,8 +6,9 @@ import psutil
|
|||||||
from fastapi import FastAPI
|
from fastapi import FastAPI
|
||||||
from fastapi.middleware.cors import CORSMiddleware
|
from fastapi.middleware.cors import CORSMiddleware
|
||||||
|
|
||||||
from app.api.routes import training, samples
|
from app.api.routes import training, samples, config
|
||||||
from app.core.config import settings
|
from app.core.config import settings
|
||||||
|
from app.services.config_manager import ConfigManager
|
||||||
from app.services.sample_manager import SampleManager
|
from app.services.sample_manager import SampleManager
|
||||||
from app.services.training_monitor import TrainingMonitor
|
from app.services.training_monitor import TrainingMonitor
|
||||||
|
|
||||||
@@ -35,17 +36,60 @@ app.add_middleware(
|
|||||||
# Create and store SampleManager instance
# NOTE(review): these module-level instances appear to be superseded at
# startup by initialize_services(), which creates fresh managers and
# overwrites these app.state entries — confirm whether the import-time
# instances are still needed (e.g. by shutdown handlers).
sample_manager = SampleManager()
training_monitor = TrainingMonitor()
config_manager = ConfigManager()

# Expose the managers to request handlers via the application state.
app.state.sample_manager = sample_manager
app.state.training_monitor = training_monitor
app.state.config_manager = config_manager
|
||||||
|
|
||||||
|
|
||||||
|
async def initialize_services():
    """
    Initializes all service managers in the correct order, ensuring dependencies
    are properly set up before they're needed.

    The ConfigManager is created first because other services may need the
    training configuration; the loaded config and every manager are stored on
    ``app.state`` so request handlers can reach them.

    Raises:
        Exception: any initialization failure is logged and re-raised so the
            application does not start with partially initialized services.
    """
    logger.info("Starting services initialization...")

    try:
        # First, initialize config manager as other services might need configuration.
        # Constructed inside the try so a configuration error (e.g. no config
        # path settings, which makes ConfigManager raise) is logged like any
        # other init failure instead of escaping silently.
        config_manager = ConfigManager()
        app.state.config_manager = config_manager

        # Load initial configuration
        config = await config_manager.get_config()
        logger.info(f"Configuration loaded successfully for job: {config.job}")

        # Store config in app state for easy access
        app.state.training_config = config

        # Initialize other managers with access to config
        sample_manager = SampleManager()
        training_monitor = TrainingMonitor()

        # Store managers in app state
        app.state.sample_manager = sample_manager
        app.state.training_monitor = training_monitor

        # Start the managers
        await sample_manager.startup()
        await training_monitor.startup()

        logger.info("All services initialized successfully")

    except Exception as e:
        logger.error(f"Failed to initialize services: {str(e)}")
        # Re-raise to prevent app from starting with partially initialized services
        raise
|
||||||
|
|
||||||
|
|
||||||
# NOTE(review): @app.on_event is deprecated in newer FastAPI releases in
# favor of lifespan handlers — confirm the pinned FastAPI version.
@app.on_event("startup")
async def startup_event():
    """
    Startup event handler that ensures all services are properly initialized
    before the application starts accepting requests.
    """
    logger.info("Starting up Training Monitor API")
    await initialize_services()
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
@@ -60,6 +104,7 @@ async def shutdown_event():
|
|||||||
# Include routers with versioning — every route group is mounted under the
# API version prefix from settings (e.g. /api/v1/training).
app.include_router(training.router, prefix=f"{settings.API_VER_STR}/training", tags=["training"])
app.include_router(samples.router, prefix=f"{settings.API_VER_STR}/samples", tags=["samples"])
app.include_router(config.router, prefix=f"{settings.API_VER_STR}/config", tags=["config"])
|
||||||
|
|
||||||
|
|
||||||
@app.get("/")
|
@app.get("/")
|
||||||
|
|||||||
93
backend/app/models/config.py
Normal file
93
backend/app/models/config.py
Normal file
@@ -0,0 +1,93 @@
|
|||||||
|
from typing import List, Optional, Dict, Any
|
||||||
|
|
||||||
|
from pydantic import BaseModel
|
||||||
|
|
||||||
|
|
||||||
|
class SampleConfig(BaseModel):
    """Settings controlling sample-image generation during training."""

    sampler: str  # sampler/scheduler name used for generation
    sample_every: int  # generate samples every N training steps
    width: int  # sample image width in pixels
    height: int  # sample image height in pixels
    prompts: List[str]  # prompts rendered at each sampling step
    neg: str  # negative prompt
    seed: int
    walk_seed: bool  # presumably varies the seed per sample — TODO confirm
    guidance_scale: float
    sample_steps: int  # inference steps per sample image
|
||||||
|
|
||||||
|
|
||||||
|
class DatasetConfig(BaseModel):
    """One entry of the training `datasets` list in the YAML config."""

    folder_path: str  # directory containing the training data
    caption_ext: Optional[str] = None  # caption file extension
    caption_dropout_rate: Optional[float] = None
    shuffle_tokens: Optional[bool] = False
    resolution: Optional[List[int]] = None  # training resolution(s)
|
||||||
|
|
||||||
|
|
||||||
|
class EMAConfig(BaseModel):
    """Exponential-moving-average settings for model weights."""

    use_ema: Optional[bool] = False
    ema_decay: Optional[float] = None
|
||||||
|
|
||||||
|
|
||||||
|
class TrainConfig(BaseModel):
    """Core training hyperparameters parsed from the YAML `train` section."""

    batch_size: int
    bypass_guidance_embedding: Optional[bool] = False
    timestep_type: Optional[str] = None
    steps: int  # total training steps
    gradient_accumulation: Optional[int] = 1
    train_unet: Optional[bool] = True
    train_text_encoder: Optional[bool] = False
    gradient_checkpointing: Optional[bool] = False
    noise_scheduler: Optional[str] = None
    optimizer: Optional[str] = None
    lr: Optional[float] = None  # learning rate
    ema_config: Optional[EMAConfig] = None
    dtype: Optional[str] = None  # training precision, e.g. as named in the YAML
    # NOTE(review): "paramiter" spelling appears to mirror the upstream YAML
    # schema keys — do not rename without checking the config producer.
    do_paramiter_swapping: Optional[bool] = False
    paramiter_swapping_factor: Optional[float] = None
    skip_first_sample: Optional[bool] = False
    disable_sampling: Optional[bool] = False
|
||||||
|
|
||||||
|
|
||||||
|
class ModelConfig(BaseModel):
    """Base-model selection and quantization flags."""

    name_or_path: str  # model identifier or filesystem path
    is_flux: Optional[bool] = False
    quantize: Optional[bool] = False
    quantize_te: Optional[bool] = False  # quantize the text encoder
|
||||||
|
|
||||||
|
|
||||||
|
class SaveConfig(BaseModel):
    """Checkpoint-saving options from the YAML `save` section."""

    dtype: Optional[str] = None  # precision used when writing checkpoints
    save_every: Optional[int] = None  # save a checkpoint every N steps
    max_step_saves_to_keep: Optional[int] = None
    save_format: Optional[str] = None
|
||||||
|
|
||||||
|
|
||||||
|
class ProcessConfig(BaseModel):
    """A single training process entry (`config.process[i]` in the YAML)."""

    type: str  # process type tag from the YAML
    training_folder: str  # output folder for this training run
    performance_log_every: Optional[int] = None
    device: Optional[str] = None
    trigger_word: Optional[str] = None
    save: Optional[SaveConfig] = None
    datasets: List[DatasetConfig]
    train: TrainConfig
    model: ModelConfig
    sample: SampleConfig
|
||||||
|
|
||||||
|
|
||||||
|
class MetaConfig(BaseModel):
    """Optional metadata block (`meta`) of the training YAML."""

    name: Optional[str] = None
    version: Optional[str] = None
|
||||||
|
|
||||||
|
|
||||||
|
class TrainingConfig(BaseModel):
    """Top-level training configuration as returned by ConfigManager."""

    job: str  # job identifier from the YAML root
    config: Dict[str, Any] # This will contain 'name' and 'process'
    meta: MetaConfig
|
||||||
|
|
||||||
|
|
||||||
|
# And a Config class to represent the middle layer:
class Config(BaseModel):
    """Typed form of the middle `config` layer (name + process list)."""

    name: str
    process: List[ProcessConfig]
|
||||||
151
backend/app/services/config_manager.py
Normal file
151
backend/app/services/config_manager.py
Normal file
@@ -0,0 +1,151 @@
|
|||||||
|
# app/services/config_manager.py
|
||||||
|
import logging
|
||||||
|
import os
|
||||||
|
|
||||||
|
import aiofiles
|
||||||
|
import paramiko
|
||||||
|
import yaml
|
||||||
|
from fastapi import HTTPException
|
||||||
|
|
||||||
|
from app.core.config import settings
|
||||||
|
from app.models.config import TrainingConfig, ProcessConfig, SaveConfig, DatasetConfig
|
||||||
|
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
|
class ConfigManager:
    """
    Manages access to training configuration files, supporting both local and remote (SFTP) paths.
    Handles YAML parsing and conversion to strongly-typed configuration objects.

    The parsed configuration is cached for the lifetime of the manager.
    """

    def __init__(self):
        # Initialize paths from settings, defaulting to None if not configured
        self.remote_path = getattr(settings, 'TRAINING_CONFIG_REMOTE_PATH', None)
        self.local_path = getattr(settings, 'TRAINING_CONFIG_LOCAL_PATH', None)
        # Keep a handle on the SSH client as well as the SFTP channel:
        # closing only the SFTP channel would leak the SSH transport.
        self.ssh_client = None
        self.sftp_client = None
        self.cached_config = None

        # Validate that at least one path is configured
        if not self.remote_path and not self.local_path:
            raise ValueError("Either TRAINING_CONFIG_REMOTE_PATH or TRAINING_CONFIG_LOCAL_PATH must be configured")

        logger.info(f"ConfigManager initialized with remote_path={self.remote_path}, local_path={self.local_path}")

    async def _connect_sftp(self):
        """Establishes SFTP connection using SSH key authentication.

        Raises:
            HTTPException: 500 if the connection cannot be established.
        """
        try:
            key_path = os.path.expanduser(settings.SFTP_KEY_PATH)
            logger.info(f"Connecting to SFTP {settings.SFTP_HOST} with key {key_path}")

            ssh = paramiko.SSHClient()
            # NOTE(review): AutoAddPolicy trusts unknown host keys; confirm
            # this is acceptable for the deployment environment.
            ssh.set_missing_host_key_policy(paramiko.AutoAddPolicy())

            ssh.connect(
                hostname=settings.SFTP_HOST,
                username=settings.SFTP_USER,
                port=settings.SFTP_PORT,
                key_filename=key_path,
            )

            # Retain the SSH client so _disconnect_sftp can close it; the
            # previous code dropped the reference, leaking the connection.
            self.ssh_client = ssh
            self.sftp_client = ssh.open_sftp()
            logger.info("SFTP connection established successfully")

        except Exception as e:
            logger.error(f"Failed to establish SFTP connection: {str(e)}")
            raise HTTPException(status_code=500, detail=f"SFTP connection failed: {str(e)}")

    def _disconnect_sftp(self):
        """Safely closes the SFTP channel and the underlying SSH connection."""
        if self.sftp_client:
            try:
                self.sftp_client.close()
                self.sftp_client = None
                logger.info("SFTP connection closed")
            except Exception as e:
                logger.error(f"Error closing SFTP connection: {str(e)}")
        if self.ssh_client:
            try:
                self.ssh_client.close()
            except Exception as e:
                logger.error(f"Error closing SSH connection: {str(e)}")
            finally:
                self.ssh_client = None

    async def _read_remote_config(self) -> dict:
        """Reads and parses YAML configuration from remote SFTP location.

        NOTE(review): paramiko calls are blocking; consider a thread executor
        if this path is hit frequently.
        """
        if not self.sftp_client:
            await self._connect_sftp()

        try:
            with self.sftp_client.open(self.remote_path, 'r') as f:
                content = f.read()
                return yaml.safe_load(content)
        except Exception as e:
            logger.error(f"Failed to read remote config: {str(e)}")
            raise HTTPException(status_code=500, detail=f"Failed to read remote config: {str(e)}")
        finally:
            # One-shot connection: always tear down after a read attempt.
            self._disconnect_sftp()

    async def _read_local_config(self) -> dict:
        """Reads and parses YAML configuration from local filesystem."""
        try:
            async with aiofiles.open(self.local_path, 'r') as f:
                content = await f.read()
                return yaml.safe_load(content)
        except Exception as e:
            logger.error(f"Failed to read local config: {str(e)}")
            raise HTTPException(status_code=500, detail=f"Failed to read local config: {str(e)}")

    def _parse_config(self, raw_config: dict) -> TrainingConfig:
        """
        Converts raw YAML dictionary into strongly-typed configuration objects.
        Handles optional fields and nested configurations.

        Raises:
            HTTPException: 500 when the raw dictionary cannot be parsed.
        """
        try:
            # Extract the first process configuration (assuming single process for now)
            process_data = raw_config['config']['process'][0]

            # Build the process config with all its nested components
            process = ProcessConfig(
                type=process_data['type'],
                training_folder=process_data['training_folder'],
                performance_log_every=process_data.get('performance_log_every'),
                device=process_data.get('device'),
                trigger_word=process_data.get('trigger_word'),
                save=SaveConfig(**process_data['save']) if 'save' in process_data else None,
                datasets=[DatasetConfig(**ds) for ds in process_data.get('datasets', [])],
                train=process_data['train'],
                model=process_data['model'],
                sample=process_data['sample']
            )

            # Reconstruct the config dictionary with our parsed process
            config_dict = dict(raw_config['config'])
            config_dict['process'] = [process]

            # Create the full training config
            return TrainingConfig(
                job=raw_config.get('job', ''),
                config=config_dict,
                meta=raw_config.get('meta', {})
            )

        except Exception as e:
            logger.error(f"Failed to parse config: {str(e)}")
            raise HTTPException(status_code=500, detail=f"Config parsing failed: {str(e)}")

    async def get_config(self) -> TrainingConfig:
        """
        Main method to retrieve and parse configuration.
        Automatically handles local or remote access based on configuration.

        The result is cached; subsequent calls return the cached object.
        """
        if self.cached_config is not None:
            return self.cached_config

        try:
            # Read raw config from appropriate source
            raw_config = await self._read_remote_config() if self.remote_path else await self._read_local_config()

            # Parse and cache the strongly-typed config
            parsed_config = self._parse_config(raw_config)
            self.cached_config = parsed_config
            return parsed_config

        except HTTPException:
            # Preserve the original status and detail instead of re-wrapping
            # (previously the detail text was wrapped a second time).
            raise
        except Exception as e:
            logger.error(f"Failed to get config: {str(e)}")
            raise HTTPException(status_code=500, detail=str(e))
|
||||||
@@ -11,4 +11,5 @@ aiofiles>=23.2.1
|
|||||||
pytest>=7.4.3
|
pytest>=7.4.3
|
||||||
httpx>=0.25.1
|
httpx>=0.25.1
|
||||||
pytest-asyncio>=0.21.1
|
pytest-asyncio>=0.21.1
|
||||||
psutil>=5.9.8
|
psutil>=5.9.8
|
||||||
|
PyYAML~=6.0.2
|
||||||
Reference in New Issue
Block a user