Add comparison support

Signed-off-by: Felipe Cardoso <felipe.cardoso@hotmail.it>
This commit is contained in:
2025-01-31 13:20:46 +01:00
parent 3a79065163
commit 9803e32f66
13 changed files with 974 additions and 2 deletions

View File

@@ -0,0 +1,68 @@
# app/api/routes/comparison.py
from fastapi import APIRouter, HTTPException
from starlette.responses import FileResponse
from app.models.comparison import PathRequest
from app.services.comparison_service import ComparisonService
router = APIRouter()
comparison_service = ComparisonService()
@router.post("/register")
async def register_comparison_path(request: PathRequest):
"""Register a new comparison path and get its ID."""
config_id = comparison_service.register_config(request.path)
return {"config_id": config_id}
@router.get("/image/{config_id}/{config_name}/{model_name}/{filename}")
async def get_comparison_image(config_id: str, config_name: str, model_name: str, filename: str):
"""Serve image files using the cached base path."""
base_path = comparison_service.get_base_path(config_id)
if not base_path:
raise HTTPException(status_code=404, detail="Configuration not found")
try:
full_path = base_path / config_name / model_name / filename
if not full_path.exists() or not full_path.is_file():
raise HTTPException(status_code=404, detail="Image not found")
return FileResponse(full_path)
except Exception as e:
raise HTTPException(status_code=500, detail=str(e))
@router.get("/{config_id}/available")
async def fetch_available_configs(config_id: str):
"""Fetch available configs."""
base_path = comparison_service.get_base_path(config_id)
if not base_path:
raise HTTPException(status_code=404, detail="Configuration not found")
try:
return comparison_service.get_available_configs(str(base_path))
except ValueError as e:
# Convert ValueError from the service into a proper HTTP error
raise HTTPException(status_code=404, detail=str(e))
@router.get("/{config_id}/{config_name}")
async def fetch_config(config_id: str, config_name: str):
"""
Fetch detailed comparison data for a specific configuration.
Parameters:
config_id: The identifier returned from the initial path registration
config_name: The name of the specific configuration to load (e.g. 'cloth_lora')
"""
base_path = comparison_service.get_base_path(config_id)
if not base_path:
raise HTTPException(status_code=404, detail="Configuration not found")
try:
return comparison_service.load_config_data(str(base_path), config_name)
except ValueError as e:
# Convert ValueError from the service into a proper HTTP error
raise HTTPException(status_code=404, detail=str(e))

View File

@@ -6,7 +6,7 @@ import psutil
from fastapi import FastAPI
from fastapi.middleware.cors import CORSMiddleware
from app.api.routes import training, samples, config
from app.api.routes import training, samples, config, comparison
from app.core.config import settings
from app.services.config_manager import ConfigManager
from app.services.sample_manager import SampleManager
@@ -105,6 +105,7 @@ async def shutdown_event():
app.include_router(training.router, prefix=f"{settings.API_VER_STR}/training", tags=["training"])
app.include_router(samples.router, prefix=f"{settings.API_VER_STR}/samples", tags=["samples"])
app.include_router(config.router, prefix=f"{settings.API_VER_STR}/config", tags=["config"])
app.include_router(comparison.router, prefix=f"{settings.API_VER_STR}/comparison", tags=["comparison"])
@app.get("/")

View File

@@ -0,0 +1,55 @@
from typing import List, Optional, Dict
from pydantic import BaseModel
class ComparisonImage(BaseModel):
path: str
model: str
config: str
prompt_index: int
seed: int
lora1: Optional[str] = None
lora2: Optional[str] = None
prompt: Optional[str] = None # Adding prompt field
class ComparisonPair(BaseModel):
model1: ComparisonImage
model2: ComparisonImage
config: str
prompt_index: int
seed: int
prompt: str # The prompt used for both images
class ComparisonData(BaseModel):
configs: List[str] # Available config types (cloth_lora, identity_lora, dual_lora)
prompts: Dict[str, List[str]] # Mapping of config -> list of prompts
seeds: List[int] # All available seeds
pairs: List[ComparisonPair] # All comparison pairs with their prompts
class PathRequest(BaseModel):
"""Request model for providing the base comparison path"""
path: str
class ConfigRequest(BaseModel):
"""Request model for fetching specific configuration data"""
path: str
config_name: str
class ConfigInfo(BaseModel):
"""Basic information about an available configuration"""
name: str
model_count: int
prompt_count: int
seed_count: int
class AvailableConfigs(BaseModel):
"""Response model for the fetchConfigs endpoint"""
base_path: str
configs: List[ConfigInfo]

View File

