Add comparison support

Signed-off-by: Felipe Cardoso <felipe.cardoso@hotmail.it>
This commit is contained in:
2025-01-31 13:20:46 +01:00
parent 3a79065163
commit 9803e32f66
13 changed files with 974 additions and 2 deletions

View File

@@ -0,0 +1,68 @@
# app/api/routes/comparison.py
from fastapi import APIRouter, HTTPException
from starlette.responses import FileResponse
from app.models.comparison import PathRequest
from app.services.comparison_service import ComparisonService
router = APIRouter()
comparison_service = ComparisonService()
@router.post("/register")
async def register_comparison_path(request: PathRequest):
"""Register a new comparison path and get its ID."""
config_id = comparison_service.register_config(request.path)
return {"config_id": config_id}
@router.get("/image/{config_id}/{config_name}/{model_name}/{filename}")
async def get_comparison_image(config_id: str, config_name: str, model_name: str, filename: str):
"""Serve image files using the cached base path."""
base_path = comparison_service.get_base_path(config_id)
if not base_path:
raise HTTPException(status_code=404, detail="Configuration not found")
try:
full_path = base_path / config_name / model_name / filename
if not full_path.exists() or not full_path.is_file():
raise HTTPException(status_code=404, detail="Image not found")
return FileResponse(full_path)
except Exception as e:
raise HTTPException(status_code=500, detail=str(e))
@router.get("/{config_id}/available")
async def fetch_available_configs(config_id: str):
"""Fetch available configs."""
base_path = comparison_service.get_base_path(config_id)
if not base_path:
raise HTTPException(status_code=404, detail="Configuration not found")
try:
return comparison_service.get_available_configs(str(base_path))
except ValueError as e:
# Convert ValueError from the service into a proper HTTP error
raise HTTPException(status_code=404, detail=str(e))
@router.get("/{config_id}/{config_name}")
async def fetch_config(config_id: str, config_name: str):
"""
Fetch detailed comparison data for a specific configuration.
Parameters:
config_id: The identifier returned from the initial path registration
config_name: The name of the specific configuration to load (e.g. 'cloth_lora')
"""
base_path = comparison_service.get_base_path(config_id)
if not base_path:
raise HTTPException(status_code=404, detail="Configuration not found")
try:
return comparison_service.load_config_data(str(base_path), config_name)
except ValueError as e:
# Convert ValueError from the service into a proper HTTP error
raise HTTPException(status_code=404, detail=str(e))

View File

@@ -6,7 +6,7 @@ import psutil
from fastapi import FastAPI from fastapi import FastAPI
from fastapi.middleware.cors import CORSMiddleware from fastapi.middleware.cors import CORSMiddleware
from app.api.routes import training, samples, config from app.api.routes import training, samples, config, comparison
from app.core.config import settings from app.core.config import settings
from app.services.config_manager import ConfigManager from app.services.config_manager import ConfigManager
from app.services.sample_manager import SampleManager from app.services.sample_manager import SampleManager
@@ -105,6 +105,7 @@ async def shutdown_event():
app.include_router(training.router, prefix=f"{settings.API_VER_STR}/training", tags=["training"]) app.include_router(training.router, prefix=f"{settings.API_VER_STR}/training", tags=["training"])
app.include_router(samples.router, prefix=f"{settings.API_VER_STR}/samples", tags=["samples"]) app.include_router(samples.router, prefix=f"{settings.API_VER_STR}/samples", tags=["samples"])
app.include_router(config.router, prefix=f"{settings.API_VER_STR}/config", tags=["config"]) app.include_router(config.router, prefix=f"{settings.API_VER_STR}/config", tags=["config"])
app.include_router(comparison.router, prefix=f"{settings.API_VER_STR}/comparison", tags=["comparison"])
@app.get("/") @app.get("/")

View File

