# Copyright 2022 Lunar Ring. All rights reserved.
# Written by Johannes Stelzer, email [email protected] twitter @j_stelzer
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import os
import torch
torch.backends.cudnn.benchmark = False
torch.set_grad_enabled(False)
import numpy as np
import warnings
warnings.filterwarnings('ignore')
import time
from tqdm.auto import tqdm
from PIL import Image
from movie_util import MovieSaver
from typing import List, Optional
from ldm.models.diffusion.ddpm import LatentUpscaleDiffusion, LatentInpaintDiffusion
import lpips
from utils import interpolate_spherical, interpolate_linear, add_frames_linear_interp, yml_load, yml_save


class LatentBlending():
    def __init__(
            self,
            sdh,
            guidance_scale: float = 4,
            guidance_scale_mid_damper: float = 0.5,
            mid_compression_scaler: float = 1.2):
        r"""
        Initializes the latent blending class.
        Args:
            sdh:
                Holder object wrapping the diffusion model (provides device, width,
                height, num_inference_steps and the diffusion runners).
            guidance_scale: float
                Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598).
                `guidance_scale` is defined as `w` of equation 2. of [Imagen
                Paper](https://arxiv.org/pdf/2205.11487.pdf). Guidance scale is enabled by setting `guidance_scale >
                1`. A higher guidance scale encourages the model to generate images that are closely linked to the
                text `prompt`, usually at the expense of lower image quality.
            guidance_scale_mid_damper: float = 0.5
                Reduces the guidance scale towards the middle of the transition.
                A value of 0.5 decreases the guidance_scale linearly towards the middle, down to 0.5 of its base value.
            mid_compression_scaler: float = 1.2
                Increases the sampling density in the middle (where most changes happen). Higher values
                imply more samples in the middle. However, the inflection point can occur outside the middle,
                so high values can give rough transitions. Values around 2 should be fine.
        """
        assert guidance_scale_mid_damper > 0 \
            and guidance_scale_mid_damper <= 1.0, \
            f"guidance_scale_mid_damper needs to be in interval (0,1], you provided {guidance_scale_mid_damper}"

        self.sdh = sdh
        self.device = self.sdh.device
        self.width = self.sdh.width
        self.height = self.sdh.height
        self.guidance_scale_mid_damper = guidance_scale_mid_damper
        self.mid_compression_scaler = mid_compression_scaler
        self.seed1 = 0
        self.seed2 = 0

        # Initialize vars
        self.prompt1 = ""
        self.prompt2 = ""
        self.negative_prompt = ""

        self.tree_latents = [None, None]
        self.tree_fracts = None
        self.idx_injection = []
        self.tree_status = None
        self.tree_final_imgs = []

        self.list_nmb_branches_prev = []
        self.list_injection_idx_prev = []
        self.text_embedding1 = None
        self.text_embedding2 = None
        self.image1_lowres = None
        self.image2_lowres = None
        self.num_inference_steps = self.sdh.num_inference_steps
        self.noise_level_upscaling = 20
        self.list_injection_idx = None
        self.list_nmb_branches = None

        # Mixing parameters
        self.branch1_crossfeed_power = 0.1
        self.branch1_crossfeed_range = 0.6
        self.branch1_crossfeed_decay = 0.8

        self.parental_crossfeed_power = 0.1
        self.parental_crossfeed_range = 0.8
        self.parental_crossfeed_power_decay = 0.8

        self.set_guidance_scale(guidance_scale)
        self.init_mode()
        self.multi_transition_img_first = None
        self.multi_transition_img_last = None
        self.dt_per_diff = 0
        self.spatial_mask = None
        self.lpips = lpips.LPIPS(net='alex').cuda(self.device)
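
    # Usage sketch (assumes a StableDiffusionHolder-style wrapper `sdh`, as used
    # throughout this file; the checkpoint path is illustrative):
    #   sdh = StableDiffusionHolder("v2-1_512-ema-pruned.ckpt")
    #   lb = LatentBlending(sdh, guidance_scale=4.0)
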
    def init_mode(self):
        r"""
        Sets the operational mode. Currently supported are standard, inpainting and x4 upscaling.
        """
        if isinstance(self.sdh.model, LatentUpscaleDiffusion):
            self.mode = 'upscale'
        elif isinstance(self.sdh.model, LatentInpaintDiffusion):
            self.sdh.image_source = None
            self.sdh.mask_image = None
            self.mode = 'inpaint'
        else:
            self.mode = 'standard'

    def set_guidance_scale(self, guidance_scale):
        r"""
        Sets the guidance scale.
        """
        self.guidance_scale_base = guidance_scale
        self.guidance_scale = guidance_scale
        self.sdh.guidance_scale = guidance_scale

    def set_negative_prompt(self, negative_prompt):
        r"""
        Sets the negative prompt. Currently only one negative prompt is supported.
        """
        self.negative_prompt = negative_prompt
        self.sdh.set_negative_prompt(negative_prompt)

    def set_guidance_mid_dampening(self, fract_mixing):
        r"""
        Tunes the guidance scale down as a linear function of fract_mixing;
        the minimum is reached at fract_mixing = 0.5.
        """
        mid_factor = 1 - np.abs(fract_mixing - 0.5) / 0.5
        max_guidance_reduction = self.guidance_scale_base * (1 - self.guidance_scale_mid_damper) - 1
        guidance_scale_effective = self.guidance_scale_base - max_guidance_reduction * mid_factor
        self.guidance_scale = guidance_scale_effective
        self.sdh.guidance_scale = guidance_scale_effective
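
    # Worked example of the damping above: with guidance_scale_base = 4.0 and
    # guidance_scale_mid_damper = 0.5, max_guidance_reduction = 4.0 * (1 - 0.5) - 1 = 1.0.
    # At the endpoints (fract_mixing = 0 or 1) mid_factor = 0 and the scale stays 4.0;
    # at the middle (fract_mixing = 0.5) mid_factor = 1 and the effective scale is 3.0.
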
    def set_branch1_crossfeed(self, crossfeed_power, crossfeed_range, crossfeed_decay):
        r"""
        Sets the crossfeed parameters for the first branch to the last branch.
        Args:
            crossfeed_power: float [0,1]
                Controls the level of cross-feeding between the first and last image branch.
            crossfeed_range: float [0,1]
                Sets the duration of active crossfeed during development.
            crossfeed_decay: float [0,1]
                Sets the decay for branch1_crossfeed_power. Lower values make the decay stronger across the range.
        """
        self.branch1_crossfeed_power = np.clip(crossfeed_power, 0, 1)
        self.branch1_crossfeed_range = np.clip(crossfeed_range, 0, 1)
        self.branch1_crossfeed_decay = np.clip(crossfeed_decay, 0, 1)

    def set_parental_crossfeed(self, crossfeed_power, crossfeed_range, crossfeed_decay):
        r"""
        Sets the crossfeed parameters for all transition images (between the first and last branch).
        Args:
            crossfeed_power: float [0,1]
                Controls the level of cross-feeding from the parental branches.
            crossfeed_range: float [0,1]
                Sets the duration of active crossfeed during development.
            crossfeed_decay: float [0,1]
                Sets the decay for parental_crossfeed_power. Lower values make the decay stronger across the range.
        """
        self.parental_crossfeed_power = np.clip(crossfeed_power, 0, 1)
        self.parental_crossfeed_range = np.clip(crossfeed_range, 0, 1)
        self.parental_crossfeed_power_decay = np.clip(crossfeed_decay, 0, 1)
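
    # Example (sketch; values are illustrative): a stronger parental influence that
    # stays active for 80% of the sampling steps and decays gently:
    #   lb.set_parental_crossfeed(crossfeed_power=0.2, crossfeed_range=0.8, crossfeed_decay=0.8)
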
    def set_prompt1(self, prompt: str):
        r"""
        Sets the first prompt (for the first keyframe) including text embeddings.
        Args:
            prompt: str
                ABC trending on artstation painted by Greg Rutkowski
        """
        prompt = prompt.replace("_", " ")
        self.prompt1 = prompt
        self.text_embedding1 = self.get_text_embeddings(self.prompt1)

    def set_prompt2(self, prompt: str):
        r"""
        Sets the second prompt (for the second keyframe) including text embeddings.
        Args:
            prompt: str
                XYZ trending on artstation painted by Greg Rutkowski
        """
        prompt = prompt.replace("_", " ")
        self.prompt2 = prompt
        self.text_embedding2 = self.get_text_embeddings(self.prompt2)

    def set_image1(self, image: Image):
        r"""
        Sets the first image (keyframe), relevant for the upscaling model transitions.
        Args:
            image: Image
        """
        self.image1_lowres = image

    def set_image2(self, image: Image):
        r"""
        Sets the second image (keyframe), relevant for the upscaling model transitions.
        Args:
            image: Image
        """
        self.image2_lowres = image

    def run_transition(
            self,
            recycle_img1: Optional[bool] = False,
            recycle_img2: Optional[bool] = False,
            num_inference_steps: Optional[int] = 30,
            depth_strength: Optional[float] = 0.3,
            t_compute_max_allowed: Optional[float] = None,
            nmb_max_branches: Optional[int] = None,
            fixed_seeds: Optional[List[int]] = None):
        r"""
        Function for computing transitions.
        Returns a list of transition images using spherical latent blending.
        Args:
            recycle_img1: Optional[bool]
                Don't recompute the latents for the first keyframe (purely prompt1). Saves compute.
            recycle_img2: Optional[bool]
                Don't recompute the latents for the second keyframe (purely prompt2). Saves compute.
            num_inference_steps:
                Number of diffusion steps. Higher values will take more compute time.
            depth_strength:
                Determines how deep the first injection will happen.
                Deeper injections will cause (unwanted) formation of new structures,
                shallower values will go into alpha-blendy land.
            t_compute_max_allowed:
                Either provide t_compute_max_allowed or nmb_max_branches.
                The maximum time allowed for computation. Higher values give better results but take longer.
            nmb_max_branches: int
                Either provide t_compute_max_allowed or nmb_max_branches. The maximum number of branches
                to be computed. Higher values give better results. Use this if you want to have controllable
                results independent of your computer.
            fixed_seeds: Optional[List[int]]
                You can supply two seeds that are used for the first and second keyframe (prompt1 and prompt2).
                Otherwise random seeds will be taken.
        """
        # Sanity checks first
        assert self.text_embedding1 is not None, 'Set the first text embedding with .set_prompt1(...) before'
        assert self.text_embedding2 is not None, 'Set the second text embedding with .set_prompt2(...) before'

        # Random seeds
        if fixed_seeds is not None:
            if fixed_seeds == 'randomize':
                fixed_seeds = list(np.random.randint(0, 1000000, 2).astype(np.int32))
            else:
                assert len(fixed_seeds) == 2, "Supply a list with len = 2"
            self.seed1 = fixed_seeds[0]
            self.seed2 = fixed_seeds[1]

        # Ensure correct num_inference_steps in holder
        self.num_inference_steps = num_inference_steps
        self.sdh.num_inference_steps = num_inference_steps

        # Compute / Recycle first image
        if not recycle_img1 or len(self.tree_latents[0]) != self.num_inference_steps:
            list_latents1 = self.compute_latents1()
        else:
            list_latents1 = self.tree_latents[0]

        # Compute / Recycle second image
        if not recycle_img2 or len(self.tree_latents[-1]) != self.num_inference_steps:
            list_latents2 = self.compute_latents2()
        else:
            list_latents2 = self.tree_latents[-1]

        # Reset the tree, injecting the edge latents1/2 we just generated/recycled
        self.tree_latents = [list_latents1, list_latents2]
        self.tree_fracts = [0.0, 1.0]
        self.tree_final_imgs = [self.sdh.latent2image(self.tree_latents[0][-1]), self.sdh.latent2image(self.tree_latents[-1][-1])]
        self.tree_idx_injection = [0, 0]

        # Hard-fix. Apply spatial mask only for list_latents2 but not for transition. WIP...
        self.spatial_mask = None

        # Set up branching scheme (dependent on provided compute time)
        list_idx_injection, list_nmb_stems = self.get_time_based_branching(depth_strength, t_compute_max_allowed, nmb_max_branches)

        # Run iteratively, starting with the longest trajectory.
        # Always inserting new branches where they are needed most according to image similarity
        for s_idx in tqdm(range(len(list_idx_injection))):
            nmb_stems = list_nmb_stems[s_idx]
            idx_injection = list_idx_injection[s_idx]
            for i in range(nmb_stems):
                fract_mixing, b_parent1, b_parent2 = self.get_mixing_parameters(idx_injection)
                self.set_guidance_mid_dampening(fract_mixing)
                list_latents = self.compute_latents_mix(fract_mixing, b_parent1, b_parent2, idx_injection)
                self.insert_into_tree(fract_mixing, idx_injection, list_latents)
                # print(f"fract_mixing: {fract_mixing} idx_injection {idx_injection}")
        return self.tree_final_imgs
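
    # End-to-end usage sketch (assumes a holder `sdh` as in __init__; prompts and
    # seeds are illustrative placeholders):
    #   lb = LatentBlending(sdh)
    #   lb.set_prompt1("photo of a forest covered in white flowers")
    #   lb.set_prompt2("photo of an eerie statue surrounded by ferns")
    #   imgs = lb.run_transition(depth_strength=0.3, nmb_max_branches=10, fixed_seeds=[69731, 504430])
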
    def compute_latents1(self, return_image=False):
        r"""
        Runs a diffusion trajectory for the first image.
        Args:
            return_image: bool
                whether to return an image or the list of latents
        """
        print("starting compute_latents1")
        list_conditionings = self.get_mixed_conditioning(0)
        t0 = time.time()
        latents_start = self.get_noise(self.seed1)
        list_latents1 = self.run_diffusion(
            list_conditionings,
            latents_start=latents_start,
            idx_start=0)
        t1 = time.time()
        self.dt_per_diff = (t1 - t0) / self.num_inference_steps
        self.tree_latents[0] = list_latents1
        if return_image:
            return self.sdh.latent2image(list_latents1[-1])
        else:
            return list_latents1

    def compute_latents2(self, return_image=False):
        r"""
        Runs a diffusion trajectory for the last image, which may be affected by the first image's trajectory.
        Args:
            return_image: bool
                whether to return an image or the list of latents
        """
        print("starting compute_latents2")
        list_conditionings = self.get_mixed_conditioning(1)
        latents_start = self.get_noise(self.seed2)
        # Influence from branch1
        if self.branch1_crossfeed_power > 0.0:
            # Set up the mixing_coeffs
            idx_mixing_stop = int(round(self.num_inference_steps * self.branch1_crossfeed_range))
            mixing_coeffs = list(np.linspace(self.branch1_crossfeed_power, self.branch1_crossfeed_power * self.branch1_crossfeed_decay, idx_mixing_stop))
            mixing_coeffs.extend((self.num_inference_steps - idx_mixing_stop) * [0])
            list_latents_mixing = self.tree_latents[0]
            list_latents2 = self.run_diffusion(
                list_conditionings,
                latents_start=latents_start,
                idx_start=0,
                list_latents_mixing=list_latents_mixing,
                mixing_coeffs=mixing_coeffs)
        else:
            list_latents2 = self.run_diffusion(list_conditionings, latents_start)
        self.tree_latents[-1] = list_latents2
        if return_image:
            return self.sdh.latent2image(list_latents2[-1])
        else:
            return list_latents2
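
    # Worked example of the mixing_coeffs above: with num_inference_steps = 30,
    # branch1_crossfeed_range = 0.6, power = 0.1 and decay = 0.8, idx_mixing_stop = 18,
    # so the first 18 coefficients fall linearly from 0.1 to 0.08 and the remaining
    # 12 are 0 (no crossfeed towards the end of the trajectory).
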
    def compute_latents_mix(self, fract_mixing, b_parent1, b_parent2, idx_injection):
        r"""
        Runs a diffusion trajectory, using the latents from the respective parents.
        Args:
            fract_mixing: float
                the fraction along the transition axis [0, 1]
            b_parent1: int
                index of parent1 to be used
            b_parent2: int
                index of parent2 to be used
            idx_injection: int
                the index in terms of diffusion steps, where the next insertion will start.
        """
        list_conditionings = self.get_mixed_conditioning(fract_mixing)
        fract_mixing_parental = (fract_mixing - self.tree_fracts[b_parent1]) / (self.tree_fracts[b_parent2] - self.tree_fracts[b_parent1])
        # idx_reversed = self.num_inference_steps - idx_injection

        list_latents_parental_mix = []
        for i in range(self.num_inference_steps):
            latents_p1 = self.tree_latents[b_parent1][i]
            latents_p2 = self.tree_latents[b_parent2][i]
            if latents_p1 is None or latents_p2 is None:
                latents_parental = None
            else:
                latents_parental = interpolate_spherical(latents_p1, latents_p2, fract_mixing_parental)
            list_latents_parental_mix.append(latents_parental)

        idx_mixing_stop = int(round(self.num_inference_steps * self.parental_crossfeed_range))
        mixing_coeffs = idx_injection * [self.parental_crossfeed_power]
        nmb_mixing = idx_mixing_stop - idx_injection
        if nmb_mixing > 0:
            mixing_coeffs.extend(list(np.linspace(self.parental_crossfeed_power, self.parental_crossfeed_power * self.parental_crossfeed_power_decay, nmb_mixing)))
        mixing_coeffs.extend((self.num_inference_steps - len(mixing_coeffs)) * [0])
        latents_start = list_latents_parental_mix[idx_injection - 1]
        list_latents = self.run_diffusion(
            list_conditionings,
            latents_start=latents_start,
            idx_start=idx_injection,
            list_latents_mixing=list_latents_parental_mix,
            mixing_coeffs=mixing_coeffs)
        return list_latents

    def get_time_based_branching(self, depth_strength, t_compute_max_allowed=None, nmb_max_branches=None):
        r"""
        Sets up the branching scheme dependent on the time that is granted for compute.
        The scheme uses an estimate derived from the first image's computation speed.
        Either provide t_compute_max_allowed or nmb_max_branches.
        Args:
            depth_strength:
                Determines how deep the first injection will happen.
                Deeper injections will cause (unwanted) formation of new structures,
                shallower values will go into alpha-blendy land.
            t_compute_max_allowed: float
                The maximum time allowed for computation. Higher values give better results
                but take longer. Use this if you want to fix your waiting time for the results.
            nmb_max_branches: int
                The maximum number of branches to be computed. Higher values give better
                results. Use this if you want to have controllable results independent
                of your computer.
        """
        idx_injection_base = int(round(self.num_inference_steps * depth_strength))
        list_idx_injection = np.arange(idx_injection_base, self.num_inference_steps - 1, 3)
        list_nmb_stems = np.ones(len(list_idx_injection), dtype=np.int32)
        t_compute = 0

        if nmb_max_branches is None:
            assert t_compute_max_allowed is not None, "Either specify t_compute_max_allowed or nmb_max_branches"
            stop_criterion = "t_compute_max_allowed"
        elif t_compute_max_allowed is None:
            assert nmb_max_branches is not None, "Either specify t_compute_max_allowed or nmb_max_branches"
            stop_criterion = "nmb_max_branches"
            nmb_max_branches -= 2  # Discounting the outer frames
        else:
            raise ValueError("Specify only one of t_compute_max_allowed and nmb_max_branches, not both")

        stop_criterion_reached = False
        is_first_iteration = True
        while not stop_criterion_reached:
            list_compute_steps = self.num_inference_steps - list_idx_injection
            list_compute_steps *= list_nmb_stems
            t_compute = np.sum(list_compute_steps) * self.dt_per_diff + 0.15 * np.sum(list_nmb_stems)
            increase_done = False
            for s_idx in range(len(list_nmb_stems) - 1):
                if list_nmb_stems[s_idx + 1] / list_nmb_stems[s_idx] >= 2:
                    list_nmb_stems[s_idx] += 1
                    increase_done = True
                    break
            if not increase_done:
                list_nmb_stems[-1] += 1

            if stop_criterion == "t_compute_max_allowed" and t_compute > t_compute_max_allowed:
                stop_criterion_reached = True
            elif stop_criterion == "nmb_max_branches" and np.sum(list_nmb_stems) >= nmb_max_branches:
                stop_criterion_reached = True
                if is_first_iteration:
                    # Need to undersample.
                    list_idx_injection = np.linspace(list_idx_injection[0], list_idx_injection[-1], nmb_max_branches).astype(np.int32)
                    list_nmb_stems = np.ones(len(list_idx_injection), dtype=np.int32)
            else:
                is_first_iteration = False
        # print(f"t_compute {t_compute} list_nmb_stems {list_nmb_stems}")
        return list_idx_injection, list_nmb_stems
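
    # Worked example of the schedule above: with num_inference_steps = 30 and
    # depth_strength = 0.3, idx_injection_base = 9 and the candidate injection depths
    # are [9, 12, 15, 18, 21, 24, 27]. The loop then adds stems one at a time,
    # topping up a shallower depth once the next deeper one holds at least twice as
    # many, until the time or branch budget is exhausted.
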
    def get_mixing_parameters(self, idx_injection):
        r"""
        Computes which parental latents should be mixed together to achieve a smooth blend.
        As metric, we use the LPIPS image distance. The insertion takes place where the
        distance is maximal, i.e. between the two neighboring images that are least similar.
        Args:
            idx_injection: int
                the index in terms of diffusion steps, where the next insertion will start.
        """
        # get_lpips_similarity
        similarities = []
        for i in range(len(self.tree_final_imgs) - 1):
            similarities.append(self.get_lpips_similarity(self.tree_final_imgs[i], self.tree_final_imgs[i + 1]))
        b_closest1 = np.argmax(similarities)
        b_closest2 = b_closest1 + 1
        fract_closest1 = self.tree_fracts[b_closest1]
        fract_closest2 = self.tree_fracts[b_closest2]

        # Ensure that the parents are indeed older!
        b_parent1 = b_closest1
        while True:
            if self.tree_idx_injection[b_parent1] < idx_injection:
                break
            else:
                b_parent1 -= 1
        b_parent2 = b_closest2
        while True:
            if self.tree_idx_injection[b_parent2] < idx_injection:
                break
            else:
                b_parent2 += 1
        fract_mixing = (fract_closest1 + fract_closest2) / 2
        return fract_mixing, b_parent1, b_parent2

    def insert_into_tree(self, fract_mixing, idx_injection, list_latents):
        r"""
        Inserts all necessary parameters into the trajectory tree.
        Args:
            fract_mixing: float
                the fraction along the transition axis [0, 1]
            idx_injection: int
                the index in terms of diffusion steps, where the next insertion will start.
            list_latents: list
                list of the latents to be inserted
        """
        b_parent1, b_parent2 = self.get_closest_idx(fract_mixing)
        self.tree_latents.insert(b_parent1 + 1, list_latents)
        self.tree_final_imgs.insert(b_parent1 + 1, self.sdh.latent2image(list_latents[-1]))
        self.tree_fracts.insert(b_parent1 + 1, fract_mixing)
        self.tree_idx_injection.insert(b_parent1 + 1, idx_injection)

    def get_spatial_mask_template(self):
        r"""
        Experimental helper function to get a spatial mask template.
        """
        shape_latents = [self.sdh.C, self.sdh.height // self.sdh.f, self.sdh.width // self.sdh.f]
        C, H, W = shape_latents
        return np.ones((H, W))

    def set_spatial_mask(self, img_mask):
        r"""
        Experimental helper function to set a spatial mask.
        The mask forces latents to be overwritten.
        Args:
            img_mask:
                mask image [0,1]. You can get a template using get_spatial_mask_template
        """
        shape_latents = [self.sdh.C, self.sdh.height // self.sdh.f, self.sdh.width // self.sdh.f]
        C, H, W = shape_latents
        img_mask = np.asarray(img_mask)
        assert len(img_mask.shape) == 2, "Currently, only 2D images are supported as mask"
        img_mask = np.clip(img_mask, 0, 1)
        assert img_mask.shape[0] == H, f"Your mask needs to be of dimension {H} x {W}"
        assert img_mask.shape[1] == W, f"Your mask needs to be of dimension {H} x {W}"
        spatial_mask = torch.from_numpy(img_mask).to(device=self.device)
        spatial_mask = torch.unsqueeze(spatial_mask, 0)
        spatial_mask = spatial_mask.repeat((C, 1, 1))
        spatial_mask = torch.unsqueeze(spatial_mask, 0)
        self.spatial_mask = spatial_mask
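
    # Example (sketch): build a mask from the template and zero out one half.
    # Note the exact mask semantics (which value freezes vs. overwrites latents)
    # are determined by how run_diffusion_standard consumes spatial_mask.
    #   mask = lb.get_spatial_mask_template()
    #   mask[:, :mask.shape[1] // 2] = 0
    #   lb.set_spatial_mask(mask)
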
    def get_noise(self, seed):
        r"""
        Helper function to get noise given seed.
        Args:
            seed: int
        """
        generator = torch.Generator(device=self.sdh.device).manual_seed(int(seed))
        if self.mode == 'standard':
            shape_latents = [self.sdh.C, self.sdh.height // self.sdh.f, self.sdh.width // self.sdh.f]
            C, H, W = shape_latents
        elif self.mode == 'upscale':
            w = self.image1_lowres.size[0]
            h = self.image1_lowres.size[1]
            shape_latents = [self.sdh.model.channels, h, w]
            C, H, W = shape_latents
        return torch.randn((1, C, H, W), generator=generator, device=self.sdh.device)

    def run_diffusion(
            self,
            list_conditionings,
            latents_start: torch.FloatTensor = None,
            idx_start: int = 0,
            list_latents_mixing=None,
            mixing_coeffs=0.0,
            return_image: Optional[bool] = False):
        r"""
        Wrapper function for the diffusion runners.
        Depending on the mode, the correct one will be executed.
        Args:
            list_conditionings: list
                List of all conditionings for the diffusion model.
            latents_start: torch.FloatTensor
                Latents that are used for injection
            idx_start: int
                Index of the diffusion process start and where the latents_for_injection are injected
            list_latents_mixing: list
                List of latents (latent trajectories) that are used for mixing
            mixing_coeffs: float or list
                Coefficients determining how strongly each element of list_latents_mixing is mixed in.
            return_image: Optional[bool]
                Optionally return image directly
        """
        # Ensure correct num_inference_steps in holder
        self.sdh.num_inference_steps = self.num_inference_steps
        assert type(list_conditionings) is list, "list_conditionings needs to be a list"

        if self.mode == 'standard':
            text_embeddings = list_conditionings[0]
            return self.sdh.run_diffusion_standard(
                text_embeddings=text_embeddings,
                latents_start=latents_start,
                idx_start=idx_start,
                list_latents_mixing=list_latents_mixing,
                mixing_coeffs=mixing_coeffs,
                spatial_mask=self.spatial_mask,
                return_image=return_image)
        elif self.mode == 'upscale':
            cond = list_conditionings[0]
            uc_full = list_conditionings[1]
            return self.sdh.run_diffusion_upscaling(
                cond,
                uc_full,
                latents_start=latents_start,
                idx_start=idx_start,
                list_latents_mixing=list_latents_mixing,
                mixing_coeffs=mixing_coeffs,
                return_image=return_image)

    def run_upscaling(
            self,
            dp_img: str,
            depth_strength: float = 0.65,
            num_inference_steps: int = 100,
            nmb_max_branches_highres: int = 5,
            nmb_max_branches_lowres: int = 6,
            duration_single_segment=3,
            fps=24,
            fixed_seeds: Optional[List[int]] = None):
        r"""
        Runs upscaling with the x4 model. Requires that you run a transition with a low-res model
        beforehand and save the results using write_imgs_transition.
        Args:
            dp_img: str
                Path to the low-res transition directory (as saved in write_imgs_transition)
            depth_strength:
                Determines how deep the first injection will happen.
                Deeper injections will cause (unwanted) formation of new structures,
                shallower values will go into alpha-blendy land.
            num_inference_steps:
                Number of diffusion steps. Higher values will take more compute time.
            nmb_max_branches_highres: int
                Number of final branches of the upscaling transition pass. Note this is the number
                of branches between each pair of low-res images.
            nmb_max_branches_lowres: int
                Number of input low-res images, subsampling all transition images written in the low-res pass.
                Setting this number lower (e.g. 6) will decrease the compute time but not affect the results too much.
            duration_single_segment: float
                The duration of each high-res movie segment. You will have nmb_max_branches_lowres - 1 segments in total.
            fps: float
                frames per second of the movie
            fixed_seeds: Optional[List[int]]
                You can supply two seeds that are used for the first and second keyframe (prompt1 and prompt2).
                Otherwise random seeds will be taken.
        """
        fp_yml = os.path.join(dp_img, "lowres.yaml")
        fp_movie = os.path.join(dp_img, "movie_highres.mp4")
        ms = MovieSaver(fp_movie, fps=fps)
        assert os.path.isfile(fp_yml), "lowres.yaml does not exist. did you forget run_upscaling_step1?"
        dict_stuff = yml_load(fp_yml)

        # load lowres images
        nmb_images_lowres = dict_stuff['nmb_images']
        prompt1 = dict_stuff['prompt1']
        prompt2 = dict_stuff['prompt2']
        idx_img_lowres = np.round(np.linspace(0, nmb_images_lowres - 1, nmb_max_branches_lowres)).astype(np.int32)
        imgs_lowres = []
        for i in idx_img_lowres:
            fp_img_lowres = os.path.join(dp_img, f"lowres_img_{str(i).zfill(4)}.jpg")
            assert os.path.isfile(fp_img_lowres), f"{fp_img_lowres} does not exist. did you forget run_upscaling_step1?"
            imgs_lowres.append(Image.open(fp_img_lowres))

        # set up upscaling
        text_embeddingA = self.sdh.get_text_embedding(prompt1)
        text_embeddingB = self.sdh.get_text_embedding(prompt2)
        list_fract_mixing = np.linspace(0, 1, nmb_max_branches_lowres - 1)
        for i in range(nmb_max_branches_lowres - 1):
            print(f"Starting movie segment {i+1}/{nmb_max_branches_lowres-1}")
            self.text_embedding1 = interpolate_linear(text_embeddingA, text_embeddingB, list_fract_mixing[i])
            self.text_embedding2 = interpolate_linear(text_embeddingA, text_embeddingB, 1 - list_fract_mixing[i])
            if i == 0:
                recycle_img1 = False
            else:
                self.swap_forward()
                recycle_img1 = True

            self.set_image1(imgs_lowres[i])
            self.set_image2(imgs_lowres[i + 1])
            list_imgs = self.run_transition(
                recycle_img1=recycle_img1,
                recycle_img2=False,
                num_inference_steps=num_inference_steps,
                depth_strength=depth_strength,
                nmb_max_branches=nmb_max_branches_highres)
            list_imgs_interp = add_frames_linear_interp(list_imgs, fps, duration_single_segment)

            # Save movie frames
            for img in list_imgs_interp:
                ms.write_frame(img)
        ms.finalize()
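
    # Usage sketch (assumes a holder loaded with the x4 upscaling model and a directory
    # previously written by write_imgs_transition; the path is illustrative):
    #   lb = LatentBlending(sdh)
    #   lb.run_upscaling("transition_lowres", depth_strength=0.65, num_inference_steps=100)
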
    def get_mixed_conditioning(self, fract_mixing):
        if self.mode == 'standard':
            text_embeddings_mix = interpolate_linear(self.text_embedding1, self.text_embedding2, fract_mixing)
            list_conditionings = [text_embeddings_mix]
        elif self.mode == 'inpaint':
            text_embeddings_mix = interpolate_linear(self.text_embedding1, self.text_embedding2, fract_mixing)
            list_conditionings = [text_embeddings_mix]
        elif self.mode == 'upscale':
            text_embeddings_mix = interpolate_linear(self.text_embedding1, self.text_embedding2, fract_mixing)
            cond, uc_full = self.sdh.get_cond_upscaling(self.image1_lowres, text_embeddings_mix, self.noise_level_upscaling)
            condB, uc_fullB = self.sdh.get_cond_upscaling(self.image2_lowres, text_embeddings_mix, self.noise_level_upscaling)
            cond['c_concat'][0] = interpolate_spherical(cond['c_concat'][0], condB['c_concat'][0], fract_mixing)
            uc_full['c_concat'][0] = interpolate_spherical(uc_full['c_concat'][0], uc_fullB['c_concat'][0], fract_mixing)
            list_conditionings = [cond, uc_full]
        else:
            raise ValueError(f"mix_conditioning: unknown mode {self.mode}")
        return list_conditionings

    def get_text_embeddings(
            self,
            prompt: str):
        r"""
        Computes the text embeddings provided a prompt string.
        Adapted from the stable diffusion repo.
        Args:
            prompt: str
                ABC trending on artstation painted by Old Greg.
        """
        return self.sdh.get_text_embedding(prompt)

    def write_imgs_transition(self, dp_img):
        r"""
        Writes the transition images into the folder dp_img.
        Requires run_transition to be completed.
        Args:
            dp_img: str
                Directory, into which the transition images, yaml file and latents are written.
        """
        imgs_transition = self.tree_final_imgs
        os.makedirs(dp_img, exist_ok=True)
        for i, img in enumerate(imgs_transition):
            img_leaf = Image.fromarray(img)
            img_leaf.save(os.path.join(dp_img, f"lowres_img_{str(i).zfill(4)}.jpg"))
        fp_yml = os.path.join(dp_img, "lowres.yaml")
        self.save_statedict(fp_yml)

    def write_movie_transition(self, fp_movie, duration_transition, fps=30):
        r"""
        Writes the transition movie to fp_movie, using the given duration and fps.
        The missing frames are linearly interpolated.
        Args:
            fp_movie: str
                file path of the final movie.
            duration_transition: float
                duration of the movie in seconds
            fps: int
                fps of the movie
        """
        # Let's get more cheap frames via linear interpolation (duration_transition * fps frames)
        imgs_transition_ext = add_frames_linear_interp(self.tree_final_imgs, duration_transition, fps)

        # Save as MP4
        if os.path.isfile(fp_movie):
            os.remove(fp_movie)
        ms = MovieSaver(fp_movie, fps=fps, shape_hw=[self.sdh.height, self.sdh.width])
        for img in tqdm(imgs_transition_ext):
            ms.write_frame(img)
        ms.finalize()
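
    # Example (sketch; run after run_transition has completed, paths are illustrative):
    #   lb.write_imgs_transition("transition_lowres")  # frames + lowres.yaml for run_upscaling
    #   lb.write_movie_transition("movie.mp4", duration_transition=12, fps=30)
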
    def save_statedict(self, fp_yml):
        # Dump everything relevant into yaml
        imgs_transition = self.tree_final_imgs
        state_dict = self.get_state_dict()
        state_dict['nmb_images'] = len(imgs_transition)
        yml_save(fp_yml, state_dict)

    def get_state_dict(self):
        state_dict = {}
        grab_vars = ['prompt1', 'prompt2', 'seed1', 'seed2', 'height', 'width',
                     'num_inference_steps', 'depth_strength', 'guidance_scale',
                     'guidance_scale_mid_damper', 'mid_compression_scaler', 'negative_prompt',
                     'branch1_crossfeed_power', 'branch1_crossfeed_range', 'branch1_crossfeed_decay',
                     'parental_crossfeed_power', 'parental_crossfeed_range', 'parental_crossfeed_power_decay']
        for v in grab_vars:
            if hasattr(self, v):
                if v == 'seed1' or v == 'seed2':
                    state_dict[v] = int(getattr(self, v))
                elif v == 'guidance_scale':
                    state_dict[v] = float(getattr(self, v))
                else:
                    try:
                        state_dict[v] = getattr(self, v)
                    except Exception:
                        pass
        return state_dict

    def randomize_seed(self):
        r"""
        Set a random seed for a fresh start.
        """
        seed = np.random.randint(999999999)
        self.set_seed(seed)

    def set_seed(self, seed: int):
        r"""
        Set the seed for a fresh start.
        """
        self.seed = seed
        self.sdh.seed = seed

    def set_width(self, width):
        r"""
        Set the width of the resulting image.
        """
        assert np.mod(width, 64) == 0, "set_width: value needs to be divisible by 64"
        self.width = width
        self.sdh.width = width

    def set_height(self, height):
        r"""
        Set the height of the resulting image.
        """
        assert np.mod(height, 64) == 0, "set_height: value needs to be divisible by 64"
        self.height = height
        self.sdh.height = height

    def swap_forward(self):
        r"""
        Moves over keyframe two -> keyframe one. Useful for making a sequence of transitions
        as in run_multi_transition()
        """
        # Move over all latents
        self.tree_latents[0] = self.tree_latents[-1]
        # Move over prompts and text embeddings
        self.prompt1 = self.prompt2
        self.text_embedding1 = self.text_embedding2
        # Final cleanup for extra sanity
        self.tree_final_imgs = []

    def get_lpips_similarity(self, imgA, imgB):
        r"""
        Computes the LPIPS distance between two images imgA and imgB.
        Used to determine the optimal point of insertion to create smooth transitions.
        High values indicate low similarity.
        """
        tensorA = torch.from_numpy(imgA).float().cuda(self.device)
        tensorA = 2 * tensorA / 255.0 - 1
        tensorA = tensorA.permute([2, 0, 1]).unsqueeze(0)
        tensorB = torch.from_numpy(imgB).float().cuda(self.device)
        tensorB = 2 * tensorB / 255.0 - 1
        tensorB = tensorB.permute([2, 0, 1]).unsqueeze(0)
        lploss = self.lpips(tensorA, tensorB)
        lploss = float(lploss[0][0][0][0])
        return lploss
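
    # Interpretation note: LPIPS is a perceptual distance, so values near 0 mean the
    # two frames are nearly identical and larger values mean a bigger visual jump;
    # get_mixing_parameters() inserts new branches where this distance peaks.
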
    # Auxiliary functions
    def get_closest_idx(
            self,
            fract_mixing: float):
        r"""
        Helper function to retrieve the parents for any given mixing.
        Example: fract_mixing = 0.4 and self.tree_fracts = [0, 0.3, 0.6, 1.0]
        will return the two closest values here, i.e. [1, 2]
        """
        pdist = fract_mixing - np.asarray(self.tree_fracts)
        pdist_pos = pdist.copy()
        pdist_pos[pdist_pos < 0] = np.inf
        b_parent1 = np.argmin(pdist_pos)
        pdist_neg = -pdist.copy()
        pdist_neg[pdist_neg <= 0] = np.inf
        b_parent2 = np.argmin(pdist_neg)

        if b_parent1 > b_parent2:
            tmp = b_parent2
            b_parent2 = b_parent1
            b_parent1 = tmp
        return b_parent1, b_parent2