Add config parsing support in backend
Signed-off-by: Felipe Cardoso <felipe.cardoso@hotmail.it>
This commit is contained in:
151
backend/app/services/config_manager.py
Normal file
151
backend/app/services/config_manager.py
Normal file
@@ -0,0 +1,151 @@
|
||||
# app/services/config_manager.py
|
||||
import logging
|
||||
import os
|
||||
|
||||
import aiofiles
|
||||
import paramiko
|
||||
import yaml
|
||||
from fastapi import HTTPException
|
||||
|
||||
from app.core.config import settings
|
||||
from app.models.config import TrainingConfig, ProcessConfig, SaveConfig, DatasetConfig
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class ConfigManager:
    """
    Manages access to training configuration files, supporting both local and remote (SFTP) paths.
    Handles YAML parsing and conversion to strongly-typed configuration objects.

    The remote path takes precedence when both paths are configured. A
    successfully parsed configuration is cached on the instance and returned
    by every later ``get_config`` call.
    """

    def __init__(self):
        """
        Read the configured paths from ``settings`` and initialise connection state.

        Raises:
            ValueError: if neither TRAINING_CONFIG_REMOTE_PATH nor
                TRAINING_CONFIG_LOCAL_PATH is set to a non-empty value.
        """
        # Initialize paths from settings, defaulting to None if not configured
        self.remote_path = getattr(settings, 'TRAINING_CONFIG_REMOTE_PATH', None)
        self.local_path = getattr(settings, 'TRAINING_CONFIG_LOCAL_PATH', None)
        # The SSH client is kept next to the SFTP client so that the
        # underlying transport can be closed explicitly on disconnect.
        self.ssh_client = None
        self.sftp_client = None
        self.cached_config = None

        # Validate that at least one path is configured
        if not self.remote_path and not self.local_path:
            raise ValueError("Either TRAINING_CONFIG_REMOTE_PATH or TRAINING_CONFIG_LOCAL_PATH must be configured")

        logger.info(f"ConfigManager initialized with remote_path={self.remote_path}, local_path={self.local_path}")

    async def _connect_sftp(self):
        """
        Establish an SFTP connection using SSH key authentication.

        On success both ``self.ssh_client`` and ``self.sftp_client`` are set.

        Raises:
            HTTPException: status 500 if any step of the connection fails.
        """
        ssh = None
        try:
            key_path = os.path.expanduser(settings.SFTP_KEY_PATH)
            logger.info(f"Connecting to SFTP {settings.SFTP_HOST} with key {key_path}")

            ssh = paramiko.SSHClient()
            # SECURITY NOTE(review): AutoAddPolicy accepts any unknown host key,
            # which leaves the connection open to man-in-the-middle attacks.
            # Consider loading known hosts and using RejectPolicy instead.
            ssh.set_missing_host_key_policy(paramiko.AutoAddPolicy())

            # NOTE(review): paramiko calls are synchronous, so the event loop
            # is blocked for the duration of the connection handshake.
            ssh.connect(
                hostname=settings.SFTP_HOST,
                username=settings.SFTP_USER,
                port=settings.SFTP_PORT,
                key_filename=key_path,
            )
            sftp = ssh.open_sftp()
        except Exception as e:
            # Do not leave a half-open SSH client behind on failure.
            if ssh is not None:
                ssh.close()
            logger.error(f"Failed to establish SFTP connection: {str(e)}")
            raise HTTPException(status_code=500, detail=f"SFTP connection failed: {str(e)}") from e

        self.ssh_client = ssh
        self.sftp_client = sftp
        logger.info("SFTP connection established successfully")

    def _disconnect_sftp(self):
        """
        Close the SFTP channel and the SSH client if they exist.

        Errors raised while closing are logged and not propagated. The stored
        references are cleared in every case so a broken client is never reused.
        """
        sftp = self.sftp_client
        ssh = getattr(self, 'ssh_client', None)
        # Clear the references first: even if close() fails below, the next
        # read must open a fresh connection instead of reusing a stale one.
        self.sftp_client = None
        self.ssh_client = None

        if sftp is None and ssh is None:
            return

        closed_cleanly = True
        for handle in (sftp, ssh):
            if handle is None:
                continue
            try:
                handle.close()
            except Exception as e:
                closed_cleanly = False
                logger.error(f"Error closing SFTP connection: {str(e)}")

        if closed_cleanly:
            logger.info("SFTP connection closed")

    async def _read_remote_config(self) -> dict:
        """
        Read and parse the YAML configuration from the remote SFTP location.

        A connection is opened if none exists and is always closed afterwards.

        Returns:
            The result of ``yaml.safe_load`` on the remote file content.

        Raises:
            HTTPException: status 500 if connecting, reading or YAML parsing fails.
        """
        if not self.sftp_client:
            await self._connect_sftp()

        try:
            with self.sftp_client.open(self.remote_path, 'r') as f:
                content = f.read()
            return yaml.safe_load(content)
        except Exception as e:
            logger.error(f"Failed to read remote config: {str(e)}")
            raise HTTPException(status_code=500, detail=f"Failed to read remote config: {str(e)}") from e
        finally:
            self._disconnect_sftp()

    async def _read_local_config(self) -> dict:
        """
        Read and parse the YAML configuration from the local filesystem.

        Returns:
            The result of ``yaml.safe_load`` on the local file content.

        Raises:
            HTTPException: status 500 if reading or YAML parsing fails.
        """
        try:
            async with aiofiles.open(self.local_path, 'r') as f:
                content = await f.read()
            return yaml.safe_load(content)
        except Exception as e:
            logger.error(f"Failed to read local config: {str(e)}")
            raise HTTPException(status_code=500, detail=f"Failed to read local config: {str(e)}") from e

    def _parse_config(self, raw_config: dict) -> "TrainingConfig":
        """
        Convert the raw YAML dictionary into strongly-typed configuration objects.

        Only the first entry of ``config.process`` is converted; optional
        fields that are absent are passed as ``None``.

        Args:
            raw_config: the dictionary produced by ``yaml.safe_load``.

        Returns:
            A ``TrainingConfig`` built from ``raw_config``.

        Raises:
            HTTPException: status 500 if the structure is missing required
                sections or the model constructors reject the data.
        """
        try:
            # An empty YAML document loads as None; report that clearly
            # instead of failing with "'NoneType' object is not subscriptable".
            if not isinstance(raw_config, dict) or not isinstance(raw_config.get('config'), dict):
                raise ValueError("configuration must be a mapping with a 'config' section")

            processes = raw_config['config'].get('process')
            if not processes:
                raise ValueError("'config.process' must contain at least one process entry")

            # Extract the first process configuration (assuming single process for now)
            process_data = processes[0]

            # ``save: null`` / ``datasets: null`` in YAML load as None, so
            # treat an explicit null the same way as a missing key.
            save_data = process_data.get('save')
            dataset_entries = process_data.get('datasets') or []

            # Build the process config with all its nested components
            process = ProcessConfig(
                type=process_data['type'],
                training_folder=process_data['training_folder'],
                performance_log_every=process_data.get('performance_log_every'),
                device=process_data.get('device'),
                trigger_word=process_data.get('trigger_word'),
                save=SaveConfig(**save_data) if save_data is not None else None,
                datasets=[DatasetConfig(**ds) for ds in dataset_entries],
                train=process_data['train'],
                model=process_data['model'],
                sample=process_data['sample'],
            )

            # Reconstruct the config dictionary with our parsed process
            config_dict = dict(raw_config['config'])
            config_dict['process'] = [process]

            # Create the full training config
            return TrainingConfig(
                job=raw_config.get('job', ''),
                config=config_dict,
                meta=raw_config.get('meta', {}),
            )

        except Exception as e:
            logger.error(f"Failed to parse config: {str(e)}")
            raise HTTPException(status_code=500, detail=f"Config parsing failed: {str(e)}") from e

    async def get_config(self) -> "TrainingConfig":
        """
        Main method to retrieve and parse configuration.

        Returns the cached configuration when one exists. Otherwise reads from
        the remote path if configured, else from the local path, then parses
        and caches the result.

        Raises:
            HTTPException: status 500 if reading or parsing fails. Errors
                already raised as HTTPException by the helpers are propagated
                unchanged so their detail message is preserved.
        """
        if self.cached_config is not None:
            return self.cached_config

        try:
            # Read raw config from appropriate source
            if self.remote_path:
                raw_config = await self._read_remote_config()
            else:
                raw_config = await self._read_local_config()

            # Parse into a strongly-typed config
            parsed_config = self._parse_config(raw_config)
        except HTTPException:
            # Already carries a specific status and detail; do not re-wrap.
            raise
        except Exception as e:
            logger.error(f"Failed to get config: {str(e)}")
            raise HTTPException(status_code=500, detail=str(e)) from e

        self.cached_config = parsed_config
        return parsed_config
|
||||
Reference in New Issue
Block a user