Add config parsing support in backend

Signed-off-by: Felipe Cardoso <felipe.cardoso@hotmail.it>
This commit is contained in:
2025-01-23 13:46:30 +01:00
parent f99564434a
commit 37b71464f6
6 changed files with 309 additions and 5 deletions

View File

@@ -0,0 +1,12 @@
from fastapi import APIRouter, Request
from app.models.config import TrainingConfig
router = APIRouter()
@router.get("/config", response_model=TrainingConfig)
async def get_training_config(request: Request):
"""Retrieves the current training configuration"""
config_manager = request.app.state.config_manager
return await config_manager.get_config()

View File

@@ -11,10 +11,12 @@ class Settings(BaseSettings):
SFTP_PATH: Optional[str] = None SFTP_PATH: Optional[str] = None
SFTP_PORT: int = 22 SFTP_PORT: int = 22
TRAINING_LOG_REMOTE_PATH: Optional[str] = None TRAINING_LOG_REMOTE_PATH: Optional[str] = None
TRAINING_CONFIG_LOCAL_PATH: Optional[str] = None
# Local Settings (Optional) # Local Settings (Optional)
LOCAL_PATH: Optional[str] = None LOCAL_PATH: Optional[str] = None
TRAINING_LOG_LOCAL_PATH: Optional[str] = None TRAINING_LOG_LOCAL_PATH: Optional[str] = None
TRAINING_CONFIG_REMOTE_PATH: Optional[str] = None
# API Settings # API Settings
PROJECT_NAME: str = "Training Monitor" PROJECT_NAME: str = "Training Monitor"

View File

