Hansheng Chen
Rename policy.u() to policy.pi() to better align with the paper notation
f24a89a
# Copyright (c) 2025 Hansheng Chen
import torch
from .base import BasePolicy
class DXPolicy(BasePolicy):
    """DX policy built from a grid of denoising outputs.

    The number of grid points N is taken from dimension 1 of the denoising output. At construction the
    outputs are converted to x_0 predictions; ``pi`` later interpolates linearly between neighbouring grid
    points to obtain the x_0 estimate for a queried noise level.

    Note: segment_size and shift are intrinsic parameters of the DX policy. For elastic inference (i.e.,
    changing the number of function evaluations or noise schedule at test time), these parameters should be
    kept unchanged.

    Args:
        denoising_output (torch.Tensor): The output of the denoising model. Shape (B, N, C, H, W) or
            (B, N, C, T, H, W).
        x_t_src (torch.Tensor): The initial noisy sample. Shape (B, C, H, W) or (B, C, T, H, W).
        sigma_t_src (torch.Tensor): The initial noise level. Shape (B,).
        segment_size (float): The size of each DX policy time segment. Defaults to 1.0.
        shift (float): The shift parameter for the DX policy noise schedule. Defaults to 1.0.
        eps (float): A small value to avoid numerical issues. Defaults to 1e-4.
    """

    def __init__(
            self,
            denoising_output: torch.Tensor,
            x_t_src: torch.Tensor,
            sigma_t_src: torch.Tensor,
            segment_size: float = 1.0,
            shift: float = 1.0,
            eps: float = 1e-4):
        self.x_t_src = x_t_src
        self.ndim = x_t_src.dim()
        self.shift = shift
        self.eps = eps

        # Give sigma trailing singleton dims so it broadcasts against x_t_src.
        self.sigma_t_src = self._pad_sigma(sigma_t_src)

        # Segment end points in unwarped time; the destination never goes below 0.
        raw_src = self._unwarp_t(self.sigma_t_src)
        raw_dst = (raw_src - segment_size).clamp(min=0)
        self.raw_t_src = raw_src
        self.raw_t_dst = raw_dst
        # Effective segment length after clamping, floored at eps because pi() divides by it.
        self.segment_size = (raw_src - raw_dst).clamp(min=eps)

        self.denoising_output_x_0 = self._u_to_x_0(
            denoising_output, self.x_t_src, self.sigma_t_src)

    def _pad_sigma(self, sigma):
        """Reshape ``sigma`` by appending size-1 dims until it has ``self.ndim`` dims."""
        num_missing = self.ndim - sigma.dim()
        padded_shape = (*sigma.size(), *([1] * num_missing))
        return sigma.reshape(padded_shape)

    def _unwarp_t(self, sigma_t):
        """Map a noise level to unwarped time: sigma / (shift + (1 - shift) * sigma)."""
        denominator = self.shift + (1 - self.shift) * sigma_t
        return sigma_t / denominator

    @staticmethod
    def _u_to_x_0(denoising_output, x_t, sigma_t):
        """Convert denoising outputs to x_0 predictions: x_0 = x_t - sigma_t * output.

        ``x_t`` and ``sigma_t`` get a new dim at position 1 so they broadcast over the grid dimension of
        ``denoising_output``.
        """
        x_t_expanded = x_t.unsqueeze(1)
        sigma_expanded = sigma_t.unsqueeze(1)
        return x_t_expanded - sigma_expanded * denoising_output

    @staticmethod
    def _interpolate(x, t):
        """Linearly interpolate along dimension 1 of ``x`` at fractional position ``t``.

        Args:
            x (torch.Tensor): (B, N, *)
            t (torch.Tensor): (B, *) in [0, 1]

        Returns:
            torch.Tensor: (B, *)
        """
        num_points = x.size(1)
        if num_points < 2:
            # A single grid point leaves nothing to interpolate.
            return x.squeeze(1)

        # Continuous grid coordinate in [0, N - 1].
        position = t.clamp(min=0, max=1) * (num_points - 1)
        # Lower neighbour is capped at N - 2 so the upper neighbour stays a valid index.
        lower = position.floor().to(torch.long).clamp(min=0, max=num_points - 2)
        upper = lower + 1

        index = torch.stack([lower, upper], dim=1)  # (B, 2, *)
        index = index.expand(-1, -1, *x.shape[2:])
        neighbours = torch.gather(x, dim=1, index=index)
        value_lower = neighbours[:, 0]
        value_upper = neighbours[:, 1]

        return (upper - position) * value_lower + (position - lower) * value_upper

    def pi(self, x_t, sigma_t):
        """Compute the flow velocity at (x_t, t).

        Args:
            x_t (torch.Tensor): Noisy input at time t.
            sigma_t (torch.Tensor): Noise level at time t.

        Returns:
            torch.Tensor: The computed flow velocity, (x_t - x_0) / sigma_t, with sigma_t floored at eps.
        """
        sigma_t = self._pad_sigma(sigma_t)
        raw_t = self._unwarp_t(sigma_t)

        # Fractional position of raw_t inside the segment (clamped to [0, 1] by _interpolate).
        progress = (raw_t - self.raw_t_dst) / self.segment_size
        x_0 = self._interpolate(self.denoising_output_x_0, progress)

        safe_sigma = sigma_t.clamp(min=self.eps)
        return (x_t - x_0) / safe_sigma

    def copy(self):
        """Return a shallow copy that shares all stored tensors; ``__init__`` is not re-run."""
        duplicate = DXPolicy.__new__(DXPolicy)
        shared_fields = (
            'x_t_src',
            'ndim',
            'shift',
            'eps',
            'sigma_t_src',
            'raw_t_src',
            'raw_t_dst',
            'segment_size',
            'denoising_output_x_0',
        )
        for field in shared_fields:
            setattr(duplicate, field, getattr(self, field))
        return duplicate

    def detach_(self):
        """Replace the stored x_0 predictions with a detached tensor in place and return self."""
        detached_x_0 = self.denoising_output_x_0.detach()
        self.denoising_output_x_0 = detached_x_0
        return self

    def detach(self):
        """Return a shallow copy whose x_0 predictions are detached; self is left unchanged."""
        return self.copy().detach_()