140 lines
4.3 KiB
Python
140 lines
4.3 KiB
Python
import logging
|
|
import platform
|
|
from datetime import datetime, timezone
|
|
|
|
import psutil
|
|
from fastapi import FastAPI
|
|
from fastapi.middleware.cors import CORSMiddleware
|
|
|
|
from app.api.routes import training, samples, config, comparison
|
|
from app.core.config import settings
|
|
from app.services.config_manager import ConfigManager
|
|
from app.services.sample_manager import SampleManager
|
|
from app.services.training_monitor import TrainingMonitor
|
|
|
|
# Set up process-wide logging at INFO level with a timestamped record format,
# then create the logger used by the rest of this module.
logging.basicConfig(
    format="%(asctime)s - %(name)s - %(levelname)s - %(message)s",
    level=logging.INFO,
)
logger = logging.getLogger(__name__)
|
|
# The ASGI application object; the metadata below is handed to the FastAPI
# constructor.
app = FastAPI(
    version="1.0.0",
    title="Training Monitor API",
    description="API for monitoring ML training progress and samples",
)
|
|
|
|
# Configure CORS
#
# Credentials are disabled while the origin list is a wildcard: FastAPI's CORS
# documentation states that allow_origins, allow_methods and allow_headers
# must all be listed explicitly when allow_credentials=True. Allowing
# credentials together with "*" would let any website make credentialed
# requests against this API.
# To support cookies / Authorization headers from a browser front-end, replace
# the wildcard with the explicit list of trusted origins and only then set
# allow_credentials=True.
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],  # In production, replace with specific origins
    allow_credentials=False,
    allow_methods=["*"],
    allow_headers=["*"],
)
|
|
|
|
# Build one instance of each service manager at import time. Every instance is
# kept both as a module-level name and as an attribute of app.state.
app.state.sample_manager = sample_manager = SampleManager()
app.state.training_monitor = training_monitor = TrainingMonitor()
app.state.config_manager = config_manager = ConfigManager()
|
|
|
|
|
|
async def initialize_services():
    """
    Create the service managers, publish them on ``app.state`` and start them.

    A new ``ConfigManager`` is created first and stored on ``app.state``; its
    configuration is then loaded and stored as ``app.state.training_config``.
    New ``SampleManager`` and ``TrainingMonitor`` instances are stored on
    ``app.state`` (replacing whatever was stored under those attributes
    before) and started in that order.

    Raises:
        Exception: Any error raised while loading the configuration or while
            starting a manager is logged together with its traceback and then
            re-raised, so the application does not come up with partially
            initialized services.
    """
    logger.info("Starting services initialization...")

    # First, initialize config manager as other services might need configuration
    config_manager = ConfigManager()
    app.state.config_manager = config_manager

    try:
        # Load initial configuration
        config = await config_manager.get_config()
        logger.info("Configuration loaded successfully for job: %s", config.job)

        # Store config in app state for easy access
        app.state.training_config = config

        # Create the remaining managers and publish them before starting them
        sample_manager = SampleManager()
        training_monitor = TrainingMonitor()
        app.state.sample_manager = sample_manager
        app.state.training_monitor = training_monitor

        # Start the managers
        await sample_manager.startup()
        await training_monitor.startup()

        logger.info("All services initialized successfully")

    except Exception as exc:
        # logger.exception records the traceback as well as the message, which
        # a plain logger.error(str(exc)) would drop.
        logger.exception("Failed to initialize services: %s", exc)
        # Re-raise to prevent app from starting with partially initialized services
        raise
|
|
|
|
|
|
@app.on_event("startup")
async def startup_event():
    """
    Handler registered for the application's "startup" event.

    Logs a startup message and then awaits ``initialize_services()``; any
    exception raised there propagates out of this handler.
    """
    logger.info("Starting up Training Monitor API")
    await initialize_services()
|
|
|
|
|
|
|
|
@app.on_event("shutdown")
async def shutdown_event():
    """
    Handler registered for the application's "shutdown" event.

    Shuts down the sample manager and the training monitor currently stored
    on ``app.state``. The training monitor is shut down even if shutting down
    the sample manager raises; in that case the exception still propagates
    afterwards.
    """
    logger.info("Shutting down Training Monitor API")

    # Resolve the managers through app.state instead of the module-level
    # names: initialize_services() stores the instances it starts on
    # app.state, and those are not the objects bound to the module-level
    # names created at import time.
    state = app.state
    try:
        await state.sample_manager.shutdown()
    finally:
        await state.training_monitor.shutdown()
|
|
|
|
|
|
# Mount each feature router under the versioned API prefix taken from settings
app.include_router(
    training.router,
    prefix=f"{settings.API_VER_STR}/training",
    tags=["training"],
)
app.include_router(
    samples.router,
    prefix=f"{settings.API_VER_STR}/samples",
    tags=["samples"],
)
app.include_router(
    config.router,
    prefix=f"{settings.API_VER_STR}/config",
    tags=["config"],
)
app.include_router(
    comparison.router,
    prefix=f"{settings.API_VER_STR}/comparison",
    tags=["comparison"],
)
|
|
|
|
|
|
@app.get("/")
async def root():
    """
    Root endpoint providing API status and system information.

    Returns:
        A dict with the API name and version, a fixed "operational" status,
        the current UTC timestamp in ISO 8601 form, a ``system_info`` section
        (CPU and memory usage percentages, platform string, Python version)
        and an ``endpoints`` section listing useful paths.
    """
    # Build the advertised paths from the same prefix the routers are mounted
    # under, so the listing stays correct if settings.API_VER_STR changes.
    api_prefix = settings.API_VER_STR

    return {
        "name": "Training Monitor API",
        "version": "1.0.0",
        "status": "operational",
        "timestamp": datetime.now(timezone.utc).isoformat(),
        "system_info": {
            # cpu_percent() is called without an interval, so it returns
            # immediately instead of sampling for a fixed period.
            "cpu_usage": f"{psutil.cpu_percent()}%",
            "memory_usage": f"{psutil.virtual_memory().percent}%",
            "platform": platform.platform(),
            "python": platform.python_version(),
        },
        "endpoints": {
            "docs": "/docs",
            "health": "/health",
            "training_status": f"{api_prefix}/training/status",
            "training_log": f"{api_prefix}/training/log",
            "samples_list": f"{api_prefix}/samples/list",
            "samples_latest": f"{api_prefix}/samples/latest",
        },
    }
|
|
|
|
@app.get("/health")
async def health_check():
    """Return a fixed payload reporting the service as healthy."""
    payload = {"status": "healthy"}
    return payload
|