From 1c4d78e916a8772d15e4d06d314e10a14519acac Mon Sep 17 00:00:00 2001 From: Felipe Cardoso Date: Thu, 23 Jan 2025 09:45:07 +0100 Subject: [PATCH] Add training monitor implementation Signed-off-by: Felipe Cardoso --- backend/app/api/routes/training.py | 10 +- backend/app/core/config.py | 2 + backend/app/main.py | 8 ++ backend/app/models/training.py | 3 + backend/app/services/training_monitor.py | 164 +++++++++++++++++++++-- 5 files changed, 174 insertions(+), 13 deletions(-) diff --git a/backend/app/api/routes/training.py b/backend/app/api/routes/training.py index bcf8b17..4ef2c57 100644 --- a/backend/app/api/routes/training.py +++ b/backend/app/api/routes/training.py @@ -1,29 +1,29 @@ -from fastapi import APIRouter, HTTPException +from fastapi import APIRouter, HTTPException, Request from app.models.training import TrainingStatus -from app.services.training_monitor import TrainingMonitor router = APIRouter() -monitor = TrainingMonitor() @router.get("/status", response_model=TrainingStatus) -async def get_training_status(): +async def get_training_status(request: Request): """ Get current training status including progress, loss, and learning rate """ try: + monitor = request.app.state.training_monitor return await monitor.get_status() except Exception as e: raise HTTPException(status_code=500, detail=str(e)) @router.get("/log") -async def get_training_log(): +async def get_training_log(request: Request): """ Get recent training log entries """ try: + monitor = request.app.state.training_monitor return await monitor.get_log() except Exception as e: raise HTTPException(status_code=500, detail=str(e)) diff --git a/backend/app/core/config.py b/backend/app/core/config.py index f5360de..25ec48d 100644 --- a/backend/app/core/config.py +++ b/backend/app/core/config.py @@ -10,9 +10,11 @@ class Settings(BaseSettings): SFTP_KEY_PATH: Optional[str] = "~/.ssh/id_rsa" SFTP_PATH: Optional[str] = None SFTP_PORT: int = 22 + TRAINING_LOG_REMOTE_PATH: Optional[str] = None # Local Settings (Optional) LOCAL_PATH: Optional[str] = None + TRAINING_LOG_LOCAL_PATH: Optional[str] = None # API Settings PROJECT_NAME: str = "Training Monitor" diff --git a/backend/app/main.py b/backend/app/main.py index e2831bd..7f4e60c 100644 --- a/backend/app/main.py +++ b/backend/app/main.py @@ -9,6 +9,7 @@ from fastapi.middleware.cors import CORSMiddleware from app.api.routes import training, samples from app.core.config import settings from app.services.sample_manager import SampleManager +from app.services.training_monitor import TrainingMonitor # Configure logging logging.basicConfig( @@ -33,7 +34,10 @@ app.add_middleware( # Create and store SampleManager instance sample_manager = SampleManager() +training_monitor = TrainingMonitor() + app.state.sample_manager = sample_manager +app.state.training_monitor = training_monitor @app.on_event("startup") @@ -41,6 +45,8 @@ async def startup_event(): """Initialize services on startup""" logger.info("Starting up Training Monitor API") await sample_manager.startup() + await training_monitor.startup() + @app.on_event("shutdown") @@ -48,6 +54,8 @@ async def shutdown_event(): """Cleanup on shutdown""" logger.info("Shutting down Training Monitor API") await sample_manager.shutdown() + await training_monitor.shutdown() + # Include routers with versioning app.include_router(training.router, prefix=f"{settings.API_VER_STR}/training", tags=["training"]) diff --git a/backend/app/models/training.py b/backend/app/models/training.py index 41342b6..92a5a33 100644 --- a/backend/app/models/training.py +++ b/backend/app/models/training.py @@ -9,6 +9,9 @@ class TrainingStatus(BaseModel): total_steps: int loss: float learning_rate: float + percentage: float eta_seconds: Optional[float] steps_per_second: float updated_at: datetime + source: str # 'local' or 'remote' + source_path: str diff --git a/backend/app/services/training_monitor.py b/backend/app/services/training_monitor.py index 0ebf278..5f741eb 100644 --- a/backend/app/services/training_monitor.py +++ b/backend/app/services/training_monitor.py @@ -1,13 +1,161 @@ +# app/services/training_monitor.py +import asyncio +import logging +import os +import re +from datetime import datetime +from typing import Optional, List + +import aiofiles +import paramiko +from fastapi import HTTPException + +from app.core.config import settings from app.models.training import TrainingStatus +logger = logging.getLogger(__name__) + class TrainingMonitor: - async def get_status(self) -> TrainingStatus: - # Implementation for parsing tqdm output - # This is a placeholder - actual implementation needed - pass + def __init__(self): + self.sftp_client = None + self._running = False + self._monitor_task = None + self.recent_logs: List[str] = [] # Store recent log lines + self.max_log_lines: int = 100 # Keep last 100 lines + self.current_status: Optional[TrainingStatus] = None + self.remote_path = settings.TRAINING_LOG_REMOTE_PATH if hasattr(settings, 'TRAINING_LOG_REMOTE_PATH') else None + self.local_path = settings.TRAINING_LOG_LOCAL_PATH if hasattr(settings, 'TRAINING_LOG_LOCAL_PATH') else None - async def get_log(self): - # Implementation for getting recent log entries - # This is a placeholder - actual implementation needed - pass + def _parse_tqdm_line(self, line: str) -> Optional[TrainingStatus]: + """Parse tqdm output line into TrainingStatus""" + try: + # Example: ovs_bangel_001: 0%| | 29/20000 [04:08<44:07:22, 7.95s/it, lr: 8.0e-05 loss: 3.693e-01] + pattern = r'.*?(\d+)/(\d+)\s+\[(.*?)<(.*?),\s+(.*?)s/it,\s+lr:\s+(.*?)\s+loss:\s+(.*?)\]' + match = re.search(pattern, line) + + if match: + current_step = int(match.group(1)) + total_steps = int(match.group(2)) + elapsed_str = match.group(3) + eta_str = match.group(4) + step_time = float(match.group(5)) + lr = float(match.group(6)) + loss = float(match.group(7)) + + # Convert ETA string to seconds + eta_seconds = sum(x * y for x, y in zip( + map(float, reversed(eta_str.split(':'))), + [1, 60, 3600, 86400] + )) + + return TrainingStatus( + current_step=current_step, + total_steps=total_steps, + loss=loss, + learning_rate=lr, + percentage=(current_step / total_steps) * 100, + eta_seconds=eta_seconds, + steps_per_second=1 / step_time, + updated_at=datetime.now(), + source='remote' if self.remote_path else 'local', + source_path=self.remote_path or self.local_path + ) + except Exception as e: + logger.error(f"Error parsing tqdm line: {str(e)}") + return None + + async def _read_remote_log(self) -> str: + """Read log file from remote SFTP""" + if not self.sftp_client: + await self._connect_sftp() + + try: + with self.sftp_client.open(self.remote_path, 'rb') as f: # Note the 'rb' mode + content = f.read() + return content.decode('utf-8') + except Exception as e: + logger.error(f"Error reading remote log: {str(e)}") + raise HTTPException(status_code=500, detail=str(e)) + + async def _read_local_log(self) -> str: + """Read log file from local path""" + try: + async with aiofiles.open(self.local_path, 'r') as f: + return await f.read() + except Exception as e: + logger.error(f"Error reading local log: {str(e)}") + raise HTTPException(status_code=500, detail=str(e)) + + async def _monitor_log(self): + """Monitor log file for updates""" + while self._running: + try: + content = (await self._read_remote_log() if self.remote_path + else await self._read_local_log()) + + # Get last line containing progress info + lines = content.splitlines() + self.recent_logs = lines[-self.max_log_lines:] if lines else [] + for line in reversed(lines): + if '|' in line: # Basic check for tqdm output + status = self._parse_tqdm_line(line) + if status: + self.current_status = status + break + + await asyncio.sleep(1) # Check every second + + except asyncio.CancelledError: + break + except Exception as e: + logger.error(f"Monitor error: {str(e)}") + await asyncio.sleep(5) # Wait before retry + + async def get_log(self, lines: int = 50) -> List[str]: + """Get recent log entries""" + return self.recent_logs[-lines:] + + async def startup(self): + """Start the monitor""" + if not self.remote_path and not self.local_path: + raise ValueError("No log path configured") + + self._running = True + self._monitor_task = asyncio.create_task(self._monitor_log()) + logger.info("Training monitor started") + + async def shutdown(self): + """Stop the monitor""" + self._running = False + if self._monitor_task: + self._monitor_task.cancel() + try: + await self._monitor_task + except asyncio.CancelledError: + pass + logger.info("Training monitor stopped") + + async def get_status(self) -> Optional[TrainingStatus]: + """Get current training status""" + return self.current_status + + # Include SFTP connection methods similar to SampleManager + async def _connect_sftp(self): + """Create SFTP connection using SSH key""" + try: + key_path = os.path.expanduser(settings.SFTP_KEY_PATH) + ssh = paramiko.SSHClient() + ssh.set_missing_host_key_policy(paramiko.AutoAddPolicy()) + + ssh.connect( + hostname=settings.SFTP_HOST, + username=settings.SFTP_USER, + port=settings.SFTP_PORT, + key_filename=key_path, + ) + + self.sftp_client = ssh.open_sftp() + except Exception as e: + logger.error(f"SFTP connection failed: {str(e)}") + raise HTTPException(status_code=500, detail=f"SFTP Connection failed: {str(e)}")