Spaces:
Running
on
Zero
Running
on
Zero
| import torch | |
| from typing import Callable | |
| from src.diffusion.base.training import * | |
| from src.diffusion.base.scheduling import BaseScheduler | |
| def inverse_sigma(alpha, sigma): | |
| return 1/sigma**2 | |
| def snr(alpha, sigma): | |
| return alpha/sigma | |
| def minsnr(alpha, sigma, threshold=5): | |
| return torch.clip(alpha/sigma, min=threshold) | |
| def maxsnr(alpha, sigma, threshold=5): | |
| return torch.clip(alpha/sigma, max=threshold) | |
| def constant(alpha, sigma): | |
| return 1 | |
| class VPTrainer(BaseTrainer): | |
| def __init__( | |
| self, | |
| scheduler: BaseScheduler, | |
| loss_weight_fn:Callable=constant, | |
| train_max_t=1000, | |
| lognorm_t=False, | |
| *args, | |
| **kwargs | |
| ): | |
| super().__init__(*args, **kwargs) | |
| self.lognorm_t = lognorm_t | |
| self.scheduler = scheduler | |
| self.loss_weight_fn = loss_weight_fn | |
| self.train_max_t = train_max_t | |
| def _impl_trainstep(self, net, ema_net, raw_images, x, y): | |
| batch_size = x.shape[0] | |
| if self.lognorm_t: | |
| t = torch.randn(batch_size).to(x.device, x.dtype).sigmoid() | |
| else: | |
| t = torch.rand(batch_size).to(x.device, x.dtype) | |
| noise = torch.randn_like(x) | |
| alpha = self.scheduler.alpha(t) | |
| sigma = self.scheduler.sigma(t) | |
| x_t = alpha * x + noise * sigma | |
| out = net(x_t, t*self.train_max_t, y) | |
| weight = self.loss_weight_fn(alpha, sigma) | |
| loss = weight*(out - noise)**2 | |
| out = dict( | |
| loss=loss.mean(), | |
| ) | |
| return out | |
| class DDPMTrainer(BaseTrainer): | |
| def __init__( | |
| self, | |
| scheduler: BaseScheduler, | |
| loss_weight_fn: Callable = constant, | |
| train_max_t=1000, | |
| lognorm_t=False, | |
| *args, | |
| **kwargs | |
| ): | |
| super().__init__(*args, **kwargs) | |
| self.lognorm_t = lognorm_t | |
| self.scheduler = scheduler | |
| self.loss_weight_fn = loss_weight_fn | |
| self.train_max_t = train_max_t | |
| def _impl_trainstep(self, net, ema_net, raw_images, x, y): | |
| batch_size = x.shape[0] | |
| t = torch.randint(0, self.train_max_t, (batch_size,)) | |
| noise = torch.randn_like(x) | |
| alpha = self.scheduler.alpha(t) | |
| sigma = self.scheduler.sigma(t) | |
| x_t = alpha * x + noise * sigma | |
| out = net(x_t, t, y) | |
| weight = self.loss_weight_fn(alpha, sigma) | |
| loss = weight * (out - noise) ** 2 | |
| out = dict( | |
| loss=loss.mean(), | |
| ) | |
| return out |