Add training monitor implementation
Signed-off-by: Felipe Cardoso <felipe.cardoso@hotmail.it>
This commit is contained in:
@@ -1,13 +1,161 @@
|
||||
# app/services/training_monitor.py
|
||||
import asyncio
|
||||
import logging
|
||||
import os
|
||||
import re
|
||||
from datetime import datetime
|
||||
from typing import Optional, List
|
||||
|
||||
import aiofiles
|
||||
import paramiko
|
||||
from fastapi import HTTPException
|
||||
|
||||
from app.core.config import settings
|
||||
from app.models.training import TrainingStatus
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class TrainingMonitor:
|
||||
async def get_status(self) -> TrainingStatus:
|
||||
# Implementation for parsing tqdm output
|
||||
# This is a placeholder - actual implementation needed
|
||||
pass
|
||||
def __init__(self):
|
||||
self.sftp_client = None
|
||||
self._running = False
|
||||
self._monitor_task = None
|
||||
self.recent_logs: List[str] = [] # Store recent log lines
|
||||
self.max_log_lines: int = 100 # Keep last 100 lines
|
||||
self.current_status: Optional[TrainingStatus] = None
|
||||
self.remote_path = settings.TRAINING_LOG_REMOTE_PATH if hasattr(settings, 'TRAINING_LOG_REMOTE_PATH') else None
|
||||
self.local_path = settings.TRAINING_LOG_LOCAL_PATH if hasattr(settings, 'TRAINING_LOG_LOCAL_PATH') else None
|
||||
|
||||
async def get_log(self):
|
||||
# Implementation for getting recent log entries
|
||||
# This is a placeholder - actual implementation needed
|
||||
pass
|
||||
def _parse_tqdm_line(self, line: str) -> Optional[TrainingStatus]:
|
||||
"""Parse tqdm output line into TrainingStatus"""
|
||||
try:
|
||||
# Example: ovs_bangel_001: 0%| | 29/20000 [04:08<44:07:22, 7.95s/it, lr: 8.0e-05 loss: 3.693e-01]
|
||||
pattern = r'.*?(\d+)/(\d+)\s+\[(.*?)<(.*?),\s+(.*?)s/it,\s+lr:\s+(.*?)\s+loss:\s+(.*?)\]'
|
||||
match = re.search(pattern, line)
|
||||
|
||||
if match:
|
||||
current_step = int(match.group(1))
|
||||
total_steps = int(match.group(2))
|
||||
elapsed_str = match.group(3)
|
||||
eta_str = match.group(4)
|
||||
step_time = float(match.group(5))
|
||||
lr = float(match.group(6))
|
||||
loss = float(match.group(7))
|
||||
|
||||
# Convert ETA string to seconds
|
||||
eta_seconds = sum(x * y for x, y in zip(
|
||||
map(float, reversed(eta_str.split(':'))),
|
||||
[1, 60, 3600, 86400]
|
||||
))
|
||||
|
||||
return TrainingStatus(
|
||||
current_step=current_step,
|
||||
total_steps=total_steps,
|
||||
loss=loss,
|
||||
learning_rate=lr,
|
||||
percentage=(current_step / total_steps) * 100,
|
||||
eta_seconds=eta_seconds,
|
||||
steps_per_second=1 / step_time,
|
||||
updated_at=datetime.now(),
|
||||
source='remote' if self.remote_path else 'local',
|
||||
source_path=self.remote_path or self.local_path
|
||||
)
|
||||
except Exception as e:
|
||||
logger.error(f"Error parsing tqdm line: {str(e)}")
|
||||
return None
|
||||
|
||||
async def _read_remote_log(self) -> str:
|
||||
"""Read log file from remote SFTP"""
|
||||
if not self.sftp_client:
|
||||
await self._connect_sftp()
|
||||
|
||||
try:
|
||||
with self.sftp_client.open(self.remote_path, 'rb') as f: # Note the 'rb' mode
|
||||
content = f.read()
|
||||
return content.decode('utf-8')
|
||||
except Exception as e:
|
||||
logger.error(f"Error reading remote log: {str(e)}")
|
||||
raise HTTPException(status_code=500, detail=str(e))
|
||||
|
||||
async def _read_local_log(self) -> str:
|
||||
"""Read log file from local path"""
|
||||
try:
|
||||
async with aiofiles.open(self.local_path, 'r') as f:
|
||||
return await f.read()
|
||||
except Exception as e:
|
||||
logger.error(f"Error reading local log: {str(e)}")
|
||||
raise HTTPException(status_code=500, detail=str(e))
|
||||
|
||||
async def _monitor_log(self):
|
||||
"""Monitor log file for updates"""
|
||||
while self._running:
|
||||
try:
|
||||
content = (await self._read_remote_log() if self.remote_path
|
||||
else await self._read_local_log())
|
||||
|
||||
# Get last line containing progress info
|
||||
lines = content.splitlines()
|
||||
self.recent_logs = lines[-self.max_log_lines:] if lines else []
|
||||
for line in reversed(lines):
|
||||
if '|' in line: # Basic check for tqdm output
|
||||
status = self._parse_tqdm_line(line)
|
||||
if status:
|
||||
self.current_status = status
|
||||
break
|
||||
|
||||
await asyncio.sleep(1) # Check every second
|
||||
|
||||
except asyncio.CancelledError:
|
||||
break
|
||||
except Exception as e:
|
||||
logger.error(f"Monitor error: {str(e)}")
|
||||
await asyncio.sleep(5) # Wait before retry
|
||||
|
||||
async def get_log(self, lines: int = 50) -> List[str]:
|
||||
"""Get recent log entries"""
|
||||
return self.recent_logs[-lines:]
|
||||
|
||||
async def startup(self):
|
||||
"""Start the monitor"""
|
||||
if not self.remote_path and not self.local_path:
|
||||
raise ValueError("No log path configured")
|
||||
|
||||
self._running = True
|
||||
self._monitor_task = asyncio.create_task(self._monitor_log())
|
||||
logger.info("Training monitor started")
|
||||
|
||||
async def shutdown(self):
|
||||
"""Stop the monitor"""
|
||||
self._running = False
|
||||
if self._monitor_task:
|
||||
self._monitor_task.cancel()
|
||||
try:
|
||||
await self._monitor_task
|
||||
except asyncio.CancelledError:
|
||||
pass
|
||||
logger.info("Training monitor stopped")
|
||||
|
||||
async def get_status(self) -> Optional[TrainingStatus]:
|
||||
"""Get current training status"""
|
||||
return self.current_status
|
||||
|
||||
# Include SFTP connection methods similar to SampleManager
|
||||
async def _connect_sftp(self):
|
||||
"""Create SFTP connection using SSH key"""
|
||||
try:
|
||||
key_path = os.path.expanduser(settings.SFTP_KEY_PATH)
|
||||
ssh = paramiko.SSHClient()
|
||||
ssh.set_missing_host_key_policy(paramiko.AutoAddPolicy())
|
||||
|
||||
ssh.connect(
|
||||
hostname=settings.SFTP_HOST,
|
||||
username=settings.SFTP_USER,
|
||||
port=settings.SFTP_PORT,
|
||||
key_filename=key_path,
|
||||
)
|
||||
|
||||
self.sftp_client = ssh.open_sftp()
|
||||
except Exception as e:
|
||||
logger.error(f"SFTP connection failed: {str(e)}")
|
||||
raise HTTPException(status_code=500, detail=f"SFTP Connection failed: {str(e)}")
|
||||
|
||||
Reference in New Issue
Block a user