Spaces:
Running
on
Zero
Running
on
Zero
| """ | |
| NOTE: pip install optuna | |
| """ | |
| import dataclasses | |
| import json | |
| import shutil | |
| import tempfile | |
| from pathlib import Path | |
| from typing import Any, Dict, Optional | |
| import jax | |
| import numpy as np | |
| import optuna | |
| import tyro | |
| import yaml | |
| import zea | |
| from keras import ops | |
| from PIL import Image | |
| from zea import init_device, log | |
| from eval import evaluate | |
| from main import init, run | |
| from utils import load_image | |
def load_images_from_dir(input_folder):
    """Load every PNG in ``input_folder`` and stack them into one batch.

    Files are loaded in sorted path order. ``Path.glob`` makes no ordering
    guarantee, so sorting keeps the batch order (and therefore seeded
    results) reproducible across runs and filesystems.

    Args:
        input_folder: Directory containing the ``*.png`` images.

    Returns:
        Tuple ``(images, paths)``: the images stacked along a new leading
        axis with ``keras.ops.stack``, and the list of source paths in the
        same order.

    Raises:
        ValueError: if the directory contains no PNG files.
    """
    paths = sorted(Path(input_folder).glob("*.png"))
    # Check before loading anything so the error is raised immediately.
    if not paths:
        raise ValueError(f"No PNG images found in {input_folder}")
    images = ops.stack([load_image(path) for path in paths], axis=0)
    return images, paths
def save_images_to_temp_dir(images, image_paths, prefix=""):
    """Save images as PNG files in a newly created temporary directory.

    Each image is written under the file name (``Path(path).name``) of the
    matching entry in ``image_paths``.

    Args:
        images: Iterable of array-like images. Non-uint8 images whose maximum
            is <= 1 are treated as [0, 1] and rescaled to [0, 255]. Values are
            clipped to [0, 255] before the uint8 cast. A trailing singleton
            channel axis is squeezed away.
        image_paths: Paths whose file names are reused for the output files.
        prefix: Prefix for the temporary directory name.

    Returns:
        The temporary directory path as a string. The caller owns the
        directory and must delete it.

    Raises:
        Whatever conversion or ``PIL`` saving raises. The partially written
        directory is removed before the error propagates, so it never leaks.
    """
    temp_dir_path = Path(tempfile.mkdtemp(prefix=prefix))
    try:
        for img, path in zip(images, image_paths):
            # Accept any array-like (e.g. JAX arrays); PIL needs a numpy array.
            img = np.asarray(img)
            if img.dtype != np.uint8:
                # Heuristic: a maximum <= 1 means the image is in [0, 1].
                if img.max() <= 1.0:
                    img = img * 255
                # Clip first: casting out-of-range values to uint8 wraps around.
                img = np.clip(img, 0, 255).astype(np.uint8)
            # PIL expects (H, W) for single-channel images, not (H, W, 1).
            if img.ndim == 3 and img.shape[-1] == 1:
                img = img.squeeze(-1)
            Image.fromarray(img).save(temp_dir_path / Path(path).name)
    except BaseException:
        # The caller never receives the path on failure, so clean up here.
        shutil.rmtree(temp_dir_path, ignore_errors=True)
        raise
    return str(temp_dir_path)
@dataclasses.dataclass
class SweeperConfig:
    """Configuration for hyperparameter sweeping with Optuna.

    Declared as a dataclass so that ``tyro.cli`` can build it from
    command-line arguments and the required path fields become constructor
    parameters.
    """

    # Required paths - no defaults
    input_image_dir: str  # Path to input hazy images
    roi_folder: str  # Path to ROI mask images
    reference_folder: str  # Path to reference/ground truth images
    # Base configuration
    base_config_path: str = "configs/semantic_dps.yaml"
    method: str = "semantic_dps"  # Which method to optimize
    broad_sweep: bool = False  # Choose between broad or narrow sweep
    # Optuna settings
    study_name: str = "dehaze_optimization"
    storage: Optional[str] = None  # e.g., "sqlite:///dehaze_study.db" for persistence
    n_trials: int = 100
    # Optimization settings
    objective_metric: str = "final_score"  # Which metric to optimize
    direction: str = "maximize"  # "maximize" or "minimize"
    # Output settings
    output_dir: str = "sweep_results"
    # Evaluation settings
    skip_fid: bool = False
    # Device configuration
    device: str = "auto:1"
    # Pruning settings
    enable_pruning: bool = True
    pruner_type: str = "median"  # "median", "hyperband", or "none"