@@ -0,0 +1,55 @@
from typing import List, Optional, Dict
from pydantic import BaseModel
class ComparisonImage(BaseModel):
path: str
model: str
config: str
prompt_index: int
seed: int
lora1: Optional[str] = None
lora2: Optional[str] = None
prompt: Optional[str] = None # Adding prompt field
class ComparisonPair(BaseModel):
model1: ComparisonImage
model2: ComparisonImage
config: str
prompt_index: int
seed: int
prompt: str # The prompt used for both images
class ComparisonData(BaseModel):
configs: List[str] # Available config types (cloth_lora, identity_lora, dual_lora)
prompts: Dict[str, List[str]] # Mapping of config -> list of prompts
seeds: List[int] # All available seeds
pairs: List[ComparisonPair] # All comparison pairs with their prompts
class PathRequest(BaseModel):
"""Request model for providing the base comparison path"""
path: str
class ConfigRequest(BaseModel):
"""Request model for fetching specific configuration data"""
path: str
config_name: str
class ConfigInfo(BaseModel):
"""Basic information about an available configuration"""
name: str
model_count: int
prompt_count: int
seed_count: int
class AvailableConfigs(BaseModel):
"""Response model for the fetchConfigs endpoint"""
base_path: str
configs: List[ConfigInfo]

View File

