Spaces:
Configuration error
Configuration error
File size: 11,332 Bytes
2b67076 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 204 205 206 207 208 209 210 211 212 213 214 215 216 217 218 219 220 221 222 223 224 225 226 227 228 229 230 231 232 233 234 235 236 237 238 239 240 |
import torch
from .utils import *
from functools import partial
# Many thanks to the LanPaint team for this implementation (https://github.com/scraed/LanPaint/)
def _pack_latents(latents):
    """Fold a 5-D latent into a sequence of flattened 2x2 spatial patches.

    Args:
        latents: tensor whose shape unpacks as
            ``(batch, channels, _, height, width)``. The third axis is not
            carried into the view below, so the view only succeeds when that
            axis has size 1 and ``height``/``width`` are even.

    Returns:
        Tensor of shape ``(batch, (height // 2) * (width // 2), channels * 4)``
        where each row holds one 2x2 patch, ordered channel-major.
    """
    b, c, _, h, w = latents.shape
    rows, cols = h // 2, w // 2
    # Split each spatial axis into (patch index, offset inside the patch),
    # then move both patch indices in front of the channel axis.
    patches = latents.view(b, c, rows, 2, cols, 2).permute(0, 2, 4, 1, 3, 5)
    return patches.reshape(b, rows * cols, c * 4)
