Spaces:
Runtime error
Runtime error
| """ | |
| Copyright (C) 2019 NVIDIA Corporation. Ting-Chun Wang, Ming-Yu Liu, Jun-Yan Zhu. | |
| BSD License. All rights reserved. | |
| Redistribution and use in source and binary forms, with or without | |
| modification, are permitted provided that the following conditions are met: | |
| * Redistributions of source code must retain the above copyright notice, this | |
| list of conditions and the following disclaimer. | |
| * Redistributions in binary form must reproduce the above copyright notice, | |
| this list of conditions and the following disclaimer in the documentation | |
| and/or other materials provided with the distribution. | |
| THE AUTHOR DISCLAIMS ALL WARRANTIES WITH REGARD TO THIS SOFTWARE, INCLUDING ALL | |
| IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR ANY PARTICULAR PURPOSE. | |
| IN NO EVENT SHALL THE AUTHOR BE LIABLE FOR ANY SPECIAL, INDIRECT OR CONSEQUENTIAL | |
| DAMAGES OR ANY DAMAGES WHATSOEVER RESULTING FROM LOSS OF USE, DATA OR PROFITS, | |
| WHETHER IN AN ACTION OF CONTRACT, NEGLIGENCE OR OTHER TORTIOUS ACTION, ARISING | |
| OUT OF OR IN CONNECTION WITH THE USE OR PERFORMANCE OF THIS SOFTWARE. | |
| """ | |
| import functools | |
| import numpy as np | |
| import pytorch_lightning as pl | |
| import torch | |
| import torch.nn as nn | |
| import torch.nn.functional as F | |
| from torchvision import models | |
| ############################################################################### | |
| # Functions | |
| ############################################################################### | |
def weights_init(m):
    """Re-initialise a single layer in place; meant for ``module.apply(weights_init)``.

    Layers are matched on their class name:

    * a name containing ``"Conv"`` gets weights drawn from N(0.0, 0.02);
    * a name containing ``"BatchNorm2d"`` gets weights drawn from N(1.0, 0.02)
      and its bias filled with zeros.

    Every other layer is left untouched.
    """
    layer_name = type(m).__name__
    if "Conv" in layer_name:
        m.weight.data.normal_(0.0, 0.02)
        return
    if "BatchNorm2d" in layer_name:
        m.weight.data.normal_(1.0, 0.02)
        m.bias.data.fill_(0)
def get_norm_layer(norm_type="instance"):
    """Return a constructor for the requested 2-D normalisation layer.

    Args:
        norm_type: ``"batch"`` for ``nn.BatchNorm2d`` with ``affine=True`` or
            ``"instance"`` for ``nn.InstanceNorm2d`` with ``affine=False``.

    Returns:
        A ``functools.partial`` that still expects the channel count.

    Raises:
        NotImplementedError: for any other ``norm_type``.
    """
    if norm_type == "batch":
        return functools.partial(nn.BatchNorm2d, affine=True)
    if norm_type == "instance":
        return functools.partial(nn.InstanceNorm2d, affine=False)
    raise NotImplementedError("normalization layer [%s] is not found" % norm_type)
def define_G(
    input_nc,
    output_nc,
    ngf,
    netG,
    n_downsample_global=3,
    n_blocks_global=9,
    n_local_enhancers=1,
    n_blocks_local=3,
    norm="instance",
    gpu_ids=[],
    last_op=nn.Tanh(),
):
    """Build a generator, optionally move it to a GPU, and initialise its weights.

    Args:
        input_nc: number of input channels.
        output_nc: number of output channels.
        ngf: base number of generator filters.
        netG: architecture selector: ``"global"`` (``GlobalGenerator``),
            ``"local"`` (``LocalEnhancer``) or ``"encoder"`` (``Encoder``).
        n_downsample_global: number of downsampling stages handed to the
            selected architecture.
        n_blocks_global: number of residual blocks in the global generator.
        n_local_enhancers: number of local enhancer stages (``"local"`` only).
        n_blocks_local: residual blocks per local enhancer (``"local"`` only).
        norm: normalisation type understood by ``get_norm_layer``.
        gpu_ids: sequence of CUDA device ids; when non-empty the network is
            moved to ``gpu_ids[0]``. The default list is never mutated here.
        last_op: final module of the ``"global"`` generator; it is not
            forwarded to the other architectures.

    Returns:
        The constructed network with ``weights_init`` applied.

    Raises:
        NotImplementedError: if ``netG`` (or ``norm``) is not a known option.
        AssertionError: if ``gpu_ids`` is non-empty but CUDA is unavailable.
    """
    norm_layer = get_norm_layer(norm_type=norm)
    if netG == "global":
        net = GlobalGenerator(
            input_nc,
            output_nc,
            ngf,
            n_downsample_global,
            n_blocks_global,
            norm_layer,
            last_op=last_op,
        )
    elif netG == "local":
        net = LocalEnhancer(
            input_nc,
            output_nc,
            ngf,
            n_downsample_global,
            n_blocks_global,
            n_local_enhancers,
            n_blocks_local,
            norm_layer,
        )
    elif netG == "encoder":
        net = Encoder(input_nc, output_nc, ngf, n_downsample_global, norm_layer)
    else:
        # Raising a bare string is itself a TypeError in Python 3; raise a real
        # exception, in the same style as get_norm_layer.
        raise NotImplementedError("generator not implemented!")

    if len(gpu_ids) > 0:
        assert torch.cuda.is_available()
        net.cuda(gpu_ids[0])
    net.apply(weights_init)
    return net