@@ -0,0 +1,221 @@
# app/services/comparison_service.py
import json
from datetime import datetime
from pathlib import Path
from typing import Dict, Optional
from app.models.comparison import ConfigInfo, AvailableConfigs, ComparisonPair, ComparisonData, ComparisonImage
class ComparisonService:
def __init__(self):
# Cache structure to store configuration information
self.configs: Dict[str, dict] = {} # id -> config info
self.paths: Dict[str, Path] = {} # id -> base path
self.access_times: Dict[str, datetime] = {} # id -> last access time
def generate_config_id(self) -> str:
"""Generate a unique identifier for a configuration."""
# Simple timestamp-based ID, could be made more sophisticated
return datetime.now().strftime('%Y%m%d_%H%M%S')
def _get_config_data(self, config_dir: Path) -> dict:
"""Get config data with caching."""
config_path = config_dir / f"config_{config_dir.name}.json"
cache_key = str(config_path)
if cache_key not in self.configs:
if not config_path.exists():
raise ValueError(f"Configuration file not found: {config_path}")
with open(config_path) as f:
self.configs[cache_key] = json.load(f)
return self.configs[cache_key]
def register_config(self, base_path: str) -> str:
"""
Register a new configuration base path and return its ID.
This is called when a user first submits a path.
"""
config_id = self.generate_config_id()
self.paths[config_id] = Path(base_path)
self.access_times[config_id] = datetime.now()
return config_id
def get_base_path(self, config_id: str) -> Optional[Path]:
"""
Retrieve the base path for a given configuration ID.
Updates the last access time.
"""
if config_id in self.paths:
self.access_times[config_id] = datetime.now()
return self.paths[config_id]
return None
def clean_old_configs(self, max_age_hours: int = 72):
"""
Clean up configurations that haven't been accessed in a while.
This helps manage memory usage.
"""
current_time = datetime.now()
expired_ids = [
config_id for config_id, access_time in self.access_times.items()
if (current_time - access_time).total_seconds() > max_age_hours * 3600
]
for config_id in expired_ids:
self.paths.pop(config_id, None)
self.configs.pop(config_id, None)
self.access_times.pop(config_id, None)
def get_available_configs(self, base_path: str) -> AvailableConfigs:
base_path = Path(base_path)
configs = []
for config_dir in base_path.iterdir():
if not config_dir.is_dir():
continue
try:
config_data = self._get_config_data(config_dir)
model_count = sum(1 for x in config_dir.iterdir() if x.is_dir())
configs.append(ConfigInfo(
name=config_dir.name,
model_count=model_count,
prompt_count=len(config_data.get('prompts', [])),
seed_count=len(config_data.get('seeds', []))
))
except ValueError:
# Skip this config if we can't read its data
continue
return AvailableConfigs(
base_path=str(base_path),
configs=sorted(configs, key=lambda x: x.name)
)
def load_config_data(self, base_path: str, config_name: str) -> ComparisonData:
"""Load comparison data using the cached base path."""
# base_path = self.get_base_path(config_id)
# if not base_path:
# raise ValueError(f"Configuration ID {config_id} not found")
# config_dir = base_path / config_name
#
# if not config_dir.is_dir():
# raise ValueError(f"Configuration '{config_name}' not found in {base_path}")
#
# # Load the configuration file
# config_file = config_dir / f"config_{config_name}.json"
# if not config_file.exists():
# raise ValueError(f"Configuration file not found for {config_name}")
#
# with open(config_file) as f:
# config_data = json.load(f)
base_path = Path(base_path)
config_dir = base_path / config_name
if not config_dir.is_dir():
raise ValueError(f"Configuration '{config_name}' not found in {base_path}")
# Use our cached config data instead of reading directly
try:
config_data = self._get_config_data(config_dir)
except ValueError as e:
raise ValueError(f"Failed to load configuration data: {str(e)}")
# Process model directories
models = [d.name for d in config_dir.iterdir() if d.is_dir()]
pairs = []
# Build comparison pairs for this configuration
for model in models:
model_path = config_dir / model
images = list(model_path.glob("*.png"))
for img in images:
img_data = self.parse_image_path(img, config_dir)
# Find matching image in other model
other_model = [m for m in models if m != model][0]
other_path = model_path.parent / other_model / img.name
if other_path.exists():
other_data = self.parse_image_path(other_path, config_dir)
pairs.append(ComparisonPair(
model1=img_data,
model2=other_data,
config=config_name,
prompt_index=img_data.prompt_index,
seed=img_data.seed,
prompt=img_data.prompt or ""
))
return ComparisonData(
configs=[config_name],
prompts={config_name: config_data['prompts']},
seeds=sorted(config_data['seeds']),
pairs=pairs
)
def parse_image_path(self, path: Path, config_dir: Path) -> ComparisonImage:
"""
Parse an image filename to extract metadata about the image.
The method handles both single and dual LoRA naming patterns:
- Single LoRA: lora_prompt_0_seed_42.png
- Dual LoRA: lora1_lora2_prompt_0_seed_42.png
"""
filename = Path(path).stem
parts = filename.split('_')
prompt_idx = parts.index('prompt')
seed_idx = parts.index('seed')
prompt_index = int(parts[prompt_idx + 1])
seed = int(parts[seed_idx + 1])
# Instead of reading the file directly, use our cached method
prompt = None
try:
config_data = self._get_config_data(config_dir)
prompts = config_data.get('prompts', [])
if 0 <= prompt_index < len(prompts):
prompt = prompts[prompt_index]
except ValueError:
# If we can't get the config data, we'll continue without the prompt
pass
# config_file = config_dir / f"config_{config_dir.name}.json"
# if config_file.exists():
# with open(config_file) as f:
# config_data = json.load(f)
# prompts = config_data.get('prompts', [])
# if 0 <= prompt_index < len(prompts):
# prompt = prompts[prompt_index]
# Determine if this is a dual LoRA setup by checking parts before 'prompt'
lora_parts = parts[:prompt_idx]
if len(lora_parts) > 1:
# Dual LoRA case
return ComparisonImage(
path=str(path),
model=path.parent.name,
config=config_dir.name,
prompt_index=prompt_index,
seed=seed,
lora1=lora_parts[0],
lora2=lora_parts[1],
prompt=prompt
)
else:
# Single LoRA case
return ComparisonImage(
path=str(path),
model=path.parent.name,
config=config_dir.name,
prompt_index=prompt_index,
seed=seed,
lora1=lora_parts[0],
prompt=prompt
)

2
backend/run_server.sh Normal file
View File

@@ -0,0 +1,2 @@
#!/bin/bash
uvicorn app.main:app --reload --port 2000

View File

