Files
ai-training-monitor/backend/app/services/training_monitor.py
2025-01-23 09:45:07 +01:00

162 lines
6.0 KiB
Python

# app/services/training_monitor.py
import asyncio
import logging
import os
import re
from datetime import datetime
from typing import Optional, List
import aiofiles
import paramiko
from fastapi import HTTPException
from app.core.config import settings
from app.models.training import TrainingStatus
logger = logging.getLogger(__name__)
class TrainingMonitor:
    """Poll a training log (local file or remote SFTP file) for tqdm progress.

    A background asyncio task re-reads the configured log file, keeps the
    last ``max_log_lines`` lines in ``recent_logs`` and stores the newest
    parseable tqdm progress line as ``current_status``.  When both a remote
    and a local path are configured, the remote path takes precedence.
    """

    # Delay between two successful polls of the log file.
    POLL_INTERVAL_SECONDS = 1
    # Delay before polling again after a failed read.
    RETRY_DELAY_SECONDS = 5

    # Matches the counter and bracketed tail of a tqdm line such as
    #   ovs_bangel_001: 0%| | 29/20000 [04:08<44:07:22, 7.95s/it, lr: 8.0e-05 loss: 3.693e-01]
    # tqdm prints the rate as "s/it" for slow loops and "it/s" for fast ones,
    # so both units are accepted.
    _TQDM_PATTERN = re.compile(
        r'(\d+)/(\d+)\s+\[(.*?)<(.*?),\s+(.*?)(s/it|it/s),'
        r'\s+lr:\s+(.*?)\s+loss:\s+(.*?)\]'
    )
    # Multipliers for the colon-separated interval fields, right to left:
    # seconds, minutes, hours, days.
    _INTERVAL_UNITS = (1, 60, 3600, 86400)

    def __init__(self):
        self.sftp_client = None
        # The SSHClient must stay referenced for as long as the SFTP session
        # is in use, otherwise it can be garbage-collected and the session's
        # transport closed underneath us.
        self._ssh_client = None
        self._running = False
        self._monitor_task = None
        self.recent_logs: List[str] = []  # Most recent log lines
        self.max_log_lines: int = 100  # How many trailing lines to keep
        self.current_status: Optional[TrainingStatus] = None
        self.remote_path = getattr(settings, 'TRAINING_LOG_REMOTE_PATH', None)
        self.local_path = getattr(settings, 'TRAINING_LOG_LOCAL_PATH', None)

    @classmethod
    def _parse_interval(cls, text: str) -> Optional[float]:
        """Convert a ``[[D:]H:]MM:SS`` interval string to seconds.

        Returns None when the text is not a numeric interval, e.g. the ``?``
        placeholder tqdm prints while the ETA is still unknown.
        """
        parts = text.strip().split(':')
        if len(parts) > len(cls._INTERVAL_UNITS):
            return None
        try:
            values = [float(part) for part in reversed(parts)]
        except ValueError:
            return None
        return sum(value * unit for value, unit in zip(values, cls._INTERVAL_UNITS))

    @classmethod
    def _extract_progress(cls, line: str) -> Optional[dict]:
        """Pull the numeric progress fields out of one tqdm line.

        Returns a dict with ``current_step``, ``total_steps``, ``loss``,
        ``learning_rate``, ``percentage``, ``eta_seconds`` and
        ``steps_per_second``, or None when the line is not a complete
        progress line (no match, placeholder values, zero total or rate).
        """
        match = cls._TQDM_PATTERN.search(line)
        if match is None:
            return None
        (current_text, total_text, _elapsed_text, eta_text,
         rate_text, rate_unit, lr_text, loss_text) = match.groups()

        eta_seconds = cls._parse_interval(eta_text)
        try:
            rate = float(rate_text)
            learning_rate = float(lr_text)
            loss = float(loss_text)
        except ValueError:
            return None

        current_step = int(current_text)
        total_steps = int(total_text)
        # Guard the two divisions below and skip lines without a usable ETA.
        if eta_seconds is None or total_steps <= 0 or rate <= 0:
            return None

        # "s/it" is seconds per step and must be inverted; "it/s" already is
        # steps per second.
        steps_per_second = 1 / rate if rate_unit == 's/it' else rate
        return {
            'current_step': current_step,
            'total_steps': total_steps,
            'loss': loss,
            'learning_rate': learning_rate,
            'percentage': (current_step / total_steps) * 100,
            'eta_seconds': eta_seconds,
            'steps_per_second': steps_per_second,
        }

    def _parse_tqdm_line(self, line: str) -> Optional[TrainingStatus]:
        """Parse a tqdm output line into a TrainingStatus.

        Returns None when the line carries no complete progress information
        or when building the status object fails; never raises.
        """
        try:
            progress = self._extract_progress(line)
            if progress is None:
                return None
            return TrainingStatus(
                current_step=progress['current_step'],
                total_steps=progress['total_steps'],
                loss=progress['loss'],
                learning_rate=progress['learning_rate'],
                percentage=progress['percentage'],
                eta_seconds=progress['eta_seconds'],
                steps_per_second=progress['steps_per_second'],
                updated_at=datetime.now(),
                source='remote' if self.remote_path else 'local',
                source_path=self.remote_path or self.local_path
            )
        except Exception as e:
            # Callers treat None as "no status on this line", so any failure
            # here is logged instead of propagated.
            logger.error("Error parsing tqdm line: %s", e)
            return None

    def _fetch_remote_bytes(self) -> bytes:
        """Blocking read of the whole remote log; runs in a worker thread."""
        with self.sftp_client.open(self.remote_path, 'rb') as f:
            return f.read()

    async def _read_remote_log(self) -> str:
        """Read the log file from the remote SFTP server.

        Connects first if there is no SFTP session.  Raises HTTPException
        (status 500) when connecting or reading fails.
        """
        if not self.sftp_client:
            await self._connect_sftp()
        try:
            loop = asyncio.get_running_loop()
            # paramiko is blocking; keep it off the event loop.
            raw = await loop.run_in_executor(None, self._fetch_remote_bytes)
            # The file may be mid-write, so tolerate a truncated multi-byte
            # sequence instead of failing the whole poll.
            return raw.decode('utf-8', errors='replace')
        except Exception as e:
            # Drop the session so the next poll reconnects instead of
            # reusing a possibly dead connection forever.
            self._close_sftp()
            logger.error("Error reading remote log: %s", e)
            raise HTTPException(status_code=500, detail=str(e)) from e

    async def _read_local_log(self) -> str:
        """Read the log file from the local path.

        Raises HTTPException (status 500) when the file cannot be read.
        """
        try:
            # Same decoding as the remote reader: UTF-8 with replacement.
            async with aiofiles.open(
                self.local_path, 'r', encoding='utf-8', errors='replace'
            ) as f:
                return await f.read()
        except Exception as e:
            logger.error("Error reading local log: %s", e)
            raise HTTPException(status_code=500, detail=str(e)) from e

    def _latest_status(self, lines: List[str]) -> Optional[TrainingStatus]:
        """Return the status of the newest parseable tqdm line, if any."""
        for line in reversed(lines):
            if '|' not in line:  # Cheap pre-filter for tqdm output
                continue
            status = self._parse_tqdm_line(line)
            if status:
                return status
        return None

    async def _monitor_log(self):
        """Poll the log file until the monitor is stopped or cancelled."""
        while self._running:
            try:
                if self.remote_path:
                    content = await self._read_remote_log()
                else:
                    content = await self._read_local_log()
                lines = content.splitlines()
                self.recent_logs = lines[-self.max_log_lines:] if lines else []
                status = self._latest_status(lines)
                if status:
                    # Keep the previous status when no line parses.
                    self.current_status = status
                await asyncio.sleep(self.POLL_INTERVAL_SECONDS)
            except asyncio.CancelledError:
                break
            except Exception as e:
                logger.error("Monitor error: %s", e)
                await asyncio.sleep(self.RETRY_DELAY_SECONDS)  # Wait before retry

    async def get_log(self, lines: int = 50) -> List[str]:
        """Return up to ``lines`` of the most recent log entries.

        A non-positive ``lines`` yields an empty list.
        """
        if lines <= 0:
            # recent_logs[-0:] would be the whole list, not nothing.
            return []
        return self.recent_logs[-lines:]

    async def startup(self):
        """Start the background monitor task.

        Raises ValueError when neither a remote nor a local path is
        configured.  Calling it while the monitor is already running is a
        no-op.
        """
        if not self.remote_path and not self.local_path:
            raise ValueError("No log path configured")
        if self._monitor_task is not None and not self._monitor_task.done():
            logger.warning("Training monitor already running")
            return
        self._running = True
        self._monitor_task = asyncio.create_task(self._monitor_log())
        logger.info("Training monitor started")

    async def shutdown(self):
        """Stop the monitor task and release the SFTP connection."""
        self._running = False
        if self._monitor_task:
            self._monitor_task.cancel()
            try:
                await self._monitor_task
            except asyncio.CancelledError:
                pass
            self._monitor_task = None
        self._close_sftp()
        logger.info("Training monitor stopped")

    async def get_status(self) -> Optional[TrainingStatus]:
        """Return the most recently parsed training status, if any."""
        return self.current_status

    def _open_sftp_blocking(self) -> None:
        """Blocking SSH connect + SFTP open; runs in a worker thread."""
        key_path = os.path.expanduser(settings.SFTP_KEY_PATH)
        ssh = paramiko.SSHClient()
        # NOTE(security): AutoAddPolicy accepts unknown host keys without
        # verification. Kept as in the original; consider loading known
        # host keys and rejecting unknown hosts instead.
        ssh.set_missing_host_key_policy(paramiko.AutoAddPolicy())
        try:
            ssh.connect(
                hostname=settings.SFTP_HOST,
                username=settings.SFTP_USER,
                port=settings.SFTP_PORT,
                key_filename=key_path,
            )
            sftp = ssh.open_sftp()
        except Exception:
            ssh.close()
            raise
        self._ssh_client = ssh
        self.sftp_client = sftp

    def _close_sftp(self) -> None:
        """Close the SFTP session and its SSH client, ignoring close errors."""
        sftp = self.sftp_client
        ssh = getattr(self, '_ssh_client', None)
        self.sftp_client = None
        self._ssh_client = None
        for connection in (sftp, ssh):
            if connection is None:
                continue
            try:
                connection.close()
            except Exception as e:
                # Best effort: the connection is being discarded anyway.
                logger.warning("Error closing SFTP connection: %s", e)

    async def _connect_sftp(self):
        """Create the SFTP connection using the configured SSH key.

        Raises HTTPException (status 500) when the connection fails.
        """
        try:
            loop = asyncio.get_running_loop()
            # paramiko is blocking; keep it off the event loop.
            await loop.run_in_executor(None, self._open_sftp_blocking)
        except Exception as e:
            logger.error("SFTP connection failed: %s", e)
            raise HTTPException(
                status_code=500, detail=f"SFTP Connection failed: {str(e)}"
            ) from e