# Copyright 2022 Lunar Ring. All rights reserved.
# Written by Johannes Stelzer, email [email protected] twitter @j_stelzer
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import os
import warnings
from contextlib import nullcontext
from typing import Optional

import numpy as np
import torch
from torch import autocast
from PIL import Image
from omegaconf import OmegaConf
from einops import repeat, rearrange

from ldm.util import instantiate_from_config
from ldm.models.diffusion.ddim import DDIMSampler
from utils import interpolate_spherical

torch.backends.cudnn.benchmark = False
torch.set_grad_enabled(False)
warnings.filterwarnings('ignore')
def pad_image(input_image):
    """Pad an image with edge replication so both sides are multiples of 64 (minimum 128)."""
    pad_w, pad_h = np.max(((2, 2), np.ceil(
        np.array(input_image.size) / 64).astype(int)), axis=0) * 64 - input_image.size
    im_padded = Image.fromarray(
        np.pad(np.array(input_image), ((0, pad_h), (0, pad_w), (0, 0)), mode='edge'))
    return im_padded
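
# Worked example (hypothetical input): a 500x375 image becomes 512x384, since
# ceil(500 / 64) * 64 = 512 and ceil(375 / 64) * 64 = 384; the max((2, 2), ...)
# clamp guarantees a minimum output size of 128x128.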
def make_batch_superres(
        image,
        txt,
        device,
        num_samples=1):
    image = np.array(image.convert("RGB"))
    image = torch.from_numpy(image).to(dtype=torch.float32) / 127.5 - 1.0
    batch = {
        "lr": rearrange(image, 'h w c -> 1 c h w'),
        "txt": num_samples * [txt],
    }
    batch["lr"] = repeat(batch["lr"].to(device=device),
                         "1 ... -> n ...", n=num_samples)
    return batch
def make_noise_augmentation(model, batch, noise_level=None):
    x_low = batch[model.low_scale_key]
    x_low = x_low.to(memory_format=torch.contiguous_format).float()
    x_aug, noise_level = model.low_scale_model(x_low, noise_level)
    return x_aug, noise_level
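
# How the two helpers above fit together (summary, assuming the standard
# x4-upscaling.yaml config): make_batch_superres supplies the low-resolution
# image under the key "lr" (the model's low_scale_key), and
# make_noise_augmentation degrades it with the model's low_scale_model.
# The returned noise_level is later passed to the UNet as class conditioning
# ("c_adm"), so the network knows how corrupted its low-resolution input is.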
class StableDiffusionHolder:
    def __init__(self,
                 fp_ckpt: str = None,
                 fp_config: str = None,
                 num_inference_steps: int = 30,
                 height: Optional[int] = None,
                 width: Optional[int] = None,
                 device: str = None,
                 precision: str = 'autocast',
                 ):
        r"""
        Initializes the stable diffusion holder, which contains the model and sampler.
        Args:
            fp_ckpt: File path to the .ckpt model file
            fp_config: File path to the .yaml config file
            num_inference_steps: Number of diffusion iterations. Will be overwritten by latent blending.
            height: Height of the resulting image.
            width: Width of the resulting image.
            device: Device to run the model on.
            precision: Precision to run the model with.
        """
        self.seed = 42
        self.guidance_scale = 5.0
        if device is None:
            self.device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
        else:
            self.device = device
        self.precision = precision
        self.init_model(fp_ckpt, fp_config)
        self.f = 8  # downsampling factor, most often 8 or 16
        self.C = 4  # number of latent channels
        self.ddim_eta = 0
        self.num_inference_steps = num_inference_steps
        if height is None and width is None:
            self.init_auto_res()
        else:
            assert height is not None, "specify both width and height"
            assert width is not None, "specify both width and height"
            self.height = height
            self.width = width
        self.negative_prompt = [""]
    def init_model(self, fp_ckpt, fp_config):
        r"""Loads the model and sampler.
        """
        assert os.path.isfile(fp_ckpt), f"Your model checkpoint file does not exist: {fp_ckpt}"
        self.fp_ckpt = fp_ckpt
        # Auto-detect the config from the checkpoint filename?
        if fp_config is None:
            fn_ckpt = os.path.basename(fp_ckpt)
            if 'depth' in fn_ckpt:
                fp_config = 'configs/v2-midas-inference.yaml'
            elif 'upscaler' in fn_ckpt:
                fp_config = 'configs/x4-upscaling.yaml'
            elif '512' in fn_ckpt:
                fp_config = 'configs/v2-inference.yaml'
            elif '768' in fn_ckpt:
                fp_config = 'configs/v2-inference-v.yaml'
            elif 'v1-5' in fn_ckpt:
                fp_config = 'configs/v1-inference.yaml'
            else:
                raise ValueError("Auto-detection of the config failed. Please specify fp_config manually!")
            assert os.path.isfile(fp_config), "Auto-init of the config file failed. Please specify manually."
        assert os.path.isfile(fp_config), f"Your config file does not exist: {fp_config}"
        config = OmegaConf.load(fp_config)
        self.model = instantiate_from_config(config.model)
        self.model.load_state_dict(torch.load(fp_ckpt)["state_dict"], strict=False)
        self.model = self.model.to(self.device)
        self.sampler = DDIMSampler(self.model)
    def init_auto_res(self):
        r"""Automatically set the resolution to the one used in training.
        """
        if '768' in self.fp_ckpt:
            self.height = 768
            self.width = 768
        else:
            self.height = 512
            self.width = 512
    def set_negative_prompt(self, negative_prompt):
        r"""Set the negative prompt. Currently only one negative prompt is supported.
        """
        if isinstance(negative_prompt, str):
            self.negative_prompt = [negative_prompt]
        else:
            self.negative_prompt = negative_prompt
        if len(self.negative_prompt) > 1:
            self.negative_prompt = [self.negative_prompt[0]]
    def get_text_embedding(self, prompt):
        c = self.model.get_learned_conditioning(prompt)
        return c
    def get_cond_upscaling(self, image, text_embedding, noise_level):
        r"""
        Initializes the conditioning for the x4 upscaling model.
        """
        image = pad_image(image)  # pad to an integer multiple of 64
        w, h = image.size
        noise_level = torch.Tensor(1 * [noise_level]).to(self.sampler.model.device).long()
        batch = make_batch_superres(image, txt="placeholder", device=self.device, num_samples=1)
        x_augment, noise_level = make_noise_augmentation(self.model, batch, noise_level)
        cond = {"c_concat": [x_augment], "c_crossattn": [text_embedding], "c_adm": noise_level}
        # unconditional conditioning
        uc_cross = self.model.get_unconditional_conditioning(1, "")
        uc_full = {"c_concat": [x_augment], "c_crossattn": [uc_cross], "c_adm": noise_level}
        return cond, uc_full
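
    # Usage sketch (hypothetical paths and values; the checkpoint filename must
    # contain 'upscaler' for the config auto-detection above to pick x4-upscaling.yaml):
    #   sdh = StableDiffusionHolder(fp_ckpt="checkpoints/x4-upscaler-ema.ckpt")
    #   text_embedding = sdh.get_text_embedding("a photo of a cat")
    #   cond, uc_full = sdh.get_cond_upscaling(low_res_image, text_embedding, noise_level=20)
    #   list_latents = sdh.run_diffusion_upscaling(cond, uc_full, latents_start)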
    def run_diffusion_standard(
            self,
            text_embeddings: torch.FloatTensor,
            latents_start: torch.FloatTensor,
            idx_start: int = 0,
            list_latents_mixing=None,
            mixing_coeffs=0.0,
            spatial_mask=None,
            return_image: Optional[bool] = False):
        r"""
        Diffusion standard version.
        Args:
            text_embeddings: torch.FloatTensor
                Text embeddings used for diffusion
            latents_start: torch.FloatTensor
                Latents that are used as the starting point of the diffusion
            idx_start: int
                Index of the diffusion step at which latents_start is injected
            list_latents_mixing: list
                List of latents that are mixed into the diffusion trajectory
            mixing_coeffs: float or list
                Mixing coefficients for latent blending
            spatial_mask:
                Experimental feature for enforcing pixels from list_latents_mixing
            return_image: Optional[bool]
                Optionally return the decoded image directly
        """
        # Asserts
        if isinstance(mixing_coeffs, float):
            list_mixing_coeffs = self.num_inference_steps * [mixing_coeffs]
        elif isinstance(mixing_coeffs, list):
            assert len(mixing_coeffs) == self.num_inference_steps
            list_mixing_coeffs = mixing_coeffs
        else:
            raise ValueError("mixing_coeffs should be float or list with len=num_inference_steps")
        if np.sum(list_mixing_coeffs) > 0:
            assert len(list_latents_mixing) == self.num_inference_steps
        precision_scope = autocast if self.precision == "autocast" else nullcontext
        with precision_scope("cuda"):
            with self.model.ema_scope():
                if self.guidance_scale != 1.0:
                    uc = self.model.get_learned_conditioning(self.negative_prompt)
                else:
                    uc = None
                self.sampler.make_schedule(ddim_num_steps=self.num_inference_steps - 1, ddim_eta=self.ddim_eta, verbose=False)
                latents = latents_start.clone()
                timesteps = self.sampler.ddim_timesteps
                time_range = np.flip(timesteps)
                total_steps = timesteps.shape[0]
                # Collect latents
                list_latents_out = []
                for i, step in enumerate(time_range):
                    # Set the right starting latents
                    if i < idx_start:
                        list_latents_out.append(None)
                        continue
                    elif i == idx_start:
                        latents = latents_start.clone()
                    # Mix latents
                    if i > 0 and list_mixing_coeffs[i] > 0:
                        latents_mixtarget = list_latents_mixing[i - 1].clone()
                        latents = interpolate_spherical(latents, latents_mixtarget, list_mixing_coeffs[i])
                    if spatial_mask is not None and list_latents_mixing is not None:
                        latents = interpolate_spherical(latents, list_latents_mixing[i - 1], 1 - spatial_mask)
                    index = total_steps - i - 1
                    ts = torch.full((1,), step, device=self.device, dtype=torch.long)
                    outs = self.sampler.p_sample_ddim(latents, text_embeddings, ts, index=index, use_original_steps=False,
                                                      quantize_denoised=False, temperature=1.0,
                                                      noise_dropout=0.0, score_corrector=None,
                                                      corrector_kwargs=None,
                                                      unconditional_guidance_scale=self.guidance_scale,
                                                      unconditional_conditioning=uc,
                                                      dynamic_threshold=None)
                    latents, pred_x0 = outs
                    list_latents_out.append(latents.clone())
                if return_image:
                    return self.latent2image(latents)
                else:
                    return list_latents_out
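
    # Usage sketch (hypothetical checkpoint path and prompt; latent shape follows
    # self.C channels and the f=8 downsampling factor):
    #   sdh = StableDiffusionHolder(fp_ckpt="checkpoints/v2-1_512-ema-pruned.ckpt")
    #   text_embeddings = sdh.get_text_embedding("a surreal landscape")
    #   latents_start = torch.randn(1, sdh.C, sdh.height // sdh.f, sdh.width // sdh.f,
    #                               device=sdh.device)
    #   list_latents = sdh.run_diffusion_standard(text_embeddings, latents_start)
    #   image = sdh.latent2image(list_latents[-1])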
    def run_diffusion_upscaling(
            self,
            cond,
            uc_full,
            latents_start: torch.FloatTensor,
            idx_start: int = -1,
            list_latents_mixing: list = None,
            mixing_coeffs: float = 0.0,
            return_image: Optional[bool] = False):
        r"""
        Diffusion upscaling version. Takes the conditioning from get_cond_upscaling.
        """
        # Asserts
        if isinstance(mixing_coeffs, float):
            list_mixing_coeffs = self.num_inference_steps * [mixing_coeffs]
        elif isinstance(mixing_coeffs, list):
            assert len(mixing_coeffs) == self.num_inference_steps
            list_mixing_coeffs = mixing_coeffs
        else:
            raise ValueError("mixing_coeffs should be float or list with len=num_inference_steps")
        if np.sum(list_mixing_coeffs) > 0:
            assert len(list_latents_mixing) == self.num_inference_steps
        precision_scope = autocast if self.precision == "autocast" else nullcontext
        # The latent spatial size follows the low-resolution conditioning image.
        h = uc_full['c_concat'][0].shape[2]
        w = uc_full['c_concat'][0].shape[3]
        with precision_scope("cuda"):
            with self.model.ema_scope():
                shape_latents = [self.model.channels, h, w]
                self.sampler.make_schedule(ddim_num_steps=self.num_inference_steps - 1, ddim_eta=self.ddim_eta, verbose=False)
                C, H, W = shape_latents
                size = (1, C, H, W)
                b = size[0]
                latents = latents_start.clone()
                timesteps = self.sampler.ddim_timesteps
                time_range = np.flip(timesteps)
                total_steps = timesteps.shape[0]
                # Collect latents
                list_latents_out = []
                for i, step in enumerate(time_range):
                    # Set the right starting latents
                    if i < idx_start:
                        list_latents_out.append(None)
                        continue
                    elif i == idx_start:
                        latents = latents_start.clone()
                    # Mix the latents
                    if i > 0 and list_mixing_coeffs[i] > 0:
                        latents_mixtarget = list_latents_mixing[i - 1].clone()
                        latents = interpolate_spherical(latents, latents_mixtarget, list_mixing_coeffs[i])
                    index = total_steps - i - 1
                    ts = torch.full((b,), step, device=self.device, dtype=torch.long)
                    outs = self.sampler.p_sample_ddim(latents, cond, ts, index=index, use_original_steps=False,
                                                      quantize_denoised=False, temperature=1.0,
                                                      noise_dropout=0.0, score_corrector=None,
                                                      corrector_kwargs=None,
                                                      unconditional_guidance_scale=self.guidance_scale,
                                                      unconditional_conditioning=uc_full,
                                                      dynamic_threshold=None)
                    latents, pred_x0 = outs
                    list_latents_out.append(latents.clone())
                if return_image:
                    return self.latent2image(latents)
                else:
                    return list_latents_out
    def latent2image(
            self,
            latents: torch.FloatTensor):
        r"""
        Returns an image provided a latent representation from diffusion.
        Args:
            latents: torch.FloatTensor
                Result of the diffusion process.
        """
        x_sample = self.model.decode_first_stage(latents)
        x_sample = torch.clamp((x_sample + 1.0) / 2.0, min=0.0, max=1.0)
        x_sample = 255 * x_sample[0, :, :].permute([1, 2, 0]).cpu().numpy()
        image = x_sample.astype(np.uint8)
        return image
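

# Minimal end-to-end sketch: checkpoint path and prompt below are placeholders
# (any v2 512 checkpoint whose filename contains '512' would be auto-configured).
if __name__ == "__main__":
    sdh = StableDiffusionHolder(fp_ckpt="checkpoints/v2-1_512-ema-pruned.ckpt")
    text_embeddings = sdh.get_text_embedding("a painting of a lighthouse at dawn")
    latents_start = torch.randn(1, sdh.C, sdh.height // sdh.f, sdh.width // sdh.f,
                                device=sdh.device)
    image = sdh.run_diffusion_standard(text_embeddings, latents_start, return_image=True)
    Image.fromarray(image).save("example_output.jpg")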