@@ -0,0 +1,71 @@
// src/app/comparison/page.tsx
"use client"
import {ComparisonViewer} from "@/components/ComparisonViewer";
import {useComparison} from "@/contexts/ComparisonContext";
import {PathSelector} from "@/components/PathSelector";
import {ConfigSelector} from "@/components/ConfigsSelector";
export default function ComparisonPage() {
// Get everything we need from the comparison context
const {
basePath,
availableConfigs,
currentConfig,
isLoading,
error,
loadConfig
} = useComparison();
// If we don't have a base path yet, show the path selector
if (!basePath) {
return (
<div className="min-h-screen bg-gray-900 flex items-center justify-center">
<PathSelector/>
</div>
);
}
return (
<div className="min-h-screen bg-gray-900">
{/* More flexible top bar */}
<div className="bg-gray-800 border-b border-gray-700 p-4">
<div className="flex flex-col lg:flex-row gap-4 items-start lg:items-center">
{/* Config selector takes full width on mobile, shares space on desktop */}
<div className="w-full lg:w-auto lg:flex-1">
<ConfigSelector
configs={availableConfigs}
selectedConfig={currentConfig}
onConfigSelect={loadConfig}
disabled={isLoading}
/>
</div>
{/* Path display that wraps naturally */}
<div className="w-full lg:w-auto flex items-center gap-2 text-sm">
<span className="text-gray-400 whitespace-nowrap">Path:</span>
<span className="text-gray-500 truncate">
{basePath}
</span>
{isLoading && (
<span className="text-blue-400 whitespace-nowrap">
Loading...
</span>
)}
</div>
</div>
</div>
{/* Main comparison area */}
<div className="h-[calc(100vh-4rem)]">
{error ? (
<div className="flex items-center justify-center h-full text-red-400">
{error}
</div>
) : (
<ComparisonViewer/>
)}
</div>
</div>
);
}

View File

