import copy
from typing import Callable

import torch
import torch.nn as nn
import timm
from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD
from torchvision.transforms import Normalize

from src.utils.no_grad import no_grad
from src.diffusion.base.training import *
from src.diffusion.base.scheduling import BaseScheduler

# Loss-weighting schemes. Each maps the scheduler's (alpha, sigma) at a
# timestep to a weight on the flow-matching loss; note that `snr` here is
# alpha / sigma rather than the squared ratio.
def inverse_sigma(alpha, sigma):
    return 1 / sigma ** 2

def snr(alpha, sigma):
    return alpha / sigma

def minsnr(alpha, sigma, threshold=5):
    # clip the ratio from below, so the weight is at least `threshold`
    return torch.clip(alpha / sigma, min=threshold)

def maxsnr(alpha, sigma, threshold=5):
    # clip the ratio from above, so the weight is at most `threshold`
    return torch.clip(alpha / sigma, max=threshold)

def constant(alpha, sigma):
    return 1
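
# Worked example (assuming, for illustration, a linear-interpolation scheduler
# with alpha(t) = 1 - t and sigma(t) = t): at t = 0.2 the ratio is
# 0.8 / 0.2 = 4, which maxsnr (threshold=5) leaves at 4 and minsnr raises to 5.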

class DINOv2(nn.Module):
    def __init__(self, weight_path: str):
        super().__init__()
        # Load a frozen DINOv2 backbone from torch.hub and keep a copy of its
        # original position embedding so it can be resampled at other grid sizes.
        self.encoder = torch.hub.load('facebookresearch/dinov2', weight_path)
        self.pos_embed = copy.deepcopy(self.encoder.pos_embed)
        self.encoder.head = torch.nn.Identity()
        self.patch_size = self.encoder.patch_embed.patch_size
        self.precomputed_pos_embed = dict()

    def fetch_pos(self, h, w):
        # Cache resampled absolute position embeddings per (h, w) patch grid.
        key = (h, w)
        if key in self.precomputed_pos_embed:
            return self.precomputed_pos_embed[key]
        value = timm.layers.pos_embed.resample_abs_pos_embed(
            self.pos_embed.data, [h, w],
        )
        self.precomputed_pos_embed[key] = value
        return value

    def forward(self, x):
        b, c, h, w = x.shape
        x = Normalize(IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD)(x)
        # Rescale so a 256px input lands at DINOv2's native 224px resolution.
        x = torch.nn.functional.interpolate(x, (int(224 * h / 256), int(224 * w / 256)), mode='bicubic')
        b, c, h, w = x.shape
        patch_num_h, patch_num_w = h // self.patch_size[0], w // self.patch_size[1]
        pos_embed_data = self.fetch_pos(patch_num_h, patch_num_w)
        self.encoder.pos_embed.data = pos_embed_data
        feature = self.encoder.forward_features(x)['x_norm_patchtokens']
        return feature
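
# Usage sketch: the wrapper consumes raw pixels in [0, 1] and returns patch-token
# features. `dinov2_vitb14` is one torch.hub entry point; the choice of
# checkpoint here is an assumption.
#
#   encoder = DINOv2('dinov2_vitb14')
#   images = torch.rand(4, 3, 256, 256)   # raw images in [0, 1]
#   feats = encoder(images)               # (4, 256, 768): 16x16 patches at 224px, ViT-B width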

class REPATrainer(BaseTrainer):
    def __init__(
        self,
        scheduler: BaseScheduler,
        loss_weight_fn: Callable = constant,
        feat_loss_weight: float = 0.5,
        lognorm_t=False,
        encoder_weight_path=None,
        align_layer=8,
        proj_denoiser_dim=256,
        proj_hidden_dim=256,
        proj_encoder_dim=256,
        *args,
        **kwargs
    ):
        super().__init__(*args, **kwargs)
        self.lognorm_t = lognorm_t
        self.scheduler = scheduler
        self.loss_weight_fn = loss_weight_fn
        self.feat_loss_weight = feat_loss_weight
        self.align_layer = align_layer
        self.encoder = DINOv2(encoder_weight_path)
        no_grad(self.encoder)
        # Trainable MLP projector mapping denoiser features at `align_layer`
        # into the DINOv2 feature space.
        self.proj = nn.Sequential(
            nn.Linear(proj_denoiser_dim, proj_hidden_dim),
            nn.SiLU(),
            nn.Linear(proj_hidden_dim, proj_hidden_dim),
            nn.SiLU(),
            nn.Linear(proj_hidden_dim, proj_encoder_dim),
        )

    def _impl_trainstep(self, net, ema_net, raw_images, x, y):
        batch_size, c, height, width = x.shape
        # Sample timesteps, optionally from a logit-normal distribution.
        if self.lognorm_t:
            base_t = torch.randn((batch_size,), device=x.device, dtype=x.dtype).sigmoid()
        else:
            base_t = torch.rand((batch_size,), device=x.device, dtype=x.dtype)
        t = base_t
        noise = torch.randn_like(x)
        alpha = self.scheduler.alpha(t)
        dalpha = self.scheduler.dalpha(t)
        sigma = self.scheduler.sigma(t)
        dsigma = self.scheduler.dsigma(t)
        # Interpolate between data and noise; v_t is the velocity target.
        x_t = alpha * x + noise * sigma
        v_t = dalpha * x + dsigma * noise

        # Capture the denoiser's intermediate activations at the alignment
        # layer via a forward hook.
        src_feature = []
        def forward_hook(module, input, output):
            src_feature.append(output)
        handle = net.blocks[self.align_layer - 1].register_forward_hook(forward_hook)
        out = net(x_t, t, y)
        handle.remove()
        src_feature = self.proj(src_feature[0])

        with torch.no_grad():
            dst_feature = self.encoder(raw_images)

        # REPA alignment loss: cosine distance between projected denoiser
        # features and frozen DINOv2 patch features.
        cos_sim = torch.nn.functional.cosine_similarity(src_feature, dst_feature, dim=-1)
        cos_loss = 1 - cos_sim

        # Weighted flow-matching loss on the velocity prediction.
        weight = self.loss_weight_fn(alpha, sigma)
        fm_loss = weight * (out - v_t) ** 2

        return dict(
            fm_loss=fm_loss.mean(),
            cos_loss=cos_loss.mean(),
            loss=fm_loss.mean() + self.feat_loss_weight * cos_loss.mean(),
        )

    def state_dict(self, *args, destination=None, prefix="", keep_vars=False):
        # Checkpoint only the trainable projector; the frozen DINOv2 encoder
        # is deliberately left out.
        return self.proj.state_dict(
            destination=destination,
            prefix=prefix + "proj.",
            keep_vars=keep_vars)
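
# Construction sketch: the projector dims and encoder variant below are
# illustrative assumptions, and any remaining BaseTrainer arguments pass
# through **kwargs.
#
#   trainer = REPATrainer(
#       scheduler=scheduler,               # any BaseScheduler implementation
#       loss_weight_fn=constant,
#       feat_loss_weight=0.5,
#       lognorm_t=True,
#       encoder_weight_path='dinov2_vitb14',
#       align_layer=8,
#       proj_denoiser_dim=1152,            # hypothetical denoiser hidden size
#       proj_hidden_dim=2048,
#       proj_encoder_dim=768,              # DINOv2 ViT-B/14 token dim
#   )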