Add comparison support

Signed-off-by: Felipe Cardoso <felipe.cardoso@hotmail.it>
This commit is contained in:
2025-01-31 13:20:46 +01:00
parent 3a79065163
commit 9803e32f66
13 changed files with 974 additions and 2 deletions

View File

@@ -0,0 +1,71 @@
// src/app/comparison/page.tsx
"use client"
import {ComparisonViewer} from "@/components/ComparisonViewer";
import {useComparison} from "@/contexts/ComparisonContext";
import {PathSelector} from "@/components/PathSelector";
import {ConfigSelector} from "@/components/ConfigsSelector";
export default function ComparisonPage() {
// Get everything we need from the comparison context
const {
basePath,
availableConfigs,
currentConfig,
isLoading,
error,
loadConfig
} = useComparison();
// If we don't have a base path yet, show the path selector
if (!basePath) {
return (
<div className="min-h-screen bg-gray-900 flex items-center justify-center">
<PathSelector/>
</div>
);
}
return (
<div className="min-h-screen bg-gray-900">
{/* More flexible top bar */}
<div className="bg-gray-800 border-b border-gray-700 p-4">
<div className="flex flex-col lg:flex-row gap-4 items-start lg:items-center">
{/* Config selector takes full width on mobile, shares space on desktop */}
<div className="w-full lg:w-auto lg:flex-1">
<ConfigSelector
configs={availableConfigs}
selectedConfig={currentConfig}
onConfigSelect={loadConfig}
disabled={isLoading}
/>
</div>
{/* Path display that wraps naturally */}
<div className="w-full lg:w-auto flex items-center gap-2 text-sm">
<span className="text-gray-400 whitespace-nowrap">Path:</span>
<span className="text-gray-500 truncate">
{basePath}
</span>
{isLoading && (
<span className="text-blue-400 whitespace-nowrap">
Loading...
</span>
)}
</div>
</div>
</div>
{/* Main comparison area */}
<div className="h-[calc(100vh-4rem)]">
{error ? (
<div className="flex items-center justify-center h-full text-red-400">
{error}
</div>
) : (
<ComparisonViewer/>
)}
</div>
</div>
);
}

View File

