Add comparison support
Signed-off-by: Felipe Cardoso <felipe.cardoso@hotmail.it>
This commit is contained in:
221
backend/app/services/comparison_service.py
Normal file
221
backend/app/services/comparison_service.py
Normal file
@@ -0,0 +1,221 @@
|
||||
# app/services/comparison_service.py
|
||||
import json
|
||||
from datetime import datetime
|
||||
from pathlib import Path
|
||||
from typing import Dict, Optional
|
||||
|
||||
from app.models.comparison import ConfigInfo, AvailableConfigs, ComparisonPair, ComparisonData, ComparisonImage
|
||||
|
||||
|
||||
class ComparisonService:
|
||||
def __init__(self):
|
||||
# Cache structure to store configuration information
|
||||
self.configs: Dict[str, dict] = {} # id -> config info
|
||||
self.paths: Dict[str, Path] = {} # id -> base path
|
||||
self.access_times: Dict[str, datetime] = {} # id -> last access time
|
||||
|
||||
def generate_config_id(self) -> str:
|
||||
"""Generate a unique identifier for a configuration."""
|
||||
# Simple timestamp-based ID, could be made more sophisticated
|
||||
return datetime.now().strftime('%Y%m%d_%H%M%S')
|
||||
|
||||
def _get_config_data(self, config_dir: Path) -> dict:
|
||||
"""Get config data with caching."""
|
||||
config_path = config_dir / f"config_{config_dir.name}.json"
|
||||
cache_key = str(config_path)
|
||||
|
||||
if cache_key not in self.configs:
|
||||
if not config_path.exists():
|
||||
raise ValueError(f"Configuration file not found: {config_path}")
|
||||
|
||||
with open(config_path) as f:
|
||||
self.configs[cache_key] = json.load(f)
|
||||
|
||||
return self.configs[cache_key]
|
||||
|
||||
def register_config(self, base_path: str) -> str:
|
||||
"""
|
||||
Register a new configuration base path and return its ID.
|
||||
This is called when a user first submits a path.
|
||||
"""
|
||||
config_id = self.generate_config_id()
|
||||
self.paths[config_id] = Path(base_path)
|
||||
self.access_times[config_id] = datetime.now()
|
||||
return config_id
|
||||
|
||||
def get_base_path(self, config_id: str) -> Optional[Path]:
|
||||
"""
|
||||
Retrieve the base path for a given configuration ID.
|
||||
Updates the last access time.
|
||||
"""
|
||||
if config_id in self.paths:
|
||||
self.access_times[config_id] = datetime.now()
|
||||
return self.paths[config_id]
|
||||
return None
|
||||
|
||||
def clean_old_configs(self, max_age_hours: int = 72):
|
||||
"""
|
||||
Clean up configurations that haven't been accessed in a while.
|
||||
This helps manage memory usage.
|
||||
"""
|
||||
current_time = datetime.now()
|
||||
expired_ids = [
|
||||
config_id for config_id, access_time in self.access_times.items()
|
||||
if (current_time - access_time).total_seconds() > max_age_hours * 3600
|
||||
]
|
||||
for config_id in expired_ids:
|
||||
self.paths.pop(config_id, None)
|
||||
self.configs.pop(config_id, None)
|
||||
self.access_times.pop(config_id, None)
|
||||
|
||||
def get_available_configs(self, base_path: str) -> AvailableConfigs:
|
||||
base_path = Path(base_path)
|
||||
configs = []
|
||||
|
||||
for config_dir in base_path.iterdir():
|
||||
if not config_dir.is_dir():
|
||||
continue
|
||||
|
||||
try:
|
||||
config_data = self._get_config_data(config_dir)
|
||||
model_count = sum(1 for x in config_dir.iterdir() if x.is_dir())
|
||||
|
||||
configs.append(ConfigInfo(
|
||||
name=config_dir.name,
|
||||
model_count=model_count,
|
||||
prompt_count=len(config_data.get('prompts', [])),
|
||||
seed_count=len(config_data.get('seeds', []))
|
||||
))
|
||||
except ValueError:
|
||||
# Skip this config if we can't read its data
|
||||
continue
|
||||
|
||||
return AvailableConfigs(
|
||||
base_path=str(base_path),
|
||||
configs=sorted(configs, key=lambda x: x.name)
|
||||
)
|
||||
|
||||
def load_config_data(self, base_path: str, config_name: str) -> ComparisonData:
|
||||
"""Load comparison data using the cached base path."""
|
||||
# base_path = self.get_base_path(config_id)
|
||||
# if not base_path:
|
||||
# raise ValueError(f"Configuration ID {config_id} not found")
|
||||
# config_dir = base_path / config_name
|
||||
#
|
||||
# if not config_dir.is_dir():
|
||||
# raise ValueError(f"Configuration '{config_name}' not found in {base_path}")
|
||||
#
|
||||
# # Load the configuration file
|
||||
# config_file = config_dir / f"config_{config_name}.json"
|
||||
# if not config_file.exists():
|
||||
# raise ValueError(f"Configuration file not found for {config_name}")
|
||||
#
|
||||
# with open(config_file) as f:
|
||||
# config_data = json.load(f)
|
||||
base_path = Path(base_path)
|
||||
config_dir = base_path / config_name
|
||||
|
||||
if not config_dir.is_dir():
|
||||
raise ValueError(f"Configuration '{config_name}' not found in {base_path}")
|
||||
|
||||
# Use our cached config data instead of reading directly
|
||||
try:
|
||||
config_data = self._get_config_data(config_dir)
|
||||
except ValueError as e:
|
||||
raise ValueError(f"Failed to load configuration data: {str(e)}")
|
||||
|
||||
# Process model directories
|
||||
models = [d.name for d in config_dir.iterdir() if d.is_dir()]
|
||||
pairs = []
|
||||
|
||||
# Build comparison pairs for this configuration
|
||||
for model in models:
|
||||
model_path = config_dir / model
|
||||
images = list(model_path.glob("*.png"))
|
||||
|
||||
for img in images:
|
||||
img_data = self.parse_image_path(img, config_dir)
|
||||
|
||||
# Find matching image in other model
|
||||
other_model = [m for m in models if m != model][0]
|
||||
other_path = model_path.parent / other_model / img.name
|
||||
|
||||
if other_path.exists():
|
||||
other_data = self.parse_image_path(other_path, config_dir)
|
||||
pairs.append(ComparisonPair(
|
||||
model1=img_data,
|
||||
model2=other_data,
|
||||
config=config_name,
|
||||
prompt_index=img_data.prompt_index,
|
||||
seed=img_data.seed,
|
||||
prompt=img_data.prompt or ""
|
||||
))
|
||||
|
||||
return ComparisonData(
|
||||
configs=[config_name],
|
||||
prompts={config_name: config_data['prompts']},
|
||||
seeds=sorted(config_data['seeds']),
|
||||
pairs=pairs
|
||||
)
|
||||
|
||||
def parse_image_path(self, path: Path, config_dir: Path) -> ComparisonImage:
|
||||
|
||||
"""
|
||||
Parse an image filename to extract metadata about the image.
|
||||
The method handles both single and dual LoRA naming patterns:
|
||||
- Single LoRA: lora_prompt_0_seed_42.png
|
||||
- Dual LoRA: lora1_lora2_prompt_0_seed_42.png
|
||||
"""
|
||||
filename = Path(path).stem
|
||||
parts = filename.split('_')
|
||||
|
||||
prompt_idx = parts.index('prompt')
|
||||
seed_idx = parts.index('seed')
|
||||
|
||||
prompt_index = int(parts[prompt_idx + 1])
|
||||
seed = int(parts[seed_idx + 1])
|
||||
|
||||
# Instead of reading the file directly, use our cached method
|
||||
prompt = None
|
||||
try:
|
||||
config_data = self._get_config_data(config_dir)
|
||||
prompts = config_data.get('prompts', [])
|
||||
if 0 <= prompt_index < len(prompts):
|
||||
prompt = prompts[prompt_index]
|
||||
except ValueError:
|
||||
# If we can't get the config data, we'll continue without the prompt
|
||||
pass
|
||||
# config_file = config_dir / f"config_{config_dir.name}.json"
|
||||
# if config_file.exists():
|
||||
# with open(config_file) as f:
|
||||
# config_data = json.load(f)
|
||||
# prompts = config_data.get('prompts', [])
|
||||
# if 0 <= prompt_index < len(prompts):
|
||||
# prompt = prompts[prompt_index]
|
||||
|
||||
# Determine if this is a dual LoRA setup by checking parts before 'prompt'
|
||||
lora_parts = parts[:prompt_idx]
|
||||
|
||||
if len(lora_parts) > 1:
|
||||
# Dual LoRA case
|
||||
return ComparisonImage(
|
||||
path=str(path),
|
||||
model=path.parent.name,
|
||||
config=config_dir.name,
|
||||
prompt_index=prompt_index,
|
||||
seed=seed,
|
||||
lora1=lora_parts[0],
|
||||
lora2=lora_parts[1],
|
||||
prompt=prompt
|
||||
)
|
||||
else:
|
||||
# Single LoRA case
|
||||
return ComparisonImage(
|
||||
path=str(path),
|
||||
model=path.parent.name,
|
||||
config=config_dir.name,
|
||||
prompt_index=prompt_index,
|
||||
seed=seed,
|
||||
lora1=lora_parts[0],
|
||||
prompt=prompt
|
||||
)
|
||||
Reference in New Issue
Block a user