def define_D(
    input_nc,
    ndf,
    n_layers_D,
    norm='instance',
    use_sigmoid=False,
    num_D=1,
    getIntermFeat=False,
    gpu_ids=[]
):
    """Build a ``MultiscaleDiscriminator``, optionally move it to a GPU, and initialise it.

    Args:
        input_nc: number of input channels.
        ndf: base number of discriminator filters.
        n_layers_D: layer count forwarded to the discriminator as ``n_layers``.
        norm: normalisation type understood by ``get_norm_layer``.
        use_sigmoid: forwarded to the discriminator.
        num_D: number of discriminator scales.
        getIntermFeat: forwarded to the discriminator.
        gpu_ids: sequence of CUDA device ids; when non-empty the network is
            moved to ``gpu_ids[0]``.

    Returns:
        The constructed discriminator with ``weights_init`` applied.
    """
    discriminator = MultiscaleDiscriminator(
        input_nc,
        ndf,
        n_layers_D,
        get_norm_layer(norm_type=norm),
        use_sigmoid,
        num_D,
        getIntermFeat,
    )
    wants_gpu = len(gpu_ids) > 0
    if wants_gpu:
        assert torch.cuda.is_available()
        discriminator.cuda(gpu_ids[0])
    discriminator.apply(weights_init)
    return discriminator
def print_network(net):
    """Print a network followed by its total parameter count.

    Args:
        net: a module, or a list of modules of which only the first is reported.
    """
    target = net[0] if isinstance(net, list) else net
    total = sum(p.numel() for p in target.parameters())
    print(target)
    print("Total number of parameters: %d" % total)
| ############################################################################## | |
| # Generator | |
| ############################################################################## | |
class LocalEnhancer(pl.LightningModule):
    """Generator that refines a truncated ``GlobalGenerator`` with local enhancer stages.

    Registered sub-modules:
        model: the global generator without its last three layers.
        model{n}_1: downsampling head of enhancer stage ``n`` (1-based).
        model{n}_2: residual blocks + upsampling tail of enhancer stage ``n``.
        downsample: average pooling used to build the input pyramid.
    """

    def __init__(
        self,
        input_nc,
        output_nc,
        ngf=32,
        n_downsample_global=3,
        n_blocks_global=9,
        n_local_enhancers=1,
        n_blocks_local=3,
        norm_layer=nn.BatchNorm2d,
        padding_type="reflect",
    ):
        """Build the truncated global generator and ``n_local_enhancers`` enhancer stages."""
        super().__init__()
        self.n_local_enhancers = n_local_enhancers

        # ---- global generator, widened by 2**n_local_enhancers ----
        global_width = ngf * (2**n_local_enhancers)
        global_layers = GlobalGenerator(
            input_nc,
            output_nc,
            global_width,
            n_downsample_global,
            n_blocks_global,
            norm_layer,
        ).model
        # Drop the last three layers (the final convolution layers).
        self.model = nn.Sequential(*list(global_layers)[:-3])

        # ---- local enhancer stages ----
        for level in range(1, n_local_enhancers + 1):
            width = ngf * (2**(n_local_enhancers - level))

            # Head: 7x7 convolution followed by one stride-2 convolution.
            head = [
                nn.ReflectionPad2d(3),
                nn.Conv2d(input_nc, width, kernel_size=7, padding=0),
                norm_layer(width),
                nn.ReLU(True),
                nn.Conv2d(width, width * 2, kernel_size=3, stride=2, padding=1),
                norm_layer(width * 2),
                nn.ReLU(True),
            ]

            # Tail: residual blocks, then one stride-2 transposed convolution.
            tail = [
                ResnetBlock(width * 2, padding_type=padding_type, norm_layer=norm_layer)
                for _ in range(n_blocks_local)
            ]
            tail.extend([
                nn.ConvTranspose2d(
                    width * 2,
                    width,
                    kernel_size=3,
                    stride=2,
                    padding=1,
                    output_padding=1,
                ),
                norm_layer(width),
                nn.ReLU(True),
            ])

            # Only the last stage ends with the output convolution.
            if level == n_local_enhancers:
                tail.extend([
                    nn.ReflectionPad2d(3),
                    nn.Conv2d(ngf, output_nc, kernel_size=7, padding=0),
                    nn.Tanh(),
                ])

            setattr(self, f"model{level}_1", nn.Sequential(*head))
            setattr(self, f"model{level}_2", nn.Sequential(*tail))

        self.downsample = nn.AvgPool2d(3, stride=2, padding=[1, 1], count_include_pad=False)

    def forward(self, input):
        """Run the global model on the most-downsampled input, then each enhancer stage."""
        # Input pyramid: index 0 is the original input, the last entry is the
        # most-downsampled one.
        pyramid = [input]
        for _ in range(self.n_local_enhancers):
            pyramid.append(self.downsample(pyramid[-1]))

        # Output at the coarsest level.
        output_prev = self.model(pyramid[-1])

        # Build up one stage at a time, feeding each stage a finer input.
        for level in range(1, self.n_local_enhancers + 1):
            head = getattr(self, f"model{level}_1")
            tail = getattr(self, f"model{level}_2")
            stage_input = pyramid[self.n_local_enhancers - level]
            output_prev = tail(head(stage_input) + output_prev)
        return output_prev