class OptunaObjective:
    """Optuna objective function for hyperparameter optimization.

    The base configuration, the hazy input images and the diffusion model are
    loaded once in ``__init__`` and reused by every trial. Each call runs the
    model with freshly suggested hyperparameters, writes the predicted tissue
    images to a temporary directory, evaluates them against the ROI and
    reference folders, and returns ``config.objective_metric``.
    """

    def __init__(self, config: SweeperConfig):
        """Initialize the device, then load the config, images and model.

        Args:
            config: Sweep configuration (paths, device, objective metric, ...).
        """
        self.config = config
        # Select the device before any tensor is created
        # (``load_images_from_dir`` stacks the images with ``keras.ops``), so
        # the images and the model are placed on the requested device.
        init_device(config.device)
        self.base_config = self._load_base_config()
        self.hazy_images, self.image_paths = load_images_from_dir(
            config.input_image_dir
        )
        # Initialize the diffusion model once; every trial reuses it.
        self.diffusion_model = init(self.base_config)

    def _load_base_config(self):
        """Load the base YAML configuration and wrap it in a ``zea.Config``."""
        with open(self.config.base_config_path, "r", encoding="utf-8") as f:
            # An empty YAML file parses to None; treat it as an empty mapping.
            config_dict = yaml.safe_load(f) or {}
        return zea.Config(**config_dict)

    def _create_trial_params(self, trial: optuna.Trial) -> Dict[str, Any]:
        """Suggest hyperparameters for ``trial`` and fill in the fixed ones.

        Returns:
            Keyword arguments for ``run``. The suggested values are overlaid on
            ``base_config.params``. Nested groups (e.g. ``guidance_kwargs``) are
            merged key by key, so base values the sweep does not tune are kept.
        """
        params = {
            "guidance_kwargs": {
                "omega": trial.suggest_float("omega", 0.5, 50.0, log=True),
                "omega_vent": trial.suggest_float("omega_vent", 0.0001, 50.0, log=True),
                "omega_sept": trial.suggest_float("omega_sept", 0.1, 50.0, log=True),
                "eta": trial.suggest_float("eta", 0.001, 1.0, log=True),
                "smooth_l1_beta": trial.suggest_float(
                    "smooth_l1_beta", 0.1, 10.0, log=True
                ),
            },
            "skeleton_params": {
                "sigma_pre": trial.suggest_float("skeleton_sigma_pre", 0.0, 10.0),
                "sigma_post": trial.suggest_float("skeleton_sigma_post", 0.0, 10.0),
                "threshold": trial.suggest_float("skeleton_threshold", 0.0, 1.0),
            },
            "mask_params": {
                "threshold": trial.suggest_float("mask_threshold", 0.0, 1.0),
                "sigma": trial.suggest_float("mask_sigma", 0.0, 10.0),
            },
        }
        # Add base config parameters that aren't being optimized.
        base_params = getattr(self.base_config, "params", None)
        if base_params is not None:
            for key, base_value in base_params.items():
                if key not in params:
                    params[key] = base_value
                elif isinstance(params[key], dict) and hasattr(base_value, "items"):
                    # Base sub-values act as defaults; suggested values win.
                    params[key] = {**dict(base_value.items()), **params[key]}
        return params

    @staticmethod
    def _record_params(trial: optuna.Trial, params: Dict[str, Any]) -> None:
        """Store ``params`` as user attributes, flattening one nesting level."""
        for key, value in params.items():
            if isinstance(value, dict):
                for subkey, subvalue in value.items():
                    trial.set_user_attr(f"{key}_{subkey}", subvalue)
            else:
                trial.set_user_attr(key, value)

    def _evaluate_predictions(self, pred_tissue_images) -> Optional[float]:
        """Save predictions to a temp dir, evaluate them, and return the metric.

        Returns:
            The objective metric as a float, or None (after logging) if
            conversion, saving or evaluation fails. The temporary directory is
            always removed.
        """
        pred_tissue_temp_dir = None
        try:
            # Backend-agnostic conversion (numpy, JAX, TF, torch tensors).
            pred_tissue_images = ops.convert_to_numpy(pred_tissue_images)
            pred_tissue_temp_dir = save_images_to_temp_dir(
                pred_tissue_images, self.image_paths, prefix="pred_tissue_"
            )
            results = evaluate(
                folder=pred_tissue_temp_dir,
                noisy_folder=self.config.input_image_dir,
                roi_folder=self.config.roi_folder,
                reference_folder=self.config.reference_folder,
            )
            return float(results[self.config.objective_metric])
        except Exception as e:
            log.error(f"Error during evaluation: {e}")
            return None
        finally:
            if pred_tissue_temp_dir and Path(pred_tissue_temp_dir).exists():
                try:
                    shutil.rmtree(pred_tissue_temp_dir)
                except Exception as e:
                    log.warning(
                        f"Failed to clean up temp directory {pred_tissue_temp_dir}: {e}"
                    )

    def __call__(self, trial: optuna.Trial) -> float:
        """Run one trial and return its objective value.

        Returns:
            The objective metric. NaN is returned when inference or evaluation
            fails, which Optuna records as a FAIL trial. A fixed sentinel
            such as 0.0 would instead be ranked as a real score (and would be
            the best score when minimizing).

        Raises:
            optuna.TrialPruned: if the study's pruner rejects the trial.
        """
        params = self._create_trial_params(trial)
        # Record the parameter set first so failed and pruned trials keep it.
        self._record_params(trial, params)

        # Per-trial seed: reproducible, but different for every trial.
        seed = jax.random.PRNGKey(self.base_config.seed + trial.number)

        try:
            _, pred_tissue_images, _, _ = run(
                hazy_images=self.hazy_images,
                diffusion_model=self.diffusion_model,
                seed=seed,
                **params,
            )
        except Exception as e:
            log.error(f"Error during model inference: {e}")
            return float("nan")

        objective_value = self._evaluate_predictions(pred_tissue_images)
        if objective_value is None:
            return float("nan")

        # Report the value so the configured pruner can judge this trial.
        trial.report(objective_value, step=0)
        if trial.should_prune():
            raise optuna.TrialPruned()

        log.info(
            f"Trial {trial.number}: {self.config.objective_metric} = {objective_value:.4f}"
        )
        return objective_value
