# app/services/config_manager.py
import logging
import os

import aiofiles
import paramiko
import yaml
from fastapi import HTTPException

from app.core.config import settings
from app.models.config import TrainingConfig, ProcessConfig, SaveConfig, DatasetConfig

logger = logging.getLogger(__name__)


class ConfigManager:
    """
    Manages access to training configuration files, supporting both local and
    remote (SFTP) paths.

    Handles YAML parsing and conversion to strongly-typed configuration
    objects. When ``TRAINING_CONFIG_REMOTE_PATH`` is set it takes precedence
    over ``TRAINING_CONFIG_LOCAL_PATH``. The first successfully parsed config
    is cached on the instance; call ``get_config(force_reload=True)`` to
    re-read it.
    """

    def __init__(self):
        """
        Read the config locations from ``settings``.

        Raises:
            ValueError: if neither TRAINING_CONFIG_REMOTE_PATH nor
                TRAINING_CONFIG_LOCAL_PATH is configured.
        """
        # Initialize paths from settings, defaulting to None if not configured
        self.remote_path = getattr(settings, 'TRAINING_CONFIG_REMOTE_PATH', None)
        self.local_path = getattr(settings, 'TRAINING_CONFIG_LOCAL_PATH', None)
        # The SSH client owns the transport the SFTP session runs over; both
        # must be closed to actually release the network connection.
        self.ssh_client = None
        self.sftp_client = None
        self.cached_config = None

        # Validate that at least one path is configured
        if not self.remote_path and not self.local_path:
            raise ValueError("Either TRAINING_CONFIG_REMOTE_PATH or TRAINING_CONFIG_LOCAL_PATH must be configured")

        logger.info(
            "ConfigManager initialized with remote_path=%s, local_path=%s",
            self.remote_path,
            self.local_path,
        )

    async def _connect_sftp(self):
        """
        Open an SSH connection with key authentication and start an SFTP session.

        On success ``self.ssh_client`` and ``self.sftp_client`` are populated.
        On failure any partially opened SSH client is closed before raising.

        Raises:
            HTTPException: 500 if the connection or SFTP session cannot be opened.
        """
        # NOTE: paramiko is a synchronous library, so these calls block the
        # event loop while they run; consider run_in_executor if this matters.
        ssh = None
        try:
            key_path = os.path.expanduser(settings.SFTP_KEY_PATH)
            logger.info("Connecting to SFTP %s with key %s", settings.SFTP_HOST, key_path)

            ssh = paramiko.SSHClient()
            # NOTE(security): AutoAddPolicy accepts any unknown host key, which
            # allows man-in-the-middle attacks. Prefer load_system_host_keys()
            # plus RejectPolicy once known_hosts is provisioned.
            ssh.set_missing_host_key_policy(paramiko.AutoAddPolicy())
            ssh.connect(
                hostname=settings.SFTP_HOST,
                username=settings.SFTP_USER,
                port=settings.SFTP_PORT,
                key_filename=key_path,
            )
            self.sftp_client = ssh.open_sftp()
            self.ssh_client = ssh
            logger.info("SFTP connection established successfully")
        except Exception as e:
            # Don't leak the transport if connect() or open_sftp() failed midway.
            if ssh is not None:
                try:
                    ssh.close()
                except Exception:
                    logger.debug("Ignoring error while closing failed SSH client", exc_info=True)
            logger.error("Failed to establish SFTP connection: %s", e)
            raise HTTPException(status_code=500, detail=f"SFTP connection failed: {str(e)}") from e

    def _disconnect_sftp(self):
        """
        Close the SFTP session and its underlying SSH connection, if open.

        The stored clients are cleared *before* closing so that a failing
        ``close()`` never leaves a broken client cached for the next read.
        Errors while closing are logged, not raised (best-effort cleanup).
        """
        sftp, self.sftp_client = self.sftp_client, None
        ssh, self.ssh_client = self.ssh_client, None

        if sftp is not None:
            try:
                sftp.close()
                logger.info("SFTP connection closed")
            except Exception as e:
                logger.error("Error closing SFTP connection: %s", e)

        if ssh is not None:
            try:
                ssh.close()
            except Exception as e:
                logger.error("Error closing SSH connection: %s", e)

    async def _read_remote_config(self) -> dict:
        """
        Read and parse the YAML config at ``self.remote_path`` over SFTP.

        Connects if no session is active and always disconnects afterwards.

        Returns:
            The result of ``yaml.safe_load`` on the file contents (``None`` for
            an empty file).

        Raises:
            HTTPException: 500 if connecting, reading, or YAML parsing fails.
        """
        if self.sftp_client is None:
            await self._connect_sftp()

        try:
            with self.sftp_client.open(self.remote_path, 'r') as f:
                content = f.read()
            # safe_load accepts the bytes paramiko returns and detects encoding.
            return yaml.safe_load(content)
        except Exception as e:
            logger.error("Failed to read remote config: %s", e)
            raise HTTPException(status_code=500, detail=f"Failed to read remote config: {str(e)}") from e
        finally:
            self._disconnect_sftp()

    async def _read_local_config(self) -> dict:
        """
        Read and parse the YAML config at ``self.local_path``.

        The file is decoded as UTF-8 rather than the platform default encoding.

        Returns:
            The result of ``yaml.safe_load`` on the file contents (``None`` for
            an empty file).

        Raises:
            HTTPException: 500 if the file cannot be read or parsed as YAML.
        """
        try:
            async with aiofiles.open(self.local_path, 'r', encoding='utf-8') as f:
                content = await f.read()
            return yaml.safe_load(content)
        except Exception as e:
            logger.error("Failed to read local config: %s", e)
            raise HTTPException(status_code=500, detail=f"Failed to read local config: {str(e)}") from e

    def _parse_config(self, raw_config: dict) -> TrainingConfig:
        """
        Convert the raw YAML document into a ``TrainingConfig``.

        Only the first entry of ``config.process`` is converted into a
        ``ProcessConfig`` (single-process assumption); it replaces the whole
        ``process`` list in the resulting config. ``job`` and ``meta`` default
        to ``''`` and ``{}`` when absent.

        Args:
            raw_config: Parsed YAML; must be a mapping containing
                ``config.process[0]`` with ``type``, ``training_folder``,
                ``train``, ``model`` and ``sample`` keys.

        Raises:
            HTTPException: 500 if the document is not a mapping, a required key
                is missing, or model construction fails.
        """
        try:
            # An empty file yields None; a scalar/list top level is also invalid.
            if not isinstance(raw_config, dict):
                raise ValueError(
                    f"expected a YAML mapping at the top level, got {type(raw_config).__name__}"
                )

            # Extract the first process configuration (assuming single process for now)
            process_data = raw_config['config']['process'][0]

            # Build the process config with all its nested components
            process = ProcessConfig(
                type=process_data['type'],
                training_folder=process_data['training_folder'],
                performance_log_every=process_data.get('performance_log_every'),
                device=process_data.get('device'),
                trigger_word=process_data.get('trigger_word'),
                save=SaveConfig(**process_data['save']) if 'save' in process_data else None,
                datasets=[DatasetConfig(**ds) for ds in process_data.get('datasets', [])],
                train=process_data['train'],
                model=process_data['model'],
                sample=process_data['sample'],
            )

            # Copy so the caller's raw dict isn't mutated, then swap in the typed process
            config_dict = dict(raw_config['config'])
            config_dict['process'] = [process]

            return TrainingConfig(
                job=raw_config.get('job', ''),
                config=config_dict,
                meta=raw_config.get('meta', {}),
            )
        except Exception as e:
            logger.error("Failed to parse config: %s", e)
            raise HTTPException(status_code=500, detail=f"Config parsing failed: {str(e)}") from e

    async def get_config(self, *, force_reload: bool = False) -> TrainingConfig:
        """
        Return the parsed training configuration, reading it on first use.

        The remote (SFTP) path is used when configured, otherwise the local
        path. The result is cached on the instance.

        Args:
            force_reload: If True, ignore any cached config and read it again.

        Raises:
            HTTPException: 500 if reading or parsing fails. Exceptions already
                raised as HTTPException by the helpers propagate unchanged.
        """
        if self.cached_config is not None and not force_reload:
            return self.cached_config

        try:
            if self.remote_path:
                raw_config = await self._read_remote_config()
            else:
                raw_config = await self._read_local_config()
            parsed_config = self._parse_config(raw_config)
        except HTTPException:
            # Already logged and carrying a specific detail; don't re-wrap it.
            raise
        except Exception as e:
            logger.error("Failed to get config: %s", e)
            raise HTTPException(status_code=500, detail=str(e)) from e

        self.cached_config = parsed_config
        return parsed_config