import os
import time
import torch
import shutil
import argparse
import numpy as np
from tqdm import tqdm
from PIL import Image
from datasets import load_dataset
from accelerate import Accelerator
from diffusers.utils import load_image
from diffusers import (
    AutoencoderKL,
    StableDiffusionXLControlNetPipeline,
    ControlNetModel,
    UNet2DConditionModel,
)
from huggingface_hub import hf_hub_download
from safetensors.torch import load_file


# Define the function to parse command-line arguments
def parse_args(input_args=None):
    parser = argparse.ArgumentParser(description="Simple example of a ControlNet evaluation script.")
    parser.add_argument(
        "--pretrained_model_name_or_path",
        type=str,
        default=None,
        required=True,
        help="Path to pretrained model or model identifier from huggingface.co/models.",
    )
    parser.add_argument(
        "--pretrained_vae_model_name_or_path",
        type=str,
        default=None,
        help="Path to an improved VAE to stabilize training. For more details check out: https://github.com/huggingface/diffusers/pull/4038.",
    )
    parser.add_argument(
        "--controlnet_model_name_or_path",
        type=str,
        default=None,
        required=True,
        help="Path to pretrained ControlNet model.",
    )
    parser.add_argument(
        "--output_dir",
        type=str,
        default=None,
        required=True,
        help="Path to output results.",
    )
    parser.add_argument(
        "--dataset",
        type=str,
        default="nickpai/coco2017-colorization",
        help="Dataset used for evaluation.",
    )
    parser.add_argument(
        "--dataset_revision",
        type=str,
        default="caption-free",
        choices=["main", "caption-free", "custom-caption"],
        help="Dataset revision option (main/caption-free/custom-caption).",
    )
    parser.add_argument(
        "--mixed_precision",
        type=str,
        default=None,
        choices=["no", "fp16", "bf16"],
        help=(
            "Whether to use mixed precision. Choose between fp16 and bf16 (bfloat16). Bf16 requires PyTorch >="
            " 1.10 and an Nvidia Ampere GPU. Defaults to the value of the accelerate config of the current system or the"
            " flag passed with the `accelerate.launch` command. Use this argument to override the accelerate config."
        ),
    )
    parser.add_argument(
        "--variant",
        type=str,
        default=None,
        help="Variant of the model files of the pretrained model identifier from huggingface.co/models, e.g. fp16",
    )
    parser.add_argument(
        "--revision",
        type=str,
        default=None,
        required=False,
        help="Revision of pretrained model identifier from huggingface.co/models.",
    )
    parser.add_argument(
        "--num_inference_steps",
        type=int,
        default=8,
        help="Number of denoising steps; should match the 1-step, 2-step, 4-step, or 8-step distilled model.",
    )
    parser.add_argument(
        "--repo",
        type=str,
        default="ByteDance/SDXL-Lightning",
        required=True,
        help="Repository from huggingface.co hosting the distilled UNet checkpoints.",
    )
    parser.add_argument(
        "--ckpt",
        type=str,
        default="sdxl_lightning_4step_unet.safetensors",
        required=True,
        help="Checkpoint file to load from the repository.",
    )
    parser.add_argument(
        "--negative_prompt",
        action="store_true",
        help="Whether to use a fixed negative prompt steering generation away from monochrome results.",
    )
    if input_args is not None:
        args = parser.parse_args(input_args)
    else:
        args = parser.parse_args()
    return args
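
# Example invocation (the script name, paths, and checkpoint choice below are
# illustrative; substitute your own ControlNet checkpoint and output folder):
#
#   python eval_controlnet_sdxl_light.py \
#       --pretrained_model_name_or_path stabilityai/stable-diffusion-xl-base-1.0 \
#       --controlnet_model_name_or_path ./checkpoints/controlnet-colorization \
#       --repo ByteDance/SDXL-Lightning \
#       --ckpt sdxl_lightning_8step_unet.safetensors \
#       --num_inference_steps 8 \
#       --output_dir ./eval_output
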
def apply_color(image, color_map):
    # Convert both inputs to the LAB color space
    image_lab = image.convert('LAB')
    color_map_lab = color_map.convert('LAB')
    # Split the LAB channels
    l, a, b = image_lab.split()
    _, a_map, b_map = color_map_lab.split()
    # Keep the luminance of `image`, take the chroma from `color_map`
    merged_lab = Image.merge('LAB', (l, a_map, b_map))
    # Convert the merged LAB image back to RGB
    result_rgb = merged_lab.convert('RGB')
    return result_rgb
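
# `apply_color` is a simple LAB-space color transfer: the lightness (L)
# channel comes from the first image, so structural detail is preserved,
# while only the chroma (a/b) channels are taken from the generated color
# map. Minimal standalone sketch (file names are illustrative):
#
#   gray = Image.open("photo.jpg").convert("RGB")
#   colorized = apply_color(gray, Image.open("model_output.png"))
#   colorized.save("result.png")
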
def main(args):
    generator = torch.manual_seed(0)

    # Path to the eval_results folder
    eval_results_folder = os.path.join(args.output_dir, "results")

    # Remove the eval_results folder if it already exists
    if os.path.exists(eval_results_folder):
        shutil.rmtree(eval_results_folder)

    # Create directory for eval_results
    os.makedirs(eval_results_folder)

    # Create subfolders for comparison strips and colorized images
    compare_folder = os.path.join(eval_results_folder, "compare")
    colorized_folder = os.path.join(eval_results_folder, "colorized")
    os.makedirs(compare_folder)
    os.makedirs(colorized_folder)

    # Load the validation split of the colorization dataset
    val_dataset = load_dataset(args.dataset, split="validation", revision=args.dataset_revision)

    accelerator = Accelerator(
        mixed_precision=args.mixed_precision,
    )

    weight_dtype = torch.float32
    if accelerator.mixed_precision == "fp16":
        weight_dtype = torch.float16
    elif accelerator.mixed_precision == "bf16":
        weight_dtype = torch.bfloat16

    vae_path = (
        args.pretrained_model_name_or_path
        if args.pretrained_vae_model_name_or_path is None
        else args.pretrained_vae_model_name_or_path
    )
    vae = AutoencoderKL.from_pretrained(
        vae_path,
        subfolder="vae" if args.pretrained_vae_model_name_or_path is None else None,
        revision=args.revision,
        variant=args.variant,
    )
    # Instantiate the UNet from the base model's config, then load the
    # SDXL-Lightning distilled weights into it
    unet = UNet2DConditionModel.from_config(
        args.pretrained_model_name_or_path,
        subfolder="unet",
        revision=args.revision,
        variant=args.variant,
    )
    unet.load_state_dict(load_file(hf_hub_download(args.repo, args.ckpt)))
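    # Note: the SDXL-Lightning repository names its distilled UNets by step
    # count (the default sdxl_lightning_4step_unet.safetensors pairs with
    # --num_inference_steps 4); --num_inference_steps should match the step
    # count baked into --ckpt, otherwise output quality typically degrades.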

    # Move the VAE and UNet to device and cast to weight_dtype
    # (the base model's VAE is kept in float32 to avoid NaN issues)
    if args.pretrained_vae_model_name_or_path is not None:
        vae.to(accelerator.device, dtype=weight_dtype)
    else:
        vae.to(accelerator.device, dtype=torch.float32)
    unet.to(accelerator.device, dtype=weight_dtype)

    controlnet = ControlNetModel.from_pretrained(args.controlnet_model_name_or_path, torch_dtype=weight_dtype)

    pipe = StableDiffusionXLControlNetPipeline.from_pretrained(
        args.pretrained_model_name_or_path,
        vae=vae,
        unet=unet,
        controlnet=controlnet,
    )
    pipe.to(accelerator.device, dtype=weight_dtype)

    # Prepare everything with our `accelerator`
    pipe, val_dataset = accelerator.prepare(pipe, val_dataset)
    pipe.safety_checker = None
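    # Note: StableDiffusionXLControlNetPipeline has no built-in safety-checker
    # component, so the assignment above appears to be a no-op; similarly,
    # `accelerator.prepare` has no special handling for pipelines or datasets
    # and should pass both objects through unchanged (no training occurs here).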

    # Counter for processed images
    processed_images = 0

    # Record start time
    start_time = time.time()

    # Iterate through the validation dataset
    for example in tqdm(val_dataset, desc="Processing Images"):
        image_path = example["file_name"]

        prompt = []
        for caption in example["captions"]:
            if isinstance(caption, str):
                prompt.append(caption)
            elif isinstance(caption, (list, np.ndarray)):
                # take the first caption if there are multiple
                prompt.append(caption[0])
            else:
                raise ValueError(
                    "Caption column `captions` should contain either strings or lists of strings."
                )

        negative_prompt = None
        if args.negative_prompt:
            negative_prompt = [
                "low quality, bad quality, low contrast, black and white, bw, monochrome, grainy, blurry, historical, restored, desaturate"
            ]

        # Load the ground truth and build the grayscale conditioning image
        ground_truth_image = load_image(image_path).resize((512, 512))
        control_image = load_image(image_path).convert("L").convert("RGB").resize((512, 512))
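        # The conditioning image is the luminance of the ground truth,
        # replicated to three channels because the ControlNet expects an RGB
        # input; the model's job is to predict plausible chroma for it.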

        # Generate the colorized image
        image = pipe(prompt=prompt,
                     negative_prompt=negative_prompt,
                     num_inference_steps=args.num_inference_steps,
                     generator=generator,
                     image=control_image).images[0]

        # Apply color mapping: keep the ground truth luminance,
        # take the generated chroma
        image = apply_color(ground_truth_image, image)

        # Concatenate input, result, and ground truth into one comparison row
        row_image = np.hstack((np.array(control_image), np.array(image), np.array(ground_truth_image)))
        row_image = Image.fromarray(row_image)

        # Save the row image in the compare folder
        compare_output_path = os.path.join(compare_folder, os.path.basename(image_path))
        row_image.save(compare_output_path)

        # Save the colorized image in the colorized folder
        colorized_output_path = os.path.join(colorized_folder, os.path.basename(image_path))
        image.save(colorized_output_path)

        # Increment processed images counter
        processed_images += 1

    # Record end time
    end_time = time.time()

    # Calculate total time taken and throughput
    total_time = end_time - start_time
    fps = processed_images / total_time

    print("All images processed.")
    print(f"Total time taken: {total_time:.2f} seconds")
    print(f"FPS: {fps:.2f}")


# Entry point of the script
if __name__ == "__main__":
    args = parse_args()
    main(args)
