Spaces:
Build error
Build error
| import json | |
| import math | |
| import random | |
| import time | |
| from pathlib import Path | |
| from uuid import uuid4 | |
| import torch | |
| from diffusers import __version__ as diffusers_version | |
| from huggingface_hub import CommitOperationAdd, create_commit, create_repo | |
| from .upsampling import RealESRGANModel | |
| from .utils import pad_along_axis | |
def get_all_files(root: Path):
    """Yield every regular file under *root*, descending into subdirectories.

    Uses an explicit stack instead of recursion, so arbitrarily deep trees
    are handled without hitting the recursion limit. Order of results is
    unspecified (depends on stack pop order).
    """
    stack = [root]
    while stack:
        directory = stack.pop()
        for entry in directory.iterdir():
            if entry.is_file():
                yield entry
            elif entry.is_dir():
                stack.append(entry)
def get_groups_of_n(n: int, iterator):
    """Yield successive lists of at most *n* items from *iterator*.

    Every yielded list has exactly *n* items except possibly the last,
    which holds the remainder. Nothing is yielded for an empty iterator.

    Args:
        n (int): Maximum size of each group. Must be >= 1.
        iterator: Any iterable to chunk.

    Raises:
        ValueError: If ``n`` is less than 1.
    """
    # The original used `assert n > 1`, which (a) silently disappears under
    # `python -O` and (b) rejected n == 1 even though chunking by 1 is well
    # defined. Validate explicitly and accept any positive n.
    if n < 1:
        raise ValueError(f"Group size n must be >= 1, got {n}")
    buffer = []
    for elt in iterator:
        buffer.append(elt)
        if len(buffer) == n:
            yield buffer
            buffer = []
    # Flush the final partial group, if any.
    if buffer:
        yield buffer
def upload_folder_chunked(
    repo_id: str,
    upload_dir: Path,
    n: int = 100,
    private: bool = False,
    create_pr: bool = False,
):
    """Upload a folder to the Hugging Face Hub in chunks of n files at a time.

    Each chunk becomes one commit, so very large folders do not require a
    single huge commit. Files are stored in the repo under
    ``<parent dir name>/<file name>``.

    Args:
        repo_id (str): The repo id to upload to.
        upload_dir (Path): The directory to upload.
        n (int, *optional*, defaults to 100): The number of files to upload at a time.
        private (bool, *optional*): Whether to upload the repo as private.
        create_pr (bool, *optional*): Whether to create a PR after uploading instead of commiting directly.
    """
    url = create_repo(repo_id, exist_ok=True, private=private, repo_type="dataset")
    print(f"Uploading files to: {url}")

    root = Path(upload_dir)
    if not root.exists():
        raise ValueError(f"Upload directory {root} does not exist.")

    for part, chunk in enumerate(get_groups_of_n(n, get_all_files(root))):
        print(f"Committing {chunk}")
        operations = []
        for file_path in chunk:
            operations.append(
                CommitOperationAdd(
                    path_in_repo=f"{file_path.parent.name}/{file_path.name}",
                    path_or_fileobj=str(file_path),
                )
            )
        create_commit(
            repo_id=repo_id,
            operations=operations,
            commit_message=f"Upload part {part}",
            repo_type="dataset",
            create_pr=create_pr,
        )