@@ -6,8 +6,9 @@ import psutil
from fastapi import FastAPI from fastapi import FastAPI
from fastapi.middleware.cors import CORSMiddleware from fastapi.middleware.cors import CORSMiddleware
from app.api.routes import training, samples from app.api.routes import training, samples, config
from app.core.config import settings from app.core.config import settings
from app.services.config_manager import ConfigManager
from app.services.sample_manager import SampleManager from app.services.sample_manager import SampleManager
from app.services.training_monitor import TrainingMonitor from app.services.training_monitor import TrainingMonitor
@@ -35,17 +36,60 @@ app.add_middleware(
# Create and store SampleManager instance # Create and store SampleManager instance
sample_manager = SampleManager() sample_manager = SampleManager()
training_monitor = TrainingMonitor() training_monitor = TrainingMonitor()
config_manager = ConfigManager()
app.state.sample_manager = sample_manager app.state.sample_manager = sample_manager
app.state.training_monitor = training_monitor app.state.training_monitor = training_monitor
app.state.config_manager = config_manager
async def initialize_services():
"""
Initializes all service managers in the correct order, ensuring dependencies
are properly set up before they're needed.
"""
logger.info("Starting services initialization...")
# First, initialize config manager as other services might need configuration
config_manager = ConfigManager()
app.state.config_manager = config_manager
try:
# Load initial configuration
config = await config_manager.get_config()
logger.info(f"Configuration loaded successfully for job: {config.job}")
# Store config in app state for easy access
app.state.training_config = config
# Initialize other managers with access to config
sample_manager = SampleManager()
training_monitor = TrainingMonitor()
# Store managers in app state
app.state.sample_manager = sample_manager
app.state.training_monitor = training_monitor
# Start the managers
await sample_manager.startup()
await training_monitor.startup()
logger.info("All services initialized successfully")
except Exception as e:
logger.error(f"Failed to initialize services: {str(e)}")
# Re-raise to prevent app from starting with partially initialized services
raise
@app.on_event("startup") @app.on_event("startup")
async def startup_event(): async def startup_event():
"""Initialize services on startup""" """
Startup event handler that ensures all services are properly initialized
before the application starts accepting requests.
"""
logger.info("Starting up Training Monitor API") logger.info("Starting up Training Monitor API")
await sample_manager.startup() await initialize_services()
await training_monitor.startup()
@@ -60,6 +104,7 @@ async def shutdown_event():
# Include routers with versioning # Include routers with versioning
app.include_router(training.router, prefix=f"{settings.API_VER_STR}/training", tags=["training"]) app.include_router(training.router, prefix=f"{settings.API_VER_STR}/training", tags=["training"])
app.include_router(samples.router, prefix=f"{settings.API_VER_STR}/samples", tags=["samples"]) app.include_router(samples.router, prefix=f"{settings.API_VER_STR}/samples", tags=["samples"])
app.include_router(config.router, prefix=f"{settings.API_VER_STR}/config", tags=["config"])
@app.get("/") @app.get("/")

View File

@@ -0,0 +1,93 @@
from typing import List, Optional, Dict, Any
from pydantic import BaseModel
class SampleConfig(BaseModel):
sampler: str
sample_every: int
width: int
height: int
prompts: List[str]
neg: str
seed: int
walk_seed: bool
guidance_scale: float
sample_steps: int
class DatasetConfig(BaseModel):
folder_path: str
caption_ext: Optional[str] = None
caption_dropout_rate: Optional[float] = None
shuffle_tokens: Optional[bool] = False
resolution: Optional[List[int]] = None
class EMAConfig(BaseModel):
use_ema: Optional[bool] = False
ema_decay: Optional[float] = None
class TrainConfig(BaseModel):
batch_size: int
bypass_guidance_embedding: Optional[bool] = False
timestep_type: Optional[str] = None
steps: int
gradient_accumulation: Optional[int] = 1
train_unet: Optional[bool] = True
train_text_encoder: Optional[bool] = False
gradient_checkpointing: Optional[bool] = False
noise_scheduler: Optional[str] = None
optimizer: Optional[str] = None
lr: Optional[float] = None
ema_config: Optional[EMAConfig] = None
dtype: Optional[str] = None
do_paramiter_swapping: Optional[bool] = False
paramiter_swapping_factor: Optional[float] = None
skip_first_sample: Optional[bool] = False
disable_sampling: Optional[bool] = False
class ModelConfig(BaseModel):
name_or_path: str
is_flux: Optional[bool] = False
quantize: Optional[bool] = False
quantize_te: Optional[bool] = False
class SaveConfig(BaseModel):
dtype: Optional[str] = None
save_every: Optional[int] = None
max_step_saves_to_keep: Optional[int] = None
save_format: Optional[str] = None
class ProcessConfig(BaseModel):
type: str
training_folder: str
performance_log_every: Optional[int] = None
device: Optional[str] = None
trigger_word: Optional[str] = None
save: Optional[SaveConfig] = None
datasets: List[DatasetConfig]
train: TrainConfig
model: ModelConfig
sample: SampleConfig
class MetaConfig(BaseModel):
name: Optional[str] = None
version: Optional[str] = None
class TrainingConfig(BaseModel):
job: str
config: Dict[str, Any] # This will contain 'name' and 'process'
meta: MetaConfig
# And a Config class to represent the middle layer:
class Config(BaseModel):
name: str
process: List[ProcessConfig]

View File

@@ -0,0 +1,151 @@
# app/services/config_manager.py
import logging
import os
import aiofiles
import paramiko
import yaml
from fastapi import HTTPException
from app.core.config import settings
from app.models.config import TrainingConfig, ProcessConfig, SaveConfig, DatasetConfig
logger = logging.getLogger(__name__)
class ConfigManager:
"""
Manages access to training configuration files, supporting both local and remote (SFTP) paths.
Handles YAML parsing and conversion to strongly-typed configuration objects.
"""
def __init__(self):
# Initialize paths from settings, defaulting to None if not configured
self.remote_path = getattr(settings, 'TRAINING_CONFIG_REMOTE_PATH', None)
self.local_path = getattr(settings, 'TRAINING_CONFIG_LOCAL_PATH', None)
self.sftp_client = None
self.cached_config = None
# Validate that at least one path is configured
if not self.remote_path and not self.local_path:
raise ValueError("Either TRAINING_CONFIG_REMOTE_PATH or TRAINING_CONFIG_LOCAL_PATH must be configured")
logger.info(f"ConfigManager initialized with remote_path={self.remote_path}, local_path={self.local_path}")
async def _connect_sftp(self):
"""Establishes SFTP connection using SSH key authentication"""
try:
key_path = os.path.expanduser(settings.SFTP_KEY_PATH)
logger.info(f"Connecting to SFTP {settings.SFTP_HOST} with key {key_path}")
ssh = paramiko.SSHClient()
ssh.set_missing_host_key_policy(paramiko.AutoAddPolicy())
ssh.connect(
hostname=settings.SFTP_HOST,
username=settings.SFTP_USER,
port=settings.SFTP_PORT,
key_filename=key_path,
)
self.sftp_client = ssh.open_sftp()
logger.info("SFTP connection established successfully")
except Exception as e:
logger.error(f"Failed to establish SFTP connection: {str(e)}")
raise HTTPException(status_code=500, detail=f"SFTP connection failed: {str(e)}")
def _disconnect_sftp(self):
"""Safely closes SFTP connection if it exists"""
if self.sftp_client:
try:
self.sftp_client.close()
self.sftp_client = None
logger.info("SFTP connection closed")
except Exception as e:
logger.error(f"Error closing SFTP connection: {str(e)}")
async def _read_remote_config(self) -> dict:
"""Reads and parses YAML configuration from remote SFTP location"""
if not self.sftp_client:
await self._connect_sftp()
try:
with self.sftp_client.open(self.remote_path, 'r') as f:
content = f.read()
return yaml.safe_load(content)
except Exception as e:
logger.error(f"Failed to read remote config: {str(e)}")
raise HTTPException(status_code=500, detail=f"Failed to read remote config: {str(e)}")
finally:
self._disconnect_sftp()
async def _read_local_config(self) -> dict:
"""Reads and parses YAML configuration from local filesystem"""
try:
async with aiofiles.open(self.local_path, 'r') as f:
content = await f.read()
return yaml.safe_load(content)
except Exception as e:
logger.error(f"Failed to read local config: {str(e)}")
raise HTTPException(status_code=500, detail=f"Failed to read local config: {str(e)}")
def _parse_config(self, raw_config: dict) -> TrainingConfig:
"""
Converts raw YAML dictionary into strongly-typed configuration objects.
Handles optional fields and nested configurations.
"""
try:
# Extract the first process configuration (assuming single process for now)
process_data = raw_config['config']['process'][0]
# Build the process config with all its nested components
process = ProcessConfig(
type=process_data['type'],
training_folder=process_data['training_folder'],
performance_log_every=process_data.get('performance_log_every'),
device=process_data.get('device'),
trigger_word=process_data.get('trigger_word'),
save=SaveConfig(**process_data['save']) if 'save' in process_data else None,
datasets=[DatasetConfig(**ds) for ds in process_data.get('datasets', [])],
train=process_data['train'],
model=process_data['model'],
sample=process_data['sample']
)
# Reconstruct the config dictionary with our parsed process
config_dict = dict(raw_config['config'])
config_dict['process'] = [process]
# Create the full training config
return TrainingConfig(
job=raw_config.get('job', ''),
config=config_dict,
meta=raw_config.get('meta', {})
)
except Exception as e:
logger.error(f"Failed to parse config: {str(e)}")
raise HTTPException(status_code=500, detail=f"Config parsing failed: {str(e)}")
async def get_config(self) -> TrainingConfig:
"""
Main method to retrieve and parse configuration.
Automatically handles local or remote access based on configuration.
"""
if self.cached_config is not None:
return self.cached_config
try:
# Read raw config from appropriate source
raw_config = await self._read_remote_config() if self.remote_path else await self._read_local_config()
# Parse and return strongly-typed config
parsed_config = self._parse_config(raw_config)
self.cached_config = parsed_config
return parsed_config
except Exception as e:
logger.error(f"Failed to get config: {str(e)}")
raise HTTPException(status_code=500, detail=str(e))

View File

@@ -11,4 +11,5 @@ aiofiles>=23.2.1
pytest>=7.4.3 pytest>=7.4.3
httpx>=0.25.1 httpx>=0.25.1
pytest-asyncio>=0.21.1 pytest-asyncio>=0.21.1
psutil>=5.9.8 psutil>=5.9.8
PyYAML~=6.0.2