def create_pruner(pruner_type: str) -> optuna.pruners.BasePruner:
    """Create an Optuna pruner from its name.

    Args:
        pruner_type: ``"median"``, ``"hyperband"`` or ``"none"``
            (case-insensitive, surrounding whitespace ignored).

    Returns:
        The configured pruner; ``"none"`` yields a ``NopPruner``.

    Raises:
        ValueError: for any other name, so a typo cannot silently disable
            pruning.
    """
    name = pruner_type.strip().lower()
    if name == "median":
        return optuna.pruners.MedianPruner(
            n_startup_trials=5, n_warmup_steps=0, interval_steps=1
        )
    if name == "hyperband":
        return optuna.pruners.HyperbandPruner(
            min_resource=1, max_resource=100, reduction_factor=3
        )
    if name == "none":
        return optuna.pruners.NopPruner()
    raise ValueError(
        f"Unknown pruner_type {pruner_type!r}; expected 'median', 'hyperband' or 'none'."
    )
def run_optimization(config: SweeperConfig):
    """Run hyperparameter optimization using Optuna and save the results.

    Writes ``best_results_<method>_<study_name>.json`` and
    ``all_trials_<method>_<study_name>.json`` into ``config.output_dir``.

    Args:
        config: Sweep configuration.

    Returns:
        The ``optuna.Study`` after optimization.
    """
    trial_state = optuna.trial.TrialState

    # ``optuna.create_study(pruner=None)`` falls back to a MedianPruner, so
    # disabling pruning requires an explicit NopPruner.
    if config.enable_pruning:
        pruner = create_pruner(config.pruner_type)
    else:
        pruner = optuna.pruners.NopPruner()

    # Create the output directory before the (long) optimization, so an
    # unwritable path fails immediately instead of after all trials have run.
    output_dir = Path(config.output_dir)
    output_dir.mkdir(parents=True, exist_ok=True)

    # Create or load study
    study = optuna.create_study(
        study_name=config.study_name,
        storage=config.storage,
        direction=config.direction,
        pruner=pruner,
        load_if_exists=True,
    )
    log.info(f"Starting optimization for method: {config.method}")
    log.info(f"Study name: {config.study_name}")
    log.info(f"Number of trials: {config.n_trials}")
    log.info(f"Objective metric: {config.objective_metric} ({config.direction})")

    objective = OptunaObjective(config)
    study.optimize(objective, n_trials=config.n_trials)

    # One snapshot without deep copies; ``study.trials`` would deep-copy every
    # trial again on each access.
    trials = study.get_trials(deepcopy=False)
    n_complete, n_pruned, n_failed = (
        sum(t.state == state for t in trials)
        for state in (trial_state.COMPLETE, trial_state.PRUNED, trial_state.FAIL)
    )
    # ``study.best_trial`` raises ValueError when no trial completed.
    best_trial = study.best_trial if n_complete else None

    best_results = {
        "best_value": best_trial.value if best_trial is not None else None,
        "best_params": best_trial.params if best_trial is not None else None,
        "best_user_attrs": best_trial.user_attrs if best_trial is not None else None,
        "study_stats": {
            "n_trials": len(trials),
            "n_complete_trials": n_complete,
            "n_pruned_trials": n_pruned,
            "n_failed_trials": n_failed,
        },
    }
    best_path = output_dir / f"best_results_{config.method}_{config.study_name}.json"
    with open(best_path, "w", encoding="utf-8") as f:
        json.dump(best_results, f, indent=2)

    def _iso(moment):
        # Running or waiting trials have no start/completion timestamp.
        return moment.isoformat() if moment else None

    trials_data = [
        {
            "number": trial.number,
            "value": trial.value,
            "params": trial.params,
            "user_attrs": trial.user_attrs,
            "state": trial.state.name,
            "datetime_start": _iso(trial.datetime_start),
            "datetime_complete": _iso(trial.datetime_complete),
        }
        for trial in trials
    ]
    trials_path = output_dir / f"all_trials_{config.method}_{config.study_name}.json"
    with open(trials_path, "w", encoding="utf-8") as f:
        json.dump(trials_data, f, indent=2)

    # Print summary
    log.success("Optimization completed!")
    if best_trial is None:
        log.warning("No trial completed successfully; no best parameters to report.")
    else:
        log.info(f"Best {config.objective_metric}: {best_trial.value:.4f}")
        log.info("Best parameters:")
        for key, value in best_trial.params.items():
            log.info(f" {key}: {value}")

    stats = best_results["study_stats"]
    log.info("Study statistics:")
    log.info(f" Total trials: {stats['n_trials']}")
    log.info(f" Complete trials: {stats['n_complete_trials']}")
    log.info(f" Pruned trials: {stats['n_pruned_trials']}")
    log.info(f" Failed trials: {stats['n_failed_trials']}")
    return study
