Add training monitor implementation

Signed-off-by: Felipe Cardoso <felipe.cardoso@hotmail.it>
This commit is contained in:
2025-01-23 09:45:07 +01:00
parent 36ce6ac5ef
commit 1c4d78e916
5 changed files with 174 additions and 13 deletions

View File

@@ -1,29 +1,29 @@
from fastapi import APIRouter, HTTPException
from fastapi import APIRouter, HTTPException, Request
from app.models.training import TrainingStatus
from app.services.training_monitor import TrainingMonitor
router = APIRouter()
monitor = TrainingMonitor()
@router.get("/status", response_model=TrainingStatus)
async def get_training_status(request: Request):
    """
    Get current training status including progress, loss, and learning rate

    The status is read from the TrainingMonitor stored on the application
    state. Responds with 503 while the monitor has not parsed any progress
    line yet, and with 500 if querying the monitor fails.
    """
    try:
        monitor = request.app.state.training_monitor
        status = await monitor.get_status()
    except Exception as e:
        raise HTTPException(status_code=500, detail=str(e))

    # The monitor reports None until it has parsed a first progress line.
    # Returning None under response_model=TrainingStatus would fail response
    # validation, so answer with an explicit, meaningful error instead.
    if status is None:
        raise HTTPException(
            status_code=503,
            detail="Training status is not available yet",
        )
    return status
@router.get("/log")
async def get_training_log(request: Request):
    """
    Get recent training log entries
    """
    # The monitor instance is shared through the application state.
    training_monitor = request.app.state.training_monitor
    try:
        entries = await training_monitor.get_log()
    except Exception as exc:
        # Surface any monitor failure to the client as a 500 with its message.
        raise HTTPException(status_code=500, detail=str(exc))
    return entries

View File

@@ -10,9 +10,11 @@ class Settings(BaseSettings):
SFTP_KEY_PATH: Optional[str] = "~/.ssh/id_rsa"
SFTP_PATH: Optional[str] = None
SFTP_PORT: int = 22
TRAINING_LOG_REMOTE_PATH: Optional[str] = None
# Local Settings (Optional)
LOCAL_PATH: Optional[str] = None
TRAINING_LOG_LOCAL_PATH: Optional[str] = None
# API Settings
PROJECT_NAME: str = "Training Monitor"

View File

