import torch
from .utils import *  # provides StochasticHarmonicOscillator, among other helpers
from functools import partial

# Many thanks to the LanPaint team for this implementation (https://github.com/scraed/LanPaint/)

def _pack_latents(latents):
    # Flux-style 2x2 patchification: (B, C, 1, H, W) -> (B, (H/2)*(W/2), 4*C).
    batch_size, num_channels_latents, _, height, width = latents.shape

    latents = latents.view(batch_size, num_channels_latents, height // 2, 2, width // 2, 2)
    latents = latents.permute(0, 2, 4, 1, 3, 5)
    latents = latents.reshape(batch_size, (height // 2) * (width // 2), num_channels_latents * 4)

    return latents

def _unpack_latents(latents, height, width, vae_scale_factor=8):
    # Inverse of _pack_latents: (B, (H/2)*(W/2), 4*C) -> (B, C, 1, H, W), where
    # the latent-space height/width are derived from the pixel dimensions.
    batch_size, num_patches, channels = latents.shape

    height = 2 * (int(height) // (vae_scale_factor * 2))
    width = 2 * (int(width) // (vae_scale_factor * 2))

    latents = latents.view(batch_size, height // 2, width // 2, channels // 4, 2, 2)
    latents = latents.permute(0, 3, 1, 4, 2, 5)

    latents = latents.reshape(batch_size, channels // (2 * 2), 1, height, width)

    return latents
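
# Shape round-trip example (illustrative; assumes a 16-channel, VAE-factor-8
# Flux-style latent space):
#   x = torch.randn(1, 16, 1, 90, 160)             # 720x1280 image -> 90x160 latents
#   packed = _pack_latents(x)                      # -> (1, 45 * 80, 64)
#   restored = _unpack_latents(packed, 720, 1280)  # -> (1, 16, 1, 90, 160)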

class LanPaint:
    def __init__(self, NSteps=5, Friction=15, Lambda=8, Beta=1, StepSize=0.15, IS_FLUX=True, IS_FLOW=False):
        self.n_steps = NSteps      # number of Langevin iterations per sampling step
        self.chara_lamb = Lambda   # guidance strength pulling the known region toward the reference
        self.IS_FLUX = IS_FLUX
        self.IS_FLOW = IS_FLOW
        self.step_size = StepSize  # base Langevin step size (scaled by 1 - abt at runtime)
        self.friction = Friction   # friction of the underdamped Langevin dynamics
        self.chara_beta = Beta     # relative step-size scale for the known-region branch
        self.img_dim_size = None   # set per call from the latent tensor rank
    def add_none_dims(self, array):
        # Index with ':' for the batch dimension and 'None' for each remaining
        # image dimension, so a (B,) tensor broadcasts against the latents.
        index = (slice(None),) + (None,) * (self.img_dim_size - 1)
        return array[index]
    def remove_none_dims(self, array):
        # Inverse of add_none_dims: index 0 into each trailing singleton
        # dimension, collapsing back to a (B,) tensor.
        index = (slice(None),) + (0,) * (self.img_dim_size - 1)
        return array[index]
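
    # Example (illustrative): with 5-D unpacked latents (img_dim_size == 5),
    # add_none_dims maps shape (B,) -> (B, 1, 1, 1, 1) so per-batch scalars
    # broadcast against the latents, and remove_none_dims inverts it.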
    def __call__(self, denoise, cfg_predictions, true_cfg_scale, cfg_BIG, x, latent_image, noise, sigma, latent_mask, n_steps=None, height=720, width=1280, vae_scale_factor=8):
        latent_image = _unpack_latents(latent_image, height=height, width=width, vae_scale_factor=vae_scale_factor)
        noise = _unpack_latents(noise, height=height, width=width, vae_scale_factor=vae_scale_factor)
        x = _unpack_latents(x, height=height, width=width, vae_scale_factor=vae_scale_factor)
        latent_mask = _unpack_latents(latent_mask, height=height, width=width, vae_scale_factor=vae_scale_factor)
        self.height = height
        self.width = width
        self.vae_scale_factor = vae_scale_factor
        self.img_dim_size = len(x.shape)
        self.latent_image = latent_image
        self.noise = noise
        if n_steps is None:
            n_steps = self.n_steps
        out = self.LanPaint(denoise, cfg_predictions, true_cfg_scale, cfg_BIG, x, sigma, latent_mask, n_steps, self.IS_FLUX, self.IS_FLOW)
        out = _pack_latents(out)
        return out
    def LanPaint(self, denoise, cfg_predictions, true_cfg_scale, cfg_BIG, x, sigma, latent_mask, n_steps, IS_FLUX, IS_FLOW):
        if IS_FLUX:
            cfg_BIG = 1.0  # for FLUX, the big-CFG branch uses guidance scale 1.0

        def double_denoise(latents, t):
            # One model call, two guidance combinations: the standard CFG scale
            # and the (possibly different) big-CFG scale for the known region.
            latents = _pack_latents(latents)
            noise_pred, neg_noise_pred = denoise(latents, true_cfg_scale)
            if noise_pred is None: return None, None
            predict_std = cfg_predictions(noise_pred, neg_noise_pred, true_cfg_scale, t)
            predict_std = _unpack_latents(predict_std, self.height, self.width, self.vae_scale_factor)
            if true_cfg_scale == cfg_BIG:
                predict_big = predict_std
            else:
                predict_big = cfg_predictions(noise_pred, neg_noise_pred, cfg_BIG, t)
                predict_big = _unpack_latents(predict_big, self.height, self.width, self.vae_scale_factor)
            return predict_std, predict_big
        
        if len(sigma.shape) == 0:
            sigma = torch.tensor([sigma.item()], device=x.device)
        # Invert the mask so that 1 marks the known region (kept from
        # latent_image) and 0 marks the region to inpaint.
        latent_mask = 1 - latent_mask
        # Convert between time parameterizations: flow time t, variance-exploding
        # sigma, and the variance-preserving signal fraction abt (alpha-bar).
        if IS_FLUX or IS_FLOW:
            Flow_t = sigma
            abt = (1 - Flow_t)**2 / ((1 - Flow_t)**2 + Flow_t**2)
            VE_Sigma = Flow_t / (1 - Flow_t)
        else:
            VE_Sigma = sigma
            abt = 1 / (1 + VE_Sigma**2)
            Flow_t = (1 - abt)**0.5 / ((1 - abt)**0.5 + abt**0.5)
        current_times = (VE_Sigma, abt, Flow_t)
        
        step_size = self.step_size * (1 - abt)
        step_size = self.add_none_dims(step_size)

        # Replace step: re-noise the known region of latent_image to the current
        # time (flow-matching convention, as used by FLUX) and overwrite that
        # region of x.
        noisy_image = self.latent_image * (1.0 - sigma) + self.noise * sigma
        x = x * (1 - latent_mask) + noisy_image * latent_mask

        if IS_FLUX or IS_FLOW:
            x_t = x * (self.add_none_dims(abt)**0.5 + (1 - self.add_none_dims(abt))**0.5)
        else:
            x_t = x / (1 + self.add_none_dims(VE_Sigma)**2)**0.5  # switch to variance-preserving x_t values
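        # Why this factor converts flow latents to VP latents: with
        # abt = (1-t)^2 / ((1-t)^2 + t^2),
        #   sqrt(abt) + sqrt(1-abt) = ((1-t) + t) / sqrt((1-t)^2 + t^2)
        #                           = 1 / sqrt((1-t)^2 + t^2),
        # so x * (sqrt(abt) + sqrt(1-abt)) turns (1-t)*x0 + t*eps into
        # sqrt(abt)*x0 + sqrt(1-abt)*eps, the variance-preserving x_t.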

        ############ LanPaint Iterations Start ###############
        # After the replace step, the known region of x_t equals the re-noised
        # latent_image, i.e. x_t in the chosen diffusion convention.
        args = None
        for i in range(n_steps):
            score_func = partial(self.score_model, y=self.latent_image, mask=latent_mask, abt=self.add_none_dims(abt), sigma=self.add_none_dims(VE_Sigma), tflow=self.add_none_dims(Flow_t), denoise_func=double_denoise)
            x_t, args = self.langevin_dynamics(x_t, score_func, latent_mask, step_size, current_times, sigma_x=self.add_none_dims(self.sigma_x(abt)), sigma_y=self.add_none_dims(self.sigma_y(abt)), args=args)
        if IS_FLUX or IS_FLOW:
            x = x_t / (self.add_none_dims(abt)**0.5 + (1 - self.add_none_dims(abt))**0.5)  # back to flow-matching convention
        else:
            x = x_t * (1 + self.add_none_dims(VE_Sigma)**2)**0.5  # back to variance-exploding x values
        ############ LanPaint Iterations End ###############
        return x
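
    # Summary of the LanPaint step implemented above: (1) replace the known
    # region of x with a re-noised copy of latent_image, (2) switch to
    # variance-preserving coordinates, (3) run n_steps of masked Langevin
    # refinement, (4) convert back to the sampler's convention; __call__
    # re-packs the result into Flux-style packed latents.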

    def score_model(self, x_t, y, mask, abt, sigma, tflow, denoise_func):
        lamb = self.chara_lamb
        if self.IS_FLUX or self.IS_FLOW:
            x = x_t / (abt**0.5 + (1 - abt)**0.5)  # switch from VP to the Gaussian flow matching convention
            x_0, x_0_BIG = denoise_func(x, self.remove_none_dims(tflow))
        else:
            x = x_t * (1 + sigma**2)**0.5  # switch from VP to variance exploding
            x_0, x_0_BIG = denoise_func(x, self.remove_none_dims(sigma))
        if x_0 is None: return None

        # Unknown region: score from the denoised estimate (Tweedie-style).
        score_x = -(x_t - x_0)
        # Known region: attraction toward y plus a lambda-weighted correction
        # from the big-CFG denoised estimate.
        score_y = -(1 + lamb) * (x_t - y) + lamb * (x_t - x_0_BIG)
        return score_x * (1 - mask) + score_y * mask
    def sigma_x(self, abt):
        # Time-scale factor for the unknown-region (x) update; abt**0 is a
        # ones tensor with the proper shape/device.
        return abt**0
    def sigma_y(self, abt):
        # Time-scale factor for the known-region (y) update, scaled by Beta.
        return self.chara_beta * abt**0

    def langevin_dynamics(self, x_t, score, mask, step_size, current_times, sigma_x=1, sigma_y=0, args=None):
        # Prepare the step sizes and oscillator coefficients in float32.
        with torch.autocast(device_type=x_t.device.type, dtype=torch.float32):
            step_sizes = self.prepare_step_size(current_times, step_size, sigma_x, sigma_y)
            sigma, abt, dtx, dty, Gamma_x, Gamma_y, A_x, A_y, D_x, D_y = step_sizes
        if torch.mean(dtx) <= 0.:
            return x_t, args
        # -------------------------------------------------------------------------
        # Compute the Langevin dynamics update in variance-preserving notation.
        # Blend the per-branch coefficients with the mask.
        # -------------------------------------------------------------------------
        A = A_x * (1 - mask) + A_y * mask
        D = D_x * (1 - mask) + D_y * mask
        dt = dtx * (1 - mask) + dty * mask
        Gamma = Gamma_x * (1 - mask) + Gamma_y * mask


        def Coef_C(x_t):
            # Drift constant: score toward the denoised estimate, recentered by A.
            x0 = self.x0_evaluation(x_t, score, sigma, args)
            C = (abt**0.5 * x0 - x_t) / (1 - abt) + A * x_t
            return C
        def advance_time(x_t, v, dt, Gamma, A, C, D):
            # Integrate the underdamped dynamics in float32, then cast back.
            dtype = x_t.dtype
            with torch.autocast(device_type=x_t.device.type, dtype=torch.float32):
                osc = StochasticHarmonicOscillator(Gamma, A, C, D)
                x_t, v = osc.dynamics(x_t, v, dt)
            x_t = x_t.to(dtype)
            v = v.to(dtype)
            return x_t, v
        if args is None:
            # First call: no velocity yet; take one full step.
            v = None
            C = Coef_C(x_t)
            x_t, v = advance_time(x_t, v, dt, Gamma, A, C, D)
        else:
            # Subsequent calls: half step, refresh the drift constant, correct
            # the velocity, then another half step (a splitting scheme).
            v, C = args

            x_t, v = advance_time(x_t, v, dt / 2, Gamma, A, C, D)

            C_new = Coef_C(x_t)
            v = v + Gamma**0.5 * (C_new - C) * dt

            x_t, v = advance_time(x_t, v, dt / 2, Gamma, A, C, D)

            C = C_new

        return x_t, (v, C)
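
    # Note on the dynamics above: at the noise-free fixed point (v = 0 and
    # A * x_t = C), Coef_C reduces to x_t = abt**0.5 * x0, the variance-
    # preserving posterior mean, so the iterations relax x_t toward the
    # denoiser's estimate while D sets the injected-noise amplitude.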

    def prepare_step_size(self, current_times, step_size, sigma_x, sigma_y):
        # -------------------------------------------------------------------------
        # Unpack current time parameters (sigma and abt).
        sigma, abt, flow_t = current_times
        sigma = self.add_none_dims(sigma)
        abt = self.add_none_dims(abt)
        # Compute time steps (dtx, dty) for the x and y branches.
        dtx = 2 * step_size * sigma_x
        dty = 2 * step_size * sigma_y

        # -------------------------------------------------------------------------
        # Define the friction parameter Gamma_hat for each branch.
        # Multiplying by sigma**0 yields a tensor of the proper shape/device.
        Gamma_hat_x = self.friction**2 * self.step_size * sigma_x / 0.1 * sigma**0
        Gamma_hat_y = self.friction**2 * self.step_size * sigma_y / 0.1 * sigma**0
        # Halve to match denoise-addnoise step sizes.
        Gamma_hat_x /= 2.
        Gamma_hat_y /= 2.
        A_t_x = 1 / (1 - abt) * dtx / 2
        A_t_y = (1 + self.chara_lamb) / (1 - abt) * dty / 2

        # Normalize to per-unit-time coefficients for the oscillator dynamics.
        A_x = A_t_x / (dtx / 2)
        A_y = A_t_y / (dty / 2)
        Gamma_x = Gamma_hat_x / (dtx / 2)
        Gamma_y = Gamma_hat_y / (dty / 2)

        # Diffusion amplitudes (constant in the variance-preserving frame).
        D_x = (2 * abt**0)**0.5
        D_y = (2 * abt**0)**0.5
        return sigma, abt, dtx / 2, dty / 2, Gamma_x, Gamma_y, A_x, A_y, D_x, D_y
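
    # The returned coefficients parameterize the stochastic harmonic oscillator
    # used in langevin_dynamics: Gamma_* act as friction, A_* as restoring
    # strength, and D_* as noise amplitude, with half-steps dtx/2 and dty/2.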

    def x0_evaluation(self, x_t, score, sigma, args):
        # Posterior-mean estimate of x_0 given the current x_t and score
        # (Tweedie-style: x_0 = x_t + score in this normalization).
        x0 = x_t + score(x_t)
        return x0
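
# ----------------------------------------------------------------------------
# Usage sketch (illustrative comments only: this module is driven by the
# surrounding pipeline, and `denoise` / `cfg_predictions` below are hypothetical
# stand-ins for its model wrappers, not real APIs):
#
#   lanpaint = LanPaint(NSteps=5, Lambda=8)
#   # x, latent_image, noise and latent_mask are packed latents of shape
#   # (B, (H/16) * (W/16), 64) for a 16-channel Flux-style latent space.
#   x = lanpaint(denoise, cfg_predictions, true_cfg_scale=3.5, cfg_BIG=1.0,
#                x=x, latent_image=latent_image, noise=noise, sigma=sigma,
#                latent_mask=latent_mask, height=720, width=1280)
#   # The result is the packed latent after the replace step plus NSteps
#   # Langevin refinement iterations at the current sigma.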