# app/services/config_manager.py
import logging
|
|
import os
|
|
|
|
import aiofiles
|
|
import paramiko
|
|
import yaml
|
|
from fastapi import HTTPException
|
|
|
|
from app.core.config import settings
|
|
from app.models.config import TrainingConfig, ProcessConfig, SaveConfig, DatasetConfig
|
|
|
|
# Module logger, named after this module so its level and handlers can be
# configured through the standard logging hierarchy.
logger: logging.Logger = logging.getLogger(name=__name__)


class ConfigManager:
    """
    Manages access to training configuration files, supporting both local and remote (SFTP) paths.

    Handles YAML parsing and conversion to strongly-typed configuration objects.
    When both paths are configured, the remote (SFTP) path takes precedence.
    The parsed configuration is cached on the instance after the first
    successful load, so later ``get_config()`` calls perform no I/O.
    """

    def __init__(self):
        """
        Read the configuration locations from ``settings``.

        Raises:
            ValueError: if neither TRAINING_CONFIG_REMOTE_PATH nor
                TRAINING_CONFIG_LOCAL_PATH is configured.
        """
        # Initialize paths from settings, defaulting to None if not configured
        self.remote_path = getattr(settings, 'TRAINING_CONFIG_REMOTE_PATH', None)
        self.local_path = getattr(settings, 'TRAINING_CONFIG_LOCAL_PATH', None)
        # The SSH client is kept alongside the SFTP session: closing only the
        # SFTP session leaves the underlying SSH transport open.
        self.ssh_client = None
        self.sftp_client = None
        self.cached_config = None

        # Validate that at least one path is configured
        if not self.remote_path and not self.local_path:
            raise ValueError("Either TRAINING_CONFIG_REMOTE_PATH or TRAINING_CONFIG_LOCAL_PATH must be configured")

        logger.info(
            "ConfigManager initialized with remote_path=%s, local_path=%s",
            self.remote_path,
            self.local_path,
        )

    async def _connect_sftp(self):
        """
        Open an SFTP session to ``settings.SFTP_HOST`` using SSH key authentication.

        On success, stores the SSH client in ``self.ssh_client`` and the SFTP
        session in ``self.sftp_client``.

        Raises:
            HTTPException: status 500 if the SSH connection or the SFTP session
                cannot be established.
        """
        # NOTE(review): paramiko calls are blocking, so this coroutine stalls the
        # event loop during the handshake. It is tolerable only because
        # get_config() caches its result; consider offloading to an executor
        # (with a lock around the shared clients) if this becomes a hot path.
        ssh = None
        try:
            key_path = os.path.expanduser(settings.SFTP_KEY_PATH)
            logger.info("Connecting to SFTP %s with key %s", settings.SFTP_HOST, key_path)

            ssh = paramiko.SSHClient()
            # NOTE(security): AutoAddPolicy accepts any unknown host key, which
            # leaves the connection open to man-in-the-middle attacks. Prefer
            # load_system_host_keys() + RejectPolicy once the host key is known.
            ssh.set_missing_host_key_policy(paramiko.AutoAddPolicy())
            ssh.connect(
                hostname=settings.SFTP_HOST,
                username=settings.SFTP_USER,
                port=settings.SFTP_PORT,
                key_filename=key_path,
            )
            sftp = ssh.open_sftp()
        except Exception as e:
            # Don't leak a half-open SSH connection when connect()/open_sftp() fails.
            if ssh is not None:
                ssh.close()
            logger.error("Failed to establish SFTP connection: %s", e)
            raise HTTPException(status_code=500, detail=f"SFTP connection failed: {str(e)}") from e

        self.ssh_client = ssh
        self.sftp_client = sftp
        logger.info("SFTP connection established successfully")

    def _disconnect_sftp(self):
        """
        Close the SFTP session and its SSH client on a best-effort basis.

        Errors raised while closing are logged, not propagated. Both attributes
        are reset to None first, so a failed close never leaves a stale client
        behind for ``_read_remote_config`` to reuse.
        """
        sftp, ssh = self.sftp_client, self.ssh_client
        self.sftp_client = None
        self.ssh_client = None

        if sftp is not None:
            try:
                sftp.close()
                logger.info("SFTP connection closed")
            except Exception as e:
                logger.error("Error closing SFTP connection: %s", e)

        if ssh is not None:
            try:
                # SFTPClient.close() only closes its channel; the SSH transport
                # stays up until the owning SSHClient is closed.
                ssh.close()
            except Exception as e:
                logger.error("Error closing SSH connection: %s", e)

    async def _read_remote_config(self) -> dict:
        """
        Read and parse the YAML file at ``self.remote_path`` over SFTP.

        Connects if no session is open and always disconnects afterwards.

        Returns:
            The value produced by ``yaml.safe_load``: normally a dict, but None
            for an empty file.

        Raises:
            HTTPException: status 500 if connecting, reading or parsing fails.
        """
        if not self.sftp_client:
            await self._connect_sftp()

        try:
            with self.sftp_client.open(self.remote_path, 'r') as f:
                content = f.read()
            return yaml.safe_load(content)
        except Exception as e:
            logger.error("Failed to read remote config: %s", e)
            raise HTTPException(status_code=500, detail=f"Failed to read remote config: {str(e)}") from e
        finally:
            self._disconnect_sftp()

    async def _read_local_config(self) -> dict:
        """
        Read and parse the YAML file at ``self.local_path``.

        Returns:
            The value produced by ``yaml.safe_load``: normally a dict, but None
            for an empty file.

        Raises:
            HTTPException: status 500 if the file cannot be read or parsed.
        """
        try:
            # YAML files are Unicode; decode explicitly as UTF-8 rather than with
            # the platform's locale-dependent default encoding.
            async with aiofiles.open(self.local_path, 'r', encoding='utf-8') as f:
                content = await f.read()
            return yaml.safe_load(content)
        except Exception as e:
            logger.error("Failed to read local config: %s", e)
            raise HTTPException(status_code=500, detail=f"Failed to read local config: {str(e)}") from e

    def _parse_config(self, raw_config: dict) -> TrainingConfig:
        """
        Convert the raw YAML mapping into strongly-typed configuration objects.

        Only the first entry of ``raw_config['config']['process']`` is parsed.
        That entry must contain ``type``, ``training_folder``, ``train``,
        ``model`` and ``sample``. The keys ``performance_log_every``,
        ``device``, ``trigger_word``, ``save`` and ``datasets`` are optional.
        Top-level ``job`` defaults to '' and ``meta`` defaults to {}.

        Raises:
            HTTPException: status 500 if the structure is missing required keys
                or has the wrong shape.
        """
        try:
            if not isinstance(raw_config, dict):
                # e.g. an empty YAML file loads as None; report that clearly instead
                # of "'NoneType' object is not subscriptable".
                raise ValueError(
                    f"expected a YAML mapping at the top level, got {type(raw_config).__name__}"
                )

            processes = raw_config['config']['process']
            if not processes:
                raise ValueError("config.process must contain at least one entry")

            # Extract the first process configuration (assuming single process for now)
            process_data = processes[0]

            # Build the process config with all its nested components
            process = ProcessConfig(
                type=process_data['type'],
                training_folder=process_data['training_folder'],
                performance_log_every=process_data.get('performance_log_every'),
                device=process_data.get('device'),
                trigger_word=process_data.get('trigger_word'),
                save=SaveConfig(**process_data['save']) if 'save' in process_data else None,
                datasets=[DatasetConfig(**ds) for ds in process_data.get('datasets', [])],
                train=process_data['train'],
                model=process_data['model'],
                sample=process_data['sample'],
            )

            # Reconstruct the config dictionary with our parsed process
            # (shallow copy so the caller's raw mapping is not mutated).
            config_dict = dict(raw_config['config'])
            config_dict['process'] = [process]

            # Create the full training config
            return TrainingConfig(
                job=raw_config.get('job', ''),
                config=config_dict,
                meta=raw_config.get('meta', {}),
            )
        except Exception as e:
            logger.error("Failed to parse config: %s", e)
            raise HTTPException(status_code=500, detail=f"Config parsing failed: {str(e)}") from e

    async def get_config(self) -> TrainingConfig:
        """
        Return the parsed training configuration, loading it on first use.

        Reads from the remote SFTP path if one is configured, otherwise from
        the local path. The result is cached on the instance.

        Raises:
            HTTPException: status 500 if the config cannot be read or parsed.
        """
        if self.cached_config is not None:
            return self.cached_config

        try:
            # Read raw config from appropriate source
            raw_config = await self._read_remote_config() if self.remote_path else await self._read_local_config()

            # Parse into the strongly-typed config
            parsed_config = self._parse_config(raw_config)
        except HTTPException:
            # The helpers have already logged and wrapped the failure. Re-wrapping
            # via str(e) would prefix the status code to the detail ("500: ...")
            # and log the same error twice.
            raise
        except Exception as e:
            logger.error("Failed to get config: %s", e)
            raise HTTPException(status_code=500, detail=str(e)) from e

        self.cached_config = parsed_config
        return parsed_config