Files
ai-training-monitor/backend/app/services/training_monitor.py
2025-01-23 09:49:50 +01:00

234 lines
8.2 KiB
Python

# app/services/training_monitor.py
import asyncio
import logging
import os
import re
from datetime import datetime
from typing import Optional, List
import aiofiles
import paramiko
from fastapi import HTTPException
from app.core.config import settings
from app.models.training import TrainingStatus
# Module-level logger, named after this module via __name__.
logger = logging.getLogger(__name__)
class TrainingMonitor:
    """Tail a training log file and track the latest tqdm progress line.

    The log is read over SFTP when ``settings.TRAINING_LOG_REMOTE_PATH`` is
    set, otherwise from ``settings.TRAINING_LOG_LOCAL_PATH``.  After
    ``startup()`` a background asyncio task polls the file once per second,
    keeps the newest ``max_log_lines`` lines in ``recent_logs`` and stores the
    most recent successfully parsed progress line in ``current_status``.

    NOTE(review): the paramiko calls used for the remote path are blocking
    and run directly on the event loop, exactly as before; consider moving
    them to an executor if remote latency becomes a problem.
    """

    # Tail of a tqdm progress line, e.g.
    #   "29/20000 [04:08<44:07:22, 7.95s/it, lr: 8.0e-05 loss: 3.693e-01]"
    # tqdm prints the rate as "<x>s/it" for slow iterations and "<x>it/s"
    # for fast ones, so both spellings are accepted.
    _TQDM_PATTERN = re.compile(
        r'(\d+)/(\d+)\s+\[(.*?)<(.*?),\s+(\S+?)(s/it|it/s),'
        r'\s+lr:\s+(.*?)\s+loss:\s+(.*?)\]'
    )

    def __init__(self):
        self.sftp_client = None
        # SSHClient that owns ``sftp_client``; kept so it can be closed.
        self._ssh_client = None
        self._running = False
        self._monitor_task = None
        self.recent_logs: List[str] = []
        self.max_log_lines: int = 500
        self.current_status: Optional["TrainingStatus"] = None
        self.remote_path = getattr(settings, 'TRAINING_LOG_REMOTE_PATH', None)
        self.local_path = getattr(settings, 'TRAINING_LOG_LOCAL_PATH', None)
        self._file_handle = None
        # Byte offset of the first byte that has not been consumed yet.
        self._last_position = 0

    # ------------------------------------------------------------------
    # Parsing
    # ------------------------------------------------------------------
    @staticmethod
    def _interval_to_seconds(text: str) -> float:
        """Convert a colon-separated interval (``MM:SS``, ``H:MM:SS``,
        ``D:HH:MM:SS``) to seconds.

        Raises:
            ValueError: if a component is not numeric (e.g. tqdm's ``?``).
        """
        weights = (1, 60, 3600, 86400)
        parts = reversed(text.split(':'))
        return sum(float(part) * weight for part, weight in zip(parts, weights))

    @classmethod
    def _extract_progress(cls, line: str) -> Optional[dict]:
        """Extract the numeric fields of a tqdm progress line.

        Returns:
            ``None`` when ``line`` does not look like a progress line,
            otherwise a dict with the keys ``current_step``, ``total_steps``,
            ``loss``, ``learning_rate``, ``percentage``, ``eta_seconds`` and
            ``steps_per_second``.

        Raises:
            ValueError: if the line matches but a field is not numeric.
        """
        match = cls._TQDM_PATTERN.search(line)
        if match is None:
            return None

        current, total, _elapsed, eta, rate, unit, lr, loss = match.groups()
        current_step = int(current)
        total_steps = int(total)
        rate_value = float(rate)

        if unit == 'it/s':
            steps_per_second = rate_value
        else:
            # "s/it": invert, guarding against a reported 0.00s/it.
            steps_per_second = 1 / rate_value if rate_value else 0.0

        # Guard against a "0/0" bar instead of dropping the whole update.
        percentage = (current_step / total_steps) * 100 if total_steps else 0.0

        return {
            'current_step': current_step,
            'total_steps': total_steps,
            'loss': float(loss),
            'learning_rate': float(lr),
            'percentage': percentage,
            'eta_seconds': cls._interval_to_seconds(eta),
            'steps_per_second': steps_per_second,
        }

    def _parse_tqdm_line(self, line: str) -> Optional["TrainingStatus"]:
        """Parse tqdm output line into TrainingStatus.

        Returns ``None`` (after logging) when the line cannot be parsed, and
        ``None`` silently when it is not a progress line at all.
        """
        try:
            fields = self._extract_progress(line)
            if fields is None:
                return None
            return TrainingStatus(
                updated_at=datetime.now(),
                source='remote' if self.remote_path else 'local',
                source_path=self.remote_path or self.local_path,
                **fields,
            )
        except Exception as e:
            logger.error(f"Error parsing tqdm line: {str(e)}")
            return None

    def _ingest_lines(self, new_lines: List[str]) -> None:
        """Append ``new_lines`` to ``recent_logs`` and refresh ``current_status``.

        Only the newest line containing ``|`` is handed to the parser; the
        status is left untouched when that line does not parse.
        """
        if not new_lines:
            return

        self.recent_logs.extend(new_lines)
        # ``lst[-0:]`` would keep everything, so handle a zero cap explicitly.
        if self.max_log_lines > 0:
            self.recent_logs = self.recent_logs[-self.max_log_lines:]
        else:
            self.recent_logs = []

        for line in reversed(new_lines):
            if '|' in line:
                status = self._parse_tqdm_line(line)
                if status:
                    self.current_status = status
                break

    # ------------------------------------------------------------------
    # Whole-file reads
    # ------------------------------------------------------------------
    async def _read_remote_log(self) -> str:
        """Read the complete log file from the remote host over SFTP.

        Raises:
            HTTPException: status 500 when connecting or reading fails.
        """
        if not self.sftp_client:
            await self._connect_sftp()
        try:
            # Binary mode: decode explicitly as UTF-8 below.
            with self.sftp_client.open(self.remote_path, 'rb') as f:
                content = f.read()
            return content.decode('utf-8')
        except Exception as e:
            logger.error(f"Error reading remote log: {str(e)}")
            raise HTTPException(status_code=500, detail=str(e)) from e

    async def _read_local_log(self) -> str:
        """Read the complete log file from the local path.

        Raises:
            HTTPException: status 500 when the file cannot be read.
        """
        try:
            # Explicit UTF-8 to match how the remote log is decoded.
            async with aiofiles.open(self.local_path, 'r', encoding='utf-8') as f:
                return await f.read()
        except Exception as e:
            logger.error(f"Error reading local log: {str(e)}")
            raise HTTPException(status_code=500, detail=str(e)) from e

    # ------------------------------------------------------------------
    # Incremental reads
    # ------------------------------------------------------------------
    async def _open_log_file(self):
        """Open the log in binary mode and keep the handle on the instance.

        The remote handle is a synchronous paramiko file; the local handle is
        an aiofiles object whose methods must be awaited.
        """
        if self.remote_path:
            if not self.sftp_client:
                await self._connect_sftp()
            self._file_handle = self.sftp_client.open(self.remote_path, 'rb')
        else:
            self._file_handle = await aiofiles.open(self.local_path, 'rb')

    async def _close_log_file(self):
        """Close the current log handle, ignoring (but logging) close errors."""
        handle, self._file_handle = self._file_handle, None
        if handle is None:
            return
        try:
            if self.remote_path:
                handle.close()
            else:
                # aiofiles' close() is a coroutine and does nothing unless awaited.
                await handle.close()
        except Exception as e:
            logger.debug(f"Error closing log file: {str(e)}")

    async def _read_new_content(self) -> str:
        """Read only the content appended since the previous call.

        Returns:
            The newly appended text, or ``""`` when nothing is new or the
            read failed (the handle is then reopened for the next attempt).
        """
        if not self._file_handle:
            await self._open_log_file()
        try:
            is_remote = bool(self.remote_path)

            if is_remote:
                self._file_handle.seek(0, 2)  # Seek to end
                file_size = self._file_handle.tell()
            else:
                file_size = os.path.getsize(self.local_path)

            if file_size < self._last_position:
                # File has been truncated/rotated.  Reopen so that a replaced
                # file is picked up instead of the stale handle.
                logger.info("Log file has been truncated, reading from start")
                self._last_position = 0
                await self._reopen_log_file()

            handle = self._file_handle
            if is_remote:
                handle.seek(self._last_position)
                raw = handle.read()
            else:
                await handle.seek(self._last_position)
                raw = await handle.read()

            # Advance by what was actually consumed; the file may have grown
            # since its size was measured and those bytes must not be re-read.
            self._last_position += len(raw)

            # 'replace' keeps one bad byte from blocking the tail forever.
            return raw.decode('utf-8', errors='replace')
        except Exception as e:
            logger.error(f"Error reading log: {str(e)}")
            # Try to reopen the file (and the SFTP session) on error
            await self._reopen_log_file(reconnect=True)
            return ""

    async def _reopen_log_file(self, reconnect: bool = False):
        """Reopen file handle in case of errors.

        Args:
            reconnect: when true and the log is remote, also drop the SFTP
                connection so that a dead session is re-established.
        """
        await self._close_log_file()
        if reconnect and self.remote_path:
            self._disconnect_sftp()
        await self._open_log_file()

    async def _monitor_log(self):
        """Poll the log once per second until stopped or cancelled.

        Any error is logged and followed by a 5 second back-off.
        """
        while self._running:
            try:
                new_content = await self._read_new_content()
                if new_content:
                    self._ingest_lines(new_content.splitlines())
                await asyncio.sleep(1)
            except asyncio.CancelledError:
                break
            except Exception as e:
                logger.error(f"Monitor error: {str(e)}")
                await asyncio.sleep(5)

    # ------------------------------------------------------------------
    # Public API
    # ------------------------------------------------------------------
    async def get_log(self, lines: int = 50) -> List[str]:
        """Get the ``lines`` most recent log entries (oldest first).

        A non-positive ``lines`` yields an empty list.
        """
        if lines <= 0:
            # ``lst[-0:]`` is the whole list, which is never what was asked for.
            return []
        return self.recent_logs[-lines:]

    async def startup(self):
        """Start the background monitor task.

        Raises:
            ValueError: if neither a remote nor a local log path is configured.
        """
        if not self.remote_path and not self.local_path:
            raise ValueError("No log path configured")
        if self._monitor_task is not None and not self._monitor_task.done():
            # A second poller would interleave reads on the shared handle.
            logger.warning("Training monitor already running")
            return
        self._running = True
        self._monitor_task = asyncio.create_task(self._monitor_log())
        logger.info("Training monitor started")

    async def shutdown(self):
        """Stop the monitor task and release the file and SFTP resources."""
        self._running = False
        task, self._monitor_task = self._monitor_task, None
        if task:
            task.cancel()
            try:
                await task
            except asyncio.CancelledError:
                pass
        await self._close_log_file()
        self._disconnect_sftp()
        logger.info("Training monitor stopped")

    async def get_status(self) -> Optional["TrainingStatus"]:
        """Get current training status (``None`` until a line was parsed)."""
        return self.current_status

    # ------------------------------------------------------------------
    # SFTP connection handling
    # ------------------------------------------------------------------
    def _disconnect_sftp(self):
        """Close the SFTP session and its SSH connection, if any."""
        for attr in ('sftp_client', '_ssh_client'):
            conn = getattr(self, attr, None)
            setattr(self, attr, None)
            if conn is None:
                continue
            try:
                conn.close()
            except Exception as e:
                logger.debug(f"Error closing {attr}: {str(e)}")

    async def _connect_sftp(self):
        """Create SFTP connection using SSH key.

        Raises:
            HTTPException: status 500 when the connection cannot be set up.
        """
        ssh = None
        try:
            key_path = os.path.expanduser(settings.SFTP_KEY_PATH)
            ssh = paramiko.SSHClient()
            # SECURITY NOTE(review): AutoAddPolicy accepts unknown host keys
            # without verification; consider loading known hosts instead.
            ssh.set_missing_host_key_policy(paramiko.AutoAddPolicy())
            ssh.connect(
                hostname=settings.SFTP_HOST,
                username=settings.SFTP_USER,
                port=settings.SFTP_PORT,
                key_filename=key_path,
            )
            sftp = ssh.open_sftp()
        except Exception as e:
            if ssh is not None:
                # Do not leak a half-open connection.
                ssh.close()
            logger.error(f"SFTP connection failed: {str(e)}")
            raise HTTPException(
                status_code=500, detail=f"SFTP Connection failed: {str(e)}"
            ) from e

        # Replace any previous connection and keep the SSH client referenced
        # so it can be closed on shutdown.
        self._disconnect_sftp()
        self._ssh_client = ssh
        self.sftp_client = sftp