# app/services/comparison_service.py
import json
from datetime import datetime
from pathlib import Path
from typing import Dict, Optional

from app.models.comparison import (
    ConfigInfo,
    AvailableConfigs,
    ComparisonPair,
    ComparisonData,
    ComparisonImage,
)


class ComparisonService:
    """Discover comparison configurations on disk and build image pairs.

    A *base path* contains one sub-directory per configuration. Each
    configuration directory ``<name>`` holds a ``config_<name>.json`` file
    (read for its ``prompts`` and ``seeds`` entries) and one sub-directory
    per model containing ``*.png`` images.
    """

    def __init__(self):
        # Parsed config JSON, keyed by the config file's path string
        # (see _get_config_data), NOT by configuration ID.
        self.configs: Dict[str, dict] = {}
        self.paths: Dict[str, Path] = {}  # id -> base path
        self.access_times: Dict[str, datetime] = {}  # id -> last access time

    def generate_config_id(self) -> str:
        """Generate an identifier that is unique among registered configs.

        The ID is a second-resolution timestamp (``YYYYmmdd_HHMMSS``). If
        that ID is already registered (two registrations within the same
        second), a numeric suffix ``_1``, ``_2``, ... is appended so an
        existing registration is never overwritten.
        """
        base_id = datetime.now().strftime('%Y%m%d_%H%M%S')
        config_id = base_id
        suffix = 1
        while config_id in self.paths:
            config_id = f"{base_id}_{suffix}"
            suffix += 1
        return config_id

    def _get_config_data(self, config_dir: Path) -> dict:
        """Return the parsed ``config_<dir name>.json`` for *config_dir*.

        The parsed JSON is cached in ``self.configs`` under the config
        file's path string, so the file is read at most once per path until
        the entry is evicted.

        Raises:
            ValueError: if the config file does not exist. Malformed JSON
                surfaces as ``json.JSONDecodeError``, which is itself a
                ``ValueError`` subclass.
        """
        config_path = config_dir / f"config_{config_dir.name}.json"
        cache_key = str(config_path)

        if cache_key not in self.configs:
            if not config_path.exists():
                raise ValueError(f"Configuration file not found: {config_path}")
            # JSON is UTF-8 by specification; don't depend on the locale.
            with open(config_path, encoding="utf-8") as f:
                self.configs[cache_key] = json.load(f)

        return self.configs[cache_key]

    def register_config(self, base_path: str) -> str:
        """
        Register a new configuration base path and return its ID.
        This is called when a user first submits a path.
        """
        config_id = self.generate_config_id()
        self.paths[config_id] = Path(base_path)
        self.access_times[config_id] = datetime.now()
        return config_id

    def get_base_path(self, config_id: str) -> Optional[Path]:
        """
        Retrieve the base path for a given configuration ID.
        Updates the last access time. Returns None for an unknown ID.
        """
        if config_id in self.paths:
            self.access_times[config_id] = datetime.now()
            return self.paths[config_id]
        return None

    def clean_old_configs(self, max_age_hours: int = 72):
        """
        Clean up configurations that haven't been accessed in a while.
        This helps manage memory usage.

        Registrations whose last access is older than *max_age_hours* are
        dropped. Cached config data located under an expired base path is
        evicted as well, unless another still-registered ID uses the same
        base path.
        """
        current_time = datetime.now()
        max_age_seconds = max_age_hours * 3600
        expired_ids = [
            config_id
            for config_id, access_time in self.access_times.items()
            if (current_time - access_time).total_seconds() > max_age_seconds
        ]

        expired_paths = []
        for config_id in expired_ids:
            base_path = self.paths.pop(config_id, None)
            self.access_times.pop(config_id, None)
            if base_path is not None:
                expired_paths.append(base_path)

        # The config cache is keyed by file path rather than by ID, so evict
        # by location: every cached file that lives below an expired base
        # path which no live registration still points at.
        live_paths = set(self.paths.values())
        stale_roots = [p for p in expired_paths if p not in live_paths]
        if not stale_roots:
            return
        for cache_key in list(self.configs):
            ancestors = Path(cache_key).parents
            if any(root in ancestors for root in stale_roots):
                del self.configs[cache_key]

    def get_available_configs(self, base_path: str) -> AvailableConfigs:
        """List the configurations found directly under *base_path*.

        Every sub-directory whose config file can be loaded becomes one
        ``ConfigInfo``; sub-directories without a readable config file are
        skipped. The result is sorted by configuration name.
        """
        base_path = Path(base_path)
        configs = []

        for config_dir in base_path.iterdir():
            if not config_dir.is_dir():
                continue

            try:
                config_data = self._get_config_data(config_dir)
            except ValueError:
                # Skip this config if we can't read its data
                continue

            model_count = sum(1 for x in config_dir.iterdir() if x.is_dir())
            configs.append(ConfigInfo(
                name=config_dir.name,
                model_count=model_count,
                prompt_count=len(config_data.get('prompts', [])),
                seed_count=len(config_data.get('seeds', []))
            ))

        return AvailableConfigs(
            base_path=str(base_path),
            configs=sorted(configs, key=lambda x: x.name)
        )

    def load_config_data(self, base_path: str, config_name: str) -> ComparisonData:
        """Build the comparison pairs for one configuration.

        For every ``*.png`` in each model directory, the image with the same
        filename is looked up in the first *other* model directory (in name
        order); when it exists the two images form a ``ComparisonPair``.

        Raises:
            ValueError: if the configuration directory or its config file
                is missing, or if an image filename cannot be parsed.
        """
        base_path = Path(base_path)
        config_dir = base_path / config_name

        if not config_dir.is_dir():
            raise ValueError(f"Configuration '{config_name}' not found in {base_path}")

        # Use our cached config data instead of reading directly
        try:
            config_data = self._get_config_data(config_dir)
        except ValueError as e:
            raise ValueError(f"Failed to load configuration data: {e}") from e

        # Sorted so the result does not depend on filesystem listing order.
        models = sorted(d.name for d in config_dir.iterdir() if d.is_dir())
        pairs = []

        # NOTE(review): with two models, an image present in both yields two
        # mirrored pairs, (A, B) and (B, A). Preserved as-is - confirm that
        # this is intended.
        for model in models:
            other_model = next((m for m in models if m != model), None)
            if other_model is None:
                # A single model directory has nothing to be compared with.
                break

            model_path = config_dir / model
            for img in sorted(model_path.glob("*.png")):
                other_path = config_dir / other_model / img.name
                if not other_path.exists():
                    continue

                img_data = self.parse_image_path(img, config_dir)
                other_data = self.parse_image_path(other_path, config_dir)
                pairs.append(ComparisonPair(
                    model1=img_data,
                    model2=other_data,
                    config=config_name,
                    prompt_index=img_data.prompt_index,
                    seed=img_data.seed,
                    prompt=img_data.prompt or ""
                ))

        return ComparisonData(
            configs=[config_name],
            prompts={config_name: config_data.get('prompts', [])},
            seeds=sorted(config_data.get('seeds', [])),
            pairs=pairs
        )

    def parse_image_path(self, path: Path, config_dir: Path) -> ComparisonImage:
        """
        Parse an image filename to extract metadata about the image.

        The method handles both single and dual LoRA naming patterns:
        - Single LoRA: lora_prompt_0_seed_42.png
        - Dual LoRA: lora1_lora2_prompt_0_seed_42.png

        The prompt text is looked up in the configuration's ``prompts`` list
        by index and is left as ``None`` when the config cannot be loaded or
        the index is out of range.

        Raises:
            ValueError: if the filename lacks the ``prompt_<n>`` or
                ``seed_<n>`` tokens, or has no LoRA name before ``prompt``.
        """
        path = Path(path)
        parts = path.stem.split('_')

        try:
            prompt_idx = parts.index('prompt')
            seed_idx = parts.index('seed')
            prompt_index = int(parts[prompt_idx + 1])
            seed = int(parts[seed_idx + 1])
        except (ValueError, IndexError) as e:
            raise ValueError(
                f"Cannot parse image filename '{path.name}': expected "
                f"'<lora>[_<lora2>]_prompt_<index>_seed_<seed>'"
            ) from e

        # Everything before the 'prompt' token names the LoRA(s).
        lora_parts = parts[:prompt_idx]
        if not lora_parts:
            raise ValueError(
                f"Cannot parse image filename '{path.name}': "
                f"missing LoRA name before 'prompt'"
            )

        # Instead of reading the file directly, use our cached method
        prompt = None
        try:
            config_data = self._get_config_data(config_dir)
        except ValueError:
            # Best effort: without config data we continue without the prompt
            pass
        else:
            prompts = config_data.get('prompts', [])
            if 0 <= prompt_index < len(prompts):
                prompt = prompts[prompt_index]

        # lora2 is only supplied for the dual LoRA pattern, so the model's
        # own default applies in the single LoRA case.
        lora_fields = {'lora1': lora_parts[0]}
        if len(lora_parts) > 1:
            lora_fields['lora2'] = lora_parts[1]

        return ComparisonImage(
            path=str(path),
            model=path.parent.name,
            config=config_dir.name,
            prompt_index=prompt_index,
            seed=seed,
            prompt=prompt,
            **lora_fields
        )