def main():
    """Parse CLI options, validate input paths, run the sweep and save plots."""
    config = tyro.cli(SweeperConfig)

    # Fail fast on missing inputs before the model is loaded.
    required_paths = [
        (config.input_image_dir, "Input image directory"),
        (config.roi_folder, "ROI folder"),
        (config.reference_folder, "Reference folder"),
    ]
    for path, description in required_paths:
        if not Path(path).exists():
            raise FileNotFoundError(f"{description} not found: {path}")

    # Set visualization style
    zea.visualize.set_mpl_style()

    study = run_optimization(config)
    _save_optimization_plots(study, config)


def _save_optimization_plots(study, config):
    """Save Optuna's matplotlib diagnostic plots to ``config.output_dir``.

    Best-effort: a missing plotting dependency, or any plot that fails to
    render or save, is logged as a warning. Each plot is attempted
    independently, and the sweep results already written to disk are never
    lost because of a plotting error.
    """
    try:
        import matplotlib.pyplot as plt
        import optuna.visualization as vis
    except ImportError:
        log.warning(
            "Optuna visualization not available. Install with: pip install optuna[visualization]"
        )
        return

    output_dir = Path(config.output_dir)
    plots = (
        ("optimization_history", vis.matplotlib.plot_optimization_history),
        ("param_importances", vis.matplotlib.plot_param_importances),
        ("parallel_coordinate", vis.matplotlib.plot_parallel_coordinate),
    )
    n_saved = 0
    for name, plot_fn in plots:
        fig = None
        try:
            fig = plot_fn(study).figure
            fig.savefig(
                output_dir / f"{name}_{config.method}.png",
                dpi=300,
                bbox_inches="tight",
            )
            n_saved += 1
        except Exception as e:
            log.warning(f"Could not create {name} plot: {e}")
        finally:
            # Always release the figure, even if saving failed.
            if fig is not None:
                plt.close(fig)

    if n_saved:
        log.success(f"Optimization plots saved to {output_dir}")


if __name__ == "__main__":
    main()