# Copyright 2020 Erik Härkönen. All rights reserved.
# This file is licensed to you under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License. You may obtain a copy
# of the License at http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software distributed under
# the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR REPRESENTATIONS
# OF ANY KIND, either express or implied. See the License for the specific language
# governing permissions and limitations under the License.

import torch
import torch.nn as nn
import torch.nn.functional as F
from collections import OrderedDict
from pathlib import Path
import requests
import pickle
import sys
import numpy as np

# Reimplementation of StyleGAN in PyTorch
# Source: https://github.com/lernapparat/lernapparat/blob/master/style_gan/pytorch_style_gan.ipynb

class MyLinear(nn.Module):
    """Linear layer with equalized learning rate and custom learning rate multiplier."""
    def __init__(self, input_size, output_size, gain=2**(0.5), use_wscale=False, lrmul=1, bias=True):
        super().__init__()
        he_std = gain * input_size**(-0.5)  # He init
        # Equalized learning rate and custom learning rate multiplier.
        if use_wscale:
            init_std = 1.0 / lrmul
            self.w_mul = he_std * lrmul
        else:
            init_std = he_std / lrmul
            self.w_mul = lrmul
        self.weight = torch.nn.Parameter(torch.randn(output_size, input_size) * init_std)
        if bias:
            self.bias = torch.nn.Parameter(torch.zeros(output_size))
            self.b_mul = lrmul
        else:
            self.bias = None

    def forward(self, x):
        bias = self.bias
        if bias is not None:
            bias = bias * self.b_mul
        return F.linear(x, self.weight * self.w_mul, bias)
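
# Sanity sketch (the helper below is ours, added for illustration): with
# use_wscale=True the stored weight is drawn as N(0, (1/lrmul)^2) and rescaled
# at runtime by w_mul = he_std * lrmul, so the *effective* weight always has
# He std gain / sqrt(fan_in), while gradient magnitudes scale with lrmul.
def _wscale_sanity_check():
    lin = MyLinear(512, 512, use_wscale=True)
    effective = lin.weight * lin.w_mul
    he_std = 2 ** 0.5 * 512 ** (-0.5)  # gain / sqrt(fan_in)
    assert abs(effective.std().item() - he_std) < 1e-2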

class MyConv2d(nn.Module):
    """Conv layer with equalized learning rate and custom learning rate multiplier."""
    def __init__(self, input_channels, output_channels, kernel_size, gain=2**(0.5), use_wscale=False, lrmul=1, bias=True,
                 intermediate=None, upscale=False):
        super().__init__()
        if upscale:
            self.upscale = Upscale2d()
        else:
            self.upscale = None
        he_std = gain * (input_channels * kernel_size ** 2) ** (-0.5)  # He init
        self.kernel_size = kernel_size
        if use_wscale:
            init_std = 1.0 / lrmul
            self.w_mul = he_std * lrmul
        else:
            init_std = he_std / lrmul
            self.w_mul = lrmul
        self.weight = torch.nn.Parameter(torch.randn(output_channels, input_channels, kernel_size, kernel_size) * init_std)
        if bias:
            self.bias = torch.nn.Parameter(torch.zeros(output_channels))
            self.b_mul = lrmul
        else:
            self.bias = None
        self.intermediate = intermediate

    def forward(self, x):
        bias = self.bias
        if bias is not None:
            bias = bias * self.b_mul

        have_convolution = False
        if self.upscale is not None and min(x.shape[2:]) * 2 >= 128:
            # This is the fused upscale + conv from StyleGAN; sadly it is not
            # numerically identical to the non-fused path, so (as in the
            # official implementation) it is only used at high resolutions.
            w = self.weight * self.w_mul
            w = w.permute(1, 0, 2, 3)
            # Sum the kernel over a 2x2 neighbourhood so that the transposed
            # convolution with stride 2 matches upscale-then-conv.
            w = F.pad(w, (1, 1, 1, 1))
            w = w[:, :, 1:, 1:] + w[:, :, :-1, 1:] + w[:, :, 1:, :-1] + w[:, :, :-1, :-1]
            x = F.conv_transpose2d(x, w, stride=2, padding=(w.size(-1) - 1) // 2)
            have_convolution = True
        elif self.upscale is not None:
            x = self.upscale(x)

        if not have_convolution and self.intermediate is None:
            return F.conv2d(x, self.weight * self.w_mul, bias, padding=self.kernel_size // 2)
        elif not have_convolution:
            x = F.conv2d(x, self.weight * self.w_mul, None, padding=self.kernel_size // 2)

        if self.intermediate is not None:
            x = self.intermediate(x)
        if bias is not None:
            x = x + bias.view(1, -1, 1, 1)
        return x

class NoiseLayer(nn.Module):
    """Adds noise that is per-pixel (constant over channels), with a per-channel weight."""
    def __init__(self, channels):
        super().__init__()
        self.weight = nn.Parameter(torch.zeros(channels))
        self.noise = None

    def forward(self, x, noise=None):
        if noise is None and self.noise is None:
            noise = torch.randn(x.size(0), 1, x.size(2), x.size(3), device=x.device, dtype=x.dtype)
        elif noise is None:
            # Here is a little trick: if you grab all the NoiseLayers and set each
            # module's .noise attribute, you can have pre-defined noise.
            # Very useful for analysis.
            noise = self.noise
        x = x + self.weight.view(1, -1, 1, 1) * noise
        return x
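
# Usage sketch for the trick above (the helper name and the sizing rule are our
# assumptions, based on the G_synthesis layout defined later in this file: two
# noise layers per block, at 4x4, 8x8, ..., up to the output resolution):
def set_fixed_noise(model, seed=0):
    g = torch.Generator().manual_seed(seed)
    noise_layers = [m for m in model.modules() if isinstance(m, NoiseLayer)]
    for i, layer in enumerate(noise_layers):
        side = 4 * 2 ** (i // 2)  # 4, 4, 8, 8, 16, 16, ...
        layer.noise = torch.randn(1, 1, side, side, generator=g)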

class StyleMod(nn.Module):
    def __init__(self, latent_size, channels, use_wscale):
        super(StyleMod, self).__init__()
        self.lin = MyLinear(latent_size,
                            channels * 2,
                            gain=1.0, use_wscale=use_wscale)

    def forward(self, x, latent):
        style = self.lin(latent)  # style => [batch_size, n_channels*2]
        shape = [-1, 2, x.size(1)] + (x.dim() - 2) * [1]
        style = style.view(shape)  # [batch_size, 2, n_channels, ...]
        x = x * (style[:, 0] + 1.) + style[:, 1]
        return x

class PixelNormLayer(nn.Module):
    def __init__(self, epsilon=1e-8):
        super().__init__()
        self.epsilon = epsilon

    def forward(self, x):
        return x * torch.rsqrt(torch.mean(x**2, dim=1, keepdim=True) + self.epsilon)

class BlurLayer(nn.Module):
    def __init__(self, kernel=[1, 2, 1], normalize=True, flip=False, stride=1):
        super(BlurLayer, self).__init__()
        kernel = torch.tensor(kernel, dtype=torch.float32)
        kernel = kernel[:, None] * kernel[None, :]  # outer product -> 2D binomial kernel
        kernel = kernel[None, None]
        if normalize:
            kernel = kernel / kernel.sum()
        if flip:
            kernel = torch.flip(kernel, [2, 3])  # negative-step slicing is unsupported in PyTorch
        self.register_buffer('kernel', kernel)
        self.stride = stride

    def forward(self, x):
        # Expand kernel channels and apply as a depthwise convolution.
        kernel = self.kernel.expand(x.size(1), -1, -1, -1)
        x = F.conv2d(
            x,
            kernel,
            stride=self.stride,
            padding=int((self.kernel.size(2) - 1) / 2),
            groups=x.size(1)
        )
        return x

def upscale2d(x, factor=2, gain=1):
    assert x.dim() == 4
    if gain != 1:
        x = x * gain
    if factor != 1:
        shape = x.shape
        x = x.view(shape[0], shape[1], shape[2], 1, shape[3], 1).expand(-1, -1, -1, factor, -1, factor)
        x = x.contiguous().view(shape[0], shape[1], factor * shape[2], factor * shape[3])
    return x


class Upscale2d(nn.Module):
    def __init__(self, factor=2, gain=1):
        super().__init__()
        assert isinstance(factor, int) and factor >= 1
        self.gain = gain
        self.factor = factor

    def forward(self, x):
        return upscale2d(x, factor=self.factor, gain=self.gain)
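
# Equivalence note (the check below is ours): upscale2d is plain
# nearest-neighbour upsampling, so with gain=1 it matches F.interpolate.
def _upscale2d_equivalence_check():
    x = torch.randn(1, 3, 4, 4)
    assert torch.allclose(upscale2d(x, factor=2),
                          F.interpolate(x, scale_factor=2, mode='nearest'))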

class G_mapping(nn.Sequential):
    def __init__(self, nonlinearity='lrelu', use_wscale=True):
        act, gain = {'relu': (torch.relu, np.sqrt(2)),
                     'lrelu': (nn.LeakyReLU(negative_slope=0.2), np.sqrt(2))}[nonlinearity]
        layers = [
            ('pixel_norm', PixelNormLayer()),
            ('dense0', MyLinear(512, 512, gain=gain, lrmul=0.01, use_wscale=use_wscale)),
            ('dense0_act', act),
            ('dense1', MyLinear(512, 512, gain=gain, lrmul=0.01, use_wscale=use_wscale)),
            ('dense1_act', act),
            ('dense2', MyLinear(512, 512, gain=gain, lrmul=0.01, use_wscale=use_wscale)),
            ('dense2_act', act),
            ('dense3', MyLinear(512, 512, gain=gain, lrmul=0.01, use_wscale=use_wscale)),
            ('dense3_act', act),
            ('dense4', MyLinear(512, 512, gain=gain, lrmul=0.01, use_wscale=use_wscale)),
            ('dense4_act', act),
            ('dense5', MyLinear(512, 512, gain=gain, lrmul=0.01, use_wscale=use_wscale)),
            ('dense5_act', act),
            ('dense6', MyLinear(512, 512, gain=gain, lrmul=0.01, use_wscale=use_wscale)),
            ('dense6_act', act),
            ('dense7', MyLinear(512, 512, gain=gain, lrmul=0.01, use_wscale=use_wscale)),
            ('dense7_act', act)
        ]
        super().__init__(OrderedDict(layers))

    def forward(self, x):
        return super().forward(x)

class Truncation(nn.Module):
    def __init__(self, avg_latent, max_layer=8, threshold=0.7):
        super().__init__()
        self.max_layer = max_layer
        self.threshold = threshold
        self.register_buffer('avg_latent', avg_latent)

    def forward(self, x):
        assert x.dim() == 3
        interp = torch.lerp(self.avg_latent, x, self.threshold)
        # Only the first max_layer styles are truncated; keep the mask on the
        # same device as x so the where() below works on GPU as well.
        do_trunc = (torch.arange(x.size(1)) < self.max_layer).view(1, -1, 1).to(x.device)
        return torch.where(do_trunc, interp, x)
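
# Usage sketch (our helper; mirrors the commented-out 'truncation' entry in
# StyleGAN_G below): estimate avg_latent as the mean mapping output over many
# random z, then insert the Truncation module between mapping and synthesis.
def make_truncation(g_mapping, n_samples=4096, max_layer=8, threshold=0.7):
    with torch.no_grad():
        z = torch.randn(n_samples, 512)
        avg = g_mapping(z).mean(0, keepdim=True)  # [1, 512]
    # Truncation.forward expects [batch, num_layers, 512]; the [1, 1, 512]
    # buffer broadcasts over both the batch and layer dimensions.
    return Truncation(avg.unsqueeze(0), max_layer=max_layer, threshold=threshold)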

class LayerEpilogue(nn.Module):
    """Things to do at the end of each layer."""
    def __init__(self, channels, dlatent_size, use_wscale, use_noise, use_pixel_norm, use_instance_norm, use_styles, activation_layer):
        super().__init__()
        layers = []
        if use_noise:
            layers.append(('noise', NoiseLayer(channels)))
        layers.append(('activation', activation_layer))
        if use_pixel_norm:
            layers.append(('pixel_norm', PixelNormLayer()))
        if use_instance_norm:
            layers.append(('instance_norm', nn.InstanceNorm2d(channels)))
        self.top_epi = nn.Sequential(OrderedDict(layers))
        if use_styles:
            self.style_mod = StyleMod(dlatent_size, channels, use_wscale=use_wscale)
        else:
            self.style_mod = None

    def forward(self, x, dlatents_in_slice=None):
        x = self.top_epi(x)
        if self.style_mod is not None:
            x = self.style_mod(x, dlatents_in_slice)
        else:
            assert dlatents_in_slice is None
        return x

class InputBlock(nn.Module):
    def __init__(self, nf, dlatent_size, const_input_layer, gain, use_wscale, use_noise, use_pixel_norm, use_instance_norm, use_styles, activation_layer):
        super().__init__()
        self.const_input_layer = const_input_layer
        self.nf = nf
        if self.const_input_layer:
            # Called 'const' in the TF implementation.
            self.const = nn.Parameter(torch.ones(1, nf, 4, 4))
            self.bias = nn.Parameter(torch.ones(nf))
        else:
            self.dense = MyLinear(dlatent_size, nf * 16, gain=gain / 4, use_wscale=use_wscale)  # tweak gain to match the official implementation of Progressive GAN
        self.epi1 = LayerEpilogue(nf, dlatent_size, use_wscale, use_noise, use_pixel_norm, use_instance_norm, use_styles, activation_layer)
        self.conv = MyConv2d(nf, nf, 3, gain=gain, use_wscale=use_wscale)
        self.epi2 = LayerEpilogue(nf, dlatent_size, use_wscale, use_noise, use_pixel_norm, use_instance_norm, use_styles, activation_layer)

    def forward(self, dlatents_in_range):
        batch_size = dlatents_in_range.size(0)
        if self.const_input_layer:
            x = self.const.expand(batch_size, -1, -1, -1)
            x = x + self.bias.view(1, -1, 1, 1)
        else:
            x = self.dense(dlatents_in_range[:, 0]).view(batch_size, self.nf, 4, 4)
        x = self.epi1(x, dlatents_in_range[:, 0])
        x = self.conv(x)
        x = self.epi2(x, dlatents_in_range[:, 1])
        return x

class GSynthesisBlock(nn.Module):
    def __init__(self, in_channels, out_channels, blur_filter, dlatent_size, gain, use_wscale, use_noise, use_pixel_norm, use_instance_norm, use_styles, activation_layer):
        # 2**res x 2**res, for res = 3..resolution_log2
        super().__init__()
        if blur_filter:
            blur = BlurLayer(blur_filter)
        else:
            blur = None
        self.conv0_up = MyConv2d(in_channels, out_channels, kernel_size=3, gain=gain, use_wscale=use_wscale,
                                 intermediate=blur, upscale=True)
        self.epi1 = LayerEpilogue(out_channels, dlatent_size, use_wscale, use_noise, use_pixel_norm, use_instance_norm, use_styles, activation_layer)
        self.conv1 = MyConv2d(out_channels, out_channels, kernel_size=3, gain=gain, use_wscale=use_wscale)
        self.epi2 = LayerEpilogue(out_channels, dlatent_size, use_wscale, use_noise, use_pixel_norm, use_instance_norm, use_styles, activation_layer)

    def forward(self, x, dlatents_in_range):
        x = self.conv0_up(x)
        x = self.epi1(x, dlatents_in_range[:, 0])
        x = self.conv1(x)
        x = self.epi2(x, dlatents_in_range[:, 1])
        return x

class G_synthesis(nn.Module):
    def __init__(self,
                 dlatent_size      = 512,           # Disentangled latent (W) dimensionality.
                 num_channels      = 3,             # Number of output color channels.
                 resolution        = 1024,          # Output resolution.
                 fmap_base         = 8192,          # Overall multiplier for the number of feature maps.
                 fmap_decay        = 1.0,           # log2 feature map reduction when doubling the resolution.
                 fmap_max          = 512,           # Maximum number of feature maps in any layer.
                 use_styles        = True,          # Enable style inputs?
                 const_input_layer = True,          # First layer is a learned constant?
                 use_noise         = True,          # Enable noise inputs?
                 randomize_noise   = True,          # True = randomize noise inputs every time (non-deterministic), False = read noise inputs from variables.
                 nonlinearity      = 'lrelu',       # Activation function: 'relu', 'lrelu'
                 use_wscale        = True,          # Enable equalized learning rate?
                 use_pixel_norm    = False,         # Enable pixelwise feature vector normalization?
                 use_instance_norm = True,          # Enable instance normalization?
                 dtype             = torch.float32, # Data type to use for activations and outputs.
                 blur_filter       = [1, 2, 1],     # Low-pass filter to apply when resampling activations. None = no filtering.
                 ):
        super().__init__()

        def nf(stage):
            return min(int(fmap_base / (2.0 ** (stage * fmap_decay))), fmap_max)

        self.dlatent_size = dlatent_size
        resolution_log2 = int(np.log2(resolution))
        assert resolution == 2**resolution_log2 and resolution >= 4

        act, gain = {'relu': (torch.relu, np.sqrt(2)),
                     'lrelu': (nn.LeakyReLU(negative_slope=0.2), np.sqrt(2))}[nonlinearity]
        num_layers = resolution_log2 * 2 - 2
        num_styles = num_layers if use_styles else 1
        blocks = []
        for res in range(2, resolution_log2 + 1):
            channels = nf(res - 1)
            name = '{s}x{s}'.format(s=2**res)
            if res == 2:
                blocks.append((name,
                               InputBlock(channels, dlatent_size, const_input_layer, gain, use_wscale,
                                          use_noise, use_pixel_norm, use_instance_norm, use_styles, act)))
            else:
                blocks.append((name,
                               GSynthesisBlock(last_channels, channels, blur_filter, dlatent_size, gain, use_wscale, use_noise, use_pixel_norm, use_instance_norm, use_styles, act)))
            last_channels = channels
        self.torgb = MyConv2d(channels, num_channels, 1, gain=1, use_wscale=use_wscale)
        self.blocks = nn.ModuleDict(OrderedDict(blocks))

    def forward(self, dlatents_in):
        # Input: disentangled latents (W) [minibatch, num_layers, dlatent_size].
        batch_size = dlatents_in.size(0)
        for i, m in enumerate(self.blocks.values()):
            if i == 0:
                x = m(dlatents_in[:, 2*i:2*i+2])
            else:
                x = m(x, dlatents_in[:, 2*i:2*i+2])
        rgb = self.torgb(x)
        return rgb
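
# Shape sketch (our check): for resolution R the network stacks log2(R) - 1
# blocks and consumes num_layers = 2*log2(R) - 2 style vectors, two per block:
#   dlatents_in [batch, num_layers, 512] -> rgb [batch, num_channels, R, R]
def _synthesis_shape_check():
    g = G_synthesis(resolution=64)  # 5 blocks, 2*6 - 2 = 10 styles
    w = torch.randn(2, 10, 512)
    assert g(w).shape == (2, 3, 64, 64)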

class StyleGAN_G(nn.Sequential):
    def __init__(self, resolution, truncation=1.0):
        self.resolution = resolution
        self.layers = OrderedDict([
            ('g_mapping', G_mapping()),
            #('truncation', Truncation(avg_latent)),
            ('g_synthesis', G_synthesis(resolution=resolution)),
        ])
        super().__init__(self.layers)

    def forward(self, x, latent_is_w=False):
        if isinstance(x, list):
            assert len(x) == 18, 'Must provide 1 or 18 latents'
            if not latent_is_w:
                x = [self.layers['g_mapping'].forward(l) for l in x]
            x = torch.stack(x, dim=1)
        else:
            if not latent_is_w:
                x = self.layers['g_mapping'].forward(x)
            x = x.unsqueeze(1).expand(-1, 18, -1)
        x = self.layers['g_synthesis'].forward(x)
        return x

    # Converted checkpoints from: https://github.com/lernapparat/lernapparat/releases/download/v2019-02-01/
    def load_weights(self, checkpoint):
        self.load_state_dict(torch.load(checkpoint))

    def export_from_tf(self, pickle_path):
        module_path = Path(__file__).parent / 'stylegan_tf'
        sys.path.append(str(module_path.resolve()))

        import dnnlib, dnnlib.tflib, pickle, torch, collections
        dnnlib.tflib.init_tf()

        weights = pickle.load(open(pickle_path, 'rb'))
        weights_pt = [collections.OrderedDict([(k, torch.from_numpy(v.value().eval())) for k, v in w.trainables.items()]) for w in weights]
        state_G, state_D, state_Gs = weights_pt

        # Translate the TF variable names to the module names used above.
        def key_translate(k):
            k = k.lower().split('/')
            if k[0] == 'g_synthesis':
                if not k[1].startswith('torgb'):
                    k.insert(1, 'blocks')
                k = '.'.join(k)
                k = (k.replace('const.const', 'const').replace('const.bias', 'bias').replace('const.stylemod', 'epi1.style_mod.lin')
                      .replace('const.noise.weight', 'epi1.top_epi.noise.weight')
                      .replace('conv.noise.weight', 'epi2.top_epi.noise.weight')
                      .replace('conv.stylemod', 'epi2.style_mod.lin')
                      .replace('conv0_up.noise.weight', 'epi1.top_epi.noise.weight')
                      .replace('conv0_up.stylemod', 'epi1.style_mod.lin')
                      .replace('conv1.noise.weight', 'epi2.top_epi.noise.weight')
                      .replace('conv1.stylemod', 'epi2.style_mod.lin')
                      .replace('torgb_lod0', 'torgb'))
            else:
                k = '.'.join(k)
            return k

        # TF stores linear weights as [in, out] and conv weights as [kh, kw, in, out];
        # PyTorch expects [out, in] and [out, in, kh, kw].
        def weight_translate(k, w):
            k = key_translate(k)
            if k.endswith('.weight'):
                if w.dim() == 2:
                    w = w.t()
                elif w.dim() == 1:
                    pass
                else:
                    assert w.dim() == 4
                    w = w.permute(3, 2, 0, 1)
            return w

        # We drop the unused per-LOD torgb filters.
        param_dict = {key_translate(k): weight_translate(k, v) for k, v in state_Gs.items() if 'torgb_lod' not in key_translate(k)}

        # Sanity check: report keys present on only one side or with mismatched shapes.
        sd_shapes = {k: v.shape for k, v in self.state_dict().items()}
        param_shapes = {k: v.shape for k, v in param_dict.items()}
        for k in list(sd_shapes) + list(param_shapes):
            pds = param_shapes.get(k)
            sds = sd_shapes.get(k)
            if pds is None:
                print("sd only", k, sds)
            elif sds is None:
                print("pd only", k, pds)
            elif sds != pds:
                print("mismatch!", k, pds, sds)

        self.load_state_dict(param_dict, strict=False)  # non-strict: the blur kernels are buffers, absent from the TF weights
        torch.save(self.state_dict(), Path(pickle_path).with_suffix('.pt'))
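
# End-to-end usage sketch. The checkpoint filenames below are assumptions
# (the FFHQ model from the official StyleGAN release), not files shipped here.
if __name__ == '__main__':
    model = StyleGAN_G(resolution=1024)
    # model.export_from_tf('karras2019stylegan-ffhq-1024x1024.pkl')  # one-off TF -> PyTorch conversion
    model.load_weights('karras2019stylegan-ffhq-1024x1024.pt')
    model.eval()
    with torch.no_grad():
        z = torch.randn(1, 512)
        img = model(z)                  # [1, 3, 1024, 1024], roughly in [-1, 1]
    img = (img.clamp(-1, 1) + 1) / 2    # rescale to [0, 1] for viewing/saving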