Compare commits

...

3 Commits

Author SHA1 Message Date
Felipe Cardoso
37b71464f6 Add config parsing support in backend
Signed-off-by: Felipe Cardoso <felipe.cardoso@hotmail.it>
2025-01-23 13:46:30 +01:00
Felipe Cardoso
f99564434a Add samples page
Signed-off-by: Felipe Cardoso <felipe.cardoso@hotmail.it>
2025-01-23 13:39:14 +01:00
Felipe Cardoso
4b9d3e7d55 Refactor samples gallery
Signed-off-by: Felipe Cardoso <felipe.cardoso@hotmail.it>
2025-01-23 13:10:58 +01:00
11 changed files with 417 additions and 15 deletions

View File

@@ -0,0 +1,12 @@
from fastapi import APIRouter, Request
from app.models.config import TrainingConfig
router = APIRouter()
@router.get("/config", response_model=TrainingConfig)
async def get_training_config(request: Request):
"""Retrieves the current training configuration"""
config_manager = request.app.state.config_manager
return await config_manager.get_config()

View File

@@ -10,7 +10,7 @@ router = APIRouter()
@router.get("/list", response_model=List[Sample])
async def list_samples(
request: Request,
limit: int = Query(20, ge=1, le=1000),
limit: int = Query(200, ge=1, le=1000),
offset: int = Query(0, ge=0)
):
"""
@@ -25,7 +25,7 @@ async def list_samples(
@router.get("/latest", response_model=List[Sample])
async def get_latest_samples(
request: Request,
count: int = Query(5, ge=1, le=20)
count: int = Query(20, ge=1, le=100)
):
"""
Get the most recent sample images

View File

@@ -11,10 +11,12 @@ class Settings(BaseSettings):
SFTP_PATH: Optional[str] = None
SFTP_PORT: int = 22
TRAINING_LOG_REMOTE_PATH: Optional[str] = None
TRAINING_CONFIG_LOCAL_PATH: Optional[str] = None
# Local Settings (Optional)
LOCAL_PATH: Optional[str] = None
TRAINING_LOG_LOCAL_PATH: Optional[str] = None
TRAINING_CONFIG_REMOTE_PATH: Optional[str] = None
# API Settings
PROJECT_NAME: str = "Training Monitor"

View File

