Add training monitor implementation
Signed-off-by: Felipe Cardoso <felipe.cardoso@hotmail.it>
@@ -1,29 +1,29 @@
-from fastapi import APIRouter, HTTPException
+from fastapi import APIRouter, HTTPException, Request

from app.models.training import TrainingStatus
-from app.services.training_monitor import TrainingMonitor

router = APIRouter()
-monitor = TrainingMonitor()


@router.get("/status", response_model=TrainingStatus)
-async def get_training_status():
+async def get_training_status(request: Request):
    """
    Get current training status including progress, loss, and learning rate
    """
    try:
+    monitor = request.app.state.training_monitor
        return await monitor.get_status()
    except Exception as e:
        raise HTTPException(status_code=500, detail=str(e))


@router.get("/log")
-async def get_training_log():
+async def get_training_log(request: Request):
    """
    Get recent training log entries
    """
    try:
+    monitor = request.app.state.training_monitor
        return await monitor.get_log()
    except Exception as e:
        raise HTTPException(status_code=500, detail=str(e))

@@ -10,9 +10,11 @@ class Settings(BaseSettings):
    SFTP_KEY_PATH: Optional[str] = "~/.ssh/id_rsa"
    SFTP_PATH: Optional[str] = None
    SFTP_PORT: int = 22
+    TRAINING_LOG_REMOTE_PATH: Optional[str] = None

    # Local Settings (Optional)
    LOCAL_PATH: Optional[str] = None
+    TRAINING_LOG_LOCAL_PATH: Optional[str] = None

    # API Settings
    PROJECT_NAME: str = "Training Monitor"

@@ -9,6 +9,7 @@ from fastapi.middleware.cors import CORSMiddleware
from app.api.routes import training, samples
from app.core.config import settings
from app.services.sample_manager import SampleManager
+from app.services.training_monitor import TrainingMonitor

# Configure logging
logging.basicConfig(

@@ -33,7 +34,10 @@ app.add_middleware(

# Create and store SampleManager instance
sample_manager = SampleManager()
+training_monitor = TrainingMonitor()

app.state.sample_manager = sample_manager
+app.state.training_monitor = training_monitor


@app.on_event("startup")

@@ -41,6 +45,8 @@ async def startup_event():
    """Initialize services on startup"""
    logger.info("Starting up Training Monitor API")
    await sample_manager.startup()
+    await training_monitor.startup()



@app.on_event("shutdown")

@@ -48,6 +54,8 @@ async def shutdown_event():
    """Cleanup on shutdown"""
    logger.info("Shutting down Training Monitor API")
    await sample_manager.shutdown()
+    await training_monitor.shutdown()


# Include routers with versioning
app.include_router(training.router, prefix=f"{settings.API_VER_STR}/training", tags=["training"])

@@ -9,6 +9,9 @@ class TrainingStatus(BaseModel):
    total_steps: int
    loss: float
    learning_rate: float
    percentage: float
    eta_seconds: Optional[float]
    steps_per_second: float
    updated_at: datetime
    source: str  # 'local' or 'remote'
    source_path: str

@@ -1,13 +1,161 @@
# app/services/training_monitor.py
import asyncio
import logging
import os
import re
from datetime import datetime
from typing import Optional, List

import aiofiles
import paramiko
from fastapi import HTTPException

from app.core.config import settings
from app.models.training import TrainingStatus

logger = logging.getLogger(__name__)


class TrainingMonitor:
-    async def get_status(self) -> TrainingStatus:
-        # Implementation for parsing tqdm output
-        # This is a placeholder - actual implementation needed
-        pass
    def __init__(self):
        self.sftp_client = None
        self._running = False
        self._monitor_task = None
        self.recent_logs: List[str] = []  # Store recent log lines
        self.max_log_lines: int = 100  # Keep last 100 lines
        self.current_status: Optional[TrainingStatus] = None
        self.remote_path = getattr(settings, 'TRAINING_LOG_REMOTE_PATH', None)
        self.local_path = getattr(settings, 'TRAINING_LOG_LOCAL_PATH', None)
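        # When both paths are configured, the remote log takes precedence (see _monitor_log).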

-    async def get_log(self):
-        # Implementation for getting recent log entries
-        # This is a placeholder - actual implementation needed
-        pass
    def _parse_tqdm_line(self, line: str) -> Optional[TrainingStatus]:
        """Parse tqdm output line into TrainingStatus"""
        try:
            # Example: ovs_bangel_001: 0%| | 29/20000 [04:08<44:07:22, 7.95s/it, lr: 8.0e-05 loss: 3.693e-01]
            pattern = r'.*?(\d+)/(\d+)\s+\[(.*?)<(.*?),\s+(.*?)s/it,\s+lr:\s+(.*?)\s+loss:\s+(.*?)\]'
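            # Capture groups: (1) current step, (2) total steps, (3) elapsed time,
            # (4) ETA, (5) seconds per step, (6) learning rate, (7) loss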
            match = re.search(pattern, line)

            if match:
                current_step = int(match.group(1))
                total_steps = int(match.group(2))
                elapsed_str = match.group(3)
                eta_str = match.group(4)
                step_time = float(match.group(5))
                lr = float(match.group(6))
                loss = float(match.group(7))

                # Convert ETA string to seconds
                eta_seconds = sum(x * y for x, y in zip(
                    map(float, reversed(eta_str.split(':'))),
                    [1, 60, 3600, 86400]
                ))
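                # e.g. "44:07:22" -> 22 + 7*60 + 44*3600 = 158,842 seconds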

                return TrainingStatus(
                    current_step=current_step,
                    total_steps=total_steps,
                    loss=loss,
                    learning_rate=lr,
                    percentage=(current_step / total_steps) * 100,
                    eta_seconds=eta_seconds,
                    steps_per_second=1 / step_time,
                    updated_at=datetime.now(),
                    source='remote' if self.remote_path else 'local',
                    source_path=self.remote_path or self.local_path
                )
        except Exception as e:
            logger.error(f"Error parsing tqdm line: {str(e)}")
            return None

    async def _read_remote_log(self) -> str:
        """Read log file from remote SFTP"""
        if not self.sftp_client:
            await self._connect_sftp()

        try:
            with self.sftp_client.open(self.remote_path, 'rb') as f:  # binary mode; decode explicitly below
                content = f.read()
                return content.decode('utf-8')
        except Exception as e:
            logger.error(f"Error reading remote log: {str(e)}")
            raise HTTPException(status_code=500, detail=str(e))

    async def _read_local_log(self) -> str:
        """Read log file from local path"""
        try:
            async with aiofiles.open(self.local_path, 'r') as f:
                return await f.read()
        except Exception as e:
            logger.error(f"Error reading local log: {str(e)}")
            raise HTTPException(status_code=500, detail=str(e))

    async def _monitor_log(self):
        """Monitor log file for updates"""
        while self._running:
            try:
                content = (await self._read_remote_log() if self.remote_path
                           else await self._read_local_log())

                # Get last line containing progress info
                lines = content.splitlines()
                self.recent_logs = lines[-self.max_log_lines:] if lines else []
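                # Scan from the newest line backwards; the first line that parses is the latest status.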
                for line in reversed(lines):
                    if '|' in line:  # Basic check for tqdm output
                        status = self._parse_tqdm_line(line)
                        if status:
                            self.current_status = status
                            break

                await asyncio.sleep(1)  # Check every second

            except asyncio.CancelledError:
                break
            except Exception as e:
                logger.error(f"Monitor error: {str(e)}")
                await asyncio.sleep(5)  # Wait before retry

    async def get_log(self, lines: int = 50) -> List[str]:
        """Get recent log entries"""
        return self.recent_logs[-lines:]

    async def startup(self):
        """Start the monitor"""
        if not self.remote_path and not self.local_path:
            raise ValueError("No log path configured")

        self._running = True
        self._monitor_task = asyncio.create_task(self._monitor_log())
        logger.info("Training monitor started")

    async def shutdown(self):
        """Stop the monitor"""
        self._running = False
        if self._monitor_task:
            self._monitor_task.cancel()
            try:
                await self._monitor_task
            except asyncio.CancelledError:
                pass
        logger.info("Training monitor stopped")

    async def get_status(self) -> Optional[TrainingStatus]:
        """Get current training status"""
        return self.current_status

    # Include SFTP connection methods similar to SampleManager
    async def _connect_sftp(self):
        """Create SFTP connection using SSH key"""
        try:
            key_path = os.path.expanduser(settings.SFTP_KEY_PATH)
            ssh = paramiko.SSHClient()
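            # Auto-accept unknown host keys; convenient for trusted hosts, but skips host key verification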
            ssh.set_missing_host_key_policy(paramiko.AutoAddPolicy())

            ssh.connect(
                hostname=settings.SFTP_HOST,
                username=settings.SFTP_USER,
                port=settings.SFTP_PORT,
                key_filename=key_path,
            )

            self.sftp_client = ssh.open_sftp()
        except Exception as e:
            logger.error(f"SFTP connection failed: {str(e)}")
            raise HTTPException(status_code=500, detail=f"SFTP connection failed: {str(e)}")