def _unpack_latents(latents, height, width, vae_scale_factor=8):
    """Unfold a packed patch sequence back into a 5-D latent.

    Args:
        latents: tensor of shape ``(batch, num_patches, channels)`` in which
            every row is a flattened 2x2 patch (``channels`` divisible by 4).
        height: image height; converted to a latent height of
            ``2 * (int(height) // (vae_scale_factor * 2))``.
        width: image width, converted the same way as ``height``.
        vae_scale_factor: divisor relating image size to latent size.

    Returns:
        Tensor of shape ``(batch, channels // 4, 1, latent_h, latent_w)``.
    """
    b, _, ch = latents.shape
    # Rounding through (vae_scale_factor * 2) keeps both latent sizes even,
    # which the 2x2 patch layout requires.
    lat_h = 2 * (int(height) // (vae_scale_factor * 2))
    lat_w = 2 * (int(width) // (vae_scale_factor * 2))
    grid = latents.view(b, lat_h // 2, lat_w // 2, ch // 4, 2, 2)
    # Interleave patch indices with in-patch offsets: (b, c, row, dy, col, dx).
    grid = grid.permute(0, 3, 1, 4, 2, 5)
    return grid.reshape(b, ch // (2 * 2), 1, lat_h, lat_w)
class LanPaint():
def __init__(self, NSteps = 5, Friction = 15, Lambda = 8, Beta = 1, StepSize = 0.15, IS_FLUX = True, IS_FLOW = False):
self.n_steps = NSteps
self.chara_lamb = Lambda
self.IS_FLUX = IS_FLUX
self.IS_FLOW = IS_FLOW
self.step_size = StepSize
self.friction = Friction
self.chara_beta = Beta
self.img_dim_size = None
def add_none_dims(self, array):
# Create a tuple with ':' for the first dimension and 'None' repeated num_nones times
index = (slice(None),) + (None,) * (self.img_dim_size-1)
return array[index]
def remove_none_dims(self, array):
# Create a tuple with ':' for the first dimension and 'None' repeated num_nones times
index = (slice(None),) + (0,) * (self.img_dim_size-1)
return array[index]
def __call__(self, denoise, cfg_predictions, true_cfg_scale, cfg_BIG, x, latent_image, noise, sigma, latent_mask, n_steps=None, height =720, width = 1280, vae_scale_factor = 8):
latent_image = _unpack_latents(latent_image, height=height, width=width, vae_scale_factor=vae_scale_factor)
noise = _unpack_latents(noise, height=height, width=width, vae_scale_factor=vae_scale_factor)
x = _unpack_latents(x, height=height, width=width, vae_scale_factor=vae_scale_factor)
latent_mask = _unpack_latents(latent_mask, height=height, width=width, vae_scale_factor=vae_scale_factor)
self.height = height
self.width = width
self.vae_scale_factor = vae_scale_factor
self.img_dim_size = len(x.shape)
self.latent_image = latent_image
self.noise = noise
if n_steps is None:
n_steps = self.n_steps
out = self.LanPaint(denoise, cfg_predictions, true_cfg_scale, cfg_BIG, x, sigma, latent_mask, n_steps, self.IS_FLUX, self.IS_FLOW)
out = _pack_latents(out)
return out
def LanPaint(self, denoise, cfg_predictions, true_cfg_scale, cfg_BIG, x, sigma, latent_mask, n_steps, IS_FLUX, IS_FLOW):
if IS_FLUX:
cfg_BIG = 1.0
def double_denoise(latents, t):
latents = _pack_latents(latents)
noise_pred, neg_noise_pred = denoise(latents, true_cfg_scale)
if noise_pred == None: return None, None
predict_std = cfg_predictions(noise_pred, neg_noise_pred, true_cfg_scale, t)
predict_std = _unpack_latents(predict_std, self.height, self.width, self.vae_scale_factor)
if true_cfg_scale == cfg_BIG:
predict_big = predict_std
else:
predict_big = cfg_predictions(noise_pred, neg_noise_pred, cfg_BIG, t)
predict_big = _unpack_latents(predict_big, self.height, self.width, self.vae_scale_factor)
return predict_std, predict_big
if len(sigma.shape) == 0:
sigma = torch.tensor([sigma.item()])
latent_mask = 1 - latent_mask
if IS_FLUX or IS_FLOW:
Flow_t = sigma
abt = (1 - Flow_t)**2 / ((1 - Flow_t)**2 + Flow_t**2 )
VE_Sigma = Flow_t / (1 - Flow_t)
#print("t", torch.mean( sigma ).item(), "VE_Sigma", torch.mean( VE_Sigma ).item())
else:
VE_Sigma = sigma
abt = 1/( 1+VE_Sigma**2 )
Flow_t = (1-abt)**0.5 / ( (1-abt)**0.5 + abt**0.5 )
# VE_Sigma, abt, Flow_t = current_times
current_times = (VE_Sigma, abt, Flow_t)
step_size = self.step_size * (1 - abt)
step_size = self.add_none_dims(step_size)
# self.inner_model.inner_model.scale_latent_inpaint returns variance exploding x_t values
# This is the replace step
# x = x * (1 - latent_mask) + self.inner_model.inner_model.scale_latent_inpaint(x=x, sigma=sigma, noise=self.noise, latent_image=self.latent_image)* latent_mask
noisy_image = self.latent_image * (1.0 - sigma) + self.noise * sigma
x = x * (1 - latent_mask) + noisy_image * latent_mask
if IS_FLUX or IS_FLOW:
x_t = x * ( self.add_none_dims(abt)**0.5 + (1-self.add_none_dims(abt))**0.5 )
else:
x_t = x / ( 1+self.add_none_dims(VE_Sigma)**2 )**0.5 # switch to variance perserving x_t values
############ LanPaint Iterations Start ###############
# after noise_scaling, noise = latent_image + noise * sigma, which is x_t in the variance exploding diffusion model notation for the known region.
args = None
for i in range(n_steps):
score_func = partial( self.score_model, y = self.latent_image, mask = latent_mask, abt = self.add_none_dims(abt), sigma = self.add_none_dims(VE_Sigma), tflow = self.add_none_dims(Flow_t), denoise_func = double_denoise )
if score_func is None: return None
x_t, args = self.langevin_dynamics(x_t, score_func , latent_mask, step_size , current_times, sigma_x = self.add_none_dims(self.sigma_x(abt)), sigma_y = self.add_none_dims(self.sigma_y(abt)), args = args)
if IS_FLUX or IS_FLOW:
x = x_t / ( self.add_none_dims(abt)**0.5 + (1-self.add_none_dims(abt))**0.5 )
else:
x = x_t * ( 1+self.add_none_dims(VE_Sigma)**2 )**0.5 # switch to variance perserving x_t values
############ LanPaint Iterations End ###############
# out is x_0
# out, _ = self.inner_model(x, sigma, model_options=model_options, seed=seed)
# out = out * (1-latent_mask) + self.latent_image * latent_mask
# return out
return x
def score_model(self, x_t, y, mask, abt, sigma, tflow, denoise_func):
lamb = self.chara_lamb
if self.IS_FLUX or self.IS_FLOW:
# compute t for flow model, with a small epsilon compensating for numerical error.
x = x_t / ( abt**0.5 + (1-abt)**0.5 ) # switch to Gaussian flow matching
x_0, x_0_BIG = denoise_func(x, self.remove_none_dims(tflow))
if x_0 is None: return None
else:
x = x_t * ( 1+sigma**2 )**0.5 # switch to variance exploding
x_0, x_0_BIG = denoise_func(x, self.remove_none_dims(sigma))
if x_0 is None: return None
score_x = -(x_t - x_0)
score_y = - (1 + lamb) * ( x_t - y ) + lamb * (x_t - x_0_BIG)
return score_x * (1 - mask) + score_y * mask
def sigma_x(self, abt):
# the time scale for the x_t update
return abt**0
def sigma_y(self, abt):
beta = self.chara_beta * abt ** 0
return beta
def langevin_dynamics(self, x_t, score, mask, step_size, current_times, sigma_x=1, sigma_y=0, args=None):
# prepare the step size and time parameters
with torch.autocast(device_type=x_t.device.type, dtype=torch.float32):
step_sizes = self.prepare_step_size(current_times, step_size, sigma_x, sigma_y)
sigma, abt, dtx, dty, Gamma_x, Gamma_y, A_x, A_y, D_x, D_y = step_sizes
# print('mask',mask.device)
if torch.mean(dtx) <= 0.:
return x_t, args
# -------------------------------------------------------------------------
# Compute the Langevin dynamics update in variance perserving notation
# -------------------------------------------------------------------------
#x0 = self.x0_evalutation(x_t, score, sigma, args)
#C = abt**0.5 * x0 / (1-abt)
A = A_x * (1-mask) + A_y * mask
D = D_x * (1-mask) + D_y * mask
dt = dtx * (1-mask) + dty * mask
Gamma = Gamma_x * (1-mask) + Gamma_y * mask
def Coef_C(x_t):
x0 = self.x0_evalutation(x_t, score, sigma, args)
C = (abt**0.5 * x0 - x_t )/ (1-abt) + A * x_t
return C
def advance_time(x_t, v, dt, Gamma, A, C, D):
dtype = x_t.dtype
with torch.autocast(device_type=x_t.device.type, dtype=torch.float32):
osc = StochasticHarmonicOscillator(Gamma, A, C, D )
x_t, v = osc.dynamics(x_t, v, dt )
x_t = x_t.to(dtype)
v = v.to(dtype)
return x_t, v
if args is None:
#v = torch.zeros_like(x_t)
v = None
C = Coef_C(x_t)
#print(torch.squeeze(dtx), torch.squeeze(dty))
x_t, v = advance_time(x_t, v, dt, Gamma, A, C, D)
else:
v, C = args
x_t, v = advance_time(x_t, v, dt/2, Gamma, A, C, D)
C_new = Coef_C(x_t)
v = v + Gamma**0.5 * ( C_new - C) *dt
x_t, v = advance_time(x_t, v, dt/2, Gamma, A, C, D)
C = C_new
return x_t, (v, C)
def prepare_step_size(self, current_times, step_size, sigma_x, sigma_y):
# -------------------------------------------------------------------------
# Unpack current times parameters (sigma and abt)
sigma, abt, flow_t = current_times
sigma = self.add_none_dims(sigma)
abt = self.add_none_dims(abt)
# Compute time step (dtx, dty) for x and y branches.
dtx = 2 * step_size * sigma_x
dty = 2 * step_size * sigma_y
# -------------------------------------------------------------------------
# Define friction parameter Gamma_hat for each branch.
# Using dtx**0 provides a tensor of the proper device/dtype.
Gamma_hat_x = self.friction **2 * self.step_size * sigma_x / 0.1 * sigma**0
Gamma_hat_y = self.friction **2 * self.step_size * sigma_y / 0.1 * sigma**0
#print("Gamma_hat_x", torch.mean(Gamma_hat_x).item(), "Gamma_hat_y", torch.mean(Gamma_hat_y).item())
# adjust dt to match denoise-addnoise steps sizes
Gamma_hat_x /= 2.
Gamma_hat_y /= 2.
A_t_x = (1) / ( 1 - abt ) * dtx / 2
A_t_y = (1+self.chara_lamb) / ( 1 - abt ) * dty / 2
A_x = A_t_x / (dtx/2)
A_y = A_t_y / (dty/2)
Gamma_x = Gamma_hat_x / (dtx/2)
Gamma_y = Gamma_hat_y / (dty/2)
#D_x = (2 * (1 + sigma**2) )**0.5
#D_y = (2 * (1 + sigma**2) )**0.5
D_x = (2 * abt**0 )**0.5
D_y = (2 * abt**0 )**0.5
return sigma, abt, dtx/2, dty/2, Gamma_x, Gamma_y, A_x, A_y, D_x, D_y
def x0_evalutation(self, x_t, score, sigma, args):
x0 = x_t + score(x_t)
return x0 |