USF00's picture
Initial commit
2b67076
raw
history blame
11.3 kB
import torch
from .utils import *
from functools import partial
# Many thanks to the LanPaint team for this implementation (https://github.com/scraed/LanPaint/)
def _pack_latents(latents):
batch_size, num_channels_latents, _, height, width = latents.shape
latents = latents.view(batch_size, num_channels_latents, height // 2, 2, width // 2, 2)
latents = latents.permute(0, 2, 4, 1, 3, 5)
latents = latents.reshape(batch_size, (height // 2) * (width // 2), num_channels_latents * 4)
return latents
def _unpack_latents(latents, height, width, vae_scale_factor=8):
batch_size, num_patches, channels = latents.shape
height = 2 * (int(height) // (vae_scale_factor * 2))
width = 2 * (int(width) // (vae_scale_factor * 2))
latents = latents.view(batch_size, height // 2, width // 2, channels // 4, 2, 2)
latents = latents.permute(0, 3, 1, 4, 2, 5)
latents = latents.reshape(batch_size, channels // (2 * 2), 1, height, width)
return latents
class LanPaint():
def __init__(self, NSteps = 5, Friction = 15, Lambda = 8, Beta = 1, StepSize = 0.15, IS_FLUX = True, IS_FLOW = False):
self.n_steps = NSteps
self.chara_lamb = Lambda
self.IS_FLUX = IS_FLUX
self.IS_FLOW = IS_FLOW
self.step_size = StepSize
self.friction = Friction
self.chara_beta = Beta
self.img_dim_size = None
def add_none_dims(self, array):
# Create a tuple with ':' for the first dimension and 'None' repeated num_nones times
index = (slice(None),) + (None,) * (self.img_dim_size-1)
return array[index]
def remove_none_dims(self, array):
# Create a tuple with ':' for the first dimension and 'None' repeated num_nones times
index = (slice(None),) + (0,) * (self.img_dim_size-1)
return array[index]
def __call__(self, denoise, cfg_predictions, true_cfg_scale, cfg_BIG, x, latent_image, noise, sigma, latent_mask, n_steps=None, height =720, width = 1280, vae_scale_factor = 8):
latent_image = _unpack_latents(latent_image, height=height, width=width, vae_scale_factor=vae_scale_factor)
noise = _unpack_latents(noise, height=height, width=width, vae_scale_factor=vae_scale_factor)
x = _unpack_latents(x, height=height, width=width, vae_scale_factor=vae_scale_factor)
latent_mask = _unpack_latents(latent_mask, height=height, width=width, vae_scale_factor=vae_scale_factor)
self.height = height
self.width = width
self.vae_scale_factor = vae_scale_factor
self.img_dim_size = len(x.shape)
self.latent_image = latent_image
self.noise = noise
if n_steps is None:
n_steps = self.n_steps
out = self.LanPaint(denoise, cfg_predictions, true_cfg_scale, cfg_BIG, x, sigma, latent_mask, n_steps, self.IS_FLUX, self.IS_FLOW)
out = _pack_latents(out)
return out
def LanPaint(self, denoise, cfg_predictions, true_cfg_scale, cfg_BIG, x, sigma, latent_mask, n_steps, IS_FLUX, IS_FLOW):
if IS_FLUX:
cfg_BIG = 1.0
def double_denoise(latents, t):
latents = _pack_latents(latents)
noise_pred, neg_noise_pred = denoise(latents, true_cfg_scale)
if noise_pred == None: return None, None
predict_std = cfg_predictions(noise_pred, neg_noise_pred, true_cfg_scale, t)
predict_std = _unpack_latents(predict_std, self.height, self.width, self.vae_scale_factor)
if true_cfg_scale == cfg_BIG:
predict_big = predict_std
else:
predict_big = cfg_predictions(noise_pred, neg_noise_pred, cfg_BIG, t)
predict_big = _unpack_latents(predict_big, self.height, self.width, self.vae_scale_factor)
return predict_std, predict_big
if len(sigma.shape) == 0:
sigma = torch.tensor([sigma.item()])
latent_mask = 1 - latent_mask
if IS_FLUX or IS_FLOW:
Flow_t = sigma
abt = (1 - Flow_t)**2 / ((1 - Flow_t)**2 + Flow_t**2 )
VE_Sigma = Flow_t / (1 - Flow_t)
#print("t", torch.mean( sigma ).item(), "VE_Sigma", torch.mean( VE_Sigma ).item())
else:
VE_Sigma = sigma
abt = 1/( 1+VE_Sigma**2 )
Flow_t = (1-abt)**0.5 / ( (1-abt)**0.5 + abt**0.5 )
# VE_Sigma, abt, Flow_t = current_times
current_times = (VE_Sigma, abt, Flow_t)
step_size = self.step_size * (1 - abt)
step_size = self.add_none_dims(step_size)
# self.inner_model.inner_model.scale_latent_inpaint returns variance exploding x_t values
# This is the replace step
# x = x * (1 - latent_mask) + self.inner_model.inner_model.scale_latent_inpaint(x=x, sigma=sigma, noise=self.noise, latent_image=self.latent_image)* latent_mask
noisy_image = self.latent_image * (1.0 - sigma) + self.noise * sigma
x = x * (1 - latent_mask) + noisy_image * latent_mask
if IS_FLUX or IS_FLOW:
x_t = x * ( self.add_none_dims(abt)**0.5 + (1-self.add_none_dims(abt))**0.5 )
else:
x_t = x / ( 1+self.add_none_dims(VE_Sigma)**2 )**0.5 # switch to variance perserving x_t values
############ LanPaint Iterations Start ###############
# after noise_scaling, noise = latent_image + noise * sigma, which is x_t in the variance exploding diffusion model notation for the known region.
args = None
for i in range(n_steps):
score_func = partial( self.score_model, y = self.latent_image, mask = latent_mask, abt = self.add_none_dims(abt), sigma = self.add_none_dims(VE_Sigma), tflow = self.add_none_dims(Flow_t), denoise_func = double_denoise )
if score_func is None: return None
x_t, args = self.langevin_dynamics(x_t, score_func , latent_mask, step_size , current_times, sigma_x = self.add_none_dims(self.sigma_x(abt)), sigma_y = self.add_none_dims(self.sigma_y(abt)), args = args)
if IS_FLUX or IS_FLOW:
x = x_t / ( self.add_none_dims(abt)**0.5 + (1-self.add_none_dims(abt))**0.5 )
else:
x = x_t * ( 1+self.add_none_dims(VE_Sigma)**2 )**0.5 # switch to variance perserving x_t values
############ LanPaint Iterations End ###############
# out is x_0
# out, _ = self.inner_model(x, sigma, model_options=model_options, seed=seed)
# out = out * (1-latent_mask) + self.latent_image * latent_mask
# return out
return x
def score_model(self, x_t, y, mask, abt, sigma, tflow, denoise_func):
lamb = self.chara_lamb
if self.IS_FLUX or self.IS_FLOW:
# compute t for flow model, with a small epsilon compensating for numerical error.
x = x_t / ( abt**0.5 + (1-abt)**0.5 ) # switch to Gaussian flow matching
x_0, x_0_BIG = denoise_func(x, self.remove_none_dims(tflow))
if x_0 is None: return None
else:
x = x_t * ( 1+sigma**2 )**0.5 # switch to variance exploding
x_0, x_0_BIG = denoise_func(x, self.remove_none_dims(sigma))
if x_0 is None: return None
score_x = -(x_t - x_0)
score_y = - (1 + lamb) * ( x_t - y ) + lamb * (x_t - x_0_BIG)
return score_x * (1 - mask) + score_y * mask
def sigma_x(self, abt):
# the time scale for the x_t update
return abt**0
def sigma_y(self, abt):
beta = self.chara_beta * abt ** 0
return beta
def langevin_dynamics(self, x_t, score, mask, step_size, current_times, sigma_x=1, sigma_y=0, args=None):
# prepare the step size and time parameters
with torch.autocast(device_type=x_t.device.type, dtype=torch.float32):
step_sizes = self.prepare_step_size(current_times, step_size, sigma_x, sigma_y)
sigma, abt, dtx, dty, Gamma_x, Gamma_y, A_x, A_y, D_x, D_y = step_sizes
# print('mask',mask.device)
if torch.mean(dtx) <= 0.:
return x_t, args
# -------------------------------------------------------------------------
# Compute the Langevin dynamics update in variance perserving notation
# -------------------------------------------------------------------------
#x0 = self.x0_evalutation(x_t, score, sigma, args)
#C = abt**0.5 * x0 / (1-abt)
A = A_x * (1-mask) + A_y * mask
D = D_x * (1-mask) + D_y * mask
dt = dtx * (1-mask) + dty * mask
Gamma = Gamma_x * (1-mask) + Gamma_y * mask
def Coef_C(x_t):
x0 = self.x0_evalutation(x_t, score, sigma, args)
C = (abt**0.5 * x0 - x_t )/ (1-abt) + A * x_t
return C
def advance_time(x_t, v, dt, Gamma, A, C, D):
dtype = x_t.dtype
with torch.autocast(device_type=x_t.device.type, dtype=torch.float32):
osc = StochasticHarmonicOscillator(Gamma, A, C, D )
x_t, v = osc.dynamics(x_t, v, dt )
x_t = x_t.to(dtype)
v = v.to(dtype)
return x_t, v
if args is None:
#v = torch.zeros_like(x_t)
v = None
C = Coef_C(x_t)
#print(torch.squeeze(dtx), torch.squeeze(dty))
x_t, v = advance_time(x_t, v, dt, Gamma, A, C, D)
else:
v, C = args
x_t, v = advance_time(x_t, v, dt/2, Gamma, A, C, D)
C_new = Coef_C(x_t)
v = v + Gamma**0.5 * ( C_new - C) *dt
x_t, v = advance_time(x_t, v, dt/2, Gamma, A, C, D)
C = C_new
return x_t, (v, C)
def prepare_step_size(self, current_times, step_size, sigma_x, sigma_y):
# -------------------------------------------------------------------------
# Unpack current times parameters (sigma and abt)
sigma, abt, flow_t = current_times
sigma = self.add_none_dims(sigma)
abt = self.add_none_dims(abt)
# Compute time step (dtx, dty) for x and y branches.
dtx = 2 * step_size * sigma_x
dty = 2 * step_size * sigma_y
# -------------------------------------------------------------------------
# Define friction parameter Gamma_hat for each branch.
# Using dtx**0 provides a tensor of the proper device/dtype.
Gamma_hat_x = self.friction **2 * self.step_size * sigma_x / 0.1 * sigma**0
Gamma_hat_y = self.friction **2 * self.step_size * sigma_y / 0.1 * sigma**0
#print("Gamma_hat_x", torch.mean(Gamma_hat_x).item(), "Gamma_hat_y", torch.mean(Gamma_hat_y).item())
# adjust dt to match denoise-addnoise steps sizes
Gamma_hat_x /= 2.
Gamma_hat_y /= 2.
A_t_x = (1) / ( 1 - abt ) * dtx / 2
A_t_y = (1+self.chara_lamb) / ( 1 - abt ) * dty / 2
A_x = A_t_x / (dtx/2)
A_y = A_t_y / (dty/2)
Gamma_x = Gamma_hat_x / (dtx/2)
Gamma_y = Gamma_hat_y / (dty/2)
#D_x = (2 * (1 + sigma**2) )**0.5
#D_y = (2 * (1 + sigma**2) )**0.5
D_x = (2 * abt**0 )**0.5
D_y = (2 * abt**0 )**0.5
return sigma, abt, dtx/2, dty/2, Gamma_x, Gamma_y, A_x, A_y, D_x, D_y
def x0_evalutation(self, x_t, score, sigma, args):
x0 = x_t + score(x_t)
return x0