# -*- coding: utf-8 -*-

# Max-Planck-Gesellschaft zur Förderung der Wissenschaften e.V. (MPG) is
# holder of all proprietary rights on this computer program.
# You can only use this computer program if you have closed
# a license agreement with MPG or you get the right to use the computer
# program from someone who is authorized to grant you that right.
# Any use of the computer program without a valid license is prohibited and
# liable to prosecution.
#
# Copyright©2019 Max-Planck-Gesellschaft zur Förderung
# der Wissenschaften e.V. (MPG). acting on behalf of its Max Planck Institute
# for Intelligent Systems. All rights reserved.
#
# Contact: [email protected]

import functools

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.autograd import grad
from torch.nn import init


def gradient(inputs, outputs):
    """Compute d(outputs)/d(inputs) via autograd, keeping the graph for higher-order terms."""
    d_points = torch.ones_like(outputs, requires_grad=False, device=outputs.device)
    points_grad = grad(
        outputs=outputs,
        inputs=inputs,
        grad_outputs=d_points,
        create_graph=True,
        retain_graph=True,
        only_inputs=True,
        allow_unused=True,
    )[0]
    return points_grad
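

# Illustrative usage sketch (not part of the original file): `gradient` is
# typically used to get spatial derivatives of an implicit field w.r.t. its
# query points. The toy quadratic field below is a hypothetical stand-in for a
# network output.
def _demo_gradient():
    points = torch.randn(4, 3, requires_grad=True)
    sdf = (points ** 2).sum(dim=-1, keepdim=True)  # toy scalar field |p|^2
    grads = gradient(points, sdf)  # analytic gradient is 2 * points
    assert torch.allclose(grads, 2 * points)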


def conv3x3(in_planes, out_planes, kernel=3, strd=1, dilation=1, padding=1, bias=False):
    """3x3 convolution with padding (kernel size, stride, and dilation are configurable)."""
    return nn.Conv2d(
        in_planes,
        out_planes,
        kernel_size=kernel,
        dilation=dilation,
        stride=strd,
        padding=padding,
        bias=bias,
    )


def conv1x1(in_planes, out_planes, stride=1):
    """1x1 convolution"""
    return nn.Conv2d(in_planes, out_planes, kernel_size=1, stride=stride, bias=False)
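

# Quick shape check (illustrative, not in the original file): with the default
# kernel=3, stride=1, padding=1, the spatial size is preserved.
def _demo_conv_shapes():
    x = torch.randn(1, 8, 32, 32)
    assert conv3x3(8, 16)(x).shape == (1, 16, 32, 32)
    assert conv1x1(8, 16)(x).shape == (1, 16, 32, 32)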


def init_weights(net, init_type="normal", init_gain=0.02):
    """Initialize network weights.
    Parameters:
        net (network)     -- network to be initialized
        init_type (str)   -- the name of an initialization method: normal | xavier | kaiming | orthogonal
        init_gain (float) -- scaling factor for normal, xavier and orthogonal.
    'normal' was used in the original pix2pix and CycleGAN paper, but xavier and
    kaiming might work better for some applications. Feel free to experiment.
    """

    def init_func(m):  # define the initialization function
        classname = m.__class__.__name__
        if hasattr(m, "weight") and (classname.find("Conv") != -1 or classname.find("Linear") != -1):
            if init_type == "normal":
                init.normal_(m.weight.data, 0.0, init_gain)
            elif init_type == "xavier":
                init.xavier_normal_(m.weight.data, gain=init_gain)
            elif init_type == "kaiming":
                init.kaiming_normal_(m.weight.data, a=0, mode="fan_in")
            elif init_type == "orthogonal":
                init.orthogonal_(m.weight.data, gain=init_gain)
            else:
                raise NotImplementedError(
                    "initialization method [%s] is not implemented" % init_type
                )
            if hasattr(m, "bias") and m.bias is not None:
                init.constant_(m.bias.data, 0.0)
        elif classname.find("BatchNorm2d") != -1:
            # BatchNorm's weight is not a matrix; only normal initialization applies.
            init.normal_(m.weight.data, 1.0, init_gain)
            init.constant_(m.bias.data, 0.0)

    net.apply(init_func)  # apply the initialization function <init_func>
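

# Illustrative usage (not in the original file): initialize a small model with
# Xavier weights; BatchNorm layers get the normal(1.0, gain) treatment above.
def _demo_init_weights():
    model = nn.Sequential(nn.Conv2d(3, 8, 3), nn.BatchNorm2d(8), nn.ReLU())
    init_weights(model, init_type="xavier", init_gain=0.02)
    return model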


def init_net(net, init_type="xavier", init_gain=0.02, gpu_ids=[]):
    """Initialize a network: 1. register CPU/GPU device (with multi-GPU support); 2. initialize the network weights
    Parameters:
        net (network)      -- the network to be initialized
        init_type (str)    -- the name of an initialization method: normal | xavier | kaiming | orthogonal
        init_gain (float)  -- scaling factor for normal, xavier and orthogonal.
        gpu_ids (int list) -- which GPUs the network runs on: e.g., 0,1,2
    Return an initialized network.
    """
    if len(gpu_ids) > 0:
        assert torch.cuda.is_available()
        net.to(gpu_ids[0])  # DataParallel expects the module on the first device
        net = torch.nn.DataParallel(net, gpu_ids)  # multi-GPU
    init_weights(net, init_type, init_gain=init_gain)
    return net
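

# Illustrative usage (not in the original file): CPU-only initialization; pass
# gpu_ids=[0, 1] to wrap the model in DataParallel across two GPUs instead.
def _demo_init_net():
    model = init_net(nn.Conv2d(3, 8, 3), init_type="kaiming", gpu_ids=[])
    return model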


def imageSpaceRotation(xy, rot):
    """
    Args:
        xy: (B, 2, N) input points
        rot: (B, 2) x,y axis rotation angles

    The rotation center is always the image center (any other rotation center
    can be represented by an additional z translation).
    """
    disp = rot.unsqueeze(2).sin().expand_as(xy)
    return (disp * xy).sum(dim=1)
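

# Illustrative shape check (not in the original file): zero angles produce zero
# displacement, and the (B, 2, N) input reduces to a (B, N) output.
def _demo_image_space_rotation():
    xy = torch.randn(2, 2, 100)
    rot = torch.zeros(2, 2)
    out = imageSpaceRotation(xy, rot)
    assert out.shape == (2, 100) and torch.allclose(out, torch.zeros_like(out))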


def cal_gradient_penalty(
    netD, real_data, fake_data, device, type="mixed", constant=1.0, lambda_gp=10.0
):
    """Calculate the gradient penalty loss, used in the WGAN-GP paper https://arxiv.org/abs/1704.00028
    Arguments:
        netD (network)           -- discriminator network
        real_data (tensor array) -- real images
        fake_data (tensor array) -- generated images from the generator
        device (str)             -- GPU / CPU: from torch.device('cuda:{}'.format(self.gpu_ids[0])) if self.gpu_ids else torch.device('cpu')
        type (str)               -- whether we mix real and fake data or not [real | fake | mixed]
        constant (float)         -- the constant used in the formula (||gradient||_2 - constant)^2
        lambda_gp (float)        -- weight for this loss
    Returns the gradient penalty loss
    """
    if lambda_gp > 0.0:
        # either use real images, fake images, or a linear interpolation of the two
        if type == "real":
            interpolatesv = real_data
        elif type == "fake":
            interpolatesv = fake_data
        elif type == "mixed":
            alpha = torch.rand(real_data.shape[0], 1, device=device)
            alpha = (
                alpha.expand(real_data.shape[0], real_data.nelement() // real_data.shape[0])
                .contiguous()
                .view(*real_data.shape)
            )
            interpolatesv = alpha * real_data + ((1 - alpha) * fake_data)
        else:
            raise NotImplementedError("{} not implemented".format(type))
        interpolatesv.requires_grad_(True)
        disc_interpolates = netD(interpolatesv)
        gradients = torch.autograd.grad(
            outputs=disc_interpolates,
            inputs=interpolatesv,
            grad_outputs=torch.ones(disc_interpolates.size()).to(device),
            create_graph=True,
            retain_graph=True,
            only_inputs=True,
        )
        gradients = gradients[0].view(real_data.size(0), -1)  # flatten the data
        gradient_penalty = (
            ((gradients + 1e-16).norm(2, dim=1) - constant) ** 2
        ).mean() * lambda_gp  # eps added for numerical stability
        return gradient_penalty, gradients
    else:
        return 0.0, None
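

# Illustrative usage sketch (not in the original file): the penalty term in a
# WGAN-GP discriminator step. The tiny `netD` and random batches below are
# hypothetical stand-ins for a real discriminator and data.
def _demo_gradient_penalty():
    netD = nn.Sequential(nn.Conv2d(3, 1, 4, stride=2, padding=1))
    real = torch.randn(4, 3, 8, 8)
    fake = torch.randn(4, 3, 8, 8)
    gp, _ = cal_gradient_penalty(netD, real, fake, device="cpu", type="mixed")
    gp.backward()  # would be added to the discriminator loss in training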


def get_norm_layer(norm_type="instance"):
    """Return a normalization layer
    Parameters:
        norm_type (str) -- the name of the normalization layer: batch | instance | group | none
    For BatchNorm, we use learnable affine parameters and track running statistics (mean/stddev).
    For InstanceNorm, we do not use learnable affine parameters and do not track running statistics.
    """
    if norm_type == "batch":
        norm_layer = functools.partial(nn.BatchNorm2d, affine=True, track_running_stats=True)
    elif norm_type == "instance":
        norm_layer = functools.partial(nn.InstanceNorm2d, affine=False, track_running_stats=False)
    elif norm_type == "group":
        norm_layer = functools.partial(nn.GroupNorm, 32)
    elif norm_type == "none":
        norm_layer = None
    else:
        raise NotImplementedError("normalization layer [%s] is not found" % norm_type)
    return norm_layer
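

# Illustrative usage (not in the original file): the returned partial is called
# with the channel count. Note that the GroupNorm partial fixes num_groups=32,
# so the channel count must be divisible by 32.
def _demo_get_norm_layer():
    gn = get_norm_layer("group")(64)  # nn.GroupNorm(32, 64)
    bn = get_norm_layer("batch")(64)  # nn.BatchNorm2d(64, affine=True, ...)
    return gn, bn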


class Flatten(nn.Module):
    """Flatten all dimensions except the batch dimension (equivalent to nn.Flatten())."""

    def forward(self, input):
        return input.view(input.size(0), -1)


class ConvBlock(nn.Module):
    """Hourglass-style residual block: three 3x3 convs whose outputs are
    concatenated (out/2 + out/4 + out/4 = out_planes channels), plus a skip connection."""

    def __init__(self, in_planes, out_planes, opt):
        super(ConvBlock, self).__init__()
        [k, s, d, p] = opt.conv3x3
        self.conv1 = conv3x3(in_planes, int(out_planes / 2), k, s, d, p)
        self.conv2 = conv3x3(int(out_planes / 2), int(out_planes / 4), k, s, d, p)
        self.conv3 = conv3x3(int(out_planes / 4), int(out_planes / 4), k, s, d, p)

        if opt.norm == "batch":
            self.bn1 = nn.BatchNorm2d(in_planes)
            self.bn2 = nn.BatchNorm2d(int(out_planes / 2))
            self.bn3 = nn.BatchNorm2d(int(out_planes / 4))
            self.bn4 = nn.BatchNorm2d(in_planes)
        elif opt.norm == "group":
            self.bn1 = nn.GroupNorm(32, in_planes)
            self.bn2 = nn.GroupNorm(32, int(out_planes / 2))
            self.bn3 = nn.GroupNorm(32, int(out_planes / 4))
            self.bn4 = nn.GroupNorm(32, in_planes)
        else:
            # fail early with a clear message instead of an AttributeError in forward()
            raise NotImplementedError("norm type [%s] is not supported in ConvBlock" % opt.norm)

        if in_planes != out_planes:
            self.downsample = nn.Sequential(
                self.bn4,
                nn.ReLU(True),
                nn.Conv2d(in_planes, out_planes, kernel_size=1, stride=1, bias=False),
            )
        else:
            self.downsample = None

    def forward(self, x):
        residual = x

        out1 = self.bn1(x)
        out1 = F.relu(out1, True)
        out1 = self.conv1(out1)

        out2 = self.bn2(out1)
        out2 = F.relu(out2, True)
        out2 = self.conv2(out2)

        out3 = self.bn3(out2)
        out3 = F.relu(out3, True)
        out3 = self.conv3(out3)

        # concatenate the three branches, then add the (possibly projected) skip
        out3 = torch.cat((out1, out2, out3), 1)
        if self.downsample is not None:
            residual = self.downsample(residual)
        out3 += residual
        return out3
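

# Illustrative usage (not in the original file): `opt` below is a hypothetical
# stand-in for the project's option object; it only needs the two attributes
# ConvBlock reads (`conv3x3` as [kernel, stride, dilation, padding] and `norm`).
def _demo_conv_block():
    from types import SimpleNamespace

    opt = SimpleNamespace(conv3x3=[3, 1, 1, 1], norm="group")
    block = ConvBlock(64, 128, opt)
    y = block(torch.randn(1, 64, 32, 32))
    assert y.shape == (1, 128, 32, 32)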