@@ -9,6 +9,7 @@ from fastapi.middleware.cors import CORSMiddleware
from app.api.routes import training, samples
from app.core.config import settings
from app.services.sample_manager import SampleManager
from app.services.training_monitor import TrainingMonitor
# Configure logging
logging.basicConfig(
@@ -33,7 +34,10 @@ app.add_middleware(
# Create and store SampleManager instance
sample_manager = SampleManager()
training_monitor = TrainingMonitor()
app.state.sample_manager = sample_manager
app.state.training_monitor = training_monitor
@app.on_event("startup")
@@ -41,6 +45,8 @@ async def startup_event():
"""Initialize services on startup"""
logger.info("Starting up Training Monitor API")
await sample_manager.startup()
await training_monitor.startup()
@app.on_event("shutdown")
@@ -48,6 +54,8 @@ async def shutdown_event():
"""Cleanup on shutdown"""
logger.info("Shutting down Training Monitor API")
await sample_manager.shutdown()
await training_monitor.shutdown()
# Include routers with versioning
app.include_router(training.router, prefix=f"{settings.API_VER_STR}/training", tags=["training"])

View File

@@ -9,6 +9,9 @@ class TrainingStatus(BaseModel):
total_steps: int
loss: float
learning_rate: float
percentage: float
eta_seconds: Optional[float]
steps_per_second: float
updated_at: datetime
source: str # 'local' or 'remote'
source_path: str

View File

@@ -1,13 +1,161 @@
# app/services/training_monitor.py
import asyncio
import logging
import os
import re
from datetime import datetime
from typing import Optional, List
import aiofiles
import paramiko
from fastapi import HTTPException
from app.core.config import settings
from app.models.training import TrainingStatus
logger = logging.getLogger(__name__)
class TrainingMonitor:
    """Background poller that tracks training progress from a tqdm log file.

    The log is read either from a remote host over SFTP (when
    ``TRAINING_LOG_REMOTE_PATH`` is configured) or from a local file
    (``TRAINING_LOG_LOCAL_PATH``). ``startup()`` launches a task that
    re-reads the log periodically, keeps the most recent lines and parses the
    latest tqdm progress line into a ``TrainingStatus``.
    """

    # Seconds between two successful polls of the log file.
    POLL_INTERVAL_SECONDS = 1
    # Seconds to wait after a failed poll before trying again.
    RETRY_DELAY_SECONDS = 5

    # Matches e.g.
    #   "29/20000 [04:08<44:07:22, 7.95s/it, lr: 8.0e-05 loss: 3.693e-01]"
    # tqdm prints the rate as "s/it" below one iteration per second and as
    # "it/s" above it, so both units have to be accepted.
    _TQDM_PATTERN = re.compile(
        r'(\d+)/(\d+)\s+\[(.*?)<(.*?),\s+(.*?)(s/it|it/s),'
        r'\s+lr:\s+(.*?)\s+loss:\s+(.*?)\]'
    )

    def __init__(self):
        self.sftp_client = None
        # The SSHClient must stay referenced for as long as the SFTP session
        # is used: paramiko closes the transport once the client object is
        # garbage-collected.
        self._ssh_client = None
        self._running = False
        self._monitor_task = None
        self.recent_logs: List[str] = []  # Store recent log lines
        self.max_log_lines: int = 100  # Keep last 100 lines
        self.current_status: Optional[TrainingStatus] = None
        self.remote_path = getattr(settings, 'TRAINING_LOG_REMOTE_PATH', None)
        self.local_path = getattr(settings, 'TRAINING_LOG_LOCAL_PATH', None)

    @staticmethod
    def _parse_interval(text: str) -> Optional[float]:
        """Convert a tqdm interval ("MM:SS" or "H:MM:SS") to seconds.

        Returns None when tqdm could not estimate the interval (shown as "?").
        """
        text = text.strip()
        if not text or '?' in text:
            return None
        parts = [float(part) for part in reversed(text.split(':'))]
        return sum(value * weight
                   for value, weight in zip(parts, (1, 60, 3600, 86400)))

    @classmethod
    def _extract_progress(cls, line: str) -> Optional[dict]:
        """Extract the numeric progress fields from one tqdm output line.

        Returns a dict with the keys ``current_step``, ``total_steps``,
        ``loss``, ``learning_rate``, ``percentage``, ``eta_seconds`` and
        ``steps_per_second``, or None when the line is not a progress line.

        Raises:
            ValueError: if a matched field is not a valid number.
        """
        match = cls._TQDM_PATTERN.search(line)
        if match is None:
            return None

        (step_text, total_text, _elapsed, eta_text,
         rate_text, unit, lr_text, loss_text) = match.groups()

        current_step = int(step_text)
        total_steps = int(total_text)

        # tqdm prints "?" while it has no rate estimate yet.
        rate_text = rate_text.strip()
        rate = 0.0 if '?' in rate_text else float(rate_text)
        if unit == 's/it':
            # Seconds per iteration: invert, guarding against "0.00s/it".
            steps_per_second = 1 / rate if rate else 0.0
        else:
            steps_per_second = rate

        return {
            'current_step': current_step,
            'total_steps': total_steps,
            'loss': float(loss_text),
            'learning_rate': float(lr_text),
            'percentage': (current_step / total_steps) * 100 if total_steps else 0.0,
            'eta_seconds': cls._parse_interval(eta_text),
            'steps_per_second': steps_per_second,
        }

    def _parse_tqdm_line(self, line: str) -> "Optional[TrainingStatus]":
        """Parse tqdm output line into TrainingStatus

        Returns None when the line carries no progress information or cannot
        be parsed; parse errors are logged rather than raised.
        """
        try:
            fields = self._extract_progress(line)
            if fields is None:
                return None
            return TrainingStatus(
                **fields,
                updated_at=datetime.now(),
                source='remote' if self.remote_path else 'local',
                source_path=self.remote_path or self.local_path,
            )
        except Exception as e:
            logger.error(f"Error parsing tqdm line: {str(e)}")
            return None

    def _read_remote_bytes(self) -> bytes:
        """Blocking read of the whole remote log file over SFTP."""
        with self.sftp_client.open(self.remote_path, 'rb') as f:
            return f.read()

    async def _read_remote_log(self) -> str:
        """Read log file from remote SFTP

        Raises:
            HTTPException: if connecting or reading fails.
        """
        if not self.sftp_client:
            await self._connect_sftp()
        try:
            # paramiko is blocking; run it in a worker thread so the event
            # loop keeps serving requests while the file is transferred.
            loop = asyncio.get_running_loop()
            content = await loop.run_in_executor(None, self._read_remote_bytes)
            # The file may be mid-write, so tolerate a truncated multibyte
            # character at the end instead of failing the whole poll.
            return content.decode('utf-8', errors='replace')
        except Exception as e:
            logger.error(f"Error reading remote log: {str(e)}")
            # Drop the (possibly broken) session so the next poll reconnects.
            self._close_sftp()
            raise HTTPException(status_code=500, detail=str(e))

    async def _read_local_log(self) -> str:
        """Read log file from local path

        Raises:
            HTTPException: if the file cannot be read.
        """
        try:
            async with aiofiles.open(self.local_path, 'r', encoding='utf-8',
                                     errors='replace') as f:
                return await f.read()
        except Exception as e:
            logger.error(f"Error reading local log: {str(e)}")
            raise HTTPException(status_code=500, detail=str(e))

    async def _monitor_log(self):
        """Monitor log file for updates"""
        while self._running:
            try:
                if self.remote_path:
                    content = await self._read_remote_log()
                else:
                    content = await self._read_local_log()

                # splitlines() also splits on the carriage returns tqdm uses
                # to redraw its bar, so each redraw becomes its own line.
                lines = content.splitlines()
                self.recent_logs = lines[-self.max_log_lines:]

                # Newest progress line wins: scan from the end of the log.
                for line in reversed(lines):
                    if '|' not in line:  # Basic check for tqdm output
                        continue
                    status = self._parse_tqdm_line(line)
                    if status:
                        self.current_status = status
                        break

                await asyncio.sleep(self.POLL_INTERVAL_SECONDS)
            except asyncio.CancelledError:
                break
            except Exception as e:
                logger.error(f"Monitor error: {str(e)}")
                await asyncio.sleep(self.RETRY_DELAY_SECONDS)

    async def get_log(self, lines: int = 50) -> List[str]:
        """Get recent log entries

        Args:
            lines: Maximum number of trailing log lines to return.

        Returns:
            Up to ``lines`` of the most recent log lines, oldest first; an
            empty list when ``lines`` is not positive.
        """
        # A plain [-lines:] slice would return the whole buffer for 0.
        if lines <= 0:
            return []
        return self.recent_logs[-lines:]

    async def startup(self):
        """Start the monitor

        Raises:
            ValueError: if neither a remote nor a local log path is configured.
        """
        if not self.remote_path and not self.local_path:
            raise ValueError("No log path configured")
        # Starting twice must not spawn a second polling task.
        if self._monitor_task and not self._monitor_task.done():
            return
        self._running = True
        self._monitor_task = asyncio.create_task(self._monitor_log())
        logger.info("Training monitor started")

    async def shutdown(self):
        """Stop the monitor and release the SFTP connection"""
        self._running = False
        task = self._monitor_task
        if task:
            task.cancel()
            try:
                await task
            except asyncio.CancelledError:
                pass
            self._monitor_task = None
        self._close_sftp()
        logger.info("Training monitor stopped")

    async def get_status(self) -> "Optional[TrainingStatus]":
        """Get current training status

        Returns None until a progress line has been parsed successfully.
        """
        return self.current_status

    def _close_sftp(self):
        """Close the SFTP session and its SSH client, if any are open."""
        for connection in (self.sftp_client, self._ssh_client):
            if connection is None:
                continue
            try:
                connection.close()
            except Exception as e:
                logger.warning(f"Error closing SFTP resources: {str(e)}")
        self.sftp_client = None
        self._ssh_client = None

    @staticmethod
    def _open_sftp_session():
        """Blocking helper: open an SSH connection and an SFTP session on it.

        Returns:
            A ``(ssh_client, sftp_client)`` tuple.
        """
        key_path = (os.path.expanduser(settings.SFTP_KEY_PATH)
                    if settings.SFTP_KEY_PATH else None)
        ssh = paramiko.SSHClient()
        # NOTE(review): AutoAddPolicy trusts any host key on first contact,
        # which permits man-in-the-middle attacks. Consider loading known
        # hosts and using RejectPolicy once the host key is provisioned.
        ssh.set_missing_host_key_policy(paramiko.AutoAddPolicy())
        try:
            ssh.connect(
                hostname=settings.SFTP_HOST,
                username=settings.SFTP_USER,
                port=settings.SFTP_PORT,
                key_filename=key_path,
            )
            return ssh, ssh.open_sftp()
        except Exception:
            # Do not leak a half-open connection.
            ssh.close()
            raise

    # Include SFTP connection methods similar to SampleManager
    async def _connect_sftp(self):
        """Create SFTP connection using SSH key

        Raises:
            HTTPException: if the connection cannot be established.
        """
        try:
            # Connecting blocks on network I/O; keep it off the event loop.
            loop = asyncio.get_running_loop()
            ssh, sftp = await loop.run_in_executor(None, self._open_sftp_session)
            self._ssh_client = ssh
            self.sftp_client = sftp
        except Exception as e:
            logger.error(f"SFTP connection failed: {str(e)}")
            raise HTTPException(status_code=500, detail=f"SFTP Connection failed: {str(e)}")