@@ -3,6 +3,7 @@ import {Geist, Geist_Mono} from "next/font/google";
import "./globals.css"; import "./globals.css";
import {TrainingProvider} from "@/contexts/TrainingContext"; import {TrainingProvider} from "@/contexts/TrainingContext";
import {SamplesProvider} from "@/contexts/SamplesContext"; import {SamplesProvider} from "@/contexts/SamplesContext";
import {ComparisonProvider} from "@/contexts/ComparisonContext";
const geistSans = Geist({ const geistSans = Geist({
variable: "--font-geist-sans", variable: "--font-geist-sans",
@@ -31,7 +32,9 @@ export default function RootLayout({
> >
<TrainingProvider> <TrainingProvider>
<SamplesProvider> <SamplesProvider>
{children} <ComparisonProvider>
{children}
</ComparisonProvider>
</SamplesProvider> </SamplesProvider>
</TrainingProvider> </TrainingProvider>

View File

@@ -0,0 +1,139 @@
// src/components/ComparisonViewer.tsx
"use client"
import Image from 'next/image'
import {useComparison} from '@/contexts/ComparisonContext'
export function ComparisonViewer() {
// We get everything we need from the context instead of managing local state
const {
getCurrentPair,
nextPair,
previousPair,
currentPairIndex,
comparisonData,
isLoading,
getImageUrl,
} = useComparison()
// Get the current pair using our context helper
const currentPair = getCurrentPair()
// Handle loading state
if (isLoading) {
return (
<div className="h-full flex items-center justify-center text-gray-400">
<div className="space-y-2 text-center">
<div className="text-lg">Loading comparisons...</div>
<div className="text-sm text-gray-500">Please wait while we prepare your images</div>
</div>
</div>
)
}
// Handle no data state
if (!currentPair || !comparisonData) {
return (
<div className="h-full flex items-center justify-center text-gray-400">
<div className="space-y-2 text-center">
<div className="text-lg">No comparison data available</div>
<div className="text-sm text-gray-500">Please select a configuration to begin</div>
</div>
</div>
)
}
// // Helper function to construct image URLs
// const getImageUrl = (path: string) => {
// return `${env.API_URL}${path}`
// }
return (
<div className="h-full flex flex-col">
{/* Main image comparison area */}
<div className="flex-1 flex">
{/* Left image */}
<div className="flex-1 relative group">
<div className="absolute inset-0 flex items-center justify-center">
<Image
src={getImageUrl(currentPair.model1)}
alt={`${currentPair.model1.model} - Seed ${currentPair.seed}`}
className="max-h-full w-auto object-contain transition-transform duration-200 group-hover:scale-[1.02]"
width={1024}
height={1024}
/>
<div
className="absolute top-2 left-2 bg-black/50 backdrop-blur-sm text-white px-3 py-1.5 rounded-lg font-medium">
{currentPair.model1.model}
</div>
{/* Add metadata tooltip on hover */}
<div
className="absolute bottom-2 left-2 opacity-0 group-hover:opacity-100 transition-opacity bg-black/50 backdrop-blur-sm text-white px-3 py-1.5 rounded-lg text-sm">
Seed: {currentPair.seed}
{currentPair.model1.lora1 && <div>LoRA: {currentPair.model1.lora1}</div>}
{currentPair.model1.lora2 && <div>LoRA 2: {currentPair.model1.lora2}</div>}
</div>
</div>
</div>
{/* Right image - mirror of left image setup */}
<div className="flex-1 relative group">
<div className="absolute inset-0 flex items-center justify-center">
<Image
src={getImageUrl(currentPair.model2)}
alt={`${currentPair.model2.model} - Seed ${currentPair.seed}`}
className="max-h-full w-auto object-contain transition-transform duration-200 group-hover:scale-[1.02]"
width={1024}
height={1024}
/>
<div
className="absolute top-2 right-2 bg-black/50 backdrop-blur-sm text-white px-3 py-1.5 rounded-lg font-medium">
{currentPair.model2.model}
</div>
<div
className="absolute bottom-2 right-2 opacity-0 group-hover:opacity-100 transition-opacity bg-black/50 backdrop-blur-sm text-white px-3 py-1.5 rounded-lg text-sm">
Seed: {currentPair.seed}
{currentPair.model2.lora1 && <div>LoRA: {currentPair.model2.lora1}</div>}
{currentPair.model2.lora2 && <div>LoRA 2: {currentPair.model2.lora2}</div>}
</div>
</div>
</div>
</div>
{/* Bottom info panel with enhanced information */}
<div className="bg-gray-800 border-t border-gray-700 p-4">
<div className="flex justify-between items-start">
<div className="flex-1 space-y-2">
<div>
<h3 className="text-gray-300 font-semibold">Prompt:</h3>
<p className="text-gray-400 text-sm mt-1">{currentPair.prompt}</p>
</div>
<div className="flex gap-4 text-sm text-gray-500">
<div>Seed: {currentPair.seed}</div>
<div>Prompt Index: {currentPair.prompt_index}</div>
</div>
</div>
<div className="flex gap-4 items-center ml-4">
<button
onClick={previousPair}
className="px-3 py-1.5 rounded-md bg-gray-700 text-gray-300 hover:bg-gray-600
hover:text-white transition-colors duration-200"
>
Previous
</button>
<span className="text-gray-400 font-medium">
{currentPairIndex + 1} / {comparisonData.pairs.length}
</span>
<button
onClick={nextPair}
className="px-3 py-1.5 rounded-md bg-gray-700 text-gray-300 hover:bg-gray-600
hover:text-white transition-colors duration-200"
>
Next
</button>
</div>
</div>
</div>
</div>
)
}

View File

@@ -0,0 +1,78 @@
// src/components/ConfigSelector.tsx
"use client"
import {useCallback} from 'react'
import type {ConfigurationInfo} from '@/types/comparison'
interface ConfigSelectorProps {
configs: ConfigurationInfo[] // Now using our structured config info
selectedConfig: string | null // Matches the context's currentConfig type
onConfigSelect: (config: string) => Promise<void> // Handle async loading
disabled?: boolean // Allow disabling during loading states
}
// src/components/ConfigSelector.tsx
interface ConfigurationDisplay {
name: string;
count: string;
}
export function ConfigSelector({
configs,
selectedConfig,
onConfigSelect,
disabled = false
}: ConfigSelectorProps) {
// Helper to create display information
const getConfigDisplay = useCallback((config: ConfigurationInfo): ConfigurationDisplay => {
const baseName = config.name
.replace('_lora', '')
.split('_')
.map(word => word.charAt(0).toUpperCase() + word.slice(1))
.join(' ');
return {
name: baseName,
count: `${config.model_count}m, ${config.prompt_count}p` // Shortened display
};
}, []);
return (
<div className="flex flex-col sm:flex-row items-start sm:items-center gap-2 sm:gap-4 w-full">
{/* Label that stacks on mobile but stays inline on larger screens */}
<span className="text-gray-400 text-sm font-medium whitespace-nowrap">
Configuration:
</span>
{/* Button container that allows wrapping on smaller screens */}
<div className="flex flex-wrap gap-2 flex-1">
{configs.map(config => {
const display = getConfigDisplay(config);
return (
<button
key={config.name}
onClick={() => onConfigSelect(config.name)}
disabled={disabled || selectedConfig === config.name}
className={`
px-3 py-1.5 rounded-md text-sm font-medium
transition-colors duration-200
flex flex-col sm:flex-row items-center gap-1
min-w-[100px] sm:min-w-0
${disabled ? 'opacity-50 cursor-not-allowed' : ''}
${selectedConfig === config.name
? 'bg-blue-600 text-white'
: 'bg-gray-700 text-gray-300 hover:bg-gray-600'
}
`}
>
<span className="whitespace-nowrap">{display.name}</span>
<span className="text-xs opacity-75 whitespace-nowrap">
{display.count}
</span>
</button>
);
})}
</div>
</div>
);
}

View File

@@ -0,0 +1,63 @@
"use client"
import {useState} from 'react'
import {useComparison} from '@/contexts/ComparisonContext'
export function PathSelector() {
const {setBasePath, isLoading, error} = useComparison();
const [path, setPath] = useState('');
const handleSubmit = async (e: React.FormEvent) => {
e.preventDefault();
if (path.trim()) {
await setBasePath(path.trim());
}
};
return (
<div className="bg-gray-800 p-6 rounded-lg shadow-lg max-w-xl w-full mx-4">
<h2 className="text-xl font-semibold text-gray-200 mb-4">
Enter Comparison Path
</h2>
<form onSubmit={handleSubmit} className="space-y-4">
<div>
<label
htmlFor="path"
className="block text-sm font-medium text-gray-300 mb-2"
>
Base Path
</label>
<input
type="text"
id="path"
value={path}
onChange={(e) => setPath(e.target.value)}
placeholder="/path/to/comparison/directory"
className="w-full px-4 py-2 bg-gray-700 border border-gray-600 rounded-md
text-gray-200 placeholder-gray-400 focus:outline-none focus:ring-2
focus:ring-blue-500 focus:border-transparent"
disabled={isLoading}
/>
</div>
{error && (
<div className="text-red-400 text-sm">
{error}
</div>
)}
<button
type="submit"
disabled={isLoading || !path.trim()}
className="w-full px-4 py-2 bg-blue-600 text-white rounded-md
hover:bg-blue-700 focus:outline-none focus:ring-2
focus:ring-blue-500 focus:ring-offset-2 focus:ring-offset-gray-800
disabled:opacity-50 disabled:cursor-not-allowed"
>
{isLoading ? 'Loading...' : 'Load Comparisons'}
</button>
</form>
</div>
);
}

View File

@@ -0,0 +1,3 @@
export const env = {
API_URL: process.env.NEXT_PUBLIC_API_URL || 'http://localhost:2000'
} as const;

View File

@@ -0,0 +1,152 @@
'use client'
import {createContext, useCallback, useContext, useState} from 'react';
import {env} from '@/config/env';
import type {AvailableConfigs, ComparisonContextType, ComparisonData, ComparisonState} from '@/types/comparison';
const ComparisonContext = createContext<ComparisonContextType | undefined>(undefined);
export function ComparisonProvider({children}: { children: React.ReactNode }) {
// Our state now needs to include the configId we get from registration
const [state, setState] = useState<ComparisonState>({
basePath: null,
configId: null, // Add this to track our registered configuration
availableConfigs: [],
currentConfig: null,
currentPairIndex: 0,
comparisonData: null,
isLoading: false,
error: null,
});
// First step: Register the path and get a configuration ID
const setBasePath = useCallback(async (path: string) => {
setState(prev => ({...prev, isLoading: true, error: null}));
try {
// Register the path first to get our configId
const registerResponse = await fetch(`${env.API_URL}/api/v1/comparison/register`, {
method: 'POST',
headers: {'Content-Type': 'application/json'},
body: JSON.stringify({path})
});
if (!registerResponse.ok) throw new Error('Failed to register path');
const {config_id} = await registerResponse.json();
// After getting the configId, fetch available configurations
const configsResponse = await fetch(`${env.API_URL}/api/v1/comparison/${config_id}/available`);
if (!configsResponse.ok) throw new Error('Failed to fetch configurations');
const data: AvailableConfigs = await configsResponse.json();
setState(prev => ({
...prev,
basePath: path,
configId: config_id,
availableConfigs: data.configs,
isLoading: false
}));
} catch (error) {
setState(prev => ({
...prev,
error: error instanceof Error ? error.message : 'Unknown error',
isLoading: false
}));
}
}, []);
// Load a specific configuration using our new GET endpoint
const loadConfig = useCallback(async (configName: string) => {
if (!state.configId) return;
setState(prev => ({...prev, isLoading: true, error: null}));
try {
// Use the new GET endpoint structure
const response = await fetch(
`${env.API_URL}/api/v1/comparison/${state.configId}/${configName}`
);
if (!response.ok) throw new Error('Failed to fetch configuration data');
const data: ComparisonData = await response.json();
setState(prev => ({
...prev,
currentConfig: configName,
comparisonData: data,
currentPairIndex: 0,
isLoading: false
}));
} catch (error) {
setState(prev => ({
...prev,
error: error instanceof Error ? error.message : 'Unknown error',
isLoading: false
}));
}
}, [state.configId]);
// Helper to construct image URLs using our new endpoint structure
const getImageUrl = useCallback((model: { config: string, model: string, path: string }) => {
if (!state.configId) return '';
const filename = model.path.split("/").slice(-1)
return `${env.API_URL}/api/v1/comparison/image/${state.configId}/${model.config}/${model.model}/${filename}`;
}, [state.configId]);
// Navigation helpers remain the same
const nextPair = useCallback(() => {
if (!state.comparisonData?.pairs.length) return;
setState(prev => ({
...prev,
currentPairIndex: (prev.currentPairIndex + 1) % prev.comparisonData!.pairs.length
}));
}, [state.comparisonData?.pairs.length]);
const previousPair = useCallback(() => {
if (!state.comparisonData?.pairs.length) return;
setState(prev => ({
...prev,
currentPairIndex: (prev.currentPairIndex - 1 + prev.comparisonData!.pairs.length) % prev.comparisonData!.pairs.length
}));
}, [state.comparisonData?.pairs.length]);
const goToPair = useCallback((index: number) => {
if (!state.comparisonData?.pairs.length) return;
if (index >= 0 && index < state.comparisonData.pairs.length) {
setState(prev => ({...prev, currentPairIndex: index}));
}
}, [state.comparisonData?.pairs.length]);
const getCurrentPair = useCallback(() => {
if (!state.comparisonData?.pairs.length) return null;
return state.comparisonData.pairs[state.currentPairIndex];
}, [state.comparisonData, state.currentPairIndex]);
const value: ComparisonContextType = {
...state,
nextPair,
previousPair,
goToPair,
setBasePath,
loadConfig,
getCurrentPair,
getImageUrl, // Add this to help components construct image URLs
};
return (
<ComparisonContext.Provider value={value}>
{children}
</ComparisonContext.Provider>
);
}
export function useComparison() {
const context = useContext(ComparisonContext);
if (context === undefined) {
throw new Error('useComparison must be used within a ComparisonProvider');
}
return context;
}

View File

@@ -0,0 +1,116 @@
/**
* Represents a single image in the comparison system.
* This includes all metadata about the image and its generation parameters.
*/
interface ComparisonImage {
// Basic file information
path: string; // Full path to the image file
model: string; // Model name (e.g., 'flux_dev', 'ovs_bangel_001_000005000')
// Classification information
config: string; // Configuration type (e.g., 'cloth_lora', 'identity_lora', 'dual_lora')
prompt_index: number; // Index of the prompt used for generation
seed: number; // Seed used for generation
// LoRA information - optional as not all configs use both
lora1?: string; // First LoRA name (or single LoRA in non-dual cases)
lora2?: string; // Second LoRA name (only for dual_lora config)
// Generation parameters
prompt: string; // The actual prompt text used to generate this image
}
/**
* Represents a pair of images to be compared.
* Contains both images and their shared generation parameters.
*/
interface ComparisonPair {
model1: ComparisonImage; // First model's image and metadata
model2: ComparisonImage; // Second model's image and metadata
// Shared parameters for easy filtering and organization
config: string; // The configuration type for this pair
prompt_index: number; // Index of the shared prompt
seed: number; // Shared seed used for both generations
prompt: string; // The full prompt text used for both images
}
/**
* Contains all data needed for the comparison interface.
* Provides both the comparison pairs and the metadata needed for navigation and filtering.
*/
interface ComparisonData {
// Available configuration options
configs: string[]; // List of all configuration types (e.g., ['cloth_lora', 'identity_lora', 'dual_lora'])
// Mapping of prompts per configuration
prompts: Record<string, string[]>; // Example: { 'cloth_lora': ['prompt1', 'prompt2', ...] }
// Available seeds for filtering
seeds: number[]; // List of all seeds used in the comparisons
// The actual comparison data
pairs: ComparisonPair[]; // All comparison pairs available
}
/**
* Represents the filters that can be applied to the comparison view
*/
interface ComparisonFilters {
config?: string; // Selected configuration type
promptIndex?: number; // Selected prompt index
seed?: number; // Selected seed
}
// src/types/comparison.ts
// Add these to your existing types
// Represents the metadata about an available configuration
interface ConfigurationInfo {
name: string;
model_count: number;
prompt_count: number;
seed_count: number;
}
// Response from the fetchConfigs endpoint
interface AvailableConfigs {
base_path: string;
configs: ConfigurationInfo[];
}
// Represents the current state of comparison viewing
interface ComparisonState {
basePath: string | null;
configId: string | null; // Add this
availableConfigs: ConfigurationInfo[];
currentConfig: string | null;
currentPairIndex: number;
comparisonData: ComparisonData | null;
isLoading: boolean;
error: string | null;
}
// Actions that can be performed through the context
interface ComparisonContextType extends ComparisonState {
nextPair: () => void;
previousPair: () => void;
goToPair: (index: number) => void;
setBasePath: (path: string) => Promise<void>;
loadConfig: (configName: string) => Promise<void>;
getCurrentPair: () => ComparisonPair | null;
getImageUrl: (model: { config: string, model: string, filename: string }) => string; // Add this
}
// Export all types
export type {
ComparisonImage,
ComparisonPair,
ComparisonData,
ComparisonFilters,
ComparisonContextType,
ComparisonState,
AvailableConfigs,
ConfigurationInfo,
}