Spaces:
Runtime error
Runtime error
| """ The code is based on https://github.com/apple/ml-gsn/ with adaption. """ | |
| import math | |
| import torch | |
| import torch.nn as nn | |
| import torch.nn.functional as F | |
| from lib.torch_utils.ops.native_ops import ( | |
| FusedLeakyReLU, | |
| fused_leaky_relu, | |
| upfirdn2d, | |
| ) | |
class DiscriminatorHead(nn.Module):
    """Final discriminator head.

    Optionally appends a minibatch-standard-deviation channel (StyleGAN2-style),
    applies one 3x3 conv, then flattens and maps to a single logit.
    The linear stage assumes the incoming feature map is 4x4 — TODO confirm
    against callers.
    """

    def __init__(self, in_channel, disc_stddev=False):
        super().__init__()
        self.disc_stddev = disc_stddev
        extra_channels = 1 if disc_stddev else 0
        self.conv_stddev = ConvLayer2d(
            in_channel=in_channel + extra_channels,
            out_channel=in_channel,
            kernel_size=3,
            activate=True,
        )
        self.final_linear = nn.Sequential(
            nn.Flatten(),
            EqualLinear(in_channel=in_channel * 4 * 4, out_channel=in_channel, activate=True),
            EqualLinear(in_channel=in_channel, out_channel=1),
        )

    def cat_stddev(self, x, stddev_group=4, stddev_feat=1):
        """Concatenate a per-group feature-stddev channel onto ``x``."""
        # Shuffle the batch so that all views from a single trajectory do not
        # land in the same stddev group; the permutation is undone afterwards.
        perm = torch.randperm(len(x))
        inv_perm = torch.argsort(perm)
        batch, channel, height, width = x.shape
        x = x[perm]
        group = min(batch, stddev_group)
        grouped = x.view(group, -1, stddev_feat, channel // stddev_feat, height, width)
        stddev = torch.sqrt(grouped.var(0, unbiased=False) + 1e-8)
        stddev = stddev.mean([2, 3, 4], keepdim=True).squeeze(2)
        stddev = stddev.repeat(group, 1, height, width)
        # Restore the original batch order for both the features and the map.
        return torch.cat([x[inv_perm], stddev[inv_perm]], 1)

    def forward(self, x):
        if self.disc_stddev:
            x = self.cat_stddev(x)
        features = self.conv_stddev(x)
        return self.final_linear(features)
class ConvDecoder(nn.Module):
    """Convolutional decoder that upsamples from ``in_res`` to ``out_res``.

    Each upsampling layer doubles the spatial resolution and halves the channel
    count; a final non-activated 3x3 conv maps to ``out_channel``. Resolutions
    are presumably powers of two — TODO confirm with callers.
    """

    def __init__(self, in_channel, out_channel, in_res, out_res):
        super().__init__()
        log_in = int(math.log(in_res, 2))
        log_out = int(math.log(out_res, 2))
        blocks = []
        ch = in_channel
        for _ in range(log_in, log_out):
            blocks.append(
                ConvLayer2d(
                    in_channel=ch,
                    out_channel=ch // 2,
                    kernel_size=3,
                    upsample=True,
                    bias=True,
                    activate=True,
                )
            )
            ch //= 2
        # Output projection: no activation so the decoder can emit raw values.
        blocks.append(
            ConvLayer2d(
                in_channel=ch, out_channel=out_channel, kernel_size=3, bias=True, activate=False
            )
        )
        self.layers = nn.Sequential(*blocks)

    def forward(self, x):
        return self.layers(x)
class StyleDiscriminator(nn.Module):
    """StyleGAN2-style image discriminator.

    A stem conv is followed by residual blocks, each of which halves the
    resolution and doubles the feature count (capped at ``ch_max``) until the
    map reaches 4x4, where a stddev-augmented head produces the logit.
    """

    def __init__(self, in_channel, in_res, ch_mul=64, ch_max=512, **kwargs):
        super().__init__()
        log_in = int(math.log(in_res, 2))
        log_out = int(math.log(4, 2))
        self.conv_in = ConvLayer2d(in_channel=in_channel, out_channel=ch_mul, kernel_size=3)
        blocks = []
        ch = ch_mul
        for _ in range(log_in, log_out, -1):
            next_ch = int(min(ch * 2, ch_max))
            blocks.append(ConvResBlock2d(in_channel=ch, out_channel=next_ch, downsample=True))
            ch = next_ch
        self.layers = nn.Sequential(*blocks)
        self.disc_out = DiscriminatorHead(in_channel=ch, disc_stddev=True)

    def forward(self, x):
        return self.disc_out(self.layers(self.conv_in(x)))
def make_kernel(k):
    """Build a normalized 2D FIR kernel from a 1D or 2D tap list.

    A 1D input is expanded to its separable 2D outer product; the result is
    scaled so its entries sum to one.
    """
    kernel = torch.tensor(k, dtype=torch.float32)
    if kernel.ndim == 1:
        kernel = torch.outer(kernel, kernel)
    return kernel / kernel.sum()
class Blur(nn.Module):
    """Blur layer.

    Applies an FIR blur kernel to the input image. Blurring after convolutional
    upsampling (or before convolutional downsampling) makes models more robust
    to shifted inputs (https://richzhang.github.io/antialiased-cnns/) and, for
    GANs, yields cleaner gradients and more stable training.

    Args:
    ----
    kernel: list, int
        Blur kernel taps, e.g. [1, 3, 3, 1].
    pad: tuple, int
        (top/left, bottom/right) padding in pixels.
    upsample_factor: int
        Upsample factor; the kernel gain is scaled by its square to preserve
        signal magnitude after upsampling.
    """

    def __init__(self, kernel, pad, upsample_factor=1):
        super().__init__()
        fir = make_kernel(kernel)
        if upsample_factor > 1:
            fir = fir * (upsample_factor**2)
        self.register_buffer("kernel", fir)
        self.pad = pad

    def forward(self, input):
        return upfirdn2d(input, self.kernel, pad=self.pad)
class Upsample(nn.Module):
    """Upsampling layer.

    Performs 2x (or ``factor``x) upsampling with an FIR blur kernel.

    Args:
    ----
    kernel: list, int
        Blur kernel taps, e.g. [1, 3, 3, 1].
    factor: int
        Upsampling factor.
    """

    def __init__(self, kernel=[1, 3, 3, 1], factor=2):
        super().__init__()
        self.factor = factor
        # Scale the kernel gain by factor**2 to keep magnitudes stable.
        fir = make_kernel(kernel) * (factor**2)
        self.register_buffer("kernel", fir)
        pad_total = fir.shape[0] - factor
        self.pad = ((pad_total + 1) // 2 + factor - 1, pad_total // 2)

    def forward(self, input):
        return upfirdn2d(input, self.kernel, up=self.factor, down=1, pad=self.pad)
class Downsample(nn.Module):
    """Downsampling layer.

    Performs 2x (or ``factor``x) downsampling with an FIR blur kernel.

    Args:
    ----
    kernel: list, int
        Blur kernel taps, e.g. [1, 3, 3, 1].
    factor: int
        Downsampling factor.
    """

    def __init__(self, kernel=[1, 3, 3, 1], factor=2):
        super().__init__()
        self.factor = factor
        fir = make_kernel(kernel)
        self.register_buffer("kernel", fir)
        pad_total = fir.shape[0] - factor
        self.pad = ((pad_total + 1) // 2, pad_total // 2)

    def forward(self, input):
        return upfirdn2d(input, self.kernel, up=1, down=self.factor, pad=self.pad)
class EqualLinear(nn.Module):
    """Linear layer with equalized learning rate.

    During the forward pass the weights are scaled by the inverse of the He
    constant (i.e. sqrt(in_dim)) to prevent vanishing gradients and accelerate
    training. This constant only works for ReLU or LeakyReLU activations.

    Args:
    ----
    in_channel: int
        Input channels.
    out_channel: int
        Output channels.
    bias: bool
        Use bias term.
    bias_init: float
        Initial value for the bias.
    lr_mul: float
        Learning rate multiplier. By scaling weights and the bias we can
        proportionally scale the magnitude of the gradients, effectively
        increasing/decreasing the learning rate for this layer.
    activate: bool
        Apply leakyReLU activation.
    """

    def __init__(self, in_channel, out_channel, bias=True, bias_init=0, lr_mul=1, activate=False):
        super().__init__()
        # Dividing the init by lr_mul (and re-multiplying via `scale` at
        # runtime) keeps the effective weight fixed while scaling gradients.
        self.weight = nn.Parameter(torch.randn(out_channel, in_channel).div_(lr_mul))
        if bias:
            self.bias = nn.Parameter(torch.zeros(out_channel).fill_(bias_init))
        else:
            self.bias = None
        self.activate = activate
        self.scale = (1 / math.sqrt(in_channel)) * lr_mul
        self.lr_mul = lr_mul

    def forward(self, input):
        # Bug fix: the original computed `self.bias * self.lr_mul`
        # unconditionally, raising TypeError whenever bias=False.
        bias = self.bias * self.lr_mul if self.bias is not None else None
        if self.activate:
            out = F.linear(input, self.weight * self.scale)
            # fused_leaky_relu accepts bias=None (no bias added) — the bias add
            # is fused with the activation for efficiency.
            out = fused_leaky_relu(out, bias)
        else:
            out = F.linear(input, self.weight * self.scale, bias=bias)
        return out

    def __repr__(self):
        return f"{self.__class__.__name__}({self.weight.shape[1]}, {self.weight.shape[0]})"
class EqualConv2d(nn.Module):
    """2D convolution layer with equalized learning rate.

    During the forward pass the weights are scaled by the inverse of the He
    constant (i.e. sqrt(fan_in)) to prevent vanishing gradients and accelerate
    training. This constant only works for ReLU or LeakyReLU activations.

    Args:
    ----
    in_channel: int
        Input channels.
    out_channel: int
        Output channels.
    kernel_size: int
        Kernel size.
    stride: int
        Stride of the convolutional kernel across the input.
    padding: int
        Zero padding applied to both sides of the input.
    bias: bool
        Use bias term.
    """

    def __init__(self, in_channel, out_channel, kernel_size, stride=1, padding=0, bias=True):
        super().__init__()
        self.weight = nn.Parameter(torch.randn(out_channel, in_channel, kernel_size, kernel_size))
        # He constant: fan_in = in_channel * kernel_size**2.
        self.scale = 1 / math.sqrt(in_channel * kernel_size**2)
        self.stride = stride
        self.padding = padding
        self.bias = nn.Parameter(torch.zeros(out_channel)) if bias else None

    def forward(self, input):
        return F.conv2d(
            input,
            self.weight * self.scale,
            bias=self.bias,
            stride=self.stride,
            padding=self.padding,
        )

    def __repr__(self):
        return (
            f"{self.__class__.__name__}({self.weight.shape[1]}, {self.weight.shape[0]},"
            f" {self.weight.shape[2]}, stride={self.stride}, padding={self.padding})"
        )
class EqualConvTranspose2d(nn.Module):
    """2D transpose convolution layer with equalized learning rate.

    During the forward pass the weights are scaled by the inverse of the He
    constant (i.e. sqrt(fan_in)) to prevent vanishing gradients and accelerate
    training. This constant only works for ReLU or LeakyReLU activations.

    Args:
    ----
    in_channel: int
        Input channels.
    out_channel: int
        Output channels.
    kernel_size: int
        Kernel size.
    stride: int
        Stride of the convolutional kernel across the input.
    padding: int
        Zero padding applied to both sides of the input.
    output_padding: int
        Extra padding added to the output to reach the desired size.
    bias: bool
        Use bias term.
    """

    def __init__(
        self,
        in_channel,
        out_channel,
        kernel_size,
        stride=1,
        padding=0,
        output_padding=0,
        bias=True,
    ):
        super().__init__()
        # Transpose conv weight layout is (in_channel, out_channel, kH, kW).
        self.weight = nn.Parameter(torch.randn(in_channel, out_channel, kernel_size, kernel_size))
        self.scale = 1 / math.sqrt(in_channel * kernel_size**2)
        self.stride = stride
        self.padding = padding
        self.output_padding = output_padding
        self.bias = nn.Parameter(torch.zeros(out_channel)) if bias else None

    def forward(self, input):
        return F.conv_transpose2d(
            input,
            self.weight * self.scale,
            bias=self.bias,
            stride=self.stride,
            padding=self.padding,
            output_padding=self.output_padding,
        )

    def __repr__(self):
        return (
            f'{self.__class__.__name__}({self.weight.shape[0]}, {self.weight.shape[1]},'
            f' {self.weight.shape[2]}, stride={self.stride}, padding={self.padding})'
        )
class ConvLayer2d(nn.Sequential):
    """Conv layer with optional blur-based 2x up/downsampling and activation.

    Upsampling uses a stride-2 transpose conv followed by a blur; downsampling
    blurs first and then applies a stride-2 conv. When ``activate`` is set, the
    bias is applied inside FusedLeakyReLU rather than inside the conv, so the
    conv itself is built bias-free in that case.
    """

    def __init__(
        self,
        in_channel,
        out_channel,
        kernel_size=3,
        upsample=False,
        downsample=False,
        blur_kernel=[1, 3, 3, 1],
        bias=True,
        activate=True,
    ):
        assert not (upsample and downsample), 'Cannot upsample and downsample simultaneously'
        modules = []
        conv_bias = bias and not activate
        if upsample:
            factor = 2
            # Padding chosen so blur exactly compensates the transpose conv's
            # output growth (StyleGAN2 convention).
            p = (len(blur_kernel) - factor) - (kernel_size - 1)
            modules.append(
                EqualConvTranspose2d(
                    in_channel,
                    out_channel,
                    kernel_size,
                    padding=0,
                    stride=2,
                    bias=conv_bias,
                )
            )
            modules.append(
                Blur(blur_kernel, pad=((p + 1) // 2 + factor - 1, p // 2 + 1), upsample_factor=factor)
            )
        elif downsample:
            factor = 2
            p = (len(blur_kernel) - factor) + (kernel_size - 1)
            modules.append(Blur(blur_kernel, pad=((p + 1) // 2, p // 2)))
            modules.append(
                EqualConv2d(
                    in_channel,
                    out_channel,
                    kernel_size,
                    padding=0,
                    stride=2,
                    bias=conv_bias,
                )
            )
        else:
            # Plain same-resolution conv with "same" padding.
            modules.append(
                EqualConv2d(
                    in_channel,
                    out_channel,
                    kernel_size,
                    padding=kernel_size // 2,
                    stride=1,
                    bias=conv_bias,
                )
            )
        if activate:
            modules.append(FusedLeakyReLU(out_channel, bias=bias))
        super().__init__(*modules)
class ConvResBlock2d(nn.Module):
    """2D convolutional residual block with equalized learning rate.

    Two 3x3 conv layers with leaky ReLUs plus a (possibly projected) skip
    connection; the sum is divided by sqrt(2) to preserve variance.

    Args:
    ----
    in_channel: int
        Input channels.
    out_channel: int
        Output channels.
    upsample: bool
        Apply upsampling via strided convolution in the first conv.
    downsample: bool
        Apply downsampling via strided convolution in the second conv.
    """

    def __init__(self, in_channel, out_channel, upsample=False, downsample=False):
        super().__init__()
        assert not (upsample and downsample), 'Cannot upsample and downsample simultaneously'
        mid_channel = in_channel if downsample else out_channel
        self.conv1 = ConvLayer2d(in_channel, mid_channel, upsample=upsample, kernel_size=3)
        self.conv2 = ConvLayer2d(mid_channel, out_channel, downsample=downsample, kernel_size=3)
        # A 1x1 projection skip is needed whenever identity cannot match the
        # output's channel count or resolution.
        if upsample or downsample or (in_channel != out_channel):
            self.skip = ConvLayer2d(
                in_channel,
                out_channel,
                upsample=upsample,
                downsample=downsample,
                kernel_size=1,
                activate=False,
                bias=False,
            )

    def forward(self, input):
        out = self.conv2(self.conv1(input))
        shortcut = self.skip(input) if hasattr(self, 'skip') else input
        return (out + shortcut) / math.sqrt(2)