@@ -6,8 +6,9 @@ import psutil
from fastapi import FastAPI
from fastapi.middleware.cors import CORSMiddleware
from app.api.routes import training, samples
from app.api.routes import training, samples, config
from app.core.config import settings
from app.services.config_manager import ConfigManager
from app.services.sample_manager import SampleManager
from app.services.training_monitor import TrainingMonitor
@@ -35,17 +36,60 @@ app.add_middleware(
# Create and store SampleManager instance
sample_manager = SampleManager()
training_monitor = TrainingMonitor()
config_manager = ConfigManager()
app.state.sample_manager = sample_manager
app.state.training_monitor = training_monitor
app.state.config_manager = config_manager
async def initialize_services():
"""
Initializes all service managers in the correct order, ensuring dependencies
are properly set up before they're needed.
"""
logger.info("Starting services initialization...")
# First, initialize config manager as other services might need configuration
config_manager = ConfigManager()
app.state.config_manager = config_manager
try:
# Load initial configuration
config = await config_manager.get_config()
logger.info(f"Configuration loaded successfully for job: {config.job}")
# Store config in app state for easy access
app.state.training_config = config
# Initialize other managers with access to config
sample_manager = SampleManager()
training_monitor = TrainingMonitor()
# Store managers in app state
app.state.sample_manager = sample_manager
app.state.training_monitor = training_monitor
# Start the managers
await sample_manager.startup()
await training_monitor.startup()
logger.info("All services initialized successfully")
except Exception as e:
logger.error(f"Failed to initialize services: {str(e)}")
# Re-raise to prevent app from starting with partially initialized services
raise
@app.on_event("startup")
async def startup_event():
"""Initialize services on startup"""
"""
Startup event handler that ensures all services are properly initialized
before the application starts accepting requests.
"""
logger.info("Starting up Training Monitor API")
await sample_manager.startup()
await training_monitor.startup()
await initialize_services()
@@ -60,6 +104,7 @@ async def shutdown_event():
# Include routers with versioning
app.include_router(training.router, prefix=f"{settings.API_VER_STR}/training", tags=["training"])
app.include_router(samples.router, prefix=f"{settings.API_VER_STR}/samples", tags=["samples"])
app.include_router(config.router, prefix=f"{settings.API_VER_STR}/config", tags=["config"])
@app.get("/")

View File

@@ -0,0 +1,93 @@
from typing import List, Optional, Dict, Any
from pydantic import BaseModel
class SampleConfig(BaseModel):
sampler: str
sample_every: int
width: int
height: int
prompts: List[str]
neg: str
seed: int
walk_seed: bool
guidance_scale: float
sample_steps: int
class DatasetConfig(BaseModel):
folder_path: str
caption_ext: Optional[str] = None
caption_dropout_rate: Optional[float] = None
shuffle_tokens: Optional[bool] = False
resolution: Optional[List[int]] = None
class EMAConfig(BaseModel):
use_ema: Optional[bool] = False
ema_decay: Optional[float] = None
class TrainConfig(BaseModel):
batch_size: int
bypass_guidance_embedding: Optional[bool] = False
timestep_type: Optional[str] = None
steps: int
gradient_accumulation: Optional[int] = 1
train_unet: Optional[bool] = True
train_text_encoder: Optional[bool] = False
gradient_checkpointing: Optional[bool] = False
noise_scheduler: Optional[str] = None
optimizer: Optional[str] = None
lr: Optional[float] = None
ema_config: Optional[EMAConfig] = None
dtype: Optional[str] = None
do_paramiter_swapping: Optional[bool] = False
paramiter_swapping_factor: Optional[float] = None
skip_first_sample: Optional[bool] = False
disable_sampling: Optional[bool] = False
class ModelConfig(BaseModel):
name_or_path: str
is_flux: Optional[bool] = False
quantize: Optional[bool] = False
quantize_te: Optional[bool] = False
class SaveConfig(BaseModel):
dtype: Optional[str] = None
save_every: Optional[int] = None
max_step_saves_to_keep: Optional[int] = None
save_format: Optional[str] = None
class ProcessConfig(BaseModel):
type: str
training_folder: str
performance_log_every: Optional[int] = None
device: Optional[str] = None
trigger_word: Optional[str] = None
save: Optional[SaveConfig] = None
datasets: List[DatasetConfig]
train: TrainConfig
model: ModelConfig
sample: SampleConfig
class MetaConfig(BaseModel):
name: Optional[str] = None
version: Optional[str] = None
class TrainingConfig(BaseModel):
job: str
config: Dict[str, Any] # This will contain 'name' and 'process'
meta: MetaConfig
# And a Config class to represent the middle layer:
class Config(BaseModel):
name: str
process: List[ProcessConfig]

View File

@@ -0,0 +1,151 @@
# app/services/config_manager.py
import logging
import os
import aiofiles
import paramiko
import yaml
from fastapi import HTTPException
from app.core.config import settings
from app.models.config import TrainingConfig, ProcessConfig, SaveConfig, DatasetConfig
logger = logging.getLogger(__name__)
class ConfigManager:
"""
Manages access to training configuration files, supporting both local and remote (SFTP) paths.
Handles YAML parsing and conversion to strongly-typed configuration objects.
"""
def __init__(self):
# Initialize paths from settings, defaulting to None if not configured
self.remote_path = getattr(settings, 'TRAINING_CONFIG_REMOTE_PATH', None)
self.local_path = getattr(settings, 'TRAINING_CONFIG_LOCAL_PATH', None)
self.sftp_client = None
self.cached_config = None
# Validate that at least one path is configured
if not self.remote_path and not self.local_path:
raise ValueError("Either TRAINING_CONFIG_REMOTE_PATH or TRAINING_CONFIG_LOCAL_PATH must be configured")
logger.info(f"ConfigManager initialized with remote_path={self.remote_path}, local_path={self.local_path}")
async def _connect_sftp(self):
"""Establishes SFTP connection using SSH key authentication"""
try:
key_path = os.path.expanduser(settings.SFTP_KEY_PATH)
logger.info(f"Connecting to SFTP {settings.SFTP_HOST} with key {key_path}")
ssh = paramiko.SSHClient()
ssh.set_missing_host_key_policy(paramiko.AutoAddPolicy())
ssh.connect(
hostname=settings.SFTP_HOST,
username=settings.SFTP_USER,
port=settings.SFTP_PORT,
key_filename=key_path,
)
self.sftp_client = ssh.open_sftp()
logger.info("SFTP connection established successfully")
except Exception as e:
logger.error(f"Failed to establish SFTP connection: {str(e)}")
raise HTTPException(status_code=500, detail=f"SFTP connection failed: {str(e)}")
def _disconnect_sftp(self):
"""Safely closes SFTP connection if it exists"""
if self.sftp_client:
try:
self.sftp_client.close()
self.sftp_client = None
logger.info("SFTP connection closed")
except Exception as e:
logger.error(f"Error closing SFTP connection: {str(e)}")
async def _read_remote_config(self) -> dict:
"""Reads and parses YAML configuration from remote SFTP location"""
if not self.sftp_client:
await self._connect_sftp()
try:
with self.sftp_client.open(self.remote_path, 'r') as f:
content = f.read()
return yaml.safe_load(content)
except Exception as e:
logger.error(f"Failed to read remote config: {str(e)}")
raise HTTPException(status_code=500, detail=f"Failed to read remote config: {str(e)}")
finally:
self._disconnect_sftp()
async def _read_local_config(self) -> dict:
"""Reads and parses YAML configuration from local filesystem"""
try:
async with aiofiles.open(self.local_path, 'r') as f:
content = await f.read()
return yaml.safe_load(content)
except Exception as e:
logger.error(f"Failed to read local config: {str(e)}")
raise HTTPException(status_code=500, detail=f"Failed to read local config: {str(e)}")
def _parse_config(self, raw_config: dict) -> TrainingConfig:
"""
Converts raw YAML dictionary into strongly-typed configuration objects.
Handles optional fields and nested configurations.
"""
try:
# Extract the first process configuration (assuming single process for now)
process_data = raw_config['config']['process'][0]
# Build the process config with all its nested components
process = ProcessConfig(
type=process_data['type'],
training_folder=process_data['training_folder'],
performance_log_every=process_data.get('performance_log_every'),
device=process_data.get('device'),
trigger_word=process_data.get('trigger_word'),
save=SaveConfig(**process_data['save']) if 'save' in process_data else None,
datasets=[DatasetConfig(**ds) for ds in process_data.get('datasets', [])],
train=process_data['train'],
model=process_data['model'],
sample=process_data['sample']
)
# Reconstruct the config dictionary with our parsed process
config_dict = dict(raw_config['config'])
config_dict['process'] = [process]
# Create the full training config
return TrainingConfig(
job=raw_config.get('job', ''),
config=config_dict,
meta=raw_config.get('meta', {})
)
except Exception as e:
logger.error(f"Failed to parse config: {str(e)}")
raise HTTPException(status_code=500, detail=f"Config parsing failed: {str(e)}")
async def get_config(self) -> TrainingConfig:
"""
Main method to retrieve and parse configuration.
Automatically handles local or remote access based on configuration.
"""
if self.cached_config is not None:
return self.cached_config
try:
# Read raw config from appropriate source
raw_config = await self._read_remote_config() if self.remote_path else await self._read_local_config()
# Parse and return strongly-typed config
parsed_config = self._parse_config(raw_config)
self.cached_config = parsed_config
return parsed_config
except Exception as e:
logger.error(f"Failed to get config: {str(e)}")
raise HTTPException(status_code=500, detail=str(e))

View File

@@ -191,7 +191,7 @@ class SampleManager:
# Wait a bit before retrying on error
await asyncio.sleep(5)
async def list_samples(self, limit: int = 20, offset: int = 0) -> List[Sample]:
async def list_samples(self, limit: int = 200, offset: int = 0) -> List[Sample]:
"""List sample images with pagination"""
logger.info(f"Total samples: {len(self.samples)}")
@@ -204,7 +204,7 @@ class SampleManager:
return sorted_samples[offset:offset + limit]
async def get_latest_samples(self, count: int = 5) -> List[Sample]:
async def get_latest_samples(self, count: int = 20) -> List[Sample]:
"""Get most recent samples"""
return await self.list_samples(limit=count, offset=0)

View File

@@ -11,4 +11,5 @@ aiofiles>=23.2.1
pytest>=7.4.3
httpx>=0.25.1
pytest-asyncio>=0.21.1
psutil>=5.9.8
psutil>=5.9.8
PyYAML~=6.0.2

View File

@@ -0,0 +1,96 @@
// src/app/samples/page.tsx
"use client"
import {useSamples} from '@/contexts/SamplesContext'
import Image from 'next/image'
import {useMemo} from 'react'
interface ParsedSample {
filename: string
timestamp: number
batch: number
index: number
url: string
created_at: string
}
export default function SamplesPage() {
const {samples, isLoading, error} = useSamples()
const groupedSamples = useMemo(() => {
if (!samples?.length) return new Map()
const parsed: ParsedSample[] = samples.map(sample => {
console.debug('sample', sample)
const [timestamp, info] = sample.filename.split('__')
const [batch, index] = info.split('_')
return {
...sample,
timestamp: parseInt(timestamp),
batch: parseInt(batch),
index: parseInt(index.replace('.jpg', '')),
}
})
// Group by batch
const groups = parsed.reduce((acc, sample) => {
const group = acc.get(sample.batch) || []
group.push(sample)
acc.set(sample.batch, group)
return acc
}, new Map<number, ParsedSample[]>())
// Sort within each group
for (const [batch, items] of groups) {
groups.set(batch, items.sort((a, b) => b.index - a.index))
}
// return new Map([...groups].sort((a, b) => b[0] - a[0]))
return new Map([...groups])
}, [samples])
if (isLoading) return <div>Loading samples...</div>
if (error) return <div>Error: {error.message}</div>
return (
<div className="p-6 space-y-8">
<h1 className="text-2xl font-bold">Samples Gallery</h1>
{Array.from(groupedSamples).map(([batch, items]) => (
<div key={batch} className="space-y-2">
<h2 className="text-xl font-semibold">
Step {batch}
</h2>
<div className="overflow-x-auto">
<div className="flex gap-4 min-w-full pb-4">
{items.sort((a: any, b: any) => a.index - b.index).map((sample: any) => (
<div key={sample.filename} className="flex-shrink-0 w-48">
<Image
src={`${process.env.NEXT_PUBLIC_API_URL}${sample.url}`}
alt={`Sample ${sample.index}`}
width={200}
height={200}
className="rounded-lg shadow-sm object-cover w-full h-48"
/>
<div className="mt-2 text-sm">
<div className="text-gray-600">
{new Date(sample.created_at).toLocaleString()}
</div>
<a
href={`${process.env.NEXT_PUBLIC_API_URL}${sample.url}`}
target="_blank"
rel="noopener noreferrer"
className="text-blue-600 hover:underline"
>
Sample {sample.index}
</a>
</div>
</div>
))}
</div>
</div>
</div>
))}
</div>
)
}

View File

@@ -3,15 +3,15 @@ import {useSamples} from '@/contexts/SamplesContext'
import Image from 'next/image'
export function SamplesGallery() {
const {samples, isLoading, error, refreshSamples} = useSamples()
const {latestSamples, isLoading, error, refreshSamples} = useSamples()
if (isLoading) return <div>Loading samples...</div>
if (error) return <div>Error loading samples: {error.message}</div>
if (samples.length === 0) return <div>No samples available</div>
if (latestSamples.length === 0) return <div>No samples available</div>
return (
<div className="grid grid-cols-3 gap-4">
{samples.map((sample) => (
<div className="grid grid-cols-5 gap-4">
{latestSamples.map((sample) => (
<div key={sample.filename}>
<Image
src={`${process.env.NEXT_PUBLIC_API_URL}${sample.url}`}
@@ -20,7 +20,7 @@ export function SamplesGallery() {
height={200}
className="object-cover rounded"
/>
<p className="text-sm mt-1">{sample.filename}</p>
<p className="text-sm mt-1">{sample.url.split('__')[1]}</p>
</div>
))}
</div>

View File

@@ -4,6 +4,7 @@ import type {Sample} from '@/types/api'
interface SamplesContextType {
samples: Sample[]
latestSamples: Sample[]
isLoading: boolean
error: Error | null
refreshSamples: () => Promise<void>
@@ -37,7 +38,8 @@ export function SamplesProvider({children}: { children: React.ReactNode }) {
}, [])
return (
<SamplesContext.Provider value={{samples, isLoading, error, refreshSamples: fetchSamples}}>
<SamplesContext.Provider
value={{samples, latestSamples: samples.slice(0, 20), isLoading, error, refreshSamples: fetchSamples}}>
{children}
</SamplesContext.Provider>
)