def generate_input_batches(pipeline, prompts, seeds, batch_size, height, width):
    """Yield (batch_idx, text_embeddings, latent_noise) batches for the pipeline.

    Text embeddings and latent noise are accumulated one prompt at a time and
    flushed whenever `batch_size` items are collected (or the input is
    exhausted, so the final batch may be smaller).

    Args:
        pipeline: StableDiffusion-style pipeline providing `embed_text`, `unet`,
            and `device`.
        prompts: Sequence of prompt strings, one per image.
        seeds: Sequence of int seeds, same length as `prompts`; each seeds the
            noise for one image.
        batch_size: Number of images per yielded batch.
        height / width: Output image dimensions; latents are height//8 x width//8.

    Raises:
        ValueError: If `prompts` and `seeds` differ in length.
    """
    if len(prompts) != len(seeds):
        raise ValueError("Number of prompts and seeds must be equal.")

    embeds_batch, noise_batch = None, None
    batch_idx = 0
    for i, (prompt, seed) in enumerate(zip(prompts, seeds)):
        embeds = pipeline.embed_text(prompt)
        noise = torch.randn(
            # One latent per prompt; spatial dims are 1/8 of the image (VAE factor).
            (1, pipeline.unet.in_channels, height // 8, width // 8),
            device=pipeline.device,
            # MPS does not support device-local generators, so seed on CPU there.
            generator=torch.Generator(device="cpu" if pipeline.device.type == "mps" else pipeline.device).manual_seed(
                seed
            ),
        )
        # Grow the pending batch one item at a time.
        embeds_batch = embeds if embeds_batch is None else torch.cat([embeds_batch, embeds])
        noise_batch = noise if noise_batch is None else torch.cat([noise_batch, noise])
        # Flush when full, or on the last prompt (possibly a partial batch).
        batch_is_ready = embeds_batch.shape[0] == batch_size or i + 1 == len(prompts)
        if not batch_is_ready:
            continue
        # NOTE(review): torch.cuda.HalfTensor casts require CUDA — this will
        # fail on CPU/MPS-only setups; confirm callers always run on CUDA.
        yield batch_idx, embeds_batch.type(torch.cuda.HalfTensor), noise_batch.type(torch.cuda.HalfTensor)
        batch_idx += 1
        # Free the yielded tensors before accumulating the next batch to keep
        # peak GPU memory low.
        del embeds_batch, noise_batch
        torch.cuda.empty_cache()
        embeds_batch, noise_batch = None, None
def generate_images(
    pipeline,
    prompt,
    batch_size=1,
    num_batches=1,
    seeds=None,
    num_inference_steps=50,
    guidance_scale=7.5,
    output_dir="./images",
    image_file_ext=".jpg",
    upsample=False,
    height=512,
    width=512,
    eta=0.0,
    push_to_hub=False,
    repo_id=None,
    private=False,
    create_pr=False,
    name=None,
):
    """Generate images using the StableDiffusion pipeline.

    Args:
        pipeline (StableDiffusionWalkPipeline): The StableDiffusion pipeline instance.
        prompt (str): The prompt to use for the image generation.
        batch_size (int, *optional*, defaults to 1): The batch size to use for image generation.
        num_batches (int, *optional*, defaults to 1): The number of batches to generate.
        seeds (list[int], *optional*): The seeds to use for the image generation.
        num_inference_steps (int, *optional*, defaults to 50): The number of inference steps to take.
        guidance_scale (float, *optional*, defaults to 7.5): The guidance scale to use for image generation.
        output_dir (str, *optional*, defaults to "./images"): The output directory to save the images to.
        image_file_ext (str, *optional*, defaults to '.jpg'): The image file extension to use.
        upsample (bool, *optional*, defaults to False): Whether to upsample the images.
        height (int, *optional*, defaults to 512): The height of the images to generate.
        width (int, *optional*, defaults to 512): The width of the images to generate.
        eta (float, *optional*, defaults to 0.0): The eta parameter to use for image generation.
        push_to_hub (bool, *optional*, defaults to False): Whether to push the generated images to the Hugging Face Hub.
        repo_id (str, *optional*): The repo id to push the images to.
        private (bool, *optional*): Whether to push the repo as private.
        create_pr (bool, *optional*): Whether to create a PR after pushing instead of commiting directly.
        name (str, *optional*, defaults to current timestamp str): The name of the sub-directory of
            output_dir to save the images to.

    Returns:
        list[str]: Filepaths of the saved images, one per generated image.
    """
    # Fail fast before doing any expensive work.
    if push_to_hub:
        if repo_id is None:
            raise ValueError("Must provide repo_id if push_to_hub is True.")

    name = name or time.strftime("%Y%m%d-%H%M%S")
    save_path = Path(output_dir) / name
    # exist_ok=False: refuse to overwrite a previous run with the same name.
    save_path.mkdir(exist_ok=False, parents=True)
    prompt_config_path = save_path / "prompt_config.json"

    num_images = batch_size * num_batches
    # randint(0, 9999998) draws from the same range as the original
    # random.choice(list(range(0, 9999999))) without building a 10M-element list.
    seeds = seeds or [random.randint(0, 9999998) for _ in range(num_images)]
    if len(seeds) != num_images:
        raise ValueError("Number of seeds must be equal to batch_size * num_batches.")

    if upsample:
        # Lazily attach the upsampler so repeated calls reuse the same model.
        if getattr(pipeline, "upsampler", None) is None:
            pipeline.upsampler = RealESRGANModel.from_pretrained("nateraw/real-esrgan")
        pipeline.upsampler.to(pipeline.device)

    # Persist the full generation config next to the images for reproducibility.
    cfg = dict(
        prompt=prompt,
        guidance_scale=guidance_scale,
        eta=eta,
        num_inference_steps=num_inference_steps,
        upsample=upsample,
        height=height,
        width=width,
        scheduler=dict(pipeline.scheduler.config),
        tiled=pipeline.tiled,
        diffusers_version=diffusers_version,
        device_name=torch.cuda.get_device_name(0) if torch.cuda.is_available() else "unknown",
    )
    prompt_config_path.write_text(json.dumps(cfg, indent=2, sort_keys=False))

    frame_index = 0
    frame_filepaths = []
    for batch_idx, embeds, noise in generate_input_batches(
        pipeline, [prompt] * num_images, seeds, batch_size, height, width
    ):
        print(f"Generating batch {batch_idx}")

        outputs = pipeline(
            text_embeddings=embeds,
            latents=noise,
            num_inference_steps=num_inference_steps,
            guidance_scale=guidance_scale,
            eta=eta,
            height=height,
            width=width,
            # The upsampler consumes numpy arrays; otherwise save PIL directly.
            output_type="pil" if not upsample else "numpy",
        )["images"]
        if upsample:
            images = []
            for output in outputs:
                images.append(pipeline.upsampler(output))
        else:
            images = outputs

        # Images are named after the seed that produced them.
        for image in images:
            frame_filepath = save_path / f"{seeds[frame_index]}{image_file_ext}"
            image.save(frame_filepath)
            frame_filepaths.append(str(frame_filepath))
            frame_index += 1

    # BUG FIX: the upload used to sit *after* the return statement, so it was
    # unreachable and push_to_hub=True silently did nothing. Upload first,
    # then return.
    if push_to_hub:
        upload_folder_chunked(repo_id, save_path, private=private, create_pr=create_pr)

    return frame_filepaths
def generate_images_flax(
    pipeline,
    params,
    prompt,
    batch_size=1,
    num_batches=1,
    seeds=None,
    num_inference_steps=50,
    guidance_scale=7.5,
    output_dir="./images",
    image_file_ext=".jpg",
    upsample=False,
    height=512,
    width=512,
    push_to_hub=False,
    repo_id=None,
    private=False,
    create_pr=False,
    name=None,
):
    """Generate images using the StableDiffusion pipeline.

    Args:
        pipeline (StableDiffusionWalkPipeline): The StableDiffusion pipeline instance.
        params (`Union[Dict, FrozenDict]`): The model parameters.
        prompt (str): The prompt to use for the image generation.
        batch_size (int, *optional*, defaults to 1): The batch size to use for image generation.
        num_batches (int, *optional*, defaults to 1): The number of batches to generate.
        seeds (int, *optional*): The seed to use for the image generation.
        num_inference_steps (int, *optional*, defaults to 50): The number of inference steps to take.
        guidance_scale (float, *optional*, defaults to 7.5): The guidance scale to use for image generation.
        output_dir (str, *optional*, defaults to "./images"): The output directory to save the images to.
        image_file_ext (str, *optional*, defaults to '.jpg'): The image file extension to use.
        upsample (bool, *optional*, defaults to False): Whether to upsample the images.
        height (int, *optional*, defaults to 512): The height of the images to generate.
        width (int, *optional*, defaults to 512): The width of the images to generate.
        push_to_hub (bool, *optional*, defaults to False): Whether to push the generated images to the Hugging Face Hub.
        repo_id (str, *optional*): The repo id to push the images to.
        private (bool, *optional*): Whether to push the repo as private.
        create_pr (bool, *optional*): Whether to create a PR after pushing instead of commiting directly.
        name (str, *optional*, defaults to current timestamp str): The name of the sub-directory of
            output_dir to save the images to.

    Returns:
        list[str]: Filepaths of the saved images, one per generated image.
    """
    # FIX: the docstring originally appeared *after* these imports, so it was
    # an ordinary string statement rather than the function's docstring.
    # Imported lazily so non-Flax users don't need jax/flax installed.
    import jax
    from flax.training.common_utils import shard

    # Fail fast before doing any expensive work.
    if push_to_hub:
        if repo_id is None:
            raise ValueError("Must provide repo_id if push_to_hub is True.")

    name = name or time.strftime("%Y%m%d-%H%M%S")
    save_path = Path(output_dir) / name
    # exist_ok=False: refuse to overwrite a previous run with the same name.
    save_path.mkdir(exist_ok=False, parents=True)
    prompt_config_path = save_path / "prompt_config.json"

    num_images = batch_size * num_batches
    # randint(0, 9999998) draws from the same range as the original
    # random.choice(list(range(0, 9999999))) without building a 10M-element list.
    # NOTE: unlike generate_images, `seeds` here is a single int for the PRNG key.
    seeds = seeds or random.randint(0, 9999998)
    prng_seed = jax.random.PRNGKey(seeds)

    if upsample:
        # Lazily attach the (torch-based) upsampler so repeated calls reuse it.
        if getattr(pipeline, "upsampler", None) is None:
            pipeline.upsampler = RealESRGANModel.from_pretrained("nateraw/real-esrgan")
            if not torch.cuda.is_available():
                print("Upsampling is recommended to be done on a GPU, as it is very slow on CPU")
            else:
                pipeline.upsampler = pipeline.upsampler.cuda()

    # Persist the full generation config next to the images for reproducibility.
    cfg = dict(
        prompt=prompt,
        guidance_scale=guidance_scale,
        num_inference_steps=num_inference_steps,
        upsample=upsample,
        height=height,
        width=width,
        scheduler=dict(pipeline.scheduler.config),
        # tiled=pipeline.tiled,
        diffusers_version=diffusers_version,
        device_name=torch.cuda.get_device_name(0) if torch.cuda.is_available() else "unknown",
    )
    prompt_config_path.write_text(json.dumps(cfg, indent=2, sort_keys=False))

    NUM_TPU_CORES = jax.device_count()
    jit = True  # force jit, assume params are already sharded
    batch_size_total = NUM_TPU_CORES * batch_size if jit else batch_size

    def generate_input_batches(prompts, batch_size):
        # Simple slicing-based chunker: yields (batch_idx, prompts-slice).
        prompt_batch = None
        for batch_idx in range(math.ceil(len(prompts) / batch_size)):
            prompt_batch = prompts[batch_idx * batch_size : (batch_idx + 1) * batch_size]
            yield batch_idx, prompt_batch

    frame_index = 0
    frame_filepaths = []
    for batch_idx, prompt_batch in generate_input_batches([prompt] * num_images, batch_size_total):
        # This batch size correspond to each TPU core, so we are generating batch_size * NUM_TPU_CORES images
        print(f"Generating batches: {batch_idx*NUM_TPU_CORES} - {min((batch_idx+1)*NUM_TPU_CORES, num_batches)}")
        prompt_ids_batch = pipeline.prepare_inputs(prompt_batch)
        prng_seed_batch = prng_seed

        if jit:
            padded = False
            # Check if len of prompt_batch is multiple of NUM_TPU_CORES, if not pad its ids
            if len(prompt_batch) % NUM_TPU_CORES != 0:
                padded = True
                pad_size = NUM_TPU_CORES - (len(prompt_batch) % NUM_TPU_CORES)
                # Pad embeds_batch and noise_batch with zeros in batch dimension
                prompt_ids_batch = pad_along_axis(prompt_ids_batch, pad_size, axis=0)
            prompt_ids_batch = shard(prompt_ids_batch)
            prng_seed_batch = jax.random.split(prng_seed, jax.device_count())

        outputs = pipeline(
            params,
            prng_seed=prng_seed_batch,
            prompt_ids=prompt_ids_batch,
            height=height,
            width=width,
            guidance_scale=guidance_scale,
            num_inference_steps=num_inference_steps,
            # The upsampler consumes numpy arrays; otherwise save PIL directly.
            output_type="pil" if not upsample else "numpy",
            jit=jit,
        )["images"]

        if jit:
            # check if we padded and remove that padding from outputs
            if padded:
                outputs = outputs[:-pad_size]

        if upsample:
            images = []
            for output in outputs:
                images.append(pipeline.upsampler(output))
        else:
            images = outputs

        # A single PRNG key is shared across the run, so seeds can't name the
        # files uniquely; use a fresh UUID per image instead.
        for image in images:
            uuid = str(uuid4())
            frame_filepath = save_path / f"{uuid}{image_file_ext}"
            image.save(frame_filepath)
            frame_filepaths.append(str(frame_filepath))
            frame_index += 1

    # BUG FIX: the upload used to sit *after* the return statement, so it was
    # unreachable and push_to_hub=True silently did nothing. Upload first,
    # then return.
    if push_to_hub:
        upload_folder_chunked(repo_id, save_path, private=private, create_pr=create_pr)

    return frame_filepaths