From 37b71464f669e1c085d902dbfcb1c52a9342fe7f Mon Sep 17 00:00:00 2001
From: Felipe Cardoso
Date: Thu, 23 Jan 2025 13:46:30 +0100
Subject: [PATCH] Add config parsing support in backend

Signed-off-by: Felipe Cardoso
---
 backend/app/api/routes/config.py       |  12 ++
 backend/app/core/config.py             |   2 +
 backend/app/main.py                    |  53 ++++++++-
 backend/app/models/config.py           |  93 +++++++++++++
 backend/app/services/config_manager.py | 151 +++++++++++++++++++++++++
 backend/requirements.txt               |   3 +-
 6 files changed, 309 insertions(+), 5 deletions(-)
 create mode 100644 backend/app/api/routes/config.py
 create mode 100644 backend/app/models/config.py
 create mode 100644 backend/app/services/config_manager.py

diff --git a/backend/app/api/routes/config.py b/backend/app/api/routes/config.py
new file mode 100644
index 0000000..a0354f7
--- /dev/null
+++ b/backend/app/api/routes/config.py
@@ -0,0 +1,12 @@
+from fastapi import APIRouter, Request
+
+from app.models.config import TrainingConfig
+
+router = APIRouter()
+
+
+@router.get("/config", response_model=TrainingConfig)
+async def get_training_config(request: Request):
+    """Retrieves the current training configuration"""
+    config_manager = request.app.state.config_manager
+    return await config_manager.get_config()
diff --git a/backend/app/core/config.py b/backend/app/core/config.py
index 25ec48d..937856b 100644
--- a/backend/app/core/config.py
+++ b/backend/app/core/config.py
@@ -11,10 +11,12 @@ class Settings(BaseSettings):
     SFTP_PATH: Optional[str] = None
     SFTP_PORT: int = 22
     TRAINING_LOG_REMOTE_PATH: Optional[str] = None
+    TRAINING_CONFIG_REMOTE_PATH: Optional[str] = None
 
     # Local Settings (Optional)
     LOCAL_PATH: Optional[str] = None
     TRAINING_LOG_LOCAL_PATH: Optional[str] = None
+    TRAINING_CONFIG_LOCAL_PATH: Optional[str] = None
 
     # API Settings
     PROJECT_NAME: str = "Training Monitor"
diff --git a/backend/app/main.py b/backend/app/main.py
index 7f4e60c..176d723 100644
--- a/backend/app/main.py
+++ b/backend/app/main.py
@@ -6,8 +6,9 @@ import psutil
 from fastapi import FastAPI
 from fastapi.middleware.cors import CORSMiddleware
 
-from app.api.routes import training, samples
+from app.api.routes import training, samples, config
 from app.core.config import settings
+from app.services.config_manager import ConfigManager
 from app.services.sample_manager import SampleManager
 from app.services.training_monitor import TrainingMonitor
 
@@ -35,17 +36,60 @@ app.add_middleware(
 
 
 # Create and store SampleManager instance
 sample_manager = SampleManager()
 training_monitor = TrainingMonitor()
+config_manager = ConfigManager()
 app.state.sample_manager = sample_manager
 app.state.training_monitor = training_monitor
+app.state.config_manager = config_manager
+
+
+async def initialize_services():
+    """
+    Initializes all service managers in the correct order, ensuring dependencies
+    are properly set up before they're needed.
+ """ + logger.info("Starting services initialization...") + + # First, initialize config manager as other services might need configuration + config_manager = ConfigManager() + app.state.config_manager = config_manager + + try: + # Load initial configuration + config = await config_manager.get_config() + logger.info(f"Configuration loaded successfully for job: {config.job}") + + # Store config in app state for easy access + app.state.training_config = config + + # Initialize other managers with access to config + sample_manager = SampleManager() + training_monitor = TrainingMonitor() + + # Store managers in app state + app.state.sample_manager = sample_manager + app.state.training_monitor = training_monitor + + # Start the managers + await sample_manager.startup() + await training_monitor.startup() + + logger.info("All services initialized successfully") + + except Exception as e: + logger.error(f"Failed to initialize services: {str(e)}") + # Re-raise to prevent app from starting with partially initialized services + raise @app.on_event("startup") async def startup_event(): - """Initialize services on startup""" + """ + Startup event handler that ensures all services are properly initialized + before the application starts accepting requests. + """ logger.info("Starting up Training Monitor API") - await sample_manager.startup() - await training_monitor.startup() + await initialize_services() @@ -60,6 +104,7 @@ async def shutdown_event(): # Include routers with versioning app.include_router(training.router, prefix=f"{settings.API_VER_STR}/training", tags=["training"]) app.include_router(samples.router, prefix=f"{settings.API_VER_STR}/samples", tags=["samples"]) +app.include_router(config.router, prefix=f"{settings.API_VER_STR}/config", tags=["config"]) @app.get("/") diff --git a/backend/app/models/config.py b/backend/app/models/config.py new file mode 100644 index 0000000..7b803b3 --- /dev/null +++ b/backend/app/models/config.py @@ -0,0 +1,93 @@ +from typing import List, Optional, Dict, Any + +from pydantic import BaseModel + + +class SampleConfig(BaseModel): + sampler: str + sample_every: int + width: int + height: int + prompts: List[str] + neg: str + seed: int + walk_seed: bool + guidance_scale: float + sample_steps: int + + +class DatasetConfig(BaseModel): + folder_path: str + caption_ext: Optional[str] = None + caption_dropout_rate: Optional[float] = None + shuffle_tokens: Optional[bool] = False + resolution: Optional[List[int]] = None + + +class EMAConfig(BaseModel): + use_ema: Optional[bool] = False + ema_decay: Optional[float] = None + + +class TrainConfig(BaseModel): + batch_size: int + bypass_guidance_embedding: Optional[bool] = False + timestep_type: Optional[str] = None + steps: int + gradient_accumulation: Optional[int] = 1 + train_unet: Optional[bool] = True + train_text_encoder: Optional[bool] = False + gradient_checkpointing: Optional[bool] = False + noise_scheduler: Optional[str] = None + optimizer: Optional[str] = None + lr: Optional[float] = None + ema_config: Optional[EMAConfig] = None + dtype: Optional[str] = None + do_paramiter_swapping: Optional[bool] = False + paramiter_swapping_factor: Optional[float] = None + skip_first_sample: Optional[bool] = False + disable_sampling: Optional[bool] = False + + +class ModelConfig(BaseModel): + name_or_path: str + is_flux: Optional[bool] = False + quantize: Optional[bool] = False + quantize_te: Optional[bool] = False + + +class SaveConfig(BaseModel): + dtype: Optional[str] = None + save_every: Optional[int] = None + 
+    max_step_saves_to_keep: Optional[int] = None
+    save_format: Optional[str] = None
+
+
+class ProcessConfig(BaseModel):
+    type: str
+    training_folder: str
+    performance_log_every: Optional[int] = None
+    device: Optional[str] = None
+    trigger_word: Optional[str] = None
+    save: Optional[SaveConfig] = None
+    datasets: List[DatasetConfig]
+    train: TrainConfig
+    model: ModelConfig
+    sample: SampleConfig
+
+
+class MetaConfig(BaseModel):
+    name: Optional[str] = None
+    version: Optional[str] = None
+
+
+class TrainingConfig(BaseModel):
+    job: str
+    config: Dict[str, Any]  # This will contain 'name' and 'process'
+    meta: MetaConfig
+
+
+# And a Config class to represent the middle layer:
+class Config(BaseModel):
+    name: str
+    process: List[ProcessConfig]
diff --git a/backend/app/services/config_manager.py b/backend/app/services/config_manager.py
new file mode 100644
index 0000000..8f0a8fc
--- /dev/null
+++ b/backend/app/services/config_manager.py
@@ -0,0 +1,151 @@
+# app/services/config_manager.py
+import logging
+import os
+
+import aiofiles
+import paramiko
+import yaml
+from fastapi import HTTPException
+
+from app.core.config import settings
+from app.models.config import TrainingConfig, ProcessConfig, SaveConfig, DatasetConfig
+
+logger = logging.getLogger(__name__)
+
+
+class ConfigManager:
+    """
+    Manages access to training configuration files, supporting both local and remote (SFTP) paths.
+    Handles YAML parsing and conversion to strongly-typed configuration objects.
+    """
+
+    def __init__(self):
+        # Initialize paths from settings, defaulting to None if not configured
+        self.remote_path = getattr(settings, 'TRAINING_CONFIG_REMOTE_PATH', None)
+        self.local_path = getattr(settings, 'TRAINING_CONFIG_LOCAL_PATH', None)
+        self.sftp_client = None
+        self.cached_config = None
+
+        # Validate that at least one path is configured
+        if not self.remote_path and not self.local_path:
+            raise ValueError("Either TRAINING_CONFIG_REMOTE_PATH or TRAINING_CONFIG_LOCAL_PATH must be configured")
+
+        logger.info(f"ConfigManager initialized with remote_path={self.remote_path}, local_path={self.local_path}")
+
+    async def _connect_sftp(self):
+        """Establishes SFTP connection using SSH key authentication"""
+        try:
+            key_path = os.path.expanduser(settings.SFTP_KEY_PATH)
+            logger.info(f"Connecting to SFTP {settings.SFTP_HOST} with key {key_path}")
+
+            ssh = paramiko.SSHClient()
+            ssh.set_missing_host_key_policy(paramiko.AutoAddPolicy())
+
+            ssh.connect(
+                hostname=settings.SFTP_HOST,
+                username=settings.SFTP_USER,
+                port=settings.SFTP_PORT,
+                key_filename=key_path,
+            )
+
+            self.sftp_client = ssh.open_sftp()
+            logger.info("SFTP connection established successfully")
+
+        except Exception as e:
+            logger.error(f"Failed to establish SFTP connection: {str(e)}")
+            raise HTTPException(status_code=500, detail=f"SFTP connection failed: {str(e)}")
+
+    def _disconnect_sftp(self):
+        """Safely closes SFTP connection if it exists"""
+        if self.sftp_client:
+            try:
+                self.sftp_client.close()
+                self.sftp_client = None
+                logger.info("SFTP connection closed")
+            except Exception as e:
+                logger.error(f"Error closing SFTP connection: {str(e)}")
+
+    async def _read_remote_config(self) -> dict:
+        """Reads and parses YAML configuration from remote SFTP location"""
+        if not self.sftp_client:
+            await self._connect_sftp()
+
+        try:
+            with self.sftp_client.open(self.remote_path, 'r') as f:
+                content = f.read()
+                return yaml.safe_load(content)
+        except Exception as e:
+            logger.error(f"Failed to read remote config: {str(e)}")
+            raise HTTPException(status_code=500, detail=f"Failed to read remote config: {str(e)}")
+        finally:
+            self._disconnect_sftp()
+
+    async def _read_local_config(self) -> dict:
+        """Reads and parses YAML configuration from local filesystem"""
+        try:
+            async with aiofiles.open(self.local_path, 'r') as f:
+                content = await f.read()
+                return yaml.safe_load(content)
+        except Exception as e:
+            logger.error(f"Failed to read local config: {str(e)}")
+            raise HTTPException(status_code=500, detail=f"Failed to read local config: {str(e)}")
+
+    def _parse_config(self, raw_config: dict) -> TrainingConfig:
+        """
+        Converts raw YAML dictionary into strongly-typed configuration objects.
+        Handles optional fields and nested configurations.
+        """
+        try:
+            # Extract the first process configuration (assuming single process for now)
+            process_data = raw_config['config']['process'][0]
+
+            # Build the process config with all its nested components
+            process = ProcessConfig(
+                type=process_data['type'],
+                training_folder=process_data['training_folder'],
+                performance_log_every=process_data.get('performance_log_every'),
+                device=process_data.get('device'),
+                trigger_word=process_data.get('trigger_word'),
+                save=SaveConfig(**process_data['save']) if 'save' in process_data else None,
+                datasets=[DatasetConfig(**ds) for ds in process_data.get('datasets', [])],
+                train=process_data['train'],
+                model=process_data['model'],
+                sample=process_data['sample']
+            )
+
+            # Reconstruct the config dictionary with our parsed process
+            config_dict = dict(raw_config['config'])
+            config_dict['process'] = [process]
+
+            # Create the full training config
+            return TrainingConfig(
+                job=raw_config.get('job', ''),
+                config=config_dict,
+                meta=raw_config.get('meta', {})
+            )
+
+        except Exception as e:
+            logger.error(f"Failed to parse config: {str(e)}")
+            raise HTTPException(status_code=500, detail=f"Config parsing failed: {str(e)}")
+
+    async def get_config(self) -> TrainingConfig:
+        """
+        Main method to retrieve and parse configuration.
+        Automatically handles local or remote access based on configuration.
+        """
+        if self.cached_config is not None:
+            return self.cached_config
+
+        try:
+            # Read raw config from appropriate source
+            raw_config = await self._read_remote_config() if self.remote_path else await self._read_local_config()
+
+            # Parse and return strongly-typed config
+
+            parsed_config = self._parse_config(raw_config)
+            self.cached_config = parsed_config
+            return parsed_config
+
+        except Exception as e:
+            logger.error(f"Failed to get config: {str(e)}")
+            raise HTTPException(status_code=500, detail=str(e))
diff --git a/backend/requirements.txt b/backend/requirements.txt
index b834626..c541a2a 100644
--- a/backend/requirements.txt
+++ b/backend/requirements.txt
@@ -11,4 +11,5 @@ aiofiles>=23.2.1
 pytest>=7.4.3
 httpx>=0.25.1
 pytest-asyncio>=0.21.1
-psutil>=5.9.8
\ No newline at end of file
+psutil>=5.9.8
+PyYAML~=6.0.2
\ No newline at end of file
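For reference, a config file of roughly the following shape should satisfy the Pydantic models and the ConfigManager parsing added above (a single process entry, as _parse_config currently assumes). This is a minimal sketch: every concrete value below (job type, paths, trigger word, model name, prompt, numbers) is an illustrative placeholder, not taken from this patch.

job: extension
config:
  name: example_lora
  process:
    - type: sd_trainer
      training_folder: /workspace/output
      device: cuda:0
      trigger_word: examplesubject
      save:
        dtype: float16
        save_every: 500
        max_step_saves_to_keep: 4
      datasets:
        - folder_path: /workspace/datasets/example
          caption_ext: txt
          resolution: [512, 768, 1024]
      train:
        batch_size: 1
        steps: 2000
        gradient_accumulation: 1
        optimizer: adamw8bit
        lr: 0.0001
        dtype: bf16
      model:
        name_or_path: black-forest-labs/FLUX.1-dev
        is_flux: true
        quantize: true
      sample:
        sampler: flowmatch
        sample_every: 250
        width: 1024
        height: 1024
        prompts:
          - a photo of examplesubject in a park
        neg: ''
        seed: 42
        walk_seed: true
        guidance_scale: 4.0
        sample_steps: 20
meta:
  name: example_lora
  version: '1.0'

With TRAINING_CONFIG_LOCAL_PATH (or TRAINING_CONFIG_REMOTE_PATH) pointing at such a file, a GET to {API_VER_STR}/config/config returns the parsed TrainingConfig as JSON; the repeated segment comes from the router prefix plus the route's own "/config" path as wired in this patch.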