from transformers import logging
from diffusers import DDIMScheduler, StableDiffusionPipeline

# suppress partial model loading warning
logging.set_verbosity_error()

import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
def seed_everything(seed):
    torch.manual_seed(seed)
    torch.cuda.manual_seed(seed)
    # torch.backends.cudnn.deterministic = True
    # torch.backends.cudnn.benchmark = True
class StableDiffusion(nn.Module):
    def __init__(
        self,
        device,
        fp16=True,
        vram_O=False,
        sd_version="2.1",
        hf_key=None,
        t_range=[0.02, 0.98],
    ):
        super().__init__()

        self.device = device
        self.sd_version = sd_version

        if hf_key is not None:
            print(f"[INFO] using hugging face custom model key: {hf_key}")
            model_key = hf_key
        elif self.sd_version == "2.1":
            model_key = "stabilityai/stable-diffusion-2-1-base"
        elif self.sd_version == "2.0":
            model_key = "stabilityai/stable-diffusion-2-base"
        elif self.sd_version == "1.5":
            model_key = "runwayml/stable-diffusion-v1-5"
        else:
            raise ValueError(
                f"Stable-diffusion version {self.sd_version} not supported."
            )

        self.dtype = torch.float16 if fp16 else torch.float32

        # Create model
        pipe = StableDiffusionPipeline.from_pretrained(
            model_key, torch_dtype=self.dtype
        )

        if vram_O:
            pipe.enable_sequential_cpu_offload()
            pipe.enable_vae_slicing()
            pipe.unet.to(memory_format=torch.channels_last)
            pipe.enable_attention_slicing(1)
            # pipe.enable_model_cpu_offload()
        else:
            pipe.to(device)

        self.vae = pipe.vae
        self.tokenizer = pipe.tokenizer
        self.text_encoder = pipe.text_encoder
        self.unet = pipe.unet

        self.scheduler = DDIMScheduler.from_pretrained(
            model_key, subfolder="scheduler", torch_dtype=self.dtype
        )

        del pipe

        self.num_train_timesteps = self.scheduler.config.num_train_timesteps
        self.min_step = int(self.num_train_timesteps * t_range[0])
        self.max_step = int(self.num_train_timesteps * t_range[1])
        self.alphas = self.scheduler.alphas_cumprod.to(self.device)  # for convenience

        self.embeddings = None
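    # t_range clamps the timesteps sampled in train_step: very small t adds
    # almost no noise (uninformative gradients) and very large t is almost
    # pure noise; [0.02, 0.98] is the range used by DreamFusion.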
    @torch.no_grad()
    def get_text_embeds(self, prompts, negative_prompts):
        pos_embeds = self.encode_text(prompts)  # [1, 77, 1024] for SD 2.x (768 for SD 1.5)
        neg_embeds = self.encode_text(negative_prompts)
        self.embeddings = torch.cat([neg_embeds, pos_embeds], dim=0)  # [2, 77, 1024]

    def encode_text(self, prompt):
        # prompt: [str]
        inputs = self.tokenizer(
            prompt,
            padding="max_length",
            max_length=self.tokenizer.model_max_length,
            truncation=True,  # avoid size errors on prompts longer than 77 tokens
            return_tensors="pt",
        )
        embeddings = self.text_encoder(inputs.input_ids.to(self.device))[0]
        return embeddings
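    # Usage sketch (illustrative prompt): after
    #   sd.get_text_embeds(["a hamburger"], [""])
    # self.embeddings stacks [negative, positive] along dim 0, so everywhere
    # below torch.cat([latents] * 2) lines the first half up with the
    # unconditional embedding and the second half with the conditional one.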
    @torch.no_grad()
    def refine(self, pred_rgb, guidance_scale=100, steps=50, strength=0.8):
        batch_size = pred_rgb.shape[0]
        pred_rgb_512 = F.interpolate(pred_rgb, (512, 512), mode="bilinear", align_corners=False)
        latents = self.encode_imgs(pred_rgb_512.to(self.dtype))
        # latents = torch.randn((1, 4, 64, 64), device=self.device, dtype=self.dtype)

        self.scheduler.set_timesteps(steps)
        init_step = int(steps * strength)
        latents = self.scheduler.add_noise(latents, torch.randn_like(latents), self.scheduler.timesteps[init_step])

        for i, t in enumerate(self.scheduler.timesteps[init_step:]):
            latent_model_input = torch.cat([latents] * 2)

            noise_pred = self.unet(
                latent_model_input, t, encoder_hidden_states=self.embeddings,
            ).sample

            noise_pred_uncond, noise_pred_cond = noise_pred.chunk(2)
            noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_cond - noise_pred_uncond)

            latents = self.scheduler.step(noise_pred, t, latents).prev_sample

        imgs = self.decode_latents(latents)  # [1, 3, 512, 512]
        return imgs
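    # refine() is an SDEdit-style image-to-image pass: encode the render,
    # jump to an intermediate noise level chosen by `strength`, then denoise
    # with classifier-free guidance. For example, with steps=50 and
    # strength=0.8, init_step = 40, so only the last 10 scheduler steps run
    # and the output stays close to the input render.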
    def train_step(
        self,
        pred_rgb,
        step_ratio=None,
        guidance_scale=100,
        as_latent=False,
    ):
        batch_size = pred_rgb.shape[0]
        pred_rgb = pred_rgb.to(self.dtype)

        if as_latent:
            latents = F.interpolate(pred_rgb, (64, 64), mode="bilinear", align_corners=False) * 2 - 1
        else:
            # interp to 512x512 to be fed into vae.
            pred_rgb_512 = F.interpolate(pred_rgb, (512, 512), mode="bilinear", align_corners=False)
            # encode image into latents with vae, requires grad!
            latents = self.encode_imgs(pred_rgb_512)

        if step_ratio is not None:
            # dreamtime-like: anneal t from max_step down to min_step over training
            # t = self.max_step - (self.max_step - self.min_step) * np.sqrt(step_ratio)
            t = np.round((1 - step_ratio) * self.num_train_timesteps).clip(self.min_step, self.max_step)
            t = torch.full((batch_size,), t, dtype=torch.long, device=self.device)
        else:
            t = torch.randint(self.min_step, self.max_step + 1, (batch_size,), dtype=torch.long, device=self.device)

        # w(t), sigma_t^2
        w = (1 - self.alphas[t]).view(batch_size, 1, 1, 1)

        # predict the noise residual with unet, NO grad!
        with torch.no_grad():
            # add noise
            noise = torch.randn_like(latents)
            latents_noisy = self.scheduler.add_noise(latents, noise, t)
            # pred noise
            latent_model_input = torch.cat([latents_noisy] * 2)
            tt = torch.cat([t] * 2)

            # repeat each of [neg, pos] batch_size times so the two halves
            # line up with chunk(2) below (plain .repeat() would interleave them)
            embeddings = self.embeddings.repeat_interleave(batch_size, dim=0)
            noise_pred = self.unet(
                latent_model_input, tt, encoder_hidden_states=embeddings
            ).sample

            # perform guidance (high scale from paper!)
            noise_pred_uncond, noise_pred_pos = noise_pred.chunk(2)
            noise_pred = noise_pred_uncond + guidance_scale * (
                noise_pred_pos - noise_pred_uncond
            )

        grad = w * (noise_pred - noise)
        grad = torch.nan_to_num(grad)  # seems important to avoid NaN...
        # grad = grad.clamp(-1, 1)

        target = (latents - grad).detach()
        loss = 0.5 * F.mse_loss(latents.float(), target, reduction='sum') / latents.shape[0]

        return loss
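    # The detached-target MSE is the usual trick for implementing the score
    # distillation sampling (SDS) gradient: since target = latents - grad is
    # detached, d(loss)/d(latents) = latents - target = grad, so the latents
    # receive exactly w(t) * (noise_pred - noise) as their gradient, which
    # then flows through the differentiable VAE encoder back into pred_rgb.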
    @torch.no_grad()
    def produce_latents(
        self,
        height=512,
        width=512,
        num_inference_steps=50,
        guidance_scale=7.5,
        latents=None,
    ):
        if latents is None:
            latents = torch.randn(
                (
                    self.embeddings.shape[0] // 2,
                    self.unet.config.in_channels,
                    height // 8,
                    width // 8,
                ),
                device=self.device,
                dtype=self.dtype,  # match the unet dtype under fp16
            )

        self.scheduler.set_timesteps(num_inference_steps)

        for i, t in enumerate(self.scheduler.timesteps):
            # expand the latents if we are doing classifier-free guidance to avoid doing two forward passes.
            latent_model_input = torch.cat([latents] * 2)

            # predict the noise residual
            noise_pred = self.unet(
                latent_model_input, t, encoder_hidden_states=self.embeddings
            ).sample

            # perform guidance
            noise_pred_uncond, noise_pred_cond = noise_pred.chunk(2)
            noise_pred = noise_pred_uncond + guidance_scale * (
                noise_pred_cond - noise_pred_uncond
            )

            # compute the previous noisy sample x_t -> x_t-1
            latents = self.scheduler.step(noise_pred, t, latents).prev_sample

        return latents
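    # Classifier-free guidance computes
    #   eps = eps_uncond + s * (eps_cond - eps_uncond),
    # which reduces to the conditional prediction at s = 1 and extrapolates
    # away from the unconditional one for s > 1; the default s = 7.5 is the
    # standard Stable Diffusion setting, while train_step uses the much
    # larger s = 100 from the DreamFusion paper.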
    def decode_latents(self, latents):
        latents = 1 / self.vae.config.scaling_factor * latents

        imgs = self.vae.decode(latents).sample
        imgs = (imgs / 2 + 0.5).clamp(0, 1)

        return imgs

    def encode_imgs(self, imgs):
        # imgs: [B, 3, H, W], values in [0, 1]
        imgs = 2 * imgs - 1

        posterior = self.vae.encode(imgs).latent_dist
        latents = posterior.sample() * self.vae.config.scaling_factor

        return latents
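    # vae.config.scaling_factor (0.18215 for these SD checkpoints) normalizes
    # latents to roughly unit variance; encode_imgs multiplies it in and
    # decode_latents divides it back out, so the pair is a near-inverse up to
    # the stochastic posterior sample.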
    def prompt_to_img(
        self,
        prompts,
        negative_prompts="",
        height=512,
        width=512,
        num_inference_steps=50,
        guidance_scale=7.5,
        latents=None,
    ):
        if isinstance(prompts, str):
            prompts = [prompts]

        if isinstance(negative_prompts, str):
            negative_prompts = [negative_prompts]

        # Prompts -> text embeds
        self.get_text_embeds(prompts, negative_prompts)

        # Text embeds -> img latents
        latents = self.produce_latents(
            height=height,
            width=width,
            latents=latents,
            num_inference_steps=num_inference_steps,
            guidance_scale=guidance_scale,
        )  # [1, 4, 64, 64]

        # Img latents -> imgs
        imgs = self.decode_latents(latents)  # [1, 3, 512, 512]

        # Img to Numpy
        imgs = imgs.detach().cpu().permute(0, 2, 3, 1).numpy()
        imgs = (imgs * 255).round().astype("uint8")

        return imgs
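    # Minimal usage sketch (prompt text is arbitrary):
    #   sd = StableDiffusion(torch.device("cuda"))
    #   imgs = sd.prompt_to_img("a photo of a hamburger")
    #   # imgs: uint8 numpy array of shape [1, 512, 512, 3]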
if __name__ == "__main__":
    import argparse
    import matplotlib.pyplot as plt

    parser = argparse.ArgumentParser()
    parser.add_argument("prompt", type=str)
    parser.add_argument("--negative", default="", type=str)
    parser.add_argument(
        "--sd_version",
        type=str,
        default="2.1",
        choices=["1.5", "2.0", "2.1"],
        help="stable diffusion version",
    )
    parser.add_argument(
        "--hf_key",
        type=str,
        default=None,
        help="hugging face Stable diffusion model key",
    )
    parser.add_argument("--fp16", action="store_true", help="use float16 for training")
    parser.add_argument(
        "--vram_O", action="store_true", help="optimization for low VRAM usage"
    )
    parser.add_argument("-H", type=int, default=512)
    parser.add_argument("-W", type=int, default=512)
    parser.add_argument("--seed", type=int, default=0)
    parser.add_argument("--steps", type=int, default=50)
    opt = parser.parse_args()

    seed_everything(opt.seed)

    device = torch.device("cuda")

    sd = StableDiffusion(device, opt.fp16, opt.vram_O, opt.sd_version, opt.hf_key)

    imgs = sd.prompt_to_img(opt.prompt, opt.negative, opt.H, opt.W, opt.steps)

    # visualize image
    plt.imshow(imgs[0])
    plt.show()
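# Example invocation (the script name below is whatever this file is saved as):
#   python sd_guidance.py "a photo of an astronaut riding a horse" --fp16 --seed 42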