Add training monitor implementation
Signed-off-by: Felipe Cardoso <felipe.cardoso@hotmail.it>
This commit is contained in:
@@ -1,29 +1,29 @@
|
|||||||
from fastapi import APIRouter, HTTPException
|
from fastapi import APIRouter, HTTPException, Request
|
||||||
|
|
||||||
from app.models.training import TrainingStatus
|
from app.models.training import TrainingStatus
|
||||||
from app.services.training_monitor import TrainingMonitor
|
|
||||||
|
|
||||||
router = APIRouter()
|
router = APIRouter()
|
||||||
monitor = TrainingMonitor()
|
|
||||||
|
|
||||||
|
|
||||||
@router.get("/status", response_model=TrainingStatus)
|
@router.get("/status", response_model=TrainingStatus)
|
||||||
async def get_training_status():
|
async def get_training_status(request: Request):
|
||||||
"""
|
"""
|
||||||
Get current training status including progress, loss, and learning rate
|
Get current training status including progress, loss, and learning rate
|
||||||
"""
|
"""
|
||||||
try:
|
try:
|
||||||
|
monitor = request.app.state.training_monitor
|
||||||
return await monitor.get_status()
|
return await monitor.get_status()
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
raise HTTPException(status_code=500, detail=str(e))
|
raise HTTPException(status_code=500, detail=str(e))
|
||||||
|
|
||||||
|
|
||||||
@router.get("/log")
|
@router.get("/log")
|
||||||
async def get_training_log():
|
async def get_training_log(request: Request):
|
||||||
"""
|
"""
|
||||||
Get recent training log entries
|
Get recent training log entries
|
||||||
"""
|
"""
|
||||||
try:
|
try:
|
||||||
|
monitor = request.app.state.training_monitor
|
||||||
return await monitor.get_log()
|
return await monitor.get_log()
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
raise HTTPException(status_code=500, detail=str(e))
|
raise HTTPException(status_code=500, detail=str(e))
|
||||||
|
|||||||
@@ -10,9 +10,11 @@ class Settings(BaseSettings):
|
|||||||
SFTP_KEY_PATH: Optional[str] = "~/.ssh/id_rsa"
|
SFTP_KEY_PATH: Optional[str] = "~/.ssh/id_rsa"
|
||||||
SFTP_PATH: Optional[str] = None
|
SFTP_PATH: Optional[str] = None
|
||||||
SFTP_PORT: int = 22
|
SFTP_PORT: int = 22
|
||||||
|
TRAINING_LOG_REMOTE_PATH: Optional[str] = None
|
||||||
|
|
||||||
# Local Settings (Optional)
|
# Local Settings (Optional)
|
||||||
LOCAL_PATH: Optional[str] = None
|
LOCAL_PATH: Optional[str] = None
|
||||||
|
TRAINING_LOG_LOCAL_PATH: Optional[str] = None
|
||||||
|
|
||||||
# API Settings
|
# API Settings
|
||||||
PROJECT_NAME: str = "Training Monitor"
|
PROJECT_NAME: str = "Training Monitor"
|
||||||
|
|||||||
@@ -9,6 +9,7 @@ from fastapi.middleware.cors import CORSMiddleware
|
|||||||
from app.api.routes import training, samples
|
from app.api.routes import training, samples
|
||||||
from app.core.config import settings
|
from app.core.config import settings
|
||||||
from app.services.sample_manager import SampleManager
|
from app.services.sample_manager import SampleManager
|
||||||
|
from app.services.training_monitor import TrainingMonitor
|
||||||
|
|
||||||
# Configure logging
|
# Configure logging
|
||||||
logging.basicConfig(
|
logging.basicConfig(
|
||||||
@@ -33,7 +34,10 @@ app.add_middleware(
|
|||||||
|
|
||||||
# Create and store SampleManager instance
|
# Create and store SampleManager instance
|
||||||
sample_manager = SampleManager()
|
sample_manager = SampleManager()
|
||||||
|
training_monitor = TrainingMonitor()
|
||||||
|
|
||||||
app.state.sample_manager = sample_manager
|
app.state.sample_manager = sample_manager
|
||||||
|
app.state.training_monitor = training_monitor
|
||||||
|
|
||||||
|
|
||||||
@app.on_event("startup")
|
@app.on_event("startup")
|
||||||
@@ -41,6 +45,8 @@ async def startup_event():
|
|||||||
"""Initialize services on startup"""
|
"""Initialize services on startup"""
|
||||||
logger.info("Starting up Training Monitor API")
|
logger.info("Starting up Training Monitor API")
|
||||||
await sample_manager.startup()
|
await sample_manager.startup()
|
||||||
|
await training_monitor.startup()
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
@app.on_event("shutdown")
|
@app.on_event("shutdown")
|
||||||
@@ -48,6 +54,8 @@ async def shutdown_event():
|
|||||||
"""Cleanup on shutdown"""
|
"""Cleanup on shutdown"""
|
||||||
logger.info("Shutting down Training Monitor API")
|
logger.info("Shutting down Training Monitor API")
|
||||||
await sample_manager.shutdown()
|
await sample_manager.shutdown()
|
||||||
|
await training_monitor.shutdown()
|
||||||
|
|
||||||
|
|
||||||
# Include routers with versioning
|
# Include routers with versioning
|
||||||
app.include_router(training.router, prefix=f"{settings.API_VER_STR}/training", tags=["training"])
|
app.include_router(training.router, prefix=f"{settings.API_VER_STR}/training", tags=["training"])
|
||||||
|
|||||||
@@ -9,6 +9,9 @@ class TrainingStatus(BaseModel):
|
|||||||
total_steps: int
|
total_steps: int
|
||||||
loss: float
|
loss: float
|
||||||
learning_rate: float
|
learning_rate: float
|
||||||
|
percentage: float
|
||||||
eta_seconds: Optional[float]
|
eta_seconds: Optional[float]
|
||||||
steps_per_second: float
|
steps_per_second: float
|
||||||
updated_at: datetime
|
updated_at: datetime
|
||||||
|
source: str # 'local' or 'remote'
|
||||||
|
source_path: str
|
||||||
|
|||||||
@@ -1,13 +1,161 @@
|
|||||||
|
# app/services/training_monitor.py
|
||||||
|
import asyncio
|
||||||
|
import logging
|
||||||
|
import os
|
||||||
|
import re
|
||||||
|
from datetime import datetime
|
||||||
|
from typing import Optional, List
|
||||||
|
|
||||||
|
import aiofiles
|
||||||
|
import paramiko
|
||||||
|
from fastapi import HTTPException
|
||||||
|
|
||||||
|
from app.core.config import settings
|
||||||
from app.models.training import TrainingStatus
|
from app.models.training import TrainingStatus
|
||||||
|
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
class TrainingMonitor:
|
class TrainingMonitor:
|
||||||
async def get_status(self) -> TrainingStatus:
|
def __init__(self):
|
||||||
# Implementation for parsing tqdm output
|
self.sftp_client = None
|
||||||
# This is a placeholder - actual implementation needed
|
self._running = False
|
||||||
pass
|
self._monitor_task = None
|
||||||
|
self.recent_logs: List[str] = [] # Store recent log lines
|
||||||
|
self.max_log_lines: int = 100 # Keep last 100 lines
|
||||||
|
self.current_status: Optional[TrainingStatus] = None
|
||||||
|
self.remote_path = settings.TRAINING_LOG_REMOTE_PATH if hasattr(settings, 'TRAINING_LOG_REMOTE_PATH') else None
|
||||||
|
self.local_path = settings.TRAINING_LOG_LOCAL_PATH if hasattr(settings, 'TRAINING_LOG_LOCAL_PATH') else None
|
||||||
|
|
||||||
async def get_log(self):
|
def _parse_tqdm_line(self, line: str) -> Optional[TrainingStatus]:
|
||||||
# Implementation for getting recent log entries
|
"""Parse tqdm output line into TrainingStatus"""
|
||||||
# This is a placeholder - actual implementation needed
|
try:
|
||||||
pass
|
# Example: ovs_bangel_001: 0%| | 29/20000 [04:08<44:07:22, 7.95s/it, lr: 8.0e-05 loss: 3.693e-01]
|
||||||
|
pattern = r'.*?(\d+)/(\d+)\s+\[(.*?)<(.*?),\s+(.*?)s/it,\s+lr:\s+(.*?)\s+loss:\s+(.*?)\]'
|
||||||
|
match = re.search(pattern, line)
|
||||||
|
|
||||||
|
if match:
|
||||||
|
current_step = int(match.group(1))
|
||||||
|
total_steps = int(match.group(2))
|
||||||
|
elapsed_str = match.group(3)
|
||||||
|
eta_str = match.group(4)
|
||||||
|
step_time = float(match.group(5))
|
||||||
|
lr = float(match.group(6))
|
||||||
|
loss = float(match.group(7))
|
||||||
|
|
||||||
|
# Convert ETA string to seconds
|
||||||
|
eta_seconds = sum(x * y for x, y in zip(
|
||||||
|
map(float, reversed(eta_str.split(':'))),
|
||||||
|
[1, 60, 3600, 86400]
|
||||||
|
))
|
||||||
|
|
||||||
|
return TrainingStatus(
|
||||||
|
current_step=current_step,
|
||||||
|
total_steps=total_steps,
|
||||||
|
loss=loss,
|
||||||
|
learning_rate=lr,
|
||||||
|
percentage=(current_step / total_steps) * 100,
|
||||||
|
eta_seconds=eta_seconds,
|
||||||
|
steps_per_second=1 / step_time,
|
||||||
|
updated_at=datetime.now(),
|
||||||
|
source='remote' if self.remote_path else 'local',
|
||||||
|
source_path=self.remote_path or self.local_path
|
||||||
|
)
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"Error parsing tqdm line: {str(e)}")
|
||||||
|
return None
|
||||||
|
|
||||||
|
async def _read_remote_log(self) -> str:
|
||||||
|
"""Read log file from remote SFTP"""
|
||||||
|
if not self.sftp_client:
|
||||||
|
await self._connect_sftp()
|
||||||
|
|
||||||
|
try:
|
||||||
|
with self.sftp_client.open(self.remote_path, 'rb') as f: # Note the 'rb' mode
|
||||||
|
content = f.read()
|
||||||
|
return content.decode('utf-8')
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"Error reading remote log: {str(e)}")
|
||||||
|
raise HTTPException(status_code=500, detail=str(e))
|
||||||
|
|
||||||
|
async def _read_local_log(self) -> str:
|
||||||
|
"""Read log file from local path"""
|
||||||
|
try:
|
||||||
|
async with aiofiles.open(self.local_path, 'r') as f:
|
||||||
|
return await f.read()
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"Error reading local log: {str(e)}")
|
||||||
|
raise HTTPException(status_code=500, detail=str(e))
|
||||||
|
|
||||||
|
async def _monitor_log(self):
|
||||||
|
"""Monitor log file for updates"""
|
||||||
|
while self._running:
|
||||||
|
try:
|
||||||
|
content = (await self._read_remote_log() if self.remote_path
|
||||||
|
else await self._read_local_log())
|
||||||
|
|
||||||
|
# Get last line containing progress info
|
||||||
|
lines = content.splitlines()
|
||||||
|
self.recent_logs = lines[-self.max_log_lines:] if lines else []
|
||||||
|
for line in reversed(lines):
|
||||||
|
if '|' in line: # Basic check for tqdm output
|
||||||
|
status = self._parse_tqdm_line(line)
|
||||||
|
if status:
|
||||||
|
self.current_status = status
|
||||||
|
break
|
||||||
|
|
||||||
|
await asyncio.sleep(1) # Check every second
|
||||||
|
|
||||||
|
except asyncio.CancelledError:
|
||||||
|
break
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"Monitor error: {str(e)}")
|
||||||
|
await asyncio.sleep(5) # Wait before retry
|
||||||
|
|
||||||
|
async def get_log(self, lines: int = 50) -> List[str]:
|
||||||
|
"""Get recent log entries"""
|
||||||
|
return self.recent_logs[-lines:]
|
||||||
|
|
||||||
|
async def startup(self):
|
||||||
|
"""Start the monitor"""
|
||||||
|
if not self.remote_path and not self.local_path:
|
||||||
|
raise ValueError("No log path configured")
|
||||||
|
|
||||||
|
self._running = True
|
||||||
|
self._monitor_task = asyncio.create_task(self._monitor_log())
|
||||||
|
logger.info("Training monitor started")
|
||||||
|
|
||||||
|
async def shutdown(self):
|
||||||
|
"""Stop the monitor"""
|
||||||
|
self._running = False
|
||||||
|
if self._monitor_task:
|
||||||
|
self._monitor_task.cancel()
|
||||||
|
try:
|
||||||
|
await self._monitor_task
|
||||||
|
except asyncio.CancelledError:
|
||||||
|
pass
|
||||||
|
logger.info("Training monitor stopped")
|
||||||
|
|
||||||
|
async def get_status(self) -> Optional[TrainingStatus]:
|
||||||
|
"""Get current training status"""
|
||||||
|
return self.current_status
|
||||||
|
|
||||||
|
# Include SFTP connection methods similar to SampleManager
|
||||||
|
async def _connect_sftp(self):
|
||||||
|
"""Create SFTP connection using SSH key"""
|
||||||
|
try:
|
||||||
|
key_path = os.path.expanduser(settings.SFTP_KEY_PATH)
|
||||||
|
ssh = paramiko.SSHClient()
|
||||||
|
ssh.set_missing_host_key_policy(paramiko.AutoAddPolicy())
|
||||||
|
|
||||||
|
ssh.connect(
|
||||||
|
hostname=settings.SFTP_HOST,
|
||||||
|
username=settings.SFTP_USER,
|
||||||
|
port=settings.SFTP_PORT,
|
||||||
|
key_filename=key_path,
|
||||||
|
)
|
||||||
|
|
||||||
|
self.sftp_client = ssh.open_sftp()
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"SFTP connection failed: {str(e)}")
|
||||||
|
raise HTTPException(status_code=500, detail=f"SFTP Connection failed: {str(e)}")
|
||||||
|
|||||||
Reference in New Issue
Block a user