@@ -0,0 +1,221 @@
# app/services/comparison_service.py
import json
from datetime import datetime
from pathlib import Path
from typing import Dict, Optional
from app.models.comparison import ConfigInfo, AvailableConfigs, ComparisonPair, ComparisonData, ComparisonImage
class ComparisonService:
def __init__(self):
# Cache structure to store configuration information
self.configs: Dict[str, dict] = {} # id -> config info
self.paths: Dict[str, Path] = {} # id -> base path
self.access_times: Dict[str, datetime] = {} # id -> last access time
def generate_config_id(self) -> str:
"""Generate a unique identifier for a configuration."""
# Simple timestamp-based ID, could be made more sophisticated
return datetime.now().strftime('%Y%m%d_%H%M%S')
def _get_config_data(self, config_dir: Path) -> dict:
"""Get config data with caching."""
config_path = config_dir / f"config_{config_dir.name}.json"
cache_key = str(config_path)
if cache_key not in self.configs:
if not config_path.exists():
raise ValueError(f"Configuration file not found: {config_path}")
with open(config_path) as f:
self.configs[cache_key] = json.load(f)
return self.configs[cache_key]
def register_config(self, base_path: str) -> str:
"""
Register a new configuration base path and return its ID.
This is called when a user first submits a path.
"""
config_id = self.generate_config_id()
self.paths[config_id] = Path(base_path)
self.access_times[config_id] = datetime.now()
return config_id
def get_base_path(self, config_id: str) -> Optional[Path]:
"""
Retrieve the base path for a given configuration ID.
Updates the last access time.
"""
if config_id in self.paths:
self.access_times[config_id] = datetime.now()
return self.paths[config_id]
return None
def clean_old_configs(self, max_age_hours: int = 72):
"""
Clean up configurations that haven't been accessed in a while.
This helps manage memory usage.
"""
current_time = datetime.now()
expired_ids = [
config_id for config_id, access_time in self.access_times.items()
if (current_time - access_time).total_seconds() > max_age_hours * 3600
]
for config_id in expired_ids:
self.paths.pop(config_id, None)
self.configs.pop(config_id, None)
self.access_times.pop(config_id, None)
def get_available_configs(self, base_path: str) -> AvailableConfigs:
base_path = Path(base_path)
configs = []
for config_dir in base_path.iterdir():
if not config_dir.is_dir():
continue
try:
config_data = self._get_config_data(config_dir)
model_count = sum(1 for x in config_dir.iterdir() if x.is_dir())
configs.append(ConfigInfo(
name=config_dir.name,
model_count=model_count,
prompt_count=len(config_data.get('prompts', [])),
seed_count=len(config_data.get('seeds', []))
))
except ValueError:
# Skip this config if we can't read its data
continue
return AvailableConfigs(
base_path=str(base_path),
configs=sorted(configs, key=lambda x: x.name)
)
def load_config_data(self, base_path: str, config_name: str) -> ComparisonData:
"""Load comparison data using the cached base path."""
# base_path = self.get_base_path(config_id)
# if not base_path:
# raise ValueError(f"Configuration ID {config_id} not found")
# config_dir = base_path / config_name
#
# if not config_dir.is_dir():
# raise ValueError(f"Configuration '{config_name}' not found in {base_path}")
#
# # Load the configuration file
# config_file = config_dir / f"config_{config_name}.json"
# if not config_file.exists():
# raise ValueError(f"Configuration file not found for {config_name}")
#
# with open(config_file) as f:
# config_data = json.load(f)
base_path = Path(base_path)
config_dir = base_path / config_name
if not config_dir.is_dir():
raise ValueError(f"Configuration '{config_name}' not found in {base_path}")
# Use our cached config data instead of reading directly
try:
config_data = self._get_config_data(config_dir)
except ValueError as e:
raise ValueError(f"Failed to load configuration data: {str(e)}")
# Process model directories
models = [d.name for d in config_dir.iterdir() if d.is_dir()]
pairs = []
# Build comparison pairs for this configuration
for model in models:
model_path = config_dir / model
images = list(model_path.glob("*.png"))
for img in images:
img_data = self.parse_image_path(img, config_dir)
# Find matching image in other model
other_model = [m for m in models if m != model][0]
other_path = model_path.parent / other_model / img.name
if other_path.exists():
other_data = self.parse_image_path(other_path, config_dir)
pairs.append(ComparisonPair(
model1=img_data,
model2=other_data,
config=config_name,
prompt_index=img_data.prompt_index,
seed=img_data.seed,
prompt=img_data.prompt or ""
))
return ComparisonData(
configs=[config_name],
prompts={config_name: config_data['prompts']},
seeds=sorted(config_data['seeds']),
pairs=pairs
)
def parse_image_path(self, path: Path, config_dir: Path) -> ComparisonImage:
"""
Parse an image filename to extract metadata about the image.
The method handles both single and dual LoRA naming patterns:
- Single LoRA: lora_prompt_0_seed_42.png
- Dual LoRA: lora1_lora2_prompt_0_seed_42.png
"""
filename = Path(path).stem
parts = filename.split('_')
prompt_idx = parts.index('prompt')
seed_idx = parts.index('seed')
prompt_index = int(parts[prompt_idx + 1])
seed = int(parts[seed_idx + 1])
# Instead of reading the file directly, use our cached method
prompt = None
try:
config_data = self._get_config_data(config_dir)
prompts = config_data.get('prompts', [])
if 0 <= prompt_index < len(prompts):
prompt = prompts[prompt_index]
except ValueError:
# If we can't get the config data, we'll continue without the prompt
pass
# config_file = config_dir / f"config_{config_dir.name}.json"
# if config_file.exists():
# with open(config_file) as f:
# config_data = json.load(f)
# prompts = config_data.get('prompts', [])
# if 0 <= prompt_index < len(prompts):
# prompt = prompts[prompt_index]
# Determine if this is a dual LoRA setup by checking parts before 'prompt'
lora_parts = parts[:prompt_idx]
if len(lora_parts) > 1:
# Dual LoRA case
return ComparisonImage(
path=str(path),
model=path.parent.name,
config=config_dir.name,
prompt_index=prompt_index,
seed=seed,
lora1=lora_parts[0],
lora2=lora_parts[1],
prompt=prompt
)
else:
# Single LoRA case
return ComparisonImage(
path=str(path),
model=path.parent.name,
config=config_dir.name,
prompt_index=prompt_index,
seed=seed,
lora1=lora_parts[0],
prompt=prompt
)

2
backend/run_server.sh Normal file
View File

@@ -0,0 +1,2 @@
#!/bin/bash
uvicorn app.main:app --reload --port 2000