162 lines
6.0 KiB
Python
162 lines
6.0 KiB
Python
# app/services/training_monitor.py
|
|
import asyncio
|
|
import logging
|
|
import os
|
|
import re
|
|
from datetime import datetime
|
|
from typing import Optional, List
|
|
|
|
import aiofiles
|
|
import paramiko
|
|
from fastapi import HTTPException
|
|
|
|
from app.core.config import settings
|
|
from app.models.training import TrainingStatus
|
|
|
|
# Module-level logger; its name is taken from this module's ``__name__``.
logger = logging.getLogger(__name__)
|
|
|
|
|
|
class TrainingMonitor:
    """Poll a training log (remote over SFTP, or local) for tqdm progress.

    A background asyncio task re-reads the configured log file once per
    second, keeps the last ``max_log_lines`` lines in ``recent_logs`` and
    stores the most recent parseable tqdm progress line in
    ``current_status``.

    The remote path takes precedence over the local path when both are
    configured.
    """

    # Matches the step counter and the bracketed stats of a tqdm line, e.g.
    #   "... 29/20000 [04:08<44:07:22,  7.95s/it, lr: 8.0e-05 loss: 3.693e-01]"
    # tqdm prints the rate as "s/it" for slow loops and "it/s" for fast ones,
    # so both units are accepted and the unit is captured.
    _TQDM_PATTERN = re.compile(
        r'(\d+)/(\d+)\s+\[(.*?)<(.*?),\s+([-+.\deE]+)(s/it|it/s),'
        r'\s+lr:\s+(.*?)\s+loss:\s+(.*?)\]'
    )

    def __init__(self):
        self.sftp_client = None
        # The SSHClient that owns ``sftp_client``. It is kept so it can be
        # closed explicitly on shutdown / reconnect.
        # NOTE(review): paramiko is reported to close the transport when the
        # SSHClient is garbage-collected - confirm against the paramiko docs.
        self._ssh_client = None
        self._running = False
        self._monitor_task = None
        self.recent_logs: List[str] = []  # Store recent log lines
        self.max_log_lines: int = 100  # Keep last 100 lines
        self.current_status: Optional["TrainingStatus"] = None
        self.remote_path = getattr(settings, 'TRAINING_LOG_REMOTE_PATH', None)
        self.local_path = getattr(settings, 'TRAINING_LOG_LOCAL_PATH', None)

    # ------------------------------------------------------------------
    # Parsing helpers (pure functions, no I/O)
    # ------------------------------------------------------------------

    @staticmethod
    def _interval_to_seconds(interval: str) -> float:
        """Convert a colon-separated interval such as ``44:07:22`` to seconds.

        Components are read right to left as seconds, minutes, hours, days.

        Raises:
            ValueError: if a component is not a number (e.g. tqdm's ``?``).
        """
        components = reversed(interval.split(':'))
        weights = (1, 60, 3600, 86400)
        return sum(float(value) * weight
                   for value, weight in zip(components, weights))

    @classmethod
    def _extract_tqdm_fields(cls, line: str) -> Optional[dict]:
        """Extract the numeric progress fields from one tqdm output line.

        Returns:
            A dict with ``current_step``, ``total_steps``, ``loss``,
            ``learning_rate``, ``percentage``, ``eta_seconds`` and
            ``steps_per_second``, or ``None`` when the line does not look
            like a tqdm progress line.

        Raises:
            ValueError: if a matched field cannot be converted to a number.
        """
        match = cls._TQDM_PATTERN.search(line)
        if match is None:
            return None

        current, total, _elapsed, eta, rate, unit, lr, loss = match.groups()
        current_step = int(current)
        total_steps = int(total)
        rate_value = float(rate)

        if unit == 's/it':
            # Seconds per iteration: invert, guarding a displayed "0.00s/it".
            steps_per_second = 1 / rate_value if rate_value else 0.0
        else:
            steps_per_second = rate_value

        return {
            'current_step': current_step,
            'total_steps': total_steps,
            'loss': float(loss),
            'learning_rate': float(lr),
            # Guard "0/0" so an empty bar does not raise ZeroDivisionError.
            'percentage': (
                (current_step / total_steps) * 100 if total_steps else 0.0
            ),
            'eta_seconds': cls._interval_to_seconds(eta),
            'steps_per_second': steps_per_second,
        }

    def _parse_tqdm_line(self, line: str) -> Optional["TrainingStatus"]:
        """Parse tqdm output line into TrainingStatus.

        Returns ``None`` when the line is not a progress line or when any
        field fails to parse; parse failures are logged, never raised.
        """
        try:
            fields = self._extract_tqdm_fields(line)
            if fields is None:
                return None

            return TrainingStatus(
                current_step=fields['current_step'],
                total_steps=fields['total_steps'],
                loss=fields['loss'],
                learning_rate=fields['learning_rate'],
                percentage=fields['percentage'],
                eta_seconds=fields['eta_seconds'],
                steps_per_second=fields['steps_per_second'],
                updated_at=datetime.now(),
                source='remote' if self.remote_path else 'local',
                source_path=self.remote_path or self.local_path,
            )
        except Exception as e:
            # Deliberately broad: a malformed line must never stop the
            # monitor loop, whatever TrainingStatus validation raises.
            logger.error("Error parsing tqdm line: %s", e)
            return None

    # ------------------------------------------------------------------
    # I/O helpers
    # ------------------------------------------------------------------

    @staticmethod
    async def _run_blocking(func, *args):
        """Run a blocking callable in the default executor and await it."""
        loop = asyncio.get_running_loop()
        return await loop.run_in_executor(None, func, *args)

    def _fetch_remote_bytes(self) -> bytes:
        """Blocking read of the whole remote log; run via ``_run_blocking``."""
        with self.sftp_client.open(self.remote_path, 'rb') as f:
            return f.read()

    def _close_sftp(self) -> None:
        """Close and forget the SFTP and SSH clients, if any."""
        for attr in ('sftp_client', '_ssh_client'):
            conn = getattr(self, attr, None)
            setattr(self, attr, None)
            if conn is None:
                continue
            try:
                conn.close()
            except Exception as e:
                logger.warning("Error closing %s: %s", attr, e)

    async def _read_remote_log(self) -> str:
        """Read log file from remote SFTP.

        Connects first if there is no SFTP client yet. On a read failure the
        connection is dropped so that the next call reconnects.

        Raises:
            HTTPException: status 500 when connecting or reading fails.
        """
        if not self.sftp_client:
            await self._connect_sftp()

        try:
            raw = await self._run_blocking(self._fetch_remote_bytes)
            # The file may be read mid-write, which can cut a multi-byte
            # progress-bar character in half; do not fail on that.
            return raw.decode('utf-8', errors='replace')
        except asyncio.CancelledError:
            raise
        except Exception as e:
            logger.error("Error reading remote log: %s", e)
            self._close_sftp()
            raise HTTPException(status_code=500, detail=str(e)) from e

    async def _read_local_log(self) -> str:
        """Read log file from local path.

        Raises:
            HTTPException: status 500 when the file cannot be read.
        """
        try:
            # Same decoding rules as the remote path, for consistency.
            async with aiofiles.open(
                self.local_path, 'r', encoding='utf-8', errors='replace'
            ) as f:
                return await f.read()
        except asyncio.CancelledError:
            raise
        except Exception as e:
            logger.error("Error reading local log: %s", e)
            raise HTTPException(status_code=500, detail=str(e)) from e

    def _ingest_log_content(self, content: str) -> None:
        """Update ``recent_logs`` and ``current_status`` from log text."""
        lines = content.splitlines()
        self.recent_logs = lines[-self.max_log_lines:] if lines else []

        # Walk backwards so the newest parseable progress line wins.
        for line in reversed(lines):
            if '|' not in line:  # Basic check for tqdm output
                continue
            status = self._parse_tqdm_line(line)
            if status:
                self.current_status = status
                break

    async def _monitor_log(self):
        """Monitor log file for updates until stopped or cancelled."""
        while self._running:
            try:
                if self.remote_path:
                    content = await self._read_remote_log()
                else:
                    content = await self._read_local_log()

                self._ingest_log_content(content)
                await asyncio.sleep(1)  # Check every second
            except asyncio.CancelledError:
                break
            except Exception as e:
                logger.error("Monitor error: %s", e)
                await asyncio.sleep(5)  # Wait before retry

    # ------------------------------------------------------------------
    # Public API
    # ------------------------------------------------------------------

    async def get_log(self, lines: int = 50) -> List[str]:
        """Get recent log entries.

        Args:
            lines: maximum number of trailing lines to return. Zero or a
                negative value yields an empty list.
        """
        # ``[-0:]`` would return the whole buffer, so handle it explicitly.
        if lines <= 0:
            return []
        return self.recent_logs[-lines:]

    async def startup(self):
        """Start the monitor.

        Calling this while the monitor task is still running is a no-op.

        Raises:
            ValueError: if neither a remote nor a local log path is set.
        """
        if not self.remote_path and not self.local_path:
            raise ValueError("No log path configured")

        if self._monitor_task is not None and not self._monitor_task.done():
            logger.warning("Training monitor already running")
            return

        self._running = True
        self._monitor_task = asyncio.create_task(self._monitor_log())
        logger.info("Training monitor started")

    async def shutdown(self):
        """Stop the monitor and release the SFTP/SSH connection."""
        self._running = False
        task, self._monitor_task = self._monitor_task, None
        if task:
            task.cancel()
            try:
                await task
            except asyncio.CancelledError:
                pass
        self._close_sftp()
        logger.info("Training monitor stopped")

    async def get_status(self) -> Optional["TrainingStatus"]:
        """Get current training status (``None`` until a line has parsed)."""
        return self.current_status

    # Include SFTP connection methods similar to SampleManager
    def _open_sftp_session(self):
        """Blocking SSH connect + SFTP open; returns ``(ssh, sftp)``."""
        key_path = os.path.expanduser(settings.SFTP_KEY_PATH)
        ssh = paramiko.SSHClient()
        # SECURITY NOTE(review): AutoAddPolicy accepts unknown host keys
        # without verification. Kept as-is to preserve behavior; consider
        # loading known hosts and using RejectPolicy.
        ssh.set_missing_host_key_policy(paramiko.AutoAddPolicy())
        try:
            ssh.connect(
                hostname=settings.SFTP_HOST,
                username=settings.SFTP_USER,
                port=settings.SFTP_PORT,
                key_filename=key_path,
            )
            return ssh, ssh.open_sftp()
        except Exception:
            # Do not leak a half-open client when connect/open_sftp fails.
            ssh.close()
            raise

    async def _connect_sftp(self):
        """Create SFTP connection using SSH key.

        Raises:
            HTTPException: status 500 when the connection cannot be made.
        """
        try:
            ssh, sftp = await self._run_blocking(self._open_sftp_session)
        except asyncio.CancelledError:
            raise
        except Exception as e:
            logger.error("SFTP connection failed: %s", e)
            raise HTTPException(
                status_code=500, detail=f"SFTP Connection failed: {str(e)}"
            ) from e

        self._ssh_client = ssh
        self.sftp_client = sftp
|