Hansheng Chen commited on
Commit
f24a89a
·
1 Parent(s): 2d236ff

Rename policy.u() to policy.pi() to better align with the paper notation

Browse files
lakonlab/models/diffusions/piflow_policies/base.py CHANGED
@@ -4,7 +4,7 @@ from abc import ABCMeta, abstractmethod
4
  class BasePolicy(metaclass=ABCMeta):
5
 
6
  @abstractmethod
7
- def u(self, x_t, sigma_t):
8
  """Compute the flow velocity at (x_t, t).
9
 
10
  Args:
 
4
  class BasePolicy(metaclass=ABCMeta):
5
 
6
  @abstractmethod
7
+ def pi(self, x_t, sigma_t):
8
  """Compute the flow velocity at (x_t, t).
9
 
10
  Args:
lakonlab/models/diffusions/piflow_policies/dx.py CHANGED
@@ -69,7 +69,7 @@ class DXPolicy(BasePolicy):
69
  x_interp = (t1 - t) * x0x1[:, 0] + (t - t0) * x0x1[:, 1]
70
  return x_interp
71
 
72
- def u(self, x_t, sigma_t):
73
  """Compute the flow velocity at (x_t, t).
74
 
75
  Args:
 
69
  x_interp = (t1 - t) * x0x1[:, 0] + (t - t0) * x0x1[:, 1]
70
  return x_interp
71
 
72
+ def pi(self, x_t, sigma_t):
73
  """Compute the flow velocity at (x_t, t).
74
 
75
  Args:
lakonlab/models/diffusions/piflow_policies/gmflow.py CHANGED
@@ -93,7 +93,7 @@ class GMFlowPolicy(BasePolicy):
93
  gm_vars=gm_vars,
94
  logweights=denoising_output['logweights'])
95
 
96
- def u(self, x_t, sigma_t):
97
  """Compute the flow velocity at (x_t, t).
98
 
99
  Args:
 
93
  gm_vars=gm_vars,
94
  logweights=denoising_output['logweights'])
95
 
96
+ def pi(self, x_t, sigma_t):
97
  """Compute the flow velocity at (x_t, t).
98
 
99
  Args:
lakonlab/pipelines/piflux_pipeline.py CHANGED
@@ -447,7 +447,7 @@ class PiFluxPipeline(FluxPipeline, PiFlowLoaderMixin):
447
  for _ in range(num_inference_substeps[i]):
448
  t = timesteps[timestep_id]
449
  sigma_t = t / self.scheduler.config.num_train_timesteps
450
- u = policy.u(latents, sigma_t)
451
  latents = self.scheduler.step(u, t, latents, return_dict=False)[0]
452
  timestep_id += 1
453
 
 
447
  for _ in range(num_inference_substeps[i]):
448
  t = timesteps[timestep_id]
449
  sigma_t = t / self.scheduler.config.num_train_timesteps
450
+ u = policy.pi(latents, sigma_t)
451
  latents = self.scheduler.step(u, t, latents, return_dict=False)[0]
452
  timestep_id += 1
453
 
lakonlab/pipelines/piqwen_pipeline.py CHANGED
@@ -377,7 +377,7 @@ class PiQwenImagePipeline(QwenImagePipeline, PiFlowLoaderMixin):
377
  for _ in range(num_inference_substeps[i]):
378
  t = timesteps[timestep_id]
379
  sigma_t = t / self.scheduler.config.num_train_timesteps
380
- u = policy.u(latents, sigma_t)
381
  latents = self.scheduler.step(u, t, latents, return_dict=False)[0]
382
  timestep_id += 1
383
 
 
377
  for _ in range(num_inference_substeps[i]):
378
  t = timesteps[timestep_id]
379
  sigma_t = t / self.scheduler.config.num_train_timesteps
380
+ u = policy.pi(latents, sigma_t)
381
  latents = self.scheduler.step(u, t, latents, return_dict=False)[0]
382
  timestep_id += 1
383