Add comparison support
Signed-off-by: Felipe Cardoso <felipe.cardoso@hotmail.it>
This commit is contained in:
68
backend/app/api/routes/comparison.py
Normal file
68
backend/app/api/routes/comparison.py
Normal file
@@ -0,0 +1,68 @@
|
||||
# app/api/routes/comparison.py
|
||||
|
||||
from fastapi import APIRouter, HTTPException
|
||||
from starlette.responses import FileResponse
|
||||
|
||||
from app.models.comparison import PathRequest
|
||||
from app.services.comparison_service import ComparisonService
|
||||
|
||||
router = APIRouter()
|
||||
comparison_service = ComparisonService()
|
||||
|
||||
|
||||
@router.post("/register")
|
||||
async def register_comparison_path(request: PathRequest):
|
||||
"""Register a new comparison path and get its ID."""
|
||||
config_id = comparison_service.register_config(request.path)
|
||||
return {"config_id": config_id}
|
||||
|
||||
|
||||
@router.get("/image/{config_id}/{config_name}/{model_name}/{filename}")
|
||||
async def get_comparison_image(config_id: str, config_name: str, model_name: str, filename: str):
|
||||
"""Serve image files using the cached base path."""
|
||||
base_path = comparison_service.get_base_path(config_id)
|
||||
if not base_path:
|
||||
raise HTTPException(status_code=404, detail="Configuration not found")
|
||||
|
||||
try:
|
||||
full_path = base_path / config_name / model_name / filename
|
||||
if not full_path.exists() or not full_path.is_file():
|
||||
raise HTTPException(status_code=404, detail="Image not found")
|
||||
|
||||
return FileResponse(full_path)
|
||||
except Exception as e:
|
||||
raise HTTPException(status_code=500, detail=str(e))
|
||||
|
||||
|
||||
@router.get("/{config_id}/available")
|
||||
async def fetch_available_configs(config_id: str):
|
||||
"""Fetch available configs."""
|
||||
base_path = comparison_service.get_base_path(config_id)
|
||||
if not base_path:
|
||||
raise HTTPException(status_code=404, detail="Configuration not found")
|
||||
|
||||
try:
|
||||
return comparison_service.get_available_configs(str(base_path))
|
||||
except ValueError as e:
|
||||
# Convert ValueError from the service into a proper HTTP error
|
||||
raise HTTPException(status_code=404, detail=str(e))
|
||||
|
||||
|
||||
@router.get("/{config_id}/{config_name}")
|
||||
async def fetch_config(config_id: str, config_name: str):
|
||||
"""
|
||||
Fetch detailed comparison data for a specific configuration.
|
||||
|
||||
Parameters:
|
||||
config_id: The identifier returned from the initial path registration
|
||||
config_name: The name of the specific configuration to load (e.g. 'cloth_lora')
|
||||
"""
|
||||
base_path = comparison_service.get_base_path(config_id)
|
||||
if not base_path:
|
||||
raise HTTPException(status_code=404, detail="Configuration not found")
|
||||
|
||||
try:
|
||||
return comparison_service.load_config_data(str(base_path), config_name)
|
||||
except ValueError as e:
|
||||
# Convert ValueError from the service into a proper HTTP error
|
||||
raise HTTPException(status_code=404, detail=str(e))
|
||||
@@ -6,7 +6,7 @@ import psutil
|
||||
from fastapi import FastAPI
|
||||
from fastapi.middleware.cors import CORSMiddleware
|
||||
|
||||
from app.api.routes import training, samples, config
|
||||
from app.api.routes import training, samples, config, comparison
|
||||
from app.core.config import settings
|
||||
from app.services.config_manager import ConfigManager
|
||||
from app.services.sample_manager import SampleManager
|
||||
@@ -105,6 +105,7 @@ async def shutdown_event():
|
||||
app.include_router(training.router, prefix=f"{settings.API_VER_STR}/training", tags=["training"])
|
||||
app.include_router(samples.router, prefix=f"{settings.API_VER_STR}/samples", tags=["samples"])
|
||||
app.include_router(config.router, prefix=f"{settings.API_VER_STR}/config", tags=["config"])
|
||||
app.include_router(comparison.router, prefix=f"{settings.API_VER_STR}/comparison", tags=["comparison"])
|
||||
|
||||
|
||||
@app.get("/")
|
||||
|
||||
55
backend/app/models/comparison.py
Normal file
55
backend/app/models/comparison.py
Normal file
@@ -0,0 +1,55 @@
|
||||
from typing import List, Optional, Dict
|
||||
|
||||
from pydantic import BaseModel
|
||||
|
||||
|
||||
class ComparisonImage(BaseModel):
|
||||
path: str
|
||||
model: str
|
||||
config: str
|
||||
prompt_index: int
|
||||
seed: int
|
||||
lora1: Optional[str] = None
|
||||
lora2: Optional[str] = None
|
||||
prompt: Optional[str] = None # Adding prompt field
|
||||
|
||||
|
||||
class ComparisonPair(BaseModel):
|
||||
model1: ComparisonImage
|
||||
model2: ComparisonImage
|
||||
config: str
|
||||
prompt_index: int
|
||||
seed: int
|
||||
prompt: str # The prompt used for both images
|
||||
|
||||
|
||||
class ComparisonData(BaseModel):
|
||||
configs: List[str] # Available config types (cloth_lora, identity_lora, dual_lora)
|
||||
prompts: Dict[str, List[str]] # Mapping of config -> list of prompts
|
||||
seeds: List[int] # All available seeds
|
||||
pairs: List[ComparisonPair] # All comparison pairs with their prompts
|
||||
|
||||
|
||||
class PathRequest(BaseModel):
|
||||
"""Request model for providing the base comparison path"""
|
||||
path: str
|
||||
|
||||
|
||||
class ConfigRequest(BaseModel):
|
||||
"""Request model for fetching specific configuration data"""
|
||||
path: str
|
||||
config_name: str
|
||||
|
||||
|
||||
class ConfigInfo(BaseModel):
|
||||
"""Basic information about an available configuration"""
|
||||
name: str
|
||||
model_count: int
|
||||
prompt_count: int
|
||||
seed_count: int
|
||||
|
||||
|
||||
class AvailableConfigs(BaseModel):
|
||||
"""Response model for the fetchConfigs endpoint"""
|
||||
base_path: str
|
||||
configs: List[ConfigInfo]
|
||||
221
backend/app/services/comparison_service.py
Normal file
221
backend/app/services/comparison_service.py
Normal file
@@ -0,0 +1,221 @@
|
||||
# app/services/comparison_service.py
|
||||
import json
|
||||
from datetime import datetime
|
||||
from pathlib import Path
|
||||
from typing import Dict, Optional
|
||||
|
||||
from app.models.comparison import ConfigInfo, AvailableConfigs, ComparisonPair, ComparisonData, ComparisonImage
|
||||
|
||||
|
||||
class ComparisonService:
|
||||
def __init__(self):
|
||||
# Cache structure to store configuration information
|
||||
self.configs: Dict[str, dict] = {} # id -> config info
|
||||
self.paths: Dict[str, Path] = {} # id -> base path
|
||||
self.access_times: Dict[str, datetime] = {} # id -> last access time
|
||||
|
||||
def generate_config_id(self) -> str:
|
||||
"""Generate a unique identifier for a configuration."""
|
||||
# Simple timestamp-based ID, could be made more sophisticated
|
||||
return datetime.now().strftime('%Y%m%d_%H%M%S')
|
||||
|
||||
def _get_config_data(self, config_dir: Path) -> dict:
|
||||
"""Get config data with caching."""
|
||||
config_path = config_dir / f"config_{config_dir.name}.json"
|
||||
cache_key = str(config_path)
|
||||
|
||||
if cache_key not in self.configs:
|
||||
if not config_path.exists():
|
||||
raise ValueError(f"Configuration file not found: {config_path}")
|
||||
|
||||
with open(config_path) as f:
|
||||
self.configs[cache_key] = json.load(f)
|
||||
|
||||
return self.configs[cache_key]
|
||||
|
||||
def register_config(self, base_path: str) -> str:
|
||||
"""
|
||||
Register a new configuration base path and return its ID.
|
||||
This is called when a user first submits a path.
|
||||
"""
|
||||
config_id = self.generate_config_id()
|
||||
self.paths[config_id] = Path(base_path)
|
||||
self.access_times[config_id] = datetime.now()
|
||||
return config_id
|
||||
|
||||
def get_base_path(self, config_id: str) -> Optional[Path]:
|
||||
"""
|
||||
Retrieve the base path for a given configuration ID.
|
||||
Updates the last access time.
|
||||
"""
|
||||
if config_id in self.paths:
|
||||
self.access_times[config_id] = datetime.now()
|
||||
return self.paths[config_id]
|
||||
return None
|
||||
|
||||
def clean_old_configs(self, max_age_hours: int = 72):
|
||||
"""
|
||||
Clean up configurations that haven't been accessed in a while.
|
||||
This helps manage memory usage.
|
||||
"""
|
||||
current_time = datetime.now()
|
||||
expired_ids = [
|
||||
config_id for config_id, access_time in self.access_times.items()
|
||||
if (current_time - access_time).total_seconds() > max_age_hours * 3600
|
||||
]
|
||||
for config_id in expired_ids:
|
||||
self.paths.pop(config_id, None)
|
||||
self.configs.pop(config_id, None)
|
||||
self.access_times.pop(config_id, None)
|
||||
|
||||
def get_available_configs(self, base_path: str) -> AvailableConfigs:
|
||||
base_path = Path(base_path)
|
||||
configs = []
|
||||
|
||||
for config_dir in base_path.iterdir():
|
||||
if not config_dir.is_dir():
|
||||
continue
|
||||
|
||||
try:
|
||||
config_data = self._get_config_data(config_dir)
|
||||
model_count = sum(1 for x in config_dir.iterdir() if x.is_dir())
|
||||
|
||||
configs.append(ConfigInfo(
|
||||
name=config_dir.name,
|
||||
model_count=model_count,
|
||||
prompt_count=len(config_data.get('prompts', [])),
|
||||
seed_count=len(config_data.get('seeds', []))
|
||||
))
|
||||
except ValueError:
|
||||
# Skip this config if we can't read its data
|
||||
continue
|
||||
|
||||
return AvailableConfigs(
|
||||
base_path=str(base_path),
|
||||
configs=sorted(configs, key=lambda x: x.name)
|
||||
)
|
||||
|
||||
def load_config_data(self, base_path: str, config_name: str) -> ComparisonData:
|
||||
"""Load comparison data using the cached base path."""
|
||||
# base_path = self.get_base_path(config_id)
|
||||
# if not base_path:
|
||||
# raise ValueError(f"Configuration ID {config_id} not found")
|
||||
# config_dir = base_path / config_name
|
||||
#
|
||||
# if not config_dir.is_dir():
|
||||
# raise ValueError(f"Configuration '{config_name}' not found in {base_path}")
|
||||
#
|
||||
# # Load the configuration file
|
||||
# config_file = config_dir / f"config_{config_name}.json"
|
||||
# if not config_file.exists():
|
||||
# raise ValueError(f"Configuration file not found for {config_name}")
|
||||
#
|
||||
# with open(config_file) as f:
|
||||
# config_data = json.load(f)
|
||||
base_path = Path(base_path)
|
||||
config_dir = base_path / config_name
|
||||
|
||||
if not config_dir.is_dir():
|
||||
raise ValueError(f"Configuration '{config_name}' not found in {base_path}")
|
||||
|
||||
# Use our cached config data instead of reading directly
|
||||
try:
|
||||
config_data = self._get_config_data(config_dir)
|
||||
except ValueError as e:
|
||||
raise ValueError(f"Failed to load configuration data: {str(e)}")
|
||||
|
||||
# Process model directories
|
||||
models = [d.name for d in config_dir.iterdir() if d.is_dir()]
|
||||
pairs = []
|
||||
|
||||
# Build comparison pairs for this configuration
|
||||
for model in models:
|
||||
model_path = config_dir / model
|
||||
images = list(model_path.glob("*.png"))
|
||||
|
||||
for img in images:
|
||||
img_data = self.parse_image_path(img, config_dir)
|
||||
|
||||
# Find matching image in other model
|
||||
other_model = [m for m in models if m != model][0]
|
||||
other_path = model_path.parent / other_model / img.name
|
||||
|
||||
if other_path.exists():
|
||||
other_data = self.parse_image_path(other_path, config_dir)
|
||||
pairs.append(ComparisonPair(
|
||||
model1=img_data,
|
||||
model2=other_data,
|
||||
config=config_name,
|
||||
prompt_index=img_data.prompt_index,
|
||||
seed=img_data.seed,
|
||||
prompt=img_data.prompt or ""
|
||||
))
|
||||
|
||||
return ComparisonData(
|
||||
configs=[config_name],
|
||||
prompts={config_name: config_data['prompts']},
|
||||
seeds=sorted(config_data['seeds']),
|
||||
pairs=pairs
|
||||
)
|
||||
|
||||
def parse_image_path(self, path: Path, config_dir: Path) -> ComparisonImage:
|
||||
|
||||
"""
|
||||
Parse an image filename to extract metadata about the image.
|
||||
The method handles both single and dual LoRA naming patterns:
|
||||
- Single LoRA: lora_prompt_0_seed_42.png
|
||||
- Dual LoRA: lora1_lora2_prompt_0_seed_42.png
|
||||
"""
|
||||
filename = Path(path).stem
|
||||
parts = filename.split('_')
|
||||
|
||||
prompt_idx = parts.index('prompt')
|
||||
seed_idx = parts.index('seed')
|
||||
|
||||
prompt_index = int(parts[prompt_idx + 1])
|
||||
seed = int(parts[seed_idx + 1])
|
||||
|
||||
# Instead of reading the file directly, use our cached method
|
||||
prompt = None
|
||||
try:
|
||||
config_data = self._get_config_data(config_dir)
|
||||
prompts = config_data.get('prompts', [])
|
||||
if 0 <= prompt_index < len(prompts):
|
||||
prompt = prompts[prompt_index]
|
||||
except ValueError:
|
||||
# If we can't get the config data, we'll continue without the prompt
|
||||
pass
|
||||
# config_file = config_dir / f"config_{config_dir.name}.json"
|
||||
# if config_file.exists():
|
||||
# with open(config_file) as f:
|
||||
# config_data = json.load(f)
|
||||
# prompts = config_data.get('prompts', [])
|
||||
# if 0 <= prompt_index < len(prompts):
|
||||
# prompt = prompts[prompt_index]
|
||||
|
||||
# Determine if this is a dual LoRA setup by checking parts before 'prompt'
|
||||
lora_parts = parts[:prompt_idx]
|
||||
|
||||
if len(lora_parts) > 1:
|
||||
# Dual LoRA case
|
||||
return ComparisonImage(
|
||||
path=str(path),
|
||||
model=path.parent.name,
|
||||
config=config_dir.name,
|
||||
prompt_index=prompt_index,
|
||||
seed=seed,
|
||||
lora1=lora_parts[0],
|
||||
lora2=lora_parts[1],
|
||||
prompt=prompt
|
||||
)
|
||||
else:
|
||||
# Single LoRA case
|
||||
return ComparisonImage(
|
||||
path=str(path),
|
||||
model=path.parent.name,
|
||||
config=config_dir.name,
|
||||
prompt_index=prompt_index,
|
||||
seed=seed,
|
||||
lora1=lora_parts[0],
|
||||
prompt=prompt
|
||||
)
|
||||
2
backend/run_server.sh
Normal file
2
backend/run_server.sh
Normal file
@@ -0,0 +1,2 @@
|
||||
#!/bin/bash
|
||||
uvicorn app.main:app --reload --port 2000
|
||||
71
frontend/src/app/comparison/page.tsx
Normal file
71
frontend/src/app/comparison/page.tsx
Normal file
@@ -0,0 +1,71 @@
|
||||
// src/app/comparison/page.tsx
|
||||
"use client"
|
||||
|
||||
import {ComparisonViewer} from "@/components/ComparisonViewer";
|
||||
import {useComparison} from "@/contexts/ComparisonContext";
|
||||
import {PathSelector} from "@/components/PathSelector";
|
||||
import {ConfigSelector} from "@/components/ConfigsSelector";
|
||||
|
||||
export default function ComparisonPage() {
|
||||
// Get everything we need from the comparison context
|
||||
const {
|
||||
basePath,
|
||||
availableConfigs,
|
||||
currentConfig,
|
||||
isLoading,
|
||||
error,
|
||||
loadConfig
|
||||
} = useComparison();
|
||||
|
||||
// If we don't have a base path yet, show the path selector
|
||||
if (!basePath) {
|
||||
return (
|
||||
<div className="min-h-screen bg-gray-900 flex items-center justify-center">
|
||||
<PathSelector/>
|
||||
</div>
|
||||
);
|
||||
}
|
||||
|
||||
return (
|
||||
<div className="min-h-screen bg-gray-900">
|
||||
{/* More flexible top bar */}
|
||||
<div className="bg-gray-800 border-b border-gray-700 p-4">
|
||||
<div className="flex flex-col lg:flex-row gap-4 items-start lg:items-center">
|
||||
{/* Config selector takes full width on mobile, shares space on desktop */}
|
||||
<div className="w-full lg:w-auto lg:flex-1">
|
||||
<ConfigSelector
|
||||
configs={availableConfigs}
|
||||
selectedConfig={currentConfig}
|
||||
onConfigSelect={loadConfig}
|
||||
disabled={isLoading}
|
||||
/>
|
||||
</div>
|
||||
|
||||
{/* Path display that wraps naturally */}
|
||||
<div className="w-full lg:w-auto flex items-center gap-2 text-sm">
|
||||
<span className="text-gray-400 whitespace-nowrap">Path:</span>
|
||||
<span className="text-gray-500 truncate">
|
||||
{basePath}
|
||||
</span>
|
||||
{isLoading && (
|
||||
<span className="text-blue-400 whitespace-nowrap">
|
||||
Loading...
|
||||
</span>
|
||||
)}
|
||||
</div>
|
||||
</div>
|
||||
</div>
|
||||
|
||||
{/* Main comparison area */}
|
||||
<div className="h-[calc(100vh-4rem)]">
|
||||
{error ? (
|
||||
<div className="flex items-center justify-center h-full text-red-400">
|
||||
{error}
|
||||
</div>
|
||||
) : (
|
||||
<ComparisonViewer/>
|
||||
)}
|
||||
</div>
|
||||
</div>
|
||||
);
|
||||
}
|
||||
@@ -3,6 +3,7 @@ import {Geist, Geist_Mono} from "next/font/google";
|
||||
import "./globals.css";
|
||||
import {TrainingProvider} from "@/contexts/TrainingContext";
|
||||
import {SamplesProvider} from "@/contexts/SamplesContext";
|
||||
import {ComparisonProvider} from "@/contexts/ComparisonContext";
|
||||
|
||||
const geistSans = Geist({
|
||||
variable: "--font-geist-sans",
|
||||
@@ -31,7 +32,9 @@ export default function RootLayout({
|
||||
>
|
||||
<TrainingProvider>
|
||||
<SamplesProvider>
|
||||
{children}
|
||||
<ComparisonProvider>
|
||||
{children}
|
||||
</ComparisonProvider>
|
||||
</SamplesProvider>
|
||||
</TrainingProvider>
|
||||
|
||||
|
||||
139
frontend/src/components/ComparisonViewer.tsx
Normal file
139
frontend/src/components/ComparisonViewer.tsx
Normal file
@@ -0,0 +1,139 @@
|
||||
// src/components/ComparisonViewer.tsx
|
||||
"use client"
|
||||
|
||||
import Image from 'next/image'
|
||||
import {useComparison} from '@/contexts/ComparisonContext'
|
||||
|
||||
export function ComparisonViewer() {
|
||||
// We get everything we need from the context instead of managing local state
|
||||
const {
|
||||
getCurrentPair,
|
||||
nextPair,
|
||||
previousPair,
|
||||
currentPairIndex,
|
||||
comparisonData,
|
||||
isLoading,
|
||||
getImageUrl,
|
||||
} = useComparison()
|
||||
|
||||
// Get the current pair using our context helper
|
||||
const currentPair = getCurrentPair()
|
||||
|
||||
// Handle loading state
|
||||
if (isLoading) {
|
||||
return (
|
||||
<div className="h-full flex items-center justify-center text-gray-400">
|
||||
<div className="space-y-2 text-center">
|
||||
<div className="text-lg">Loading comparisons...</div>
|
||||
<div className="text-sm text-gray-500">Please wait while we prepare your images</div>
|
||||
</div>
|
||||
</div>
|
||||
)
|
||||
}
|
||||
|
||||
// Handle no data state
|
||||
if (!currentPair || !comparisonData) {
|
||||
return (
|
||||
<div className="h-full flex items-center justify-center text-gray-400">
|
||||
<div className="space-y-2 text-center">
|
||||
<div className="text-lg">No comparison data available</div>
|
||||
<div className="text-sm text-gray-500">Please select a configuration to begin</div>
|
||||
</div>
|
||||
</div>
|
||||
)
|
||||
}
|
||||
|
||||
// // Helper function to construct image URLs
|
||||
// const getImageUrl = (path: string) => {
|
||||
// return `${env.API_URL}${path}`
|
||||
// }
|
||||
|
||||
return (
|
||||
<div className="h-full flex flex-col">
|
||||
{/* Main image comparison area */}
|
||||
<div className="flex-1 flex">
|
||||
{/* Left image */}
|
||||
<div className="flex-1 relative group">
|
||||
<div className="absolute inset-0 flex items-center justify-center">
|
||||
<Image
|
||||
src={getImageUrl(currentPair.model1)}
|
||||
alt={`${currentPair.model1.model} - Seed ${currentPair.seed}`}
|
||||
className="max-h-full w-auto object-contain transition-transform duration-200 group-hover:scale-[1.02]"
|
||||
width={1024}
|
||||
height={1024}
|
||||
/>
|
||||
<div
|
||||
className="absolute top-2 left-2 bg-black/50 backdrop-blur-sm text-white px-3 py-1.5 rounded-lg font-medium">
|
||||
{currentPair.model1.model}
|
||||
</div>
|
||||
{/* Add metadata tooltip on hover */}
|
||||
<div
|
||||
className="absolute bottom-2 left-2 opacity-0 group-hover:opacity-100 transition-opacity bg-black/50 backdrop-blur-sm text-white px-3 py-1.5 rounded-lg text-sm">
|
||||
Seed: {currentPair.seed}
|
||||
{currentPair.model1.lora1 && <div>LoRA: {currentPair.model1.lora1}</div>}
|
||||
{currentPair.model1.lora2 && <div>LoRA 2: {currentPair.model1.lora2}</div>}
|
||||
</div>
|
||||
</div>
|
||||
</div>
|
||||
|
||||
{/* Right image - mirror of left image setup */}
|
||||
<div className="flex-1 relative group">
|
||||
<div className="absolute inset-0 flex items-center justify-center">
|
||||
<Image
|
||||
src={getImageUrl(currentPair.model2)}
|
||||
alt={`${currentPair.model2.model} - Seed ${currentPair.seed}`}
|
||||
className="max-h-full w-auto object-contain transition-transform duration-200 group-hover:scale-[1.02]"
|
||||
width={1024}
|
||||
height={1024}
|
||||
/>
|
||||
<div
|
||||
className="absolute top-2 right-2 bg-black/50 backdrop-blur-sm text-white px-3 py-1.5 rounded-lg font-medium">
|
||||
{currentPair.model2.model}
|
||||
</div>
|
||||
<div
|
||||
className="absolute bottom-2 right-2 opacity-0 group-hover:opacity-100 transition-opacity bg-black/50 backdrop-blur-sm text-white px-3 py-1.5 rounded-lg text-sm">
|
||||
Seed: {currentPair.seed}
|
||||
{currentPair.model2.lora1 && <div>LoRA: {currentPair.model2.lora1}</div>}
|
||||
{currentPair.model2.lora2 && <div>LoRA 2: {currentPair.model2.lora2}</div>}
|
||||
</div>
|
||||
</div>
|
||||
</div>
|
||||
</div>
|
||||
|
||||
{/* Bottom info panel with enhanced information */}
|
||||
<div className="bg-gray-800 border-t border-gray-700 p-4">
|
||||
<div className="flex justify-between items-start">
|
||||
<div className="flex-1 space-y-2">
|
||||
<div>
|
||||
<h3 className="text-gray-300 font-semibold">Prompt:</h3>
|
||||
<p className="text-gray-400 text-sm mt-1">{currentPair.prompt}</p>
|
||||
</div>
|
||||
<div className="flex gap-4 text-sm text-gray-500">
|
||||
<div>Seed: {currentPair.seed}</div>
|
||||
<div>Prompt Index: {currentPair.prompt_index}</div>
|
||||
</div>
|
||||
</div>
|
||||
<div className="flex gap-4 items-center ml-4">
|
||||
<button
|
||||
onClick={previousPair}
|
||||
className="px-3 py-1.5 rounded-md bg-gray-700 text-gray-300 hover:bg-gray-600
|
||||
hover:text-white transition-colors duration-200"
|
||||
>
|
||||
Previous
|
||||
</button>
|
||||
<span className="text-gray-400 font-medium">
|
||||
{currentPairIndex + 1} / {comparisonData.pairs.length}
|
||||
</span>
|
||||
<button
|
||||
onClick={nextPair}
|
||||
className="px-3 py-1.5 rounded-md bg-gray-700 text-gray-300 hover:bg-gray-600
|
||||
hover:text-white transition-colors duration-200"
|
||||
>
|
||||
Next
|
||||
</button>
|
||||
</div>
|
||||
</div>
|
||||
</div>
|
||||
</div>
|
||||
)
|
||||
}
|
||||
78
frontend/src/components/ConfigsSelector.tsx
Normal file
78
frontend/src/components/ConfigsSelector.tsx
Normal file
@@ -0,0 +1,78 @@
|
||||
// src/components/ConfigSelector.tsx
|
||||
"use client"
|
||||
|
||||
import {useCallback} from 'react'
|
||||
import type {ConfigurationInfo} from '@/types/comparison'
|
||||
|
||||
interface ConfigSelectorProps {
|
||||
configs: ConfigurationInfo[] // Now using our structured config info
|
||||
selectedConfig: string | null // Matches the context's currentConfig type
|
||||
onConfigSelect: (config: string) => Promise<void> // Handle async loading
|
||||
disabled?: boolean // Allow disabling during loading states
|
||||
}
|
||||
|
||||
// src/components/ConfigSelector.tsx
|
||||
interface ConfigurationDisplay {
|
||||
name: string;
|
||||
count: string;
|
||||
}
|
||||
|
||||
export function ConfigSelector({
|
||||
configs,
|
||||
selectedConfig,
|
||||
onConfigSelect,
|
||||
disabled = false
|
||||
}: ConfigSelectorProps) {
|
||||
// Helper to create display information
|
||||
const getConfigDisplay = useCallback((config: ConfigurationInfo): ConfigurationDisplay => {
|
||||
const baseName = config.name
|
||||
.replace('_lora', '')
|
||||
.split('_')
|
||||
.map(word => word.charAt(0).toUpperCase() + word.slice(1))
|
||||
.join(' ');
|
||||
|
||||
return {
|
||||
name: baseName,
|
||||
count: `${config.model_count}m, ${config.prompt_count}p` // Shortened display
|
||||
};
|
||||
}, []);
|
||||
|
||||
return (
|
||||
<div className="flex flex-col sm:flex-row items-start sm:items-center gap-2 sm:gap-4 w-full">
|
||||
{/* Label that stacks on mobile but stays inline on larger screens */}
|
||||
<span className="text-gray-400 text-sm font-medium whitespace-nowrap">
|
||||
Configuration:
|
||||
</span>
|
||||
|
||||
{/* Button container that allows wrapping on smaller screens */}
|
||||
<div className="flex flex-wrap gap-2 flex-1">
|
||||
{configs.map(config => {
|
||||
const display = getConfigDisplay(config);
|
||||
return (
|
||||
<button
|
||||
key={config.name}
|
||||
onClick={() => onConfigSelect(config.name)}
|
||||
disabled={disabled || selectedConfig === config.name}
|
||||
className={`
|
||||
px-3 py-1.5 rounded-md text-sm font-medium
|
||||
transition-colors duration-200
|
||||
flex flex-col sm:flex-row items-center gap-1
|
||||
min-w-[100px] sm:min-w-0
|
||||
${disabled ? 'opacity-50 cursor-not-allowed' : ''}
|
||||
${selectedConfig === config.name
|
||||
? 'bg-blue-600 text-white'
|
||||
: 'bg-gray-700 text-gray-300 hover:bg-gray-600'
|
||||
}
|
||||
`}
|
||||
>
|
||||
<span className="whitespace-nowrap">{display.name}</span>
|
||||
<span className="text-xs opacity-75 whitespace-nowrap">
|
||||
{display.count}
|
||||
</span>
|
||||
</button>
|
||||
);
|
||||
})}
|
||||
</div>
|
||||
</div>
|
||||
);
|
||||
}
|
||||
63
frontend/src/components/PathSelector.tsx
Normal file
63
frontend/src/components/PathSelector.tsx
Normal file
@@ -0,0 +1,63 @@
|
||||
"use client"
|
||||
|
||||
import {useState} from 'react'
|
||||
import {useComparison} from '@/contexts/ComparisonContext'
|
||||
|
||||
export function PathSelector() {
|
||||
const {setBasePath, isLoading, error} = useComparison();
|
||||
const [path, setPath] = useState('');
|
||||
|
||||
const handleSubmit = async (e: React.FormEvent) => {
|
||||
e.preventDefault();
|
||||
if (path.trim()) {
|
||||
await setBasePath(path.trim());
|
||||
}
|
||||
};
|
||||
|
||||
return (
|
||||
<div className="bg-gray-800 p-6 rounded-lg shadow-lg max-w-xl w-full mx-4">
|
||||
<h2 className="text-xl font-semibold text-gray-200 mb-4">
|
||||
Enter Comparison Path
|
||||
</h2>
|
||||
|
||||
<form onSubmit={handleSubmit} className="space-y-4">
|
||||
<div>
|
||||
<label
|
||||
htmlFor="path"
|
||||
className="block text-sm font-medium text-gray-300 mb-2"
|
||||
>
|
||||
Base Path
|
||||
</label>
|
||||
<input
|
||||
type="text"
|
||||
id="path"
|
||||
value={path}
|
||||
onChange={(e) => setPath(e.target.value)}
|
||||
placeholder="/path/to/comparison/directory"
|
||||
className="w-full px-4 py-2 bg-gray-700 border border-gray-600 rounded-md
|
||||
text-gray-200 placeholder-gray-400 focus:outline-none focus:ring-2
|
||||
focus:ring-blue-500 focus:border-transparent"
|
||||
disabled={isLoading}
|
||||
/>
|
||||
</div>
|
||||
|
||||
{error && (
|
||||
<div className="text-red-400 text-sm">
|
||||
{error}
|
||||
</div>
|
||||
)}
|
||||
|
||||
<button
|
||||
type="submit"
|
||||
disabled={isLoading || !path.trim()}
|
||||
className="w-full px-4 py-2 bg-blue-600 text-white rounded-md
|
||||
hover:bg-blue-700 focus:outline-none focus:ring-2
|
||||
focus:ring-blue-500 focus:ring-offset-2 focus:ring-offset-gray-800
|
||||
disabled:opacity-50 disabled:cursor-not-allowed"
|
||||
>
|
||||
{isLoading ? 'Loading...' : 'Load Comparisons'}
|
||||
</button>
|
||||
</form>
|
||||
</div>
|
||||
);
|
||||
}
|
||||
3
frontend/src/config/env.ts
Normal file
3
frontend/src/config/env.ts
Normal file
@@ -0,0 +1,3 @@
|
||||
export const env = {
|
||||
API_URL: process.env.NEXT_PUBLIC_API_URL || 'http://localhost:2000'
|
||||
} as const;
|
||||
152
frontend/src/contexts/ComparisonContext.tsx
Normal file
152
frontend/src/contexts/ComparisonContext.tsx
Normal file
@@ -0,0 +1,152 @@
|
||||
'use client'
|
||||
|
||||
import {createContext, useCallback, useContext, useState} from 'react';
|
||||
import {env} from '@/config/env';
|
||||
|
||||
import type {AvailableConfigs, ComparisonContextType, ComparisonData, ComparisonState} from '@/types/comparison';
|
||||
|
||||
const ComparisonContext = createContext<ComparisonContextType | undefined>(undefined);
|
||||
|
||||
export function ComparisonProvider({children}: { children: React.ReactNode }) {
|
||||
// Our state now needs to include the configId we get from registration
|
||||
const [state, setState] = useState<ComparisonState>({
|
||||
basePath: null,
|
||||
configId: null, // Add this to track our registered configuration
|
||||
availableConfigs: [],
|
||||
currentConfig: null,
|
||||
currentPairIndex: 0,
|
||||
comparisonData: null,
|
||||
isLoading: false,
|
||||
error: null,
|
||||
});
|
||||
|
||||
// First step: Register the path and get a configuration ID
|
||||
const setBasePath = useCallback(async (path: string) => {
|
||||
setState(prev => ({...prev, isLoading: true, error: null}));
|
||||
|
||||
try {
|
||||
// Register the path first to get our configId
|
||||
const registerResponse = await fetch(`${env.API_URL}/api/v1/comparison/register`, {
|
||||
method: 'POST',
|
||||
headers: {'Content-Type': 'application/json'},
|
||||
body: JSON.stringify({path})
|
||||
});
|
||||
|
||||
if (!registerResponse.ok) throw new Error('Failed to register path');
|
||||
|
||||
const {config_id} = await registerResponse.json();
|
||||
|
||||
// After getting the configId, fetch available configurations
|
||||
const configsResponse = await fetch(`${env.API_URL}/api/v1/comparison/${config_id}/available`);
|
||||
if (!configsResponse.ok) throw new Error('Failed to fetch configurations');
|
||||
|
||||
const data: AvailableConfigs = await configsResponse.json();
|
||||
|
||||
setState(prev => ({
|
||||
...prev,
|
||||
basePath: path,
|
||||
configId: config_id,
|
||||
availableConfigs: data.configs,
|
||||
isLoading: false
|
||||
}));
|
||||
} catch (error) {
|
||||
setState(prev => ({
|
||||
...prev,
|
||||
error: error instanceof Error ? error.message : 'Unknown error',
|
||||
isLoading: false
|
||||
}));
|
||||
}
|
||||
}, []);
|
||||
|
||||
// Load a specific configuration using our new GET endpoint
|
||||
const loadConfig = useCallback(async (configName: string) => {
|
||||
if (!state.configId) return;
|
||||
|
||||
setState(prev => ({...prev, isLoading: true, error: null}));
|
||||
|
||||
try {
|
||||
// Use the new GET endpoint structure
|
||||
const response = await fetch(
|
||||
`${env.API_URL}/api/v1/comparison/${state.configId}/${configName}`
|
||||
);
|
||||
|
||||
if (!response.ok) throw new Error('Failed to fetch configuration data');
|
||||
|
||||
const data: ComparisonData = await response.json();
|
||||
setState(prev => ({
|
||||
...prev,
|
||||
currentConfig: configName,
|
||||
comparisonData: data,
|
||||
currentPairIndex: 0,
|
||||
isLoading: false
|
||||
}));
|
||||
} catch (error) {
|
||||
setState(prev => ({
|
||||
...prev,
|
||||
error: error instanceof Error ? error.message : 'Unknown error',
|
||||
isLoading: false
|
||||
}));
|
||||
}
|
||||
}, [state.configId]);
|
||||
|
||||
// Helper to construct image URLs using our new endpoint structure
|
||||
const getImageUrl = useCallback((model: { config: string, model: string, path: string }) => {
|
||||
if (!state.configId) return '';
|
||||
const filename = model.path.split("/").slice(-1)
|
||||
return `${env.API_URL}/api/v1/comparison/image/${state.configId}/${model.config}/${model.model}/${filename}`;
|
||||
}, [state.configId]);
|
||||
|
||||
// Navigation helpers remain the same
|
||||
const nextPair = useCallback(() => {
|
||||
if (!state.comparisonData?.pairs.length) return;
|
||||
setState(prev => ({
|
||||
...prev,
|
||||
currentPairIndex: (prev.currentPairIndex + 1) % prev.comparisonData!.pairs.length
|
||||
}));
|
||||
}, [state.comparisonData?.pairs.length]);
|
||||
|
||||
const previousPair = useCallback(() => {
|
||||
if (!state.comparisonData?.pairs.length) return;
|
||||
setState(prev => ({
|
||||
...prev,
|
||||
currentPairIndex: (prev.currentPairIndex - 1 + prev.comparisonData!.pairs.length) % prev.comparisonData!.pairs.length
|
||||
}));
|
||||
}, [state.comparisonData?.pairs.length]);
|
||||
|
||||
const goToPair = useCallback((index: number) => {
|
||||
if (!state.comparisonData?.pairs.length) return;
|
||||
if (index >= 0 && index < state.comparisonData.pairs.length) {
|
||||
setState(prev => ({...prev, currentPairIndex: index}));
|
||||
}
|
||||
}, [state.comparisonData?.pairs.length]);
|
||||
|
||||
const getCurrentPair = useCallback(() => {
|
||||
if (!state.comparisonData?.pairs.length) return null;
|
||||
return state.comparisonData.pairs[state.currentPairIndex];
|
||||
}, [state.comparisonData, state.currentPairIndex]);
|
||||
|
||||
const value: ComparisonContextType = {
|
||||
...state,
|
||||
nextPair,
|
||||
previousPair,
|
||||
goToPair,
|
||||
setBasePath,
|
||||
loadConfig,
|
||||
getCurrentPair,
|
||||
getImageUrl, // Add this to help components construct image URLs
|
||||
};
|
||||
|
||||
return (
|
||||
<ComparisonContext.Provider value={value}>
|
||||
{children}
|
||||
</ComparisonContext.Provider>
|
||||
);
|
||||
}
|
||||
|
||||
export function useComparison() {
|
||||
const context = useContext(ComparisonContext);
|
||||
if (context === undefined) {
|
||||
throw new Error('useComparison must be used within a ComparisonProvider');
|
||||
}
|
||||
return context;
|
||||
}
|
||||
116
frontend/src/types/comparison.ts
Normal file
116
frontend/src/types/comparison.ts
Normal file
@@ -0,0 +1,116 @@
|
||||
/**
|
||||
* Represents a single image in the comparison system.
|
||||
* This includes all metadata about the image and its generation parameters.
|
||||
*/
|
||||
interface ComparisonImage {
|
||||
// Basic file information
|
||||
path: string; // Full path to the image file
|
||||
model: string; // Model name (e.g., 'flux_dev', 'ovs_bangel_001_000005000')
|
||||
|
||||
// Classification information
|
||||
config: string; // Configuration type (e.g., 'cloth_lora', 'identity_lora', 'dual_lora')
|
||||
prompt_index: number; // Index of the prompt used for generation
|
||||
seed: number; // Seed used for generation
|
||||
|
||||
// LoRA information - optional as not all configs use both
|
||||
lora1?: string; // First LoRA name (or single LoRA in non-dual cases)
|
||||
lora2?: string; // Second LoRA name (only for dual_lora config)
|
||||
|
||||
// Generation parameters
|
||||
prompt: string; // The actual prompt text used to generate this image
|
||||
}
|
||||
|
||||
/**
|
||||
* Represents a pair of images to be compared.
|
||||
* Contains both images and their shared generation parameters.
|
||||
*/
|
||||
interface ComparisonPair {
|
||||
model1: ComparisonImage; // First model's image and metadata
|
||||
model2: ComparisonImage; // Second model's image and metadata
|
||||
|
||||
// Shared parameters for easy filtering and organization
|
||||
config: string; // The configuration type for this pair
|
||||
prompt_index: number; // Index of the shared prompt
|
||||
seed: number; // Shared seed used for both generations
|
||||
prompt: string; // The full prompt text used for both images
|
||||
}
|
||||
|
||||
/**
|
||||
* Contains all data needed for the comparison interface.
|
||||
* Provides both the comparison pairs and the metadata needed for navigation and filtering.
|
||||
*/
|
||||
interface ComparisonData {
|
||||
// Available configuration options
|
||||
configs: string[]; // List of all configuration types (e.g., ['cloth_lora', 'identity_lora', 'dual_lora'])
|
||||
|
||||
// Mapping of prompts per configuration
|
||||
prompts: Record<string, string[]>; // Example: { 'cloth_lora': ['prompt1', 'prompt2', ...] }
|
||||
|
||||
// Available seeds for filtering
|
||||
seeds: number[]; // List of all seeds used in the comparisons
|
||||
|
||||
// The actual comparison data
|
||||
pairs: ComparisonPair[]; // All comparison pairs available
|
||||
}
|
||||
|
||||
/**
|
||||
* Represents the filters that can be applied to the comparison view
|
||||
*/
|
||||
interface ComparisonFilters {
|
||||
config?: string; // Selected configuration type
|
||||
promptIndex?: number; // Selected prompt index
|
||||
seed?: number; // Selected seed
|
||||
}
|
||||
|
||||
// src/types/comparison.ts
|
||||
// Add these to your existing types
|
||||
|
||||
// Represents the metadata about an available configuration
|
||||
interface ConfigurationInfo {
|
||||
name: string;
|
||||
model_count: number;
|
||||
prompt_count: number;
|
||||
seed_count: number;
|
||||
}
|
||||
|
||||
// Response from the fetchConfigs endpoint
|
||||
interface AvailableConfigs {
|
||||
base_path: string;
|
||||
configs: ConfigurationInfo[];
|
||||
}
|
||||
|
||||
// Represents the current state of comparison viewing
|
||||
interface ComparisonState {
|
||||
basePath: string | null;
|
||||
configId: string | null; // Add this
|
||||
availableConfigs: ConfigurationInfo[];
|
||||
currentConfig: string | null;
|
||||
currentPairIndex: number;
|
||||
comparisonData: ComparisonData | null;
|
||||
isLoading: boolean;
|
||||
error: string | null;
|
||||
}
|
||||
|
||||
// Actions that can be performed through the context
|
||||
interface ComparisonContextType extends ComparisonState {
|
||||
nextPair: () => void;
|
||||
previousPair: () => void;
|
||||
goToPair: (index: number) => void;
|
||||
setBasePath: (path: string) => Promise<void>;
|
||||
loadConfig: (configName: string) => Promise<void>;
|
||||
getCurrentPair: () => ComparisonPair | null;
|
||||
getImageUrl: (model: { config: string, model: string, filename: string }) => string; // Add this
|
||||
}
|
||||
|
||||
// Export all types
|
||||
export type {
|
||||
ComparisonImage,
|
||||
ComparisonPair,
|
||||
ComparisonData,
|
||||
ComparisonFilters,
|
||||
ComparisonContextType,
|
||||
ComparisonState,
|
||||
AvailableConfigs,
|
||||
ConfigurationInfo,
|
||||
|
||||
}
|
||||
Reference in New Issue
Block a user