@@ -3,6 +3,7 @@ import {Geist, Geist_Mono} from "next/font/google";
import "./globals.css";
import {TrainingProvider} from "@/contexts/TrainingContext";
import {SamplesProvider} from "@/contexts/SamplesContext";
import {ComparisonProvider} from "@/contexts/ComparisonContext";
const geistSans = Geist({
variable: "--font-geist-sans",
@@ -31,7 +32,9 @@ export default function RootLayout({
>
<TrainingProvider>
<SamplesProvider>
{children}
<ComparisonProvider>
{children}
</ComparisonProvider>
</SamplesProvider>
</TrainingProvider>

View File

@@ -0,0 +1,139 @@
// src/components/ComparisonViewer.tsx
"use client"
import Image from 'next/image'
import {useComparison} from '@/contexts/ComparisonContext'
export function ComparisonViewer() {
// We get everything we need from the context instead of managing local state
const {
getCurrentPair,
nextPair,
previousPair,
currentPairIndex,
comparisonData,
isLoading,
getImageUrl,
} = useComparison()
// Get the current pair using our context helper
const currentPair = getCurrentPair()
// Handle loading state
if (isLoading) {
return (
<div className="h-full flex items-center justify-center text-gray-400">
<div className="space-y-2 text-center">
<div className="text-lg">Loading comparisons...</div>
<div className="text-sm text-gray-500">Please wait while we prepare your images</div>
</div>
</div>
)
}
// Handle no data state
if (!currentPair || !comparisonData) {
return (
<div className="h-full flex items-center justify-center text-gray-400">
<div className="space-y-2 text-center">
<div className="text-lg">No comparison data available</div>
<div className="text-sm text-gray-500">Please select a configuration to begin</div>
</div>
</div>
)
}
// // Helper function to construct image URLs
// const getImageUrl = (path: string) => {
// return `${env.API_URL}${path}`
// }
return (
<div className="h-full flex flex-col">
{/* Main image comparison area */}
<div className="flex-1 flex">
{/* Left image */}
<div className="flex-1 relative group">
<div className="absolute inset-0 flex items-center justify-center">
<Image
src={getImageUrl(currentPair.model1)}
alt={`${currentPair.model1.model} - Seed ${currentPair.seed}`}
className="max-h-full w-auto object-contain transition-transform duration-200 group-hover:scale-[1.02]"
width={1024}
height={1024}
/>
<div
className="absolute top-2 left-2 bg-black/50 backdrop-blur-sm text-white px-3 py-1.5 rounded-lg font-medium">
{currentPair.model1.model}
</div>
{/* Add metadata tooltip on hover */}
<div
className="absolute bottom-2 left-2 opacity-0 group-hover:opacity-100 transition-opacity bg-black/50 backdrop-blur-sm text-white px-3 py-1.5 rounded-lg text-sm">
Seed: {currentPair.seed}
{currentPair.model1.lora1 && <div>LoRA: {currentPair.model1.lora1}</div>}
{currentPair.model1.lora2 && <div>LoRA 2: {currentPair.model1.lora2}</div>}
</div>
</div>
</div>
{/* Right image - mirror of left image setup */}
<div className="flex-1 relative group">
<div className="absolute inset-0 flex items-center justify-center">
<Image
src={getImageUrl(currentPair.model2)}
alt={`${currentPair.model2.model} - Seed ${currentPair.seed}`}
className="max-h-full w-auto object-contain transition-transform duration-200 group-hover:scale-[1.02]"
width={1024}
height={1024}
/>
<div
className="absolute top-2 right-2 bg-black/50 backdrop-blur-sm text-white px-3 py-1.5 rounded-lg font-medium">
{currentPair.model2.model}
</div>
<div
className="absolute bottom-2 right-2 opacity-0 group-hover:opacity-100 transition-opacity bg-black/50 backdrop-blur-sm text-white px-3 py-1.5 rounded-lg text-sm">
Seed: {currentPair.seed}
{currentPair.model2.lora1 && <div>LoRA: {currentPair.model2.lora1}</div>}
{currentPair.model2.lora2 && <div>LoRA 2: {currentPair.model2.lora2}</div>}
</div>
</div>
</div>
</div>
{/* Bottom info panel with enhanced information */}
<div className="bg-gray-800 border-t border-gray-700 p-4">
<div className="flex justify-between items-start">
<div className="flex-1 space-y-2">
<div>
<h3 className="text-gray-300 font-semibold">Prompt:</h3>
<p className="text-gray-400 text-sm mt-1">{currentPair.prompt}</p>
</div>
<div className="flex gap-4 text-sm text-gray-500">
<div>Seed: {currentPair.seed}</div>
<div>Prompt Index: {currentPair.prompt_index}</div>
</div>
</div>
<div className="flex gap-4 items-center ml-4">
<button
onClick={previousPair}
className="px-3 py-1.5 rounded-md bg-gray-700 text-gray-300 hover:bg-gray-600
hover:text-white transition-colors duration-200"
>
Previous
</button>
<span className="text-gray-400 font-medium">
{currentPairIndex + 1} / {comparisonData.pairs.length}
</span>
<button
onClick={nextPair}
className="px-3 py-1.5 rounded-md bg-gray-700 text-gray-300 hover:bg-gray-600
hover:text-white transition-colors duration-200"
>
Next
</button>
</div>
</div>
</div>
</div>
)
}

View File

@@ -0,0 +1,78 @@
// src/components/ConfigSelector.tsx
"use client"
import {useCallback} from 'react'
import type {ConfigurationInfo} from '@/types/comparison'
interface ConfigSelectorProps {
configs: ConfigurationInfo[] // Now using our structured config info
selectedConfig: string | null // Matches the context's currentConfig type
onConfigSelect: (config: string) => Promise<void> // Handle async loading
disabled?: boolean // Allow disabling during loading states
}
// src/components/ConfigSelector.tsx
interface ConfigurationDisplay {
name: string;
count: string;
}
export function ConfigSelector({
configs,
selectedConfig,
onConfigSelect,
disabled = false
}: ConfigSelectorProps) {
// Helper to create display information
const getConfigDisplay = useCallback((config: ConfigurationInfo): ConfigurationDisplay => {
const baseName = config.name
.replace('_lora', '')
.split('_')
.map(word => word.charAt(0).toUpperCase() + word.slice(1))
.join(' ');
return {
name: baseName,
count: `${config.model_count}m, ${config.prompt_count}p` // Shortened display
};
}, []);
return (
<div className="flex flex-col sm:flex-row items-start sm:items-center gap-2 sm:gap-4 w-full">
{/* Label that stacks on mobile but stays inline on larger screens */}
<span className="text-gray-400 text-sm font-medium whitespace-nowrap">
Configuration:
</span>
{/* Button container that allows wrapping on smaller screens */}
<div className="flex flex-wrap gap-2 flex-1">
{configs.map(config => {
const display = getConfigDisplay(config);
return (
<button
key={config.name}
onClick={() => onConfigSelect(config.name)}
disabled={disabled || selectedConfig === config.name}
className={`
px-3 py-1.5 rounded-md text-sm font-medium
transition-colors duration-200
flex flex-col sm:flex-row items-center gap-1
min-w-[100px] sm:min-w-0
${disabled ? 'opacity-50 cursor-not-allowed' : ''}
${selectedConfig === config.name
? 'bg-blue-600 text-white'
: 'bg-gray-700 text-gray-300 hover:bg-gray-600'
}
`}
>
<span className="whitespace-nowrap">{display.name}</span>
<span className="text-xs opacity-75 whitespace-nowrap">
{display.count}
</span>
</button>
);
})}
</div>
</div>
);
}

View File

@@ -0,0 +1,63 @@
"use client"
import {useState} from 'react'
import {useComparison} from '@/contexts/ComparisonContext'
export function PathSelector() {
const {setBasePath, isLoading, error} = useComparison();
const [path, setPath] = useState('');
const handleSubmit = async (e: React.FormEvent) => {
e.preventDefault();
if (path.trim()) {
await setBasePath(path.trim());
}
};
return (
<div className="bg-gray-800 p-6 rounded-lg shadow-lg max-w-xl w-full mx-4">
<h2 className="text-xl font-semibold text-gray-200 mb-4">
Enter Comparison Path
</h2>
<form onSubmit={handleSubmit} className="space-y-4">
<div>
<label
htmlFor="path"
className="block text-sm font-medium text-gray-300 mb-2"
>
Base Path
</label>
<input
type="text"
id="path"
value={path}
onChange={(e) => setPath(e.target.value)}
placeholder="/path/to/comparison/directory"
className="w-full px-4 py-2 bg-gray-700 border border-gray-600 rounded-md
text-gray-200 placeholder-gray-400 focus:outline-none focus:ring-2
focus:ring-blue-500 focus:border-transparent"
disabled={isLoading}
/>
</div>
{error && (
<div className="text-red-400 text-sm">
{error}
</div>
)}
<button
type="submit"
disabled={isLoading || !path.trim()}
className="w-full px-4 py-2 bg-blue-600 text-white rounded-md
hover:bg-blue-700 focus:outline-none focus:ring-2
focus:ring-blue-500 focus:ring-offset-2 focus:ring-offset-gray-800
disabled:opacity-50 disabled:cursor-not-allowed"
>
{isLoading ? 'Loading...' : 'Load Comparisons'}
</button>
</form>
</div>
);
}

View File

@@ -0,0 +1,3 @@
export const env = {
API_URL: process.env.NEXT_PUBLIC_API_URL || 'http://localhost:2000'
} as const;

View File

@@ -0,0 +1,152 @@
'use client'
import {createContext, useCallback, useContext, useState} from 'react';
import {env} from '@/config/env';
import type {AvailableConfigs, ComparisonContextType, ComparisonData, ComparisonState} from '@/types/comparison';
const ComparisonContext = createContext<ComparisonContextType | undefined>(undefined);
export function ComparisonProvider({children}: { children: React.ReactNode }) {
// Our state now needs to include the configId we get from registration
const [state, setState] = useState<ComparisonState>({
basePath: null,
configId: null, // Add this to track our registered configuration
availableConfigs: [],
currentConfig: null,
currentPairIndex: 0,
comparisonData: null,
isLoading: false,
error: null,
});
// First step: Register the path and get a configuration ID
const setBasePath = useCallback(async (path: string) => {
setState(prev => ({...prev, isLoading: true, error: null}));
try {
// Register the path first to get our configId
const registerResponse = await fetch(`${env.API_URL}/api/v1/comparison/register`, {
method: 'POST',
headers: {'Content-Type': 'application/json'},
body: JSON.stringify({path})
});
if (!registerResponse.ok) throw new Error('Failed to register path');
const {config_id} = await registerResponse.json();
// After getting the configId, fetch available configurations
const configsResponse = await fetch(`${env.API_URL}/api/v1/comparison/${config_id}/available`);
if (!configsResponse.ok) throw new Error('Failed to fetch configurations');
const data: AvailableConfigs = await configsResponse.json();
setState(prev => ({
...prev,
basePath: path,
configId: config_id,
availableConfigs: data.configs,
isLoading: false
}));
} catch (error) {
setState(prev => ({
...prev,
error: error instanceof Error ? error.message : 'Unknown error',
isLoading: false
}));
}
}, []);
// Load a specific configuration using our new GET endpoint
const loadConfig = useCallback(async (configName: string) => {
if (!state.configId) return;
setState(prev => ({...prev, isLoading: true, error: null}));
try {
// Use the new GET endpoint structure
const response = await fetch(
`${env.API_URL}/api/v1/comparison/${state.configId}/${configName}`
);
if (!response.ok) throw new Error('Failed to fetch configuration data');
const data: ComparisonData = await response.json();
setState(prev => ({
...prev,
currentConfig: configName,
comparisonData: data,
currentPairIndex: 0,
isLoading: false
}));
} catch (error) {
setState(prev => ({
...prev,
error: error instanceof Error ? error.message : 'Unknown error',
isLoading: false
}));
}
}, [state.configId]);
// Helper to construct image URLs using our new endpoint structure
const getImageUrl = useCallback((model: { config: string, model: string, path: string }) => {
if (!state.configId) return '';
const filename = model.path.split("/").slice(-1)
return `${env.API_URL}/api/v1/comparison/image/${state.configId}/${model.config}/${model.model}/${filename}`;
}, [state.configId]);
// Navigation helpers remain the same
const nextPair = useCallback(() => {
if (!state.comparisonData?.pairs.length) return;
setState(prev => ({
...prev,
currentPairIndex: (prev.currentPairIndex + 1) % prev.comparisonData!.pairs.length
}));
}, [state.comparisonData?.pairs.length]);
const previousPair = useCallback(() => {
if (!state.comparisonData?.pairs.length) return;
setState(prev => ({
...prev,
currentPairIndex: (prev.currentPairIndex - 1 + prev.comparisonData!.pairs.length) % prev.comparisonData!.pairs.length
}));
}, [state.comparisonData?.pairs.length]);
const goToPair = useCallback((index: number) => {
if (!state.comparisonData?.pairs.length) return;
if (index >= 0 && index < state.comparisonData.pairs.length) {
setState(prev => ({...prev, currentPairIndex: index}));
}
}, [state.comparisonData?.pairs.length]);
const getCurrentPair = useCallback(() => {
if (!state.comparisonData?.pairs.length) return null;
return state.comparisonData.pairs[state.currentPairIndex];
}, [state.comparisonData, state.currentPairIndex]);
const value: ComparisonContextType = {
...state,
nextPair,
previousPair,
goToPair,
setBasePath,
loadConfig,
getCurrentPair,
getImageUrl, // Add this to help components construct image URLs
};
return (
<ComparisonContext.Provider value={value}>
{children}
</ComparisonContext.Provider>
);
}
export function useComparison() {
const context = useContext(ComparisonContext);
if (context === undefined) {
throw new Error('useComparison must be used within a ComparisonProvider');
}
return context;
}

View File

@@ -0,0 +1,116 @@
/**
* Represents a single image in the comparison system.
* This includes all metadata about the image and its generation parameters.
*/
interface ComparisonImage {
// Basic file information
path: string; // Full path to the image file
model: string; // Model name (e.g., 'flux_dev', 'ovs_bangel_001_000005000')
// Classification information
config: string; // Configuration type (e.g., 'cloth_lora', 'identity_lora', 'dual_lora')
prompt_index: number; // Index of the prompt used for generation
seed: number; // Seed used for generation
// LoRA information - optional as not all configs use both
lora1?: string; // First LoRA name (or single LoRA in non-dual cases)
lora2?: string; // Second LoRA name (only for dual_lora config)
// Generation parameters
prompt: string; // The actual prompt text used to generate this image
}
/**
* Represents a pair of images to be compared.
* Contains both images and their shared generation parameters.
*/
interface ComparisonPair {
model1: ComparisonImage; // First model's image and metadata
model2: ComparisonImage; // Second model's image and metadata
// Shared parameters for easy filtering and organization
config: string; // The configuration type for this pair
prompt_index: number; // Index of the shared prompt
seed: number; // Shared seed used for both generations
prompt: string; // The full prompt text used for both images
}
/**
* Contains all data needed for the comparison interface.
* Provides both the comparison pairs and the metadata needed for navigation and filtering.
*/
interface ComparisonData {
// Available configuration options
configs: string[]; // List of all configuration types (e.g., ['cloth_lora', 'identity_lora', 'dual_lora'])
// Mapping of prompts per configuration
prompts: Record<string, string[]>; // Example: { 'cloth_lora': ['prompt1', 'prompt2', ...] }
// Available seeds for filtering
seeds: number[]; // List of all seeds used in the comparisons
// The actual comparison data
pairs: ComparisonPair[]; // All comparison pairs available
}
/**
* Represents the filters that can be applied to the comparison view
*/
interface ComparisonFilters {
config?: string; // Selected configuration type
promptIndex?: number; // Selected prompt index
seed?: number; // Selected seed
}
// src/types/comparison.ts
// Add these to your existing types
// Represents the metadata about an available configuration
interface ConfigurationInfo {
name: string;
model_count: number;
prompt_count: number;
seed_count: number;
}
// Response from the fetchConfigs endpoint
interface AvailableConfigs {
base_path: string;
configs: ConfigurationInfo[];
}
// Represents the current state of comparison viewing
interface ComparisonState {
basePath: string | null;
configId: string | null; // Add this
availableConfigs: ConfigurationInfo[];
currentConfig: string | null;
currentPairIndex: number;
comparisonData: ComparisonData | null;
isLoading: boolean;
error: string | null;
}
// Actions that can be performed through the context
interface ComparisonContextType extends ComparisonState {
nextPair: () => void;
previousPair: () => void;
goToPair: (index: number) => void;
setBasePath: (path: string) => Promise<void>;
loadConfig: (configName: string) => Promise<void>;
getCurrentPair: () => ComparisonPair | null;
getImageUrl: (model: { config: string, model: string, filename: string }) => string; // Add this
}
// Export all types
export type {
ComparisonImage,
ComparisonPair,
ComparisonData,
ComparisonFilters,
ComparisonContextType,
ComparisonState,
AvailableConfigs,
ConfigurationInfo,
}