Hansheng Chen
commited on
Commit
·
f24a89a
1
Parent(s):
2d236ff
Rename policy.u() to policy.pi() to better align with the paper notation
Browse files
lakonlab/models/diffusions/piflow_policies/base.py
CHANGED
|
@@ -4,7 +4,7 @@ from abc import ABCMeta, abstractmethod
|
|
| 4 |
class BasePolicy(metaclass=ABCMeta):
|
| 5 |
|
| 6 |
@abstractmethod
|
| 7 |
-
def
|
| 8 |
"""Compute the flow velocity at (x_t, t).
|
| 9 |
|
| 10 |
Args:
|
|
|
|
| 4 |
class BasePolicy(metaclass=ABCMeta):
|
| 5 |
|
| 6 |
@abstractmethod
|
| 7 |
+
def pi(self, x_t, sigma_t):
|
| 8 |
"""Compute the flow velocity at (x_t, t).
|
| 9 |
|
| 10 |
Args:
|
lakonlab/models/diffusions/piflow_policies/dx.py
CHANGED
|
@@ -69,7 +69,7 @@ class DXPolicy(BasePolicy):
|
|
| 69 |
x_interp = (t1 - t) * x0x1[:, 0] + (t - t0) * x0x1[:, 1]
|
| 70 |
return x_interp
|
| 71 |
|
| 72 |
-
def
|
| 73 |
"""Compute the flow velocity at (x_t, t).
|
| 74 |
|
| 75 |
Args:
|
|
|
|
| 69 |
x_interp = (t1 - t) * x0x1[:, 0] + (t - t0) * x0x1[:, 1]
|
| 70 |
return x_interp
|
| 71 |
|
| 72 |
+
def pi(self, x_t, sigma_t):
|
| 73 |
"""Compute the flow velocity at (x_t, t).
|
| 74 |
|
| 75 |
Args:
|
lakonlab/models/diffusions/piflow_policies/gmflow.py
CHANGED
|
@@ -93,7 +93,7 @@ class GMFlowPolicy(BasePolicy):
|
|
| 93 |
gm_vars=gm_vars,
|
| 94 |
logweights=denoising_output['logweights'])
|
| 95 |
|
| 96 |
-
def
|
| 97 |
"""Compute the flow velocity at (x_t, t).
|
| 98 |
|
| 99 |
Args:
|
|
|
|
| 93 |
gm_vars=gm_vars,
|
| 94 |
logweights=denoising_output['logweights'])
|
| 95 |
|
| 96 |
+
def pi(self, x_t, sigma_t):
|
| 97 |
"""Compute the flow velocity at (x_t, t).
|
| 98 |
|
| 99 |
Args:
|
lakonlab/pipelines/piflux_pipeline.py
CHANGED
|
@@ -447,7 +447,7 @@ class PiFluxPipeline(FluxPipeline, PiFlowLoaderMixin):
|
|
| 447 |
for _ in range(num_inference_substeps[i]):
|
| 448 |
t = timesteps[timestep_id]
|
| 449 |
sigma_t = t / self.scheduler.config.num_train_timesteps
|
| 450 |
-
u = policy.
|
| 451 |
latents = self.scheduler.step(u, t, latents, return_dict=False)[0]
|
| 452 |
timestep_id += 1
|
| 453 |
|
|
|
|
| 447 |
for _ in range(num_inference_substeps[i]):
|
| 448 |
t = timesteps[timestep_id]
|
| 449 |
sigma_t = t / self.scheduler.config.num_train_timesteps
|
| 450 |
+
u = policy.pi(latents, sigma_t)
|
| 451 |
latents = self.scheduler.step(u, t, latents, return_dict=False)[0]
|
| 452 |
timestep_id += 1
|
| 453 |
|
lakonlab/pipelines/piqwen_pipeline.py
CHANGED
|
@@ -377,7 +377,7 @@ class PiQwenImagePipeline(QwenImagePipeline, PiFlowLoaderMixin):
|
|
| 377 |
for _ in range(num_inference_substeps[i]):
|
| 378 |
t = timesteps[timestep_id]
|
| 379 |
sigma_t = t / self.scheduler.config.num_train_timesteps
|
| 380 |
-
u = policy.
|
| 381 |
latents = self.scheduler.step(u, t, latents, return_dict=False)[0]
|
| 382 |
timestep_id += 1
|
| 383 |
|
|
|
|
| 377 |
for _ in range(num_inference_substeps[i]):
|
| 378 |
t = timesteps[timestep_id]
|
| 379 |
sigma_t = t / self.scheduler.config.num_train_timesteps
|
| 380 |
+
u = policy.pi(latents, sigma_t)
|
| 381 |
latents = self.scheduler.step(u, t, latents, return_dict=False)[0]
|
| 382 |
timestep_id += 1
|
| 383 |
|