class GlobalGenerator(pl.LightningModule):
    """Generator made of a 7x7 stem, downsampling, residual blocks and upsampling.

    The whole network is stored as one ``nn.Sequential`` in ``self.model``.
    """

    def __init__(
        self,
        input_nc,
        output_nc,
        ngf=64,
        n_downsampling=3,
        n_blocks=9,
        norm_layer=nn.BatchNorm2d,
        padding_type="reflect",
        last_op=nn.Tanh(),
    ):
        """Assemble the layer stack.

        Args:
            input_nc: number of input channels.
            output_nc: number of output channels.
            ngf: channel count after the stem convolution.
            n_downsampling: number of stride-2 convolutions (and of matching
                stride-2 transposed convolutions).
            n_blocks: number of ``ResnetBlock`` modules; must be >= 0.
            norm_layer: callable building a normalisation layer from a
                channel count.
            padding_type: padding mode handed to each ``ResnetBlock``.
            last_op: module appended after the output convolution, or ``None``
                to append nothing.
        """
        assert n_blocks >= 0
        super().__init__()
        # A single ReLU instance is reused at every activation position.
        activation = nn.ReLU(True)

        # Stem.
        channels = ngf
        layers = [
            nn.ReflectionPad2d(3),
            nn.Conv2d(input_nc, channels, kernel_size=7, padding=0),
            norm_layer(channels),
            activation,
        ]

        # Downsampling: each stage doubles the channel count.
        for _ in range(n_downsampling):
            layers.extend([
                nn.Conv2d(channels, channels * 2, kernel_size=3, stride=2, padding=1),
                norm_layer(channels * 2),
                activation,
            ])
            channels *= 2

        # Residual blocks at the widest channel count.
        for _ in range(n_blocks):
            layers.append(
                ResnetBlock(
                    channels,
                    padding_type=padding_type,
                    activation=activation,
                    norm_layer=norm_layer,
                )
            )

        # Upsampling: each stage halves the channel count.
        for _ in range(n_downsampling):
            layers.extend([
                nn.ConvTranspose2d(
                    channels,
                    channels // 2,
                    kernel_size=3,
                    stride=2,
                    padding=1,
                    output_padding=1,
                ),
                norm_layer(channels // 2),
                activation,
            ])
            channels //= 2

        # Output convolution, optionally followed by last_op.
        layers.extend([
            nn.ReflectionPad2d(3),
            nn.Conv2d(ngf, output_nc, kernel_size=7, padding=0),
        ])
        if last_op is not None:
            layers.append(last_op)
        self.model = nn.Sequential(*layers)

    def forward(self, input):
        """Apply the sequential model to ``input``."""
        return self.model(input)
| # Defines the PatchGAN discriminator with the specified arguments. | |
class NLayerDiscriminator(nn.Module):
    """Convolutional discriminator built from a list of stages.

    Stages, in order: one stride-2 convolution + LeakyReLU, ``n_layers - 1``
    stride-2 convolution/norm/LeakyReLU stages, one stride-1
    convolution/norm/LeakyReLU stage, a 1-channel output convolution, and an
    optional Sigmoid.

    With ``getIntermFeat=True`` every stage is registered separately as
    ``model0``, ``model1``, ... and ``forward`` returns the output of each
    stage; otherwise all stages are flattened into ``self.model``.
    """

    def __init__(
        self,
        input_nc,
        ndf=64,
        n_layers=3,
        norm_layer=nn.BatchNorm2d,
        use_sigmoid=False,
        getIntermFeat=False
    ):
        """Build the stages.

        Args:
            input_nc: number of input channels.
            ndf: channel count after the first convolution.
            n_layers: number of stride-2 stages (including the first).
            norm_layer: callable building a normalisation layer from a
                channel count.
            use_sigmoid: append a Sigmoid stage after the output convolution.
            getIntermFeat: register and return every stage separately.
        """
        super(NLayerDiscriminator, self).__init__()
        self.getIntermFeat = getIntermFeat
        self.n_layers = n_layers

        kw = 4
        padw = int(np.ceil((kw - 1.0) / 2))
        sequence = [[
            nn.Conv2d(input_nc, ndf, kernel_size=kw, stride=2, padding=padw),
            nn.LeakyReLU(0.2, True)
        ]]

        nf = ndf
        for n in range(1, n_layers):
            nf_prev = nf
            nf = min(nf * 2, 512)  # channel count is capped at 512
            sequence += [[
                nn.Conv2d(nf_prev, nf, kernel_size=kw, stride=2, padding=padw),
                norm_layer(nf),
                nn.LeakyReLU(0.2, True)
            ]]

        nf_prev = nf
        nf = min(nf * 2, 512)
        sequence += [[
            nn.Conv2d(nf_prev, nf, kernel_size=kw, stride=1, padding=padw),
            norm_layer(nf),
            nn.LeakyReLU(0.2, True)
        ]]

        sequence += [[nn.Conv2d(nf, 1, kernel_size=kw, stride=1, padding=padw)]]

        if use_sigmoid:
            sequence += [[nn.Sigmoid()]]

        # Number of registered stages: n_layers + 2, plus one when a Sigmoid
        # stage was appended. forward() relies on this so that the Sigmoid is
        # not skipped when intermediate features are requested.
        self.n_stages = len(sequence)

        if getIntermFeat:
            for n, stage in enumerate(sequence):
                setattr(self, 'model' + str(n), nn.Sequential(*stage))
        else:
            sequence_stream = []
            for stage in sequence:
                sequence_stream += stage
            self.model = nn.Sequential(*sequence_stream)

    def forward(self, input):
        """Run the discriminator.

        Returns:
            With ``getIntermFeat=True``: a list holding the output of every
            stage in order (the last entry is the final prediction, after the
            Sigmoid when ``use_sigmoid`` was set). Otherwise the final
            prediction tensor.
        """
        if not self.getIntermFeat:
            return self.model(input)

        outputs = []
        current = input
        for n in range(self.n_stages):
            current = getattr(self, 'model' + str(n))(current)
            outputs.append(current)
        return outputs
class MultiscaleDiscriminator(pl.LightningModule):
    """Set of ``num_D`` ``NLayerDiscriminator`` networks applied at successive scales.

    The sub-discriminators' layers are re-registered on this module:
    ``scale{i}_layer{j}`` for every stage ``j`` when ``getIntermFeat=True``,
    otherwise ``layer{i}`` for the whole flattened model.
    """

    def __init__(
        self,
        input_nc,
        ndf=64,
        n_layers=3,
        norm_layer=nn.BatchNorm2d,
        use_sigmoid=False,
        num_D=3,
        getIntermFeat=False
    ):
        """Build ``num_D`` discriminators and the shared downsampling layer.

        Args:
            input_nc: number of input channels.
            ndf: base number of discriminator filters.
            n_layers: layer count of each ``NLayerDiscriminator``.
            norm_layer: callable building a normalisation layer from a
                channel count.
            use_sigmoid: forwarded to each ``NLayerDiscriminator``.
            num_D: number of scales.
            getIntermFeat: return the output of every stage, not only the last.
        """
        super(MultiscaleDiscriminator, self).__init__()
        self.num_D = num_D
        self.n_layers = n_layers
        self.getIntermFeat = getIntermFeat
        # Each sub-discriminator registers n_layers + 2 stages, plus one
        # Sigmoid stage when use_sigmoid is set. Counting that extra stage
        # keeps the Sigmoid from being dropped when getIntermFeat is True.
        self.n_stages = n_layers + 2 + (1 if use_sigmoid else 0)

        for i in range(num_D):
            netD = NLayerDiscriminator(
                input_nc, ndf, n_layers, norm_layer, use_sigmoid, getIntermFeat
            )
            if getIntermFeat:
                for j in range(self.n_stages):
                    setattr(
                        self, 'scale' + str(i) + '_layer' + str(j), getattr(netD, 'model' + str(j))
                    )
            else:
                setattr(self, 'layer' + str(i), netD.model)

        self.downsample = nn.AvgPool2d(3, stride=2, padding=[1, 1], count_include_pad=False)

    def singleD_forward(self, model, input):
        """Run one scale.

        Args:
            model: a list of stage modules when ``getIntermFeat`` is set,
                otherwise a single module.
            input: tensor fed to the first stage.

        Returns:
            A list: every stage output when ``getIntermFeat`` is set,
            otherwise a one-element list with the model output.
        """
        if not self.getIntermFeat:
            return [model(input)]

        outputs = []
        current = input
        for stage in model:
            current = stage(current)
            outputs.append(current)
        return outputs

    def forward(self, input):
        """Run every scale, downsampling the input between scales.

        The discriminators are visited from index ``num_D - 1`` down to 0, so
        the highest-index one sees the original-resolution input.

        Returns:
            A list with one entry per scale, each being the list returned by
            ``singleD_forward``.
        """
        num_D = self.num_D
        result = []
        input_downsampled = input.clone()
        for i in range(num_D):
            scale = num_D - 1 - i
            if self.getIntermFeat:
                model = [
                    getattr(self, 'scale' + str(scale) + '_layer' + str(j))
                    for j in range(self.n_stages)
                ]
            else:
                model = getattr(self, 'layer' + str(scale))
            result.append(self.singleD_forward(model, input_downsampled))
            # No need to downsample after the last scale.
            if i != (num_D - 1):
                input_downsampled = self.downsample(input_downsampled)
        return result
| # Define a resnet block | |
class ResnetBlock(pl.LightningModule):
    """Residual block: ``x + conv_block(x)`` with two 3x3 convolutions."""

    def __init__(self, dim, padding_type, norm_layer, activation=nn.ReLU(True), use_dropout=False):
        """Build the convolutional branch; see ``build_conv_block`` for the arguments."""
        super().__init__()
        self.conv_block = self.build_conv_block(
            dim, padding_type, norm_layer, activation, use_dropout
        )

    def build_conv_block(self, dim, padding_type, norm_layer, activation, use_dropout):
        """Assemble pad/conv/norm/activation [dropout] pad/conv/norm as one ``nn.Sequential``.

        Args:
            dim: channel count, identical for input and output.
            padding_type: ``"reflect"``, ``"replicate"`` or ``"zero"``.
            norm_layer: callable building a normalisation layer from ``dim``.
            activation: module placed after the first normalisation.
            use_dropout: insert ``nn.Dropout(0.5)`` after the activation.

        Raises:
            NotImplementedError: for an unknown ``padding_type``.
        """

        def padded_conv():
            # One 3x3 convolution preceded by an explicit padding layer, or
            # using the convolution's own zero padding for "zero".
            if padding_type == "reflect":
                prefix, conv_padding = [nn.ReflectionPad2d(1)], 0
            elif padding_type == "replicate":
                prefix, conv_padding = [nn.ReplicationPad2d(1)], 0
            elif padding_type == "zero":
                prefix, conv_padding = [], 1
            else:
                raise NotImplementedError("padding [%s] is not implemented" % padding_type)
            return prefix + [nn.Conv2d(dim, dim, kernel_size=3, padding=conv_padding)]

        layers = padded_conv() + [norm_layer(dim), activation]
        if use_dropout:
            layers.append(nn.Dropout(0.5))
        layers += padded_conv() + [norm_layer(dim)]
        return nn.Sequential(*layers)

    def forward(self, x):
        """Return the input plus the output of the convolutional branch."""
        return x + self.conv_block(x)
class Encoder(pl.LightningModule):
    """Encoder-decoder network whose output is averaged per instance id."""

    def __init__(self, input_nc, output_nc, ngf=32, n_downsampling=4, norm_layer=nn.BatchNorm2d):
        """Build a 7x7 stem, ``n_downsampling`` stride-2 stages, matching upsampling and a Tanh output.

        Args:
            input_nc: number of input channels.
            output_nc: number of output channels.
            ngf: channel count after the stem convolution.
            n_downsampling: number of downsampling (and upsampling) stages.
            norm_layer: callable building a normalisation layer from a
                channel count.
        """
        super().__init__()
        self.output_nc = output_nc

        channels = ngf
        layers = [
            nn.ReflectionPad2d(3),
            nn.Conv2d(input_nc, channels, kernel_size=7, padding=0),
            norm_layer(channels),
            nn.ReLU(True),
        ]

        # Downsampling: each stage doubles the channel count.
        for _ in range(n_downsampling):
            layers.extend([
                nn.Conv2d(channels, channels * 2, kernel_size=3, stride=2, padding=1),
                norm_layer(channels * 2),
                nn.ReLU(True),
            ])
            channels *= 2

        # Upsampling: each stage halves the channel count.
        for _ in range(n_downsampling):
            layers.extend([
                nn.ConvTranspose2d(
                    channels,
                    channels // 2,
                    kernel_size=3,
                    stride=2,
                    padding=1,
                    output_padding=1,
                ),
                norm_layer(channels // 2),
                nn.ReLU(True),
            ])
            channels //= 2

        layers.extend([
            nn.ReflectionPad2d(3),
            nn.Conv2d(ngf, output_nc, kernel_size=7, padding=0),
            nn.Tanh(),
        ])
        self.model = nn.Sequential(*layers)

    def forward(self, input, inst):
        """Encode ``input`` and replace each instance region by its mean feature value.

        Args:
            input: tensor fed to the network.
            inst: instance-id map; indexed as ``inst[b:b + 1]`` and expected to
                yield 4 index columns from ``nonzero()``. Presumably shaped
                (batch, 1, H, W); confirm against the caller.

        Returns:
            A copy of the network output where, per sample, per instance id and
            per output channel, the selected positions hold their mean.
        """
        encoded = self.model(input)
        pooled = encoded.clone()
        batch_size = input.size()[0]

        for inst_id in np.unique(inst.cpu().numpy().astype(int)):
            label = int(inst_id)
            for b in range(batch_size):
                hits = (inst[b:b + 1] == label).nonzero()  # n x 4
                rows = hits[:, 0] + b
                ys = hits[:, 2]
                xs = hits[:, 3]
                for j in range(self.output_nc):
                    chans = hits[:, 1] + j
                    region = encoded[rows, chans, ys, xs]
                    pooled[rows, chans, ys, xs] = torch.mean(region).expand_as(region)
        return pooled
class Vgg19(nn.Module):
    """Pretrained VGG-19 feature stack split into five consecutive slices."""

    def __init__(self, requires_grad=False):
        """Load the default torchvision VGG-19 weights and split ``features`` into slices.

        Args:
            requires_grad: when falsy, every parameter has ``requires_grad``
                set to ``False``.
        """
        super().__init__()
        features = models.vgg19(weights=models.VGG19_Weights.DEFAULT).features

        # (attribute name, first layer index, one-past-last layer index)
        slice_bounds = (
            ("slice1", 0, 2),
            ("slice2", 2, 7),
            ("slice3", 7, 12),
            ("slice4", 12, 21),
            ("slice5", 21, 30),
        )
        for attr, start, stop in slice_bounds:
            stage = torch.nn.Sequential()
            for idx in range(start, stop):
                # Keep the original layer index as the sub-module name.
                stage.add_module(str(idx), features[idx])
            setattr(self, attr, stage)

        if not requires_grad:
            for param in self.parameters():
                param.requires_grad = False

    def forward(self, X):
        """Return the outputs of the five slices, in order, as a list."""
        outputs = []
        hidden = X
        for stage in (self.slice1, self.slice2, self.slice3, self.slice4, self.slice5):
            hidden = stage(hidden)
            outputs.append(hidden)
        return outputs
class VGG19FeatLayer(nn.Module):
    """Pretrained VGG-19 feature stack that returns every layer output by name."""

    def __init__(self):
        """Load the default torchvision VGG-19 features (eval mode) and the normalisation buffers."""
        super().__init__()
        self.vgg19 = models.vgg19(weights=models.VGG19_Weights.DEFAULT).features.eval()
        # Per-channel statistics subtracted from / divided into the input.
        self.register_buffer("mean", torch.tensor([0.485, 0.456, 0.406]).view(1, 3, 1, 1))
        self.register_buffer("std", torch.tensor([0.229, 0.224, 0.225]).view(1, 3, 1, 1))

    def forward(self, x):
        """Normalise ``x`` and collect the output of every layer.

        Keys follow a running scheme: ``conv{b}_{k}`` and ``relu{b}_{k}`` share
        one counter ``k`` that advances on every convolution *and* every ReLU
        and is reset by each max-pool (``pool_{b}``), which also advances the
        block number ``b``. Batch-norm layers are stored as ``bn_{b}``.

        Returns:
            dict mapping layer name to that layer's output tensor.

        Raises:
            RuntimeError: if the stack contains an unexpected layer type.
        """
        x = (x - self.mean) / self.std

        features = {}
        block = 1
        step = 0
        for layer in self.vgg19.children():
            if isinstance(layer, nn.Conv2d):
                step += 1
                key = f'conv{block}_{step}'
            elif isinstance(layer, nn.ReLU):
                step += 1
                key = f'relu{block}_{step}'
                # Use an out-of-place ReLU so the stored conv output is not
                # overwritten.
                layer = nn.ReLU(inplace=False)
            elif isinstance(layer, nn.MaxPool2d):
                step = 0
                key = f'pool_{block}'
                block += 1
            elif isinstance(layer, nn.BatchNorm2d):
                key = f'bn_{block}'
            else:
                raise RuntimeError('Unrecognized layer: {}'.format(layer.__class__.__name__))
            x = layer(x)
            features[key] = x
        return features
class VGGLoss(pl.LightningModule):
    """Weighted sum of L1 distances between ``Vgg19`` slice outputs of two inputs."""

    def __init__(self):
        """Create the feature network (eval mode), the L1 criterion and the per-slice weights."""
        super().__init__()
        self.vgg = Vgg19().eval()
        self.criterion = nn.L1Loss()
        # One weight per Vgg19 slice, in slice order.
        self.weights = [1.0 / 32, 1.0 / 16, 1.0 / 8, 1.0 / 4, 1.0]

    def forward(self, x, y):
        """Return the weighted L1 feature distance; ``y`` features are detached."""
        feats_x = self.vgg(x)
        feats_y = self.vgg(y)
        total = 0
        for idx, feat_x in enumerate(feats_x):
            total = total + self.weights[idx] * self.criterion(feat_x, feats_y[idx].detach())
        return total
class GANLoss(pl.LightningModule):
    """Adversarial loss comparing predictions against a constant real/fake target.

    Uses ``nn.MSELoss`` when ``use_lsgan`` is set and ``nn.BCELoss`` otherwise.
    Target tensors are cached in ``real_label_var`` / ``fake_label_var`` and
    rebuilt whenever the prediction's shape, device or dtype changes.
    """

    def __init__(self, use_lsgan=True, target_real_label=1.0, target_fake_label=0.0):
        """Store the label values and pick the criterion.

        Args:
            use_lsgan: choose ``nn.MSELoss`` (True) or ``nn.BCELoss`` (False).
            target_real_label: fill value of the "real" target.
            target_fake_label: fill value of the "fake" target.
        """
        super(GANLoss, self).__init__()
        self.real_label = target_real_label
        self.fake_label = target_fake_label
        self.real_label_var = None
        self.fake_label_var = None
        # Kept only so that existing readers of ``.tensor`` keep working;
        # targets are built from the prediction itself (see
        # get_target_tensor), so the loss no longer requires CUDA.
        self.tensor = torch.cuda.FloatTensor if torch.cuda.is_available() else torch.FloatTensor
        if use_lsgan:
            self.loss = nn.MSELoss()
        else:
            self.loss = nn.BCELoss()

    def get_target_tensor(self, input, target_is_real):
        """Return a constant target matching ``input`` in shape, device and dtype.

        Args:
            input: prediction tensor the target must match.
            target_is_real: select the real (True) or fake (False) label value.

        Returns:
            A tensor filled with the selected label value that does not
            require gradients. The tensor is cached and reused while the
            prediction layout stays the same.
        """
        if target_is_real:
            cached, value = self.real_label_var, self.real_label
        else:
            cached, value = self.fake_label_var, self.fake_label

        # Comparing only numel (as before) could reuse a target of a different
        # shape or on a different device; compare the full layout instead.
        stale = (
            cached is None
            or cached.shape != input.shape
            or cached.device != input.device
            or cached.dtype != input.dtype
        )
        if stale:
            cached = torch.full_like(input, value, requires_grad=False)
            if target_is_real:
                self.real_label_var = cached
            else:
                self.fake_label_var = cached
        return cached

    def __call__(self, input, target_is_real):
        """Compute the loss.

        Args:
            input: either a list of lists (one inner list per scale, whose last
                element is the prediction) or a sequence whose last element is
                the prediction.
            target_is_real: compare against the real (True) or fake (False)
                target.

        Returns:
            The summed loss over all scales, or the single-prediction loss.
        """
        if isinstance(input[0], list):
            loss = 0
            for input_i in input:
                pred = input_i[-1]
                target_tensor = self.get_target_tensor(pred, target_is_real)
                loss += self.loss(pred, target_tensor)
            return loss

        pred = input[-1]
        target_tensor = self.get_target_tensor(pred, target_is_real)
        return self.loss(pred, target_tensor)
class IDMRFLoss(pl.LightningModule):
    """Feature-matching loss built from relative cosine distances between feature patches.

    ``forward`` sums a "style" term over ``feat_style_layers`` and a "content"
    term over ``feat_content_layers``; both use ``mrf_loss`` on the features
    returned by ``featlayer``.
    """

    def __init__(self, featlayer=VGG19FeatLayer):
        """Instantiate the feature extractor and set the layer weights and constants.

        Args:
            featlayer: zero-argument callable returning a module that maps an
                image batch to a dict of named feature maps.
        """
        super().__init__()
        self.featlayer = featlayer()
        # Feature-map name -> weight of that layer's term.
        self.feat_style_layers = {'relu3_2': 1.0, 'relu4_2': 1.0}
        self.feat_content_layers = {'relu4_2': 1.0}
        # Constants of exp((bias - d) / nn_stretch_sigma).
        self.bias = 1.0
        self.nn_stretch_sigma = 0.5
        # Global multipliers of the two terms.
        self.lambda_style = 1.0
        self.lambda_content = 1.0

    def sum_normalize(self, featmaps):
        """Divide ``featmaps`` by its sum over dimension 1."""
        channel_total = torch.sum(featmaps, dim=1, keepdim=True)
        return featmaps / channel_total

    def patch_extraction(self, featmaps):
        """Cut ``featmaps`` into 1x1 patches laid out as convolution kernels.

        The result is also stored in ``self.patches_OIHW``.
        """
        patch_size = 1
        patch_stride = 1
        windows = featmaps.unfold(2, patch_size, patch_stride)
        windows = windows.unfold(3, patch_size, patch_stride)
        # Move the channel dimension behind the two position dimensions, then
        # merge the leading dimensions into one "patch" dimension.
        reordered = windows.permute(0, 2, 3, 1, 4, 5)
        shape = reordered.size()
        self.patches_OIHW = reordered.view(-1, shape[3], shape[4], shape[5])
        return self.patches_OIHW

    def compute_relative_distances(self, cdist):
        """Divide ``cdist`` by its minimum over dimension 1 (plus a small epsilon)."""
        epsilon = 1e-5
        smallest = torch.min(cdist, dim=1, keepdim=True)[0]
        return cdist / (smallest + epsilon)

    def exp_norm_relative_dist(self, relative_dist):
        """Apply ``exp((bias - d) / sigma)`` and normalise over dimension 1.

        The result is also stored in ``self.cs_NCHW``.
        """
        unnormalised = torch.exp((self.bias - relative_dist) / self.nn_stretch_sigma)
        self.cs_NCHW = self.sum_normalize(unnormalised)
        return self.cs_NCHW

    def mrf_loss(self, gen, tar):
        """Compute the loss between one generated and one target feature map.

        Args:
            gen: generated features, sliced per sample along dimension 0.
            tar: target features of the same layout.

        Returns:
            Scalar tensor: the sum over samples of ``-log`` of the mean (over
            target patches) of the maximum normalised affinity.
        """
        # Centre both feature maps on the target's channel mean, then scale
        # each position to unit length along the channel dimension.
        target_mean = torch.mean(tar, 1, keepdim=True)
        gen_centered = gen - target_mean
        tar_centered = tar - target_mean
        gen_unit = gen_centered / torch.norm(gen_centered, p=2, dim=1, keepdim=True)
        tar_unit = tar_centered / torch.norm(tar_centered, p=2, dim=1, keepdim=True)

        # Per sample: convolve the generated features with the target patches.
        per_sample = []
        for i in range(tar.size(0)):
            tar_i = tar_unit[i:i + 1, :, :, :]
            gen_i = gen_unit[i:i + 1, :, :, :]
            kernels = self.patch_extraction(tar_i)
            per_sample.append(F.conv2d(gen_i, kernels))
        similarity = torch.cat(per_sample, dim=0)

        # Map the convolution response s to a distance -(s - 1) / 2.
        distance = -(similarity - 1) / 2
        relative = self.compute_relative_distances(distance)
        affinity = self.exp_norm_relative_dist(relative)

        shape = affinity.size()
        best_match = torch.max(affinity.view(shape[0], shape[1], -1), dim=2)[0]
        mean_best = torch.mean(best_match, dim=1)
        return torch.sum(-torch.log(mean_best))

    def forward(self, gen, tar):
        """Return the style loss plus the content loss.

        Both terms are also stored in ``self.style_loss`` and
        ``self.content_loss``.

        Args:
            gen: generated images, [bz, 3, h, w] RGB in [0, 1].
            tar: target images with the same layout.
        """
        gen_feats = self.featlayer(gen)
        tar_feats = self.featlayer(tar)

        def weighted_total(layer_weights):
            # Sum of weight * mrf_loss over the given layers, in dict order.
            terms = [
                layer_weights[name] * self.mrf_loss(gen_feats[name], tar_feats[name])
                for name in layer_weights
            ]
            return functools.reduce(lambda a, b: a + b, terms)

        self.style_loss = weighted_total(self.feat_style_layers) * self.lambda_style
        self.content_loss = weighted_total(self.feat_content_layers) * self.lambda_content
        return self.style_loss + self.content_loss