# Copyright (c) Facebook, Inc. and its affiliates.
# All rights reserved.
#
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.

import math

import torch as th
from torch import nn

from .utils import capture_init, center_trim


class BLSTM(nn.Module):
    def __init__(self, dim, layers=1):
        super().__init__()
        self.lstm = nn.LSTM(bidirectional=True, num_layers=layers, hidden_size=dim, input_size=dim)
        self.linear = nn.Linear(2 * dim, dim)

    def forward(self, x):
        # (batch, dim, time) -> (time, batch, dim), the layout nn.LSTM expects.
        x = x.permute(2, 0, 1)
        x = self.lstm(x)[0]
        # Project the concatenated forward/backward hidden states back to `dim`.
        x = self.linear(x)
        # Back to (batch, dim, time).
        x = x.permute(1, 2, 0)
        return x


def rescale_conv(conv, reference):
    std = conv.weight.std().detach()
    scale = (std / reference)**0.5
    conv.weight.data /= scale
    if conv.bias is not None:
        conv.bias.data /= scale
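

# Note (added, not in the original file): dividing by (std / reference) ** 0.5
# sets the new weight standard deviation to sqrt(std * reference), the
# geometric mean of the initial std and the target `reference`, i.e. the
# weights are pulled halfway toward the reference scale on a log axis.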


def rescale_module(module, reference):
    for sub in module.modules():
        if isinstance(sub, (nn.Conv1d, nn.ConvTranspose1d)):
            rescale_conv(sub, reference)


def upsample(x, stride):
    """
    Linear upsampling: interpolates between consecutive time steps, producing
    an output of length `stride * (time - 1)`, roughly `stride` times longer.
    """
    batch, channels, time = x.size()
    weight = th.arange(stride, device=x.device, dtype=th.float) / stride
    x = x.view(batch, channels, time, 1)
    out = x[..., :-1, :] * (1 - weight) + x[..., 1:, :] * weight
    return out.reshape(batch, channels, -1)
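

# Illustrative example (added, not in the original file): with stride=2 each
# output pair is (x[t], (x[t] + x[t + 1]) / 2), e.g.
#   upsample(th.tensor([[[0., 1., 2.]]]), 2)
#   -> tensor([[[0.0000, 0.5000, 1.0000, 1.5000]]])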


def downsample(x, stride):
    """
    Downsample x by decimation.
    """
    return x[:, :, ::stride]
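

# For example (added), downsample(th.arange(8.).view(1, 1, 8), 4) keeps every
# fourth sample: tensor([[[0., 4.]]]).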


class Demucs(nn.Module):
    @capture_init
    def __init__(self,
                 sources=4,
                 audio_channels=2,
                 channels=64,
                 depth=6,
                 rewrite=True,
                 glu=True,
                 upsample=False,
                 rescale=0.1,
                 kernel_size=8,
                 stride=4,
                 growth=2.,
                 lstm_layers=2,
                 context=3,
                 samplerate=44100):
| """ | |
| Args: | |
| sources (int): number of sources to separate | |
| audio_channels (int): stereo or mono | |
| channels (int): first convolution channels | |
| depth (int): number of encoder/decoder layers | |
| rewrite (bool): add 1x1 convolution to each encoder layer | |
| and a convolution to each decoder layer. | |
| For the decoder layer, `context` gives the kernel size. | |
| glu (bool): use glu instead of ReLU | |
| upsample (bool): use linear upsampling with convolutions | |
| Wave-U-Net style, instead of transposed convolutions | |
| rescale (int): rescale initial weights of convolutions | |
| to get their standard deviation closer to `rescale` | |
| kernel_size (int): kernel size for convolutions | |
| stride (int): stride for convolutions | |
| growth (float): multiply (resp divide) number of channels by that | |
| for each layer of the encoder (resp decoder) | |
| lstm_layers (int): number of lstm layers, 0 = no lstm | |
| context (int): kernel size of the convolution in the | |
| decoder before the transposed convolution. If > 1, | |
| will provide some context from neighboring time | |
| steps. | |
| """ | |
        super().__init__()
        self.audio_channels = audio_channels
        self.sources = sources
        self.kernel_size = kernel_size
        self.context = context
        self.stride = stride
        self.depth = depth
        self.upsample = upsample
        self.channels = channels
        self.samplerate = samplerate

        self.encoder = nn.ModuleList()
        self.decoder = nn.ModuleList()
        self.final = None
        if upsample:
            self.final = nn.Conv1d(channels + audio_channels, sources * audio_channels, 1)
            stride = 1

        if glu:
            activation = nn.GLU(dim=1)
            ch_scale = 2
        else:
            activation = nn.ReLU()
            ch_scale = 1
        in_channels = audio_channels
        for index in range(depth):
            encode = []
            encode += [nn.Conv1d(in_channels, channels, kernel_size, stride), nn.ReLU()]
            if rewrite:
                encode += [nn.Conv1d(channels, ch_scale * channels, 1), activation]
            self.encoder.append(nn.Sequential(*encode))

            decode = []
            if index > 0:
                out_channels = in_channels
            else:
                if upsample:
                    out_channels = channels
                else:
                    out_channels = sources * audio_channels
            if rewrite:
                decode += [nn.Conv1d(channels, ch_scale * channels, context), activation]
            if upsample:
                decode += [
                    nn.Conv1d(channels, out_channels, kernel_size, stride=1),
                ]
            else:
                decode += [nn.ConvTranspose1d(channels, out_channels, kernel_size, stride)]
            if index > 0:
                decode.append(nn.ReLU())
            self.decoder.insert(0, nn.Sequential(*decode))
            in_channels = channels
            channels = int(growth * channels)

        channels = in_channels
        if lstm_layers:
            self.lstm = BLSTM(channels, lstm_layers)
        else:
            self.lstm = None
        if rescale:
            rescale_module(self, reference=rescale)
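
        # Note (added): with the defaults (channels=64, growth=2, depth=6) the
        # encoder widths are 64, 128, 256, 512, 1024, 2048, so the BLSTM runs
        # on 2048-dimensional features at the bottleneck.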

    def valid_length(self, length):
        """
        Return the nearest valid length to use with the model so that
        there are no time steps left over in the convolutions, i.e. for all
        layers, (size of the input - kernel_size) % stride == 0.

        If the mixture has a valid length, the estimated sources
        will have exactly the same length when context = 1. If context > 1,
        the two signals can be center trimmed to match.

        For training, extracts should have a valid length. For evaluation
        on full tracks we recommend passing `pad = True` to :meth:`forward`.
        """
        for _ in range(self.depth):
            if self.upsample:
                length = math.ceil(length / self.stride) + self.kernel_size - 1
            else:
                length = math.ceil((length - self.kernel_size) / self.stride) + 1
            length = max(1, length)
            length += self.context - 1
        for _ in range(self.depth):
            if self.upsample:
                length = length * self.stride + self.kernel_size - 1
            else:
                length = (length - 1) * self.stride + self.kernel_size
        return int(length)
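
    # Worked example (added, not in the original file): with the defaults
    # (kernel_size=8, stride=4, depth=6, context=3, upsample=False), each
    # encoder step maps L -> ceil((L - 8) / 4) + 1 + 2 and each decoder step
    # maps L -> (L - 1) * 4 + 8, which gives valid_length(44100) == 58708.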

    def forward(self, mix):
        x = mix
        saved = [x]
        for encode in self.encoder:
            x = encode(x)
            saved.append(x)
            if self.upsample:
                x = downsample(x, self.stride)
        if self.lstm:
            x = self.lstm(x)
        for decode in self.decoder:
            if self.upsample:
                x = upsample(x, stride=self.stride)
            # U-Net skip connection: trim the stored encoder activation to
            # match the current length, then add it in.
            skip = center_trim(saved.pop(-1), x)
            x = x + skip
            x = decode(x)
        if self.final:
            skip = center_trim(saved.pop(-1), x)
            x = th.cat([x, skip], dim=1)
            x = self.final(x)
        # Split the flat channel dimension into (sources, audio_channels).
        x = x.view(x.size(0), self.sources, self.audio_channels, x.size(-1))
        return x
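

# Usage sketch (added; assumes the package context so the relative import of
# `center_trim` above resolves):
#
#   model = Demucs(sources=4, audio_channels=2)
#   length = model.valid_length(44100)   # nearest valid input length
#   mix = th.randn(1, 2, length)         # (batch, audio_channels, time)
#   out = model(mix)                     # (batch, sources, audio_channels, T)
#   # With context > 1, T < length; center_trim the references to compare.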