234 lines
8.2 KiB
Python
234 lines
8.2 KiB
Python
# app/services/training_monitor.py
|
|
import asyncio
|
|
import logging
|
|
import os
|
|
import re
|
|
from datetime import datetime
|
|
from typing import Optional, List
|
|
|
|
import aiofiles
|
|
import paramiko
|
|
from fastapi import HTTPException
|
|
|
|
from app.core.config import settings
|
|
from app.models.training import TrainingStatus
|
|
|
|
# Module-level logger, named after this module via __name__.
logger = logging.getLogger(__name__)
|
|
|
|
|
|
class TrainingMonitor:
    """Tail a training log and expose the latest tqdm progress.

    The log is read either from a remote host over SFTP (when
    ``settings.TRAINING_LOG_REMOTE_PATH`` is set) or from a local file
    (``settings.TRAINING_LOG_LOCAL_PATH``).  A background asyncio task polls
    the file for appended bytes, keeps the most recent lines in
    ``recent_logs`` and stores the newest parsable progress line in
    ``current_status``.
    """

    # Matches the tail of a tqdm progress line, e.g.
    #   ovs_bangel_001: 0%| | 29/20000 [04:08<44:07:22, 7.95s/it, lr: 8.0e-05 loss: 3.693e-01]
    # tqdm prints the rate as "s/it" for slow loops and "it/s" for fast ones,
    # so both units are accepted.
    _TQDM_PATTERN = re.compile(
        r'(?P<step>\d+)/(?P<total>\d+)\s+\[(?P<elapsed>.*?)<(?P<eta>.*?),\s+'
        r'(?P<rate>[\d.]+(?:[eE][+-]?\d+)?)(?P<unit>s/it|it/s),\s+'
        r'lr:\s+(?P<lr>.*?)\s+loss:\s+(?P<loss>.*?)\]'
    )

    def __init__(self):
        self.sftp_client = None
        # The SSHClient must stay referenced for as long as the SFTP channel
        # is in use, and is needed to close the connection on shutdown.
        self._ssh_client = None
        self._running = False
        self._monitor_task = None
        self.recent_logs: List[str] = []
        self.max_log_lines: int = 500
        self.current_status: Optional[TrainingStatus] = None
        self.remote_path = getattr(settings, 'TRAINING_LOG_REMOTE_PATH', None)
        self.local_path = getattr(settings, 'TRAINING_LOG_LOCAL_PATH', None)
        self._file_handle = None
        # Byte offset of the first byte that has not been consumed yet.
        self._last_position = 0

    # ------------------------------------------------------------------
    # Parsing
    # ------------------------------------------------------------------
    @staticmethod
    def _interval_to_seconds(interval: str) -> float:
        """Convert a ``[[D:]H:]M:S`` interval string into seconds.

        Raises:
            ValueError: if any component is not a number (e.g. tqdm's ``?``).
        """
        return sum(
            value * weight
            for value, weight in zip(
                map(float, reversed(interval.split(':'))),
                [1, 60, 3600, 86400],
            )
        )

    def _parse_tqdm_line(self, line: str) -> Optional["TrainingStatus"]:
        """Parse a tqdm output line into a TrainingStatus.

        Returns:
            The parsed status, or ``None`` when the line is not a progress
            line or one of its fields cannot be converted.
        """
        try:
            match = self._TQDM_PATTERN.search(line)
            if match is None:
                return None

            current_step = int(match.group('step'))
            total_steps = int(match.group('total'))
            rate = float(match.group('rate'))

            # Normalise the rate to steps per second whatever unit tqdm used.
            if match.group('unit') == 's/it':
                steps_per_second = 1 / rate if rate else 0.0
            else:
                steps_per_second = rate

            percentage = (current_step / total_steps) * 100 if total_steps else 0.0

            return TrainingStatus(
                current_step=current_step,
                total_steps=total_steps,
                loss=float(match.group('loss')),
                learning_rate=float(match.group('lr')),
                percentage=percentage,
                eta_seconds=self._interval_to_seconds(match.group('eta')),
                steps_per_second=steps_per_second,
                updated_at=datetime.now(),
                source='remote' if self.remote_path else 'local',
                source_path=self.remote_path or self.local_path,
            )
        except Exception as e:
            # Best effort: a malformed progress line must never stop the
            # monitor, so log it and report "no status".
            logger.error(f"Error parsing tqdm line: {str(e)}")
            return None

    # ------------------------------------------------------------------
    # Blocking-call helper
    # ------------------------------------------------------------------
    @staticmethod
    async def _run_blocking(func, *args):
        """Run a blocking callable in the default executor.

        paramiko is synchronous; calling it directly from a coroutine would
        stall the event loop for the duration of the network round trip.
        """
        loop = asyncio.get_running_loop()
        return await loop.run_in_executor(None, func, *args)

    # ------------------------------------------------------------------
    # Whole-file reads
    # ------------------------------------------------------------------
    async def _read_remote_log(self) -> str:
        """Read the whole log file from the remote host over SFTP.

        Raises:
            HTTPException: status 500 when the connection or the read fails.
        """
        if not self.sftp_client:
            await self._connect_sftp()

        def _read_all() -> bytes:
            # Binary mode: decoding is done explicitly below.
            with self.sftp_client.open(self.remote_path, 'rb') as remote_file:
                return remote_file.read()

        try:
            content = await self._run_blocking(_read_all)
            return content.decode('utf-8')
        except Exception as e:
            logger.error(f"Error reading remote log: {str(e)}")
            raise HTTPException(status_code=500, detail=str(e)) from e

    async def _read_local_log(self) -> str:
        """Read the whole log file from the local path.

        Raises:
            HTTPException: status 500 when the file cannot be read.
        """
        try:
            # Decode as UTF-8 so local and remote logs are treated alike.
            async with aiofiles.open(
                self.local_path, 'r', encoding='utf-8', errors='replace'
            ) as local_file:
                return await local_file.read()
        except Exception as e:
            logger.error(f"Error reading local log: {str(e)}")
            raise HTTPException(status_code=500, detail=str(e)) from e

    # ------------------------------------------------------------------
    # Incremental reads
    # ------------------------------------------------------------------
    async def _open_log_file(self):
        """Open the log file in binary mode and keep the handle.

        The handle is a paramiko file object (synchronous API) for remote
        logs and an aiofiles object (awaitable API) for local logs.
        """
        if self.remote_path:
            if not self.sftp_client:
                await self._connect_sftp()
            self._file_handle = await self._run_blocking(
                self.sftp_client.open, self.remote_path, 'rb'
            )
        else:
            self._file_handle = await aiofiles.open(self.local_path, 'rb')

    async def _close_file_handle(self):
        """Close the current file handle, if any, ignoring close errors."""
        handle, self._file_handle = self._file_handle, None
        if handle is None:
            return
        try:
            if self.remote_path:
                await self._run_blocking(handle.close)
            else:
                await handle.close()
        except Exception as e:
            # Closing is best effort; the handle is dropped either way.
            logger.debug(f"Error closing log file handle: {str(e)}")

    def _remote_size(self) -> int:
        """Return the current size of the remote file (blocking)."""
        self._file_handle.seek(0, 2)  # Seek to end
        return self._file_handle.tell()

    def _remote_read_from(self, position: int) -> bytes:
        """Read the remote file from ``position`` to its end (blocking)."""
        self._file_handle.seek(position)
        return self._file_handle.read()

    async def _read_new_content(self) -> str:
        """Read only the content appended since the previous call.

        Returns:
            The new text, or ``""`` when nothing was appended or the read
            failed (in which case the file is reopened for the next call).
        """
        if not self._file_handle:
            await self._open_log_file()

        try:
            if self.remote_path:
                file_size = await self._run_blocking(self._remote_size)
            else:
                file_size = os.path.getsize(self.local_path)

            if file_size < self._last_position:
                # File has been truncated/rotated.  Reopen as well, because a
                # rotated file may be a different file behind the same path.
                logger.info("Log file has been truncated, reading from start")
                self._last_position = 0
                await self._reopen_log_file()

            if self.remote_path:
                raw = await self._run_blocking(
                    self._remote_read_from, self._last_position
                )
            else:
                await self._file_handle.seek(self._last_position)
                raw = await self._file_handle.read()

            if isinstance(raw, str):
                raw = raw.encode('utf-8')

            # Advance by what was actually consumed.  Using the size measured
            # before the read would re-deliver bytes appended in between.
            self._last_position += len(raw)

            # errors='replace': a multi-byte character split across two
            # polls must not abort the read.
            return raw.decode('utf-8', errors='replace')

        except Exception as e:
            logger.error(f"Error reading log: {str(e)}")
            # Try to reopen the file on error; for remote logs also drop the
            # connection, since a dead session would fail again on reopen.
            await self._reopen_log_file(reset_connection=True)
            return ""

    async def _reopen_log_file(self, reset_connection: bool = False):
        """Reopen the file handle, e.g. after an error or a rotation.

        Args:
            reset_connection: when true and the log is remote, close the
                SFTP connection first so that a new one is established.
        """
        await self._close_file_handle()
        if reset_connection and self.remote_path:
            await self._close_sftp()
        await self._open_log_file()

    # ------------------------------------------------------------------
    # Background task
    # ------------------------------------------------------------------
    async def _monitor_log(self):
        """Poll the log once per second until the monitor is stopped."""
        while self._running:
            try:
                new_content = await self._read_new_content()
                new_lines = new_content.splitlines() if new_content else []

                if new_lines:
                    # Update recent logs, trimming in place.  An explicit
                    # count is used because a [-0:] slice would keep all.
                    self.recent_logs.extend(new_lines)
                    excess = len(self.recent_logs) - self.max_log_lines
                    if excess > 0:
                        del self.recent_logs[:excess]

                    # Update status from the last parsable progress line.
                    for line in reversed(new_lines):
                        if '|' not in line:
                            continue
                        status = self._parse_tqdm_line(line)
                        if status:
                            self.current_status = status
                            break

                await asyncio.sleep(1)

            except asyncio.CancelledError:
                # Let cancellation propagate; shutdown() awaits the task.
                raise
            except Exception as e:
                logger.error(f"Monitor error: {str(e)}")
                await asyncio.sleep(5)

    # ------------------------------------------------------------------
    # Public API
    # ------------------------------------------------------------------
    async def get_log(self, lines: int = 50) -> List[str]:
        """Return up to ``lines`` of the most recent log entries.

        A non-positive ``lines`` yields an empty list.
        """
        if lines <= 0:
            # recent_logs[-0:] would return the whole buffer.
            return []
        return self.recent_logs[-lines:]

    async def startup(self):
        """Start the background monitor task.

        Raises:
            ValueError: if neither a remote nor a local log path is set.
        """
        if not self.remote_path and not self.local_path:
            raise ValueError("No log path configured")

        if self._monitor_task is not None and not self._monitor_task.done():
            # Already running; a second task would duplicate every line.
            return

        self._running = True
        self._monitor_task = asyncio.create_task(self._monitor_log())
        logger.info("Training monitor started")

    async def shutdown(self):
        """Stop the monitor task and release the file and SFTP resources."""
        self._running = False

        task, self._monitor_task = self._monitor_task, None
        if task:
            task.cancel()
            try:
                await task
            except asyncio.CancelledError:
                pass

        await self._close_file_handle()
        await self._close_sftp()

        logger.info("Training monitor stopped")

    async def get_status(self) -> Optional["TrainingStatus"]:
        """Return the most recently parsed training status, if any."""
        return self.current_status

    # ------------------------------------------------------------------
    # SFTP connection handling
    # ------------------------------------------------------------------
    async def _connect_sftp(self):
        """Create the SFTP connection using the configured SSH key.

        Raises:
            HTTPException: status 500 when the connection cannot be made.
        """
        def _connect():
            key_path = os.path.expanduser(settings.SFTP_KEY_PATH)
            ssh = paramiko.SSHClient()
            # NOTE(security): AutoAddPolicy accepts unknown host keys without
            # verification.  Kept for compatibility; consider loading known
            # hosts and using RejectPolicy instead.
            ssh.set_missing_host_key_policy(paramiko.AutoAddPolicy())
            try:
                ssh.connect(
                    hostname=settings.SFTP_HOST,
                    username=settings.SFTP_USER,
                    port=settings.SFTP_PORT,
                    key_filename=key_path,
                )
                return ssh, ssh.open_sftp()
            except Exception:
                # Do not leak a half-open connection.
                ssh.close()
                raise

        try:
            self._ssh_client, self.sftp_client = await self._run_blocking(_connect)
        except Exception as e:
            logger.error(f"SFTP connection failed: {str(e)}")
            raise HTTPException(
                status_code=500, detail=f"SFTP Connection failed: {str(e)}"
            ) from e

    async def _close_sftp(self):
        """Close the SFTP channel and its SSH connection, ignoring errors."""
        sftp, self.sftp_client = self.sftp_client, None
        ssh, self._ssh_client = self._ssh_client, None
        for connection in (sftp, ssh):
            if connection is None:
                continue
            try:
                await self._run_blocking(connection.close)
            except Exception as e:
                logger.debug(f"Error closing SFTP connection: {str(e)}")
|