# app/services/training_monitor.py
import asyncio
import codecs
import logging
import os
import re
from datetime import datetime
from typing import List, Optional

import aiofiles
import paramiko
from fastapi import HTTPException

from app.core.config import settings
from app.models.training import TrainingStatus

logger = logging.getLogger(__name__)

# Matches a tqdm progress line such as:
#   ovs_bangel_001: 0%| | 29/20000 [04:08<44:07:22, 7.95s/it, lr: 8.0e-05 loss: 3.693e-01]
# tqdm prints the rate as "<seconds>s/it" when slower than one step per second
# and as "<steps>it/s" otherwise, so both forms are accepted.
# Groups: 1 step, 2 total, 3 elapsed, 4 eta, 5 rate value, 6 rate unit, 7 lr, 8 loss.
_TQDM_PATTERN = re.compile(
    r'(\d+)/(\d+)\s+\[(.*?)<(.*?),\s+([\d.]+)(s/it|it/s),\s+lr:\s+(.*?)\s+loss:\s+(.*?)\]'
)


class TrainingMonitor:
    """Tail a training log (local file or remote file over SFTP).

    A background task polls the log once per second, keeps the most recent
    lines in ``recent_logs`` and updates ``current_status`` from the newest
    parsable tqdm progress line.
    """

    def __init__(self):
        self.sftp_client = None
        # SSH connection that owns ``sftp_client``; kept so it can be closed.
        self._ssh_client = None
        self._running = False
        self._monitor_task = None
        self.recent_logs: List[str] = []
        self.max_log_lines: int = 500
        self.current_status: Optional[TrainingStatus] = None
        self.remote_path = getattr(settings, 'TRAINING_LOG_REMOTE_PATH', None)
        self.local_path = getattr(settings, 'TRAINING_LOG_LOCAL_PATH', None)
        self._file_handle = None
        # Byte offset in the log file up to which content has been consumed.
        self._last_position = 0
        # Incremental decoder: a multi-byte UTF-8 character split across two
        # reads is buffered instead of raising; invalid bytes become U+FFFD.
        self._decoder: codecs.IncrementalDecoder = codecs.getincrementaldecoder('utf-8')(
            errors='replace'
        )

    def _parse_tqdm_line(self, line: str) -> Optional[TrainingStatus]:
        """Parse a tqdm progress line into a TrainingStatus.

        Returns:
            The parsed status, or None when the line is not a recognizable
            progress line or a field cannot be converted (the error is logged).
        """
        match = _TQDM_PATTERN.search(line)
        if not match:
            return None
        try:
            current_step = int(match.group(1))
            total_steps = int(match.group(2))
            eta_str = match.group(4)
            rate = float(match.group(5))
            rate_unit = match.group(6)
            lr = float(match.group(7))
            loss = float(match.group(8))

            # "s/it" is seconds per step; "it/s" is already steps per second.
            steps_per_second = 1 / rate if rate_unit == 's/it' else rate

            # ETA is colon-separated with the least significant field last
            # (e.g. "44:07:22"), so weight the fields from the right.
            eta_seconds = sum(
                value * weight
                for value, weight in zip(
                    map(float, reversed(eta_str.split(':'))),
                    (1, 60, 3600, 86400),
                )
            )

            return TrainingStatus(
                current_step=current_step,
                total_steps=total_steps,
                loss=loss,
                learning_rate=lr,
                percentage=(current_step / total_steps) * 100,
                eta_seconds=eta_seconds,
                steps_per_second=steps_per_second,
                updated_at=datetime.now(),
                source='remote' if self.remote_path else 'local',
                source_path=self.remote_path or self.local_path,
            )
        except Exception as e:
            # Best effort: a malformed line must not stop monitoring.
            logger.error("Error parsing tqdm line: %s", e)
        return None

    async def _read_remote_log(self) -> str:
        """Read the whole remote log file over SFTP and decode it as UTF-8.

        Raises:
            HTTPException: 500 if connecting or reading fails.
        """
        if not self.sftp_client:
            await self._connect_sftp()
        try:
            with self.sftp_client.open(self.remote_path, 'rb') as f:
                content = f.read()
            return content.decode('utf-8')
        except Exception as e:
            logger.error("Error reading remote log: %s", e)
            raise HTTPException(status_code=500, detail=str(e)) from e

    async def _read_local_log(self) -> str:
        """Read the whole local log file as UTF-8 text.

        Raises:
            HTTPException: 500 if the file cannot be read.
        """
        try:
            async with aiofiles.open(self.local_path, 'r', encoding='utf-8') as f:
                return await f.read()
        except Exception as e:
            logger.error("Error reading local log: %s", e)
            raise HTTPException(status_code=500, detail=str(e)) from e

    async def _open_log_file(self):
        """Open the log file in binary mode and keep the handle for incremental reads.

        The remote handle is a synchronous paramiko SFTP file; the local handle
        is an aiofiles file whose methods are coroutines and must be awaited.
        """
        if self.remote_path:
            if not self.sftp_client:
                await self._connect_sftp()
            self._file_handle = self.sftp_client.open(self.remote_path, 'rb')
        else:
            self._file_handle = await aiofiles.open(self.local_path, 'rb')

    async def _close_file_handle(self):
        """Close the current log handle (best effort) and forget it."""
        handle, self._file_handle = self._file_handle, None
        if handle is None:
            return
        try:
            result = handle.close()
            # aiofiles' close() returns a coroutine; paramiko's close() is synchronous.
            if asyncio.iscoroutine(result):
                await result
        except Exception as e:
            logger.debug("Ignoring error while closing log file: %s", e)

    def _current_file_size(self) -> int:
        """Return the current size in bytes of the monitored log file."""
        if self.remote_path:
            # paramiko resolves SEEK_END from the remote file's size.
            self._file_handle.seek(0, os.SEEK_END)
            return self._file_handle.tell()
        return os.path.getsize(self.local_path)

    async def _read_from(self, offset: int) -> bytes:
        """Read everything from ``offset`` to EOF on the open handle."""
        if self.remote_path:
            self._file_handle.seek(offset)
            return self._file_handle.read()
        # aiofiles: seek/read are coroutines and must be awaited.
        await self._file_handle.seek(offset)
        return await self._file_handle.read()

    async def _read_new_content(self) -> str:
        """Return the text appended to the log since the previous call.

        If the file became smaller than the saved offset (truncated or
        rotated), the file is reopened and read from the start. On a read
        error the handle is reopened and "" is returned; the saved offset is
        kept so the next call resumes after the last successful read.
        """
        if not self._file_handle:
            await self._open_log_file()

        try:
            file_size = self._current_file_size()
            if file_size < self._last_position:
                logger.info("Log file has been truncated, reading from start")
                self._last_position = 0
                self._decoder.reset()
                # Reopen so a rotated file (new file at the same path) is
                # followed instead of re-reading the old handle from offset 0.
                await self._close_file_handle()
                await self._open_log_file()

            raw = await self._read_from(self._last_position)
            text = self._decoder.decode(raw)
            # Advance by the bytes actually read, not the size sampled above:
            # the file may have grown in between, and the stale size would
            # cause those extra bytes to be read (and duplicated) again.
            self._last_position += len(raw)
            return text
        except Exception as e:
            logger.error("Error reading log: %s", e)
            # Try to reopen the file on error
            await self._reopen_log_file()
            return ""

    async def _reopen_log_file(self):
        """Discard the current handle and open the log file again.

        For a remote log the SFTP session is dropped too, so a broken
        connection is re-established by the next open instead of reused.
        """
        await self._close_file_handle()
        if self.remote_path:
            self._close_sftp()
        await self._open_log_file()

    async def _monitor_log(self):
        """Poll the log once per second, buffering new lines and updating status."""
        while self._running:
            try:
                new_content = await self._read_new_content()
                new_lines = new_content.splitlines() if new_content else []
                if new_lines:
                    self.recent_logs.extend(new_lines)
                    limit = self.max_log_lines
                    # Keep only the newest ``max_log_lines`` entries ([-0:] would keep all).
                    self.recent_logs = self.recent_logs[-limit:] if limit > 0 else []

                    # The newest parsable progress line wins.
                    for line in reversed(new_lines):
                        if '|' in line:
                            status = self._parse_tqdm_line(line)
                            if status:
                                self.current_status = status
                                break

                await asyncio.sleep(1)
            except asyncio.CancelledError:
                break
            except Exception as e:
                logger.error("Monitor error: %s", e)
                await asyncio.sleep(5)

    async def get_log(self, lines: int = 100) -> List[str]:
        """Return up to ``lines`` most recent log lines, oldest first.

        A non-positive ``lines`` yields an empty list (``[-0:]`` would
        otherwise return the whole buffer).
        """
        if lines <= 0:
            return []
        return self.recent_logs[-lines:]

    async def startup(self):
        """Start the background monitor task.

        Raises:
            ValueError: if neither a remote nor a local log path is configured.
        """
        if not self.remote_path and not self.local_path:
            raise ValueError("No log path configured")
        if self._monitor_task is not None and not self._monitor_task.done():
            logger.warning("Training monitor already running")
            return
        self._running = True
        self._monitor_task = asyncio.create_task(self._monitor_log())
        logger.info("Training monitor started")

    async def shutdown(self):
        """Stop the monitor task and release the file handle and SFTP connection."""
        self._running = False
        if self._monitor_task:
            self._monitor_task.cancel()
            try:
                await self._monitor_task
            except asyncio.CancelledError:
                pass
            self._monitor_task = None

        await self._close_file_handle()
        self._close_sftp()
        logger.info("Training monitor stopped")

    async def get_status(self) -> Optional[TrainingStatus]:
        """Return the most recently parsed training status, if any."""
        return self.current_status

    # Include SFTP connection methods similar to SampleManager
    def _close_sftp(self):
        """Close the SFTP session and its SSH connection (best effort)."""
        sftp, self.sftp_client = self.sftp_client, None
        ssh, self._ssh_client = self._ssh_client, None
        for client in (sftp, ssh):
            if client is None:
                continue
            try:
                client.close()
            except Exception as e:
                logger.debug("Ignoring error while closing SFTP/SSH client: %s", e)

    async def _connect_sftp(self):
        """Open an SSH connection with key authentication and start SFTP on it.

        NOTE(review): paramiko's API is synchronous, so connecting runs on
        the event-loop thread and blocks it for the duration.

        Raises:
            HTTPException: 500 if connecting or opening the SFTP session fails.
        """
        ssh = None
        try:
            key_path = os.path.expanduser(settings.SFTP_KEY_PATH)
            ssh = paramiko.SSHClient()
            # SECURITY NOTE(review): AutoAddPolicy accepts any unknown host key
            # (no MITM protection). Consider known_hosts + RejectPolicy.
            ssh.set_missing_host_key_policy(paramiko.AutoAddPolicy())
            ssh.connect(
                hostname=settings.SFTP_HOST,
                username=settings.SFTP_USER,
                port=settings.SFTP_PORT,
                key_filename=key_path,
            )
            sftp = ssh.open_sftp()
        except Exception as e:
            if ssh is not None:
                ssh.close()
            logger.error("SFTP connection failed: %s", e)
            raise HTTPException(status_code=500, detail=f"SFTP Connection failed: {e}") from e

        # Release any previous session before adopting the new one.
        self._close_sftp()
        self._ssh_client = ssh
        self.sftp_client = sftp