Spaces:
Runtime error
Runtime error
| import torch | |
| import numpy as np | |
| import pytorch_lightning as pl | |
def extract_into_tensor(a, t, x_shape):
    """Look up ``a[t]`` along the last axis and reshape it for broadcasting.

    Args:
        a: Lookup table tensor; ``gather`` is applied along its last dimension.
        t: Integer index tensor whose first dimension is taken as the batch size.
        x_shape: Shape of the tensor the result will be broadcast against; only
            its length (rank) is used.

    Returns:
        The gathered values reshaped to ``(batch, 1, ..., 1)`` with
        ``len(x_shape) - 1`` trailing singleton dimensions.
    """
    batch_size, *_ = t.shape
    singleton_dims = [1] * (len(x_shape) - 1)
    picked = a.gather(-1, t)
    return picked.reshape(batch_size, *singleton_dims)
class DDIMSolver:
    """Cached DDIM sub-schedule plus a single deterministic DDIM update step.

    The full schedule of ``timesteps`` steps is subsampled into ``ddim_timesteps``
    evenly spaced steps. The ``alpha_cumprods`` values at those steps, and at each
    step's predecessor, are cached as torch tensors.

    Args:
        alpha_cumprods: 1-D array-like of length >= ``timesteps`` with one value per
            training timestep (presumably the noise scheduler's cumulative alpha
            products -- confirm against caller). A numpy array, a Python sequence,
            or a CPU torch tensor is accepted.
        timesteps: Length of the full schedule.
        ddim_timesteps: Number of steps in the DDIM sub-schedule; must lie in
            ``[1, timesteps]``.

    Raises:
        ValueError: If ``ddim_timesteps`` is outside ``[1, timesteps]``.
    """

    def __init__(self, alpha_cumprods, timesteps=1000, ddim_timesteps=50):
        if not 1 <= ddim_timesteps <= timesteps:
            # Outside this range step_ratio is 0 (every index collapses to -1) or the
            # floor division fails, so the schedule would be meaningless.
            raise ValueError(
                f"ddim_timesteps must be in [1, {timesteps}], got {ddim_timesteps}"
            )
        # Normalise to numpy so torch tensors / lists work; torch.from_numpy below
        # raises TypeError for anything that is not an ndarray.
        alpha_cumprods = np.asarray(alpha_cumprods)

        # DDIM sampling parameters: indices step_ratio-1, 2*step_ratio-1, ...,
        # ddim_timesteps*step_ratio-1. When timesteps is not a multiple of
        # ddim_timesteps, the last (timesteps % ddim_timesteps) indices are unused.
        step_ratio = timesteps // ddim_timesteps
        ddim_steps = np.arange(1, ddim_timesteps + 1, dtype=np.int64) * step_ratio - 1

        ddim_alphas = alpha_cumprods[ddim_steps]
        # The "previous" value of the first sub-step is alpha_cumprods[0]; every other
        # entry is the value at the preceding sub-step. np.concatenate keeps the input
        # dtype -- a list/.tolist() round-trip would silently upcast it to float64 and
        # make ddim_step return float64 for float32 inputs.
        ddim_alphas_prev = np.concatenate(
            [alpha_cumprods[:1], alpha_cumprods[ddim_steps[:-1]]]
        )

        # convert to torch tensors
        self.ddim_timesteps = torch.from_numpy(ddim_steps).long()
        self.ddim_alpha_cumprods = torch.from_numpy(ddim_alphas)
        self.ddim_alpha_cumprods_prev = torch.from_numpy(ddim_alphas_prev)

    def to(self, device):
        """Move the cached schedule tensors to ``device`` in place and return ``self``."""
        self.ddim_timesteps = self.ddim_timesteps.to(device)
        self.ddim_alpha_cumprods = self.ddim_alpha_cumprods.to(device)
        self.ddim_alpha_cumprods_prev = self.ddim_alpha_cumprods_prev.to(device)
        return self

    def ddim_step(self, pred_x0, pred_noise, timestep_index):
        """Return the sample at the previous DDIM sub-step (no noise term added).

        Computes ``sqrt(a_prev) * pred_x0 + sqrt(1 - a_prev) * pred_noise``, where
        ``a_prev = ddim_alpha_cumprods_prev[timestep_index]`` is reshaped to
        ``(batch, 1, ..., 1)`` so it broadcasts per sample.

        Args:
            pred_x0: Predicted clean sample; its leading dimension must match the
                length of ``timestep_index`` for per-sample broadcasting.
            pred_noise: Predicted noise, broadcast-compatible with ``pred_x0``.
            timestep_index: 1-D integer tensor of indices into the DDIM sub-schedule
                (not raw training timesteps), one per batch element.

        Returns:
            Tensor with the broadcast shape of ``pred_x0`` and ``pred_noise``.
        """
        device = pred_x0.device
        # gather requires the table and the index on the same device; move both so
        # callers do not have to call .to() on the solver or the indices first.
        alpha_cumprod_prev = extract_into_tensor(
            self.ddim_alpha_cumprods_prev.to(device),
            timestep_index.to(device),
            pred_x0.shape,
        )
        dir_xt = (1.0 - alpha_cumprod_prev).sqrt() * pred_noise
        x_prev = alpha_cumprod_prev.sqrt() * pred_x0 + dir_xt
        return x_prev