import torch
import torch.nn as nn
import torchvision.transforms as T
from torch.autograd import grad

import argparse
from tqdm import tqdm

from syncdiffusion.utils import *
import lpips

from transformers import CLIPTextModel, CLIPTokenizer
from diffusers import AutoencoderKL, UNet2DConditionModel, DDIMScheduler

class SyncDiffusion(nn.Module):
    def __init__(self, device='cuda', sd_version='2.0', hf_key=None):
        super().__init__()

        self.device = device
        self.sd_version = sd_version

        print(f'[INFO] loading stable diffusion...')
        if hf_key is not None:
            print(f'[INFO] using hugging face custom model key: {hf_key}')
            model_key = hf_key
        elif self.sd_version == '2.1':
            model_key = "stabilityai/stable-diffusion-2-1-base"
        elif self.sd_version == '2.0':
            model_key = "stabilityai/stable-diffusion-2-base"
        elif self.sd_version == '1.5':
            model_key = "runwayml/stable-diffusion-v1-5"
        else:
            raise ValueError(f'Stable-diffusion version {self.sd_version} not supported.')

        # Load pretrained models from HuggingFace
        self.vae = AutoencoderKL.from_pretrained(model_key, subfolder="vae").to(self.device)
        self.tokenizer = CLIPTokenizer.from_pretrained(model_key, subfolder="tokenizer")
        self.text_encoder = CLIPTextModel.from_pretrained(model_key, subfolder="text_encoder").to(self.device)
        self.unet = UNet2DConditionModel.from_pretrained(model_key, subfolder="unet").to(self.device)

        # Freeze models
        for p in self.unet.parameters():
            p.requires_grad_(False)
        for p in self.vae.parameters():
            p.requires_grad_(False)
        for p in self.text_encoder.parameters():
            p.requires_grad_(False)
        self.unet.eval()
        self.vae.eval()
        self.text_encoder.eval()

        print(f'[INFO] loaded stable diffusion!')

        # Set DDIM scheduler
        self.scheduler = DDIMScheduler.from_pretrained(model_key, subfolder="scheduler")

        # Load perceptual loss (LPIPS)
        self.percept_loss = lpips.LPIPS(net='vgg').to(self.device)
        print(f'[INFO] loaded perceptual loss!')

    def get_text_embeds(self, prompt, negative_prompt):
        # Tokenize text and get embeddings
        text_input = self.tokenizer(prompt, padding='max_length', max_length=self.tokenizer.model_max_length,
                                    truncation=True, return_tensors='pt')
        text_embeddings = self.text_encoder(text_input.input_ids.to(self.device))[0]

        # Repeat for unconditional embeddings
        uncond_input = self.tokenizer(negative_prompt, padding='max_length', max_length=self.tokenizer.model_max_length,
                                      return_tensors='pt')
        uncond_embeddings = self.text_encoder(uncond_input.input_ids.to(self.device))[0]

        # Concatenate for final embeddings
        text_embeddings = torch.cat([uncond_embeddings, text_embeddings])
        return text_embeddings

    def decode_latents(self, latents):
        # Scale latents back by the VAE scaling factor (0.18215) before decoding
        latents = 1 / 0.18215 * latents
        imgs = self.vae.decode(latents).sample
        # Map decoded images from [-1, 1] to [0, 1]
        imgs = (imgs / 2 + 0.5).clamp(0, 1)
        return imgs

    def sample_syncdiffusion(
        self,
        prompts,
        negative_prompts="",
        height=512,
        width=2048,
        latent_size=64,         # fix latent size to 64 for Stable Diffusion
        num_inference_steps=50,
        guidance_scale=7.5,
        sync_weight=20,         # gradient descent weight 'w' in the paper
        sync_freq=1,            # sync_freq=n: perform gradient descent every n steps
        sync_thres=50,          # sync_thres=n: compute SyncDiffusion only for the first n steps
        sync_decay_rate=0.95,   # decay rate for sync_weight, set as 0.95 in the paper
        stride=16,              # stride for latents, set as 16 in the paper
    ):
        assert height >= 512 and width >= 512, 'height and width must be at least 512'
        assert height % (stride * 8) == 0 and width % (stride * 8) == 0, \
            'height and width must be divisible by the stride multiplied by 8'
        assert stride % 8 == 0 and stride < 64, \
            'stride must be divisible by 8 and smaller than the latent size of Stable Diffusion'

        if isinstance(prompts, str):
            prompts = [prompts]
        if isinstance(negative_prompts, str):
            negative_prompts = [negative_prompts]

        # Obtain text embeddings
        text_embeds = self.get_text_embeds(prompts, negative_prompts)  # [2, 77, embed_dim]

        # Define a list of windows (views) to process in parallel
        views = get_views(height, width, stride=stride)
        print(f"[INFO] number of views to process: {len(views)}")

        # Initialize the panorama latent and the per-pixel accumulation buffers
        latent = torch.randn((1, self.unet.in_channels, height // 8, width // 8))
        count = torch.zeros_like(latent, requires_grad=False, device=self.device)
        value = torch.zeros_like(latent, requires_grad=False, device=self.device)
        latent = latent.to(self.device)

        # Set DDIM scheduler
        self.scheduler.set_timesteps(num_inference_steps)

        # Set the anchor view as the middle view
        anchor_view_idx = len(views) // 2

        # Set the SyncDiffusion weight scheduler
        sync_scheduler = exponential_decay_list(
            init_weight=sync_weight,
            decay_rate=sync_decay_rate,
            num_steps=num_inference_steps
        )
        print(f'[INFO] using exponential decay scheduler with decay rate {sync_decay_rate}')

        with torch.autocast('cuda'):
            for i, t in enumerate(tqdm(self.scheduler.timesteps)):
                count.zero_()
                value.zero_()

                '''
                (1) First, obtain the reference anchor view (for computing the perceptual loss)
                '''
                with torch.no_grad():
                    if (i + 1) % sync_freq == 0 and i < sync_thres:
                        # Decode the anchor view
                        h_start, h_end, w_start, w_end = views[anchor_view_idx]
                        latent_view = latent[:, :, h_start:h_end, w_start:w_end].detach()

                        # Expand the latents for classifier-free guidance
                        latent_model_input = torch.cat([latent_view] * 2)  # 2 x 4 x 64 x 64
                        noise_pred = self.unet(latent_model_input, t, encoder_hidden_states=text_embeds)['sample']

                        # Perform guidance
                        noise_pred_uncond, noise_pred_cond = noise_pred.chunk(2)
                        noise_pred_new = noise_pred_uncond + guidance_scale * (noise_pred_cond - noise_pred_uncond)

                        # Predict the 'foreseen denoised' latent (x0) of the anchor view
                        latent_pred_x0 = self.scheduler.step(noise_pred_new, t, latent_view)["pred_original_sample"]
                        decoded_image_anchor = self.decode_latents(latent_pred_x0)  # 1 x 3 x 512 x 512

                '''
                (2) Then perform SyncDiffusion and run a single denoising step for each view
                '''
                for view_idx, (h_start, h_end, w_start, w_end) in enumerate(views):
                    latent_view = latent[:, :, h_start:h_end, w_start:w_end].detach()

                    ##################### BEGIN: PERFORM GRADIENT DESCENT (SyncDiffusion) #####################
                    latent_view_copy = latent_view.clone().detach()

                    if (i + 1) % sync_freq == 0 and i < sync_thres:
                        # Enable gradients on latent_view
                        latent_view = latent_view.requires_grad_()

                        # Expand the latents for classifier-free guidance
                        latent_model_input = torch.cat([latent_view] * 2)

                        # Predict the noise residual
                        noise_pred = self.unet(latent_model_input, t, encoder_hidden_states=text_embeds)['sample']

                        # Perform guidance
                        noise_pred_uncond, noise_pred_cond = noise_pred.chunk(2)
                        noise_pred_new = noise_pred_uncond + guidance_scale * (noise_pred_cond - noise_pred_uncond)

                        # Compute the denoising step with the reference model
                        out = self.scheduler.step(noise_pred_new, t, latent_view)

                        # Predict the 'foreseen denoised' latent (x0)
                        latent_view_x0 = out['pred_original_sample']

                        # Decode the denoised latent
                        decoded_x0 = self.decode_latents(latent_view_x0)  # 1 x 3 x 512 x 512

                        # Compute the perceptual loss (LPIPS) against the anchor view; LPIPS expects inputs in [-1, 1]
                        percept_loss = self.percept_loss(
                            decoded_x0 * 2.0 - 1.0,
                            decoded_image_anchor * 2.0 - 1.0
                        )

                        # Compute the gradient of the perceptual loss w.r.t. the latent
                        norm_grad = grad(outputs=percept_loss, inputs=latent_view)[0]

                        # SyncDiffusion: update the original latent of every non-anchor view
                        if view_idx != anchor_view_idx:
                            latent_view_copy = latent_view_copy - sync_scheduler[i] * norm_grad  # 1 x 4 x 64 x 64
                    ###################### END: PERFORM GRADIENT DESCENT (SyncDiffusion) ######################

                    # After gradient descent, perform a single denoising step
                    with torch.no_grad():
                        latent_model_input = torch.cat([latent_view_copy] * 2)
                        noise_pred = self.unet(latent_model_input, t, encoder_hidden_states=text_embeds)['sample']
                        noise_pred_uncond, noise_pred_cond = noise_pred.chunk(2)
                        noise_pred_new = noise_pred_uncond + guidance_scale * (noise_pred_cond - noise_pred_uncond)
                        out = self.scheduler.step(noise_pred_new, t, latent_view_copy)
                        latent_view_denoised = out['prev_sample']

                    # Accumulate the denoised latent views
                    value[:, :, h_start:h_end, w_start:w_end] += latent_view_denoised
                    count[:, :, h_start:h_end, w_start:w_end] += 1

                # Take the MultiDiffusion step (average the overlapping latents)
                latent = torch.where(count > 0, value / count, value)

        # Decode latents to the panorama image
        with torch.no_grad():
            imgs = self.decode_latents(latent)  # [1, 3, height, width]
            img = T.ToPILImage()(imgs[0].cpu())

        print(f"[INFO] Done!")
        return img
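

# ----------------------------------------------------------------------------
# Illustrative usage sketch (not part of the original module): the prompt and
# the 512 x 2048 panorama settings below are assumptions for demonstration;
# the Space's own demo script is the actual entry point.
if __name__ == '__main__':
    syncdiffusion = SyncDiffusion(device='cuda', sd_version='2.0')
    panorama = syncdiffusion.sample_syncdiffusion(
        prompts="a photo of a mountain landscape at sunset",
        negative_prompts="",
        height=512,
        width=2048,
        num_inference_steps=50,
        guidance_scale=7.5,
        sync_weight=20,
        sync_decay_rate=0.95,
        stride=16,
    )
    panorama.save("syncdiffusion_panorama.png")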