Upload 16 files
- sound_extraction/model/LASSNet.py +25 -0
- sound_extraction/model/__pycache__/LASSNet.cpython-38.pyc +0 -0
- sound_extraction/model/__pycache__/film.cpython-38.pyc +0 -0
- sound_extraction/model/__pycache__/modules.cpython-38.pyc +0 -0
- sound_extraction/model/__pycache__/resunet_film.cpython-38.pyc +0 -0
- sound_extraction/model/__pycache__/text_encoder.cpython-38.pyc +0 -0
- sound_extraction/model/film.py +27 -0
- sound_extraction/model/modules.py +483 -0
- sound_extraction/model/resunet_film.py +110 -0
- sound_extraction/model/text_encoder.py +45 -0
- sound_extraction/useful_ckpts/LASSNet.pt +3 -0
- sound_extraction/utils/__pycache__/stft.cpython-38.pyc +0 -0
- sound_extraction/utils/__pycache__/wav_io.cpython-38.pyc +0 -0
- sound_extraction/utils/create_mixtures.py +98 -0
- sound_extraction/utils/stft.py +159 -0
- sound_extraction/utils/wav_io.py +23 -0
sound_extraction/model/LASSNet.py
ADDED
@@ -0,0 +1,25 @@
import torch
import torch.nn as nn
import torch.nn.functional as F
from .text_encoder import Text_Encoder
from .resunet_film import UNetRes_FiLM


class LASSNet(nn.Module):
    def __init__(self, device='cuda'):
        super(LASSNet, self).__init__()
        self.text_embedder = Text_Encoder(device)
        self.UNet = UNetRes_FiLM(channels=1, cond_embedding_dim=256)

    def forward(self, x, caption):
        # x: (batch, 1, T, 128)
        input_ids, attns_mask = self.text_embedder.tokenize(caption)

        cond_vec = self.text_embedder(input_ids, attns_mask)[0]
        dec_cond_vec = cond_vec

        mask = self.UNet(x, cond_vec, dec_cond_vec)
        mask = torch.sigmoid(mask)
        return mask

    def get_tokenizer(self):
        return self.text_embedder.tokenizer
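Note: a minimal end-to-end sketch of how this module is meant to be driven (illustrative, not part of this commit; checkpoint loading is omitted because the state-dict layout of LASSNet.pt is not shown here, and the caption and shapes are assumptions):

import torch
# from sound_extraction.model.LASSNet import LASSNet

model = LASSNet(device='cpu').eval()
spec = torch.randn(1, 1, 1001, 513)        # (batch, 1, time, freq) magnitude spectrogram
with torch.no_grad():
    mask = model(spec, ['a dog barking'])  # sigmoid mask in [0, 1], same shape as spec
separated = spec * mask                    # keep only the described source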
sound_extraction/model/__pycache__/LASSNet.cpython-38.pyc
ADDED
Binary file (1.27 kB)
sound_extraction/model/__pycache__/film.cpython-38.pyc
ADDED
Binary file (1.26 kB)
sound_extraction/model/__pycache__/modules.cpython-38.pyc
ADDED
Binary file (14.7 kB)
sound_extraction/model/__pycache__/resunet_film.cpython-38.pyc
ADDED
Binary file (3.26 kB)
sound_extraction/model/__pycache__/text_encoder.cpython-38.pyc
ADDED
Binary file (1.69 kB)
sound_extraction/model/film.py
ADDED
@@ -0,0 +1,27 @@
import torch
import torch.nn as nn


class Film(nn.Module):
    def __init__(self, channels, cond_embedding_dim):
        super(Film, self).__init__()
        self.linear = nn.Sequential(
            nn.Linear(cond_embedding_dim, channels * 2),
            nn.ReLU(inplace=True),
            nn.Linear(channels * 2, channels),
            nn.ReLU(inplace=True)
        )

    def forward(self, data, cond_vec):
        """
        :param data: [batchsize, channels, samples] or [batchsize, channels, T, F] or [batchsize, channels, F, T]
        :param cond_vec: [batchsize, cond_embedding_dim]
        :return: data with a per-channel conditioning bias added
        """
        bias = self.linear(cond_vec)  # [batchsize, channels]
        if data.dim() == 3:
            data = data + bias[..., None]
        elif data.dim() == 4:
            data = data + bias[..., None, None]
        else:
            print("Warning: the shape of the input tensor,", data.size(), "is not supported; FiLM is not applied.")
        return data
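Note: this is an additive-only variant of FiLM; the original formulation (Perez et al., 2018) predicts both a per-channel scale and shift, whereas here only a non-negative bias is produced (the final ReLU clamps it at zero). A quick broadcasting check, assuming the class above is importable:

import torch
# from sound_extraction.model.film import Film

film = Film(channels=32, cond_embedding_dim=256)
feat = torch.randn(4, 32, 100, 64)  # (batch, channels, T, F)
cond = torch.randn(4, 256)          # (batch, cond_embedding_dim)
out = film(feat, cond)
assert out.shape == feat.shape      # the per-channel bias broadcasts over T and F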
sound_extraction/model/modules.py
ADDED
@@ -0,0 +1,483 @@
import math

import torch
import torch.nn as nn
import torch.nn.functional as F

from .film import Film


class ConvBlock(nn.Module):
    def __init__(self, in_channels, out_channels, kernel_size, activation, momentum):
        super(ConvBlock, self).__init__()

        self.activation = activation
        padding = (kernel_size[0] // 2, kernel_size[1] // 2)

        self.conv1 = nn.Conv2d(
            in_channels=in_channels,
            out_channels=out_channels,
            kernel_size=kernel_size,
            stride=(1, 1),
            dilation=(1, 1),
            padding=padding,
            bias=False,
        )

        self.bn1 = nn.BatchNorm2d(out_channels, momentum=momentum)

        self.conv2 = nn.Conv2d(
            in_channels=out_channels,
            out_channels=out_channels,
            kernel_size=kernel_size,
            stride=(1, 1),
            dilation=(1, 1),
            padding=padding,
            bias=False,
        )

        self.bn2 = nn.BatchNorm2d(out_channels, momentum=momentum)

        self.init_weights()

    def init_weights(self):
        init_layer(self.conv1)
        init_layer(self.conv2)
        init_bn(self.bn1)
        init_bn(self.bn2)

    def forward(self, x):
        x = act(self.bn1(self.conv1(x)), self.activation)
        x = act(self.bn2(self.conv2(x)), self.activation)
        return x


class EncoderBlock(nn.Module):
    def __init__(self, in_channels, out_channels, kernel_size, downsample, activation, momentum):
        super(EncoderBlock, self).__init__()

        self.conv_block = ConvBlock(
            in_channels, out_channels, kernel_size, activation, momentum
        )
        self.downsample = downsample

    def forward(self, x):
        encoder = self.conv_block(x)
        encoder_pool = F.avg_pool2d(encoder, kernel_size=self.downsample)
        return encoder_pool, encoder


class DecoderBlock(nn.Module):
    def __init__(self, in_channels, out_channels, kernel_size, upsample, activation, momentum):
        super(DecoderBlock, self).__init__()
        self.kernel_size = kernel_size
        self.stride = upsample
        self.activation = activation

        self.conv1 = torch.nn.ConvTranspose2d(
            in_channels=in_channels,
            out_channels=out_channels,
            kernel_size=self.stride,
            stride=self.stride,
            padding=(0, 0),
            bias=False,
            dilation=(1, 1),
        )

        self.bn1 = nn.BatchNorm2d(out_channels, momentum=momentum)

        self.conv_block2 = ConvBlock(
            out_channels * 2, out_channels, kernel_size, activation, momentum
        )

    def init_weights(self):
        init_layer(self.conv1)
        init_bn(self.bn1)

    def prune(self, x):
        """Prune the shape of x after transpose convolution."""
        padding = (self.kernel_size[0] // 2, self.kernel_size[1] // 2)
        x = x[
            :,
            :,
            padding[0]: padding[0] - self.stride[0],
            padding[1]: padding[1] - self.stride[1]]
        return x

    def forward(self, input_tensor, concat_tensor):
        x = act(self.bn1(self.conv1(input_tensor)), self.activation)
        # x = self.prune(x)
        x = torch.cat((x, concat_tensor), dim=1)
        x = self.conv_block2(x)
        return x


class EncoderBlockRes1B(nn.Module):
    def __init__(self, in_channels, out_channels, downsample, activation, momentum):
        super(EncoderBlockRes1B, self).__init__()
        size = (3, 3)

        self.conv_block1 = ConvBlockRes(in_channels, out_channels, size, activation, momentum)
        self.conv_block2 = ConvBlockRes(out_channels, out_channels, size, activation, momentum)
        self.conv_block3 = ConvBlockRes(out_channels, out_channels, size, activation, momentum)
        self.conv_block4 = ConvBlockRes(out_channels, out_channels, size, activation, momentum)
        self.downsample = downsample

    def forward(self, x):
        encoder = self.conv_block1(x)
        encoder = self.conv_block2(encoder)
        encoder = self.conv_block3(encoder)
        encoder = self.conv_block4(encoder)
        encoder_pool = F.avg_pool2d(encoder, kernel_size=self.downsample)
        return encoder_pool, encoder


class DecoderBlockRes1B(nn.Module):
    def __init__(self, in_channels, out_channels, stride, activation, momentum):
        super(DecoderBlockRes1B, self).__init__()
        size = (3, 3)
        self.activation = activation

        self.conv1 = torch.nn.ConvTranspose2d(
            in_channels=in_channels, out_channels=out_channels, kernel_size=size,
            stride=stride, padding=(0, 0), output_padding=(0, 0), bias=False, dilation=1)

        self.bn1 = nn.BatchNorm2d(in_channels)
        self.conv_block2 = ConvBlockRes(out_channels * 2, out_channels, size, activation, momentum)
        self.conv_block3 = ConvBlockRes(out_channels, out_channels, size, activation, momentum)
        self.conv_block4 = ConvBlockRes(out_channels, out_channels, size, activation, momentum)
        self.conv_block5 = ConvBlockRes(out_channels, out_channels, size, activation, momentum)

    def init_weights(self):
        init_layer(self.conv1)

    def prune(self, x, both=False):
        """Prune the shape of x after transpose convolution."""
        if both:
            x = x[:, :, 0:-1, 0:-1]
        else:
            x = x[:, :, 0:-1, :]
        return x

    def forward(self, input_tensor, concat_tensor, both=False):
        x = self.conv1(F.relu_(self.bn1(input_tensor)))
        x = self.prune(x, both=both)
        x = torch.cat((x, concat_tensor), dim=1)
        x = self.conv_block2(x)
        x = self.conv_block3(x)
        x = self.conv_block4(x)
        x = self.conv_block5(x)
        return x


class EncoderBlockRes2BCond(nn.Module):
    def __init__(self, in_channels, out_channels, downsample, activation, momentum, cond_embedding_dim):
        super(EncoderBlockRes2BCond, self).__init__()
        size = (3, 3)

        self.conv_block1 = ConvBlockResCond(in_channels, out_channels, size, activation, momentum, cond_embedding_dim)
        self.conv_block2 = ConvBlockResCond(out_channels, out_channels, size, activation, momentum, cond_embedding_dim)
        self.downsample = downsample

    def forward(self, x, cond_vec):
        encoder = self.conv_block1(x, cond_vec)
        encoder = self.conv_block2(encoder, cond_vec)
        encoder_pool = F.avg_pool2d(encoder, kernel_size=self.downsample)
        return encoder_pool, encoder


class DecoderBlockRes2BCond(nn.Module):
    def __init__(self, in_channels, out_channels, stride, activation, momentum, cond_embedding_dim):
        super(DecoderBlockRes2BCond, self).__init__()
        size = (3, 3)
        self.activation = activation

        self.conv1 = torch.nn.ConvTranspose2d(
            in_channels=in_channels, out_channels=out_channels, kernel_size=size,
            stride=stride, padding=(0, 0), output_padding=(0, 0), bias=False, dilation=1)

        self.bn1 = nn.BatchNorm2d(in_channels)
        self.conv_block2 = ConvBlockResCond(out_channels * 2, out_channels, size, activation, momentum, cond_embedding_dim)
        self.conv_block3 = ConvBlockResCond(out_channels, out_channels, size, activation, momentum, cond_embedding_dim)

    def init_weights(self):
        init_layer(self.conv1)

    def prune(self, x, both=False):
        """Prune the shape of x after transpose convolution."""
        if both:
            x = x[:, :, 0:-1, 0:-1]
        else:
            x = x[:, :, 0:-1, :]
        return x

    def forward(self, input_tensor, concat_tensor, cond_vec, both=False):
        x = self.conv1(F.relu_(self.bn1(input_tensor)))
        x = self.prune(x, both=both)
        x = torch.cat((x, concat_tensor), dim=1)
        x = self.conv_block2(x, cond_vec)
        x = self.conv_block3(x, cond_vec)
        return x


class EncoderBlockRes4BCond(nn.Module):
    def __init__(self, in_channels, out_channels, downsample, activation, momentum, cond_embedding_dim):
        super(EncoderBlockRes4BCond, self).__init__()
        size = (3, 3)

        self.conv_block1 = ConvBlockResCond(in_channels, out_channels, size, activation, momentum, cond_embedding_dim)
        self.conv_block2 = ConvBlockResCond(out_channels, out_channels, size, activation, momentum, cond_embedding_dim)
        self.conv_block3 = ConvBlockResCond(out_channels, out_channels, size, activation, momentum, cond_embedding_dim)
        self.conv_block4 = ConvBlockResCond(out_channels, out_channels, size, activation, momentum, cond_embedding_dim)
        self.downsample = downsample

    def forward(self, x, cond_vec):
        encoder = self.conv_block1(x, cond_vec)
        encoder = self.conv_block2(encoder, cond_vec)
        encoder = self.conv_block3(encoder, cond_vec)
        encoder = self.conv_block4(encoder, cond_vec)
        encoder_pool = F.avg_pool2d(encoder, kernel_size=self.downsample)
        return encoder_pool, encoder


class DecoderBlockRes4BCond(nn.Module):
    def __init__(self, in_channels, out_channels, stride, activation, momentum, cond_embedding_dim):
        super(DecoderBlockRes4BCond, self).__init__()
        size = (3, 3)
        self.activation = activation

        self.conv1 = torch.nn.ConvTranspose2d(
            in_channels=in_channels, out_channels=out_channels, kernel_size=size,
            stride=stride, padding=(0, 0), output_padding=(0, 0), bias=False, dilation=1)

        self.bn1 = nn.BatchNorm2d(in_channels)
        self.conv_block2 = ConvBlockResCond(out_channels * 2, out_channels, size, activation, momentum, cond_embedding_dim)
        self.conv_block3 = ConvBlockResCond(out_channels, out_channels, size, activation, momentum, cond_embedding_dim)
        self.conv_block4 = ConvBlockResCond(out_channels, out_channels, size, activation, momentum, cond_embedding_dim)
        self.conv_block5 = ConvBlockResCond(out_channels, out_channels, size, activation, momentum, cond_embedding_dim)

    def init_weights(self):
        init_layer(self.conv1)

    def prune(self, x, both=False):
        """Prune the shape of x after transpose convolution."""
        if both:
            x = x[:, :, 0:-1, 0:-1]
        else:
            x = x[:, :, 0:-1, :]
        return x

    def forward(self, input_tensor, concat_tensor, cond_vec, both=False):
        x = self.conv1(F.relu_(self.bn1(input_tensor)))
        x = self.prune(x, both=both)
        x = torch.cat((x, concat_tensor), dim=1)
        x = self.conv_block2(x, cond_vec)
        x = self.conv_block3(x, cond_vec)
        x = self.conv_block4(x, cond_vec)
        x = self.conv_block5(x, cond_vec)
        return x


class EncoderBlockRes4B(nn.Module):
    def __init__(self, in_channels, out_channels, downsample, activation, momentum):
        super(EncoderBlockRes4B, self).__init__()
        size = (3, 3)

        self.conv_block1 = ConvBlockRes(in_channels, out_channels, size, activation, momentum)
        self.conv_block2 = ConvBlockRes(out_channels, out_channels, size, activation, momentum)
        self.conv_block3 = ConvBlockRes(out_channels, out_channels, size, activation, momentum)
        self.conv_block4 = ConvBlockRes(out_channels, out_channels, size, activation, momentum)
        self.downsample = downsample

    def forward(self, x):
        encoder = self.conv_block1(x)
        encoder = self.conv_block2(encoder)
        encoder = self.conv_block3(encoder)
        encoder = self.conv_block4(encoder)
        encoder_pool = F.avg_pool2d(encoder, kernel_size=self.downsample)
        return encoder_pool, encoder


class DecoderBlockRes4B(nn.Module):
    def __init__(self, in_channels, out_channels, stride, activation, momentum):
        super(DecoderBlockRes4B, self).__init__()
        size = (3, 3)
        self.activation = activation

        self.conv1 = torch.nn.ConvTranspose2d(
            in_channels=in_channels, out_channels=out_channels, kernel_size=size,
            stride=stride, padding=(0, 0), output_padding=(0, 0), bias=False, dilation=1)

        self.bn1 = nn.BatchNorm2d(in_channels)
        self.conv_block2 = ConvBlockRes(out_channels * 2, out_channels, size, activation, momentum)
        self.conv_block3 = ConvBlockRes(out_channels, out_channels, size, activation, momentum)
        self.conv_block4 = ConvBlockRes(out_channels, out_channels, size, activation, momentum)
        self.conv_block5 = ConvBlockRes(out_channels, out_channels, size, activation, momentum)

    def init_weights(self):
        init_layer(self.conv1)

    def prune(self, x, both=False):
        """Prune the shape of x after transpose convolution."""
        if both:
            x = x[:, :, 0:-1, 0:-1]
        else:
            x = x[:, :, 0:-1, :]
        return x

    def forward(self, input_tensor, concat_tensor, both=False):
        x = self.conv1(F.relu_(self.bn1(input_tensor)))
        x = self.prune(x, both=both)
        x = torch.cat((x, concat_tensor), dim=1)
        x = self.conv_block2(x)
        x = self.conv_block3(x)
        x = self.conv_block4(x)
        x = self.conv_block5(x)
        return x


class ConvBlockResCond(nn.Module):
    def __init__(self, in_channels, out_channels, kernel_size, activation, momentum, cond_embedding_dim):
        r"""Residual block with FiLM conditioning."""
        super(ConvBlockResCond, self).__init__()

        self.activation = activation
        padding = [kernel_size[0] // 2, kernel_size[1] // 2]

        self.bn1 = nn.BatchNorm2d(in_channels)
        self.bn2 = nn.BatchNorm2d(out_channels)

        self.conv1 = nn.Conv2d(in_channels=in_channels,
                               out_channels=out_channels,
                               kernel_size=kernel_size, stride=(1, 1),
                               dilation=(1, 1), padding=padding, bias=False)
        self.film1 = Film(channels=out_channels, cond_embedding_dim=cond_embedding_dim)
        self.conv2 = nn.Conv2d(in_channels=out_channels,
                               out_channels=out_channels,
                               kernel_size=kernel_size, stride=(1, 1),
                               dilation=(1, 1), padding=padding, bias=False)
        self.film2 = Film(channels=out_channels, cond_embedding_dim=cond_embedding_dim)

        if in_channels != out_channels:
            self.shortcut = nn.Conv2d(in_channels=in_channels,
                                      out_channels=out_channels, kernel_size=(1, 1), stride=(1, 1), padding=(0, 0))
            self.film_res = Film(channels=out_channels, cond_embedding_dim=cond_embedding_dim)
            self.is_shortcut = True
        else:
            self.is_shortcut = False

        self.init_weights()

    def init_weights(self):
        init_bn(self.bn1)
        init_bn(self.bn2)
        init_layer(self.conv1)
        init_layer(self.conv2)

        if self.is_shortcut:
            init_layer(self.shortcut)

    def forward(self, x, cond_vec):
        origin = x
        x = self.conv1(F.leaky_relu_(self.bn1(x), negative_slope=0.01))
        x = self.film1(x, cond_vec)
        x = self.conv2(F.leaky_relu_(self.bn2(x), negative_slope=0.01))
        x = self.film2(x, cond_vec)
        if self.is_shortcut:
            residual = self.shortcut(origin)
            residual = self.film_res(residual, cond_vec)
            return residual + x
        else:
            return origin + x


class ConvBlockRes(nn.Module):
    def __init__(self, in_channels, out_channels, kernel_size, activation, momentum):
        r"""Residual block."""
        super(ConvBlockRes, self).__init__()

        self.activation = activation
        padding = [kernel_size[0] // 2, kernel_size[1] // 2]

        self.bn1 = nn.BatchNorm2d(in_channels)
        self.bn2 = nn.BatchNorm2d(out_channels)

        self.conv1 = nn.Conv2d(in_channels=in_channels,
                               out_channels=out_channels,
                               kernel_size=kernel_size, stride=(1, 1),
                               dilation=(1, 1), padding=padding, bias=False)

        self.conv2 = nn.Conv2d(in_channels=out_channels,
                               out_channels=out_channels,
                               kernel_size=kernel_size, stride=(1, 1),
                               dilation=(1, 1), padding=padding, bias=False)

        if in_channels != out_channels:
            self.shortcut = nn.Conv2d(in_channels=in_channels,
                                      out_channels=out_channels, kernel_size=(1, 1), stride=(1, 1), padding=(0, 0))
            self.is_shortcut = True
        else:
            self.is_shortcut = False

        self.init_weights()

    def init_weights(self):
        init_bn(self.bn1)
        init_bn(self.bn2)
        init_layer(self.conv1)
        init_layer(self.conv2)

        if self.is_shortcut:
            init_layer(self.shortcut)

    def forward(self, x):
        origin = x
        x = self.conv1(F.leaky_relu_(self.bn1(x), negative_slope=0.01))
        x = self.conv2(F.leaky_relu_(self.bn2(x), negative_slope=0.01))

        if self.is_shortcut:
            return self.shortcut(origin) + x
        else:
            return origin + x


def init_layer(layer):
    """Initialize a Linear or Convolutional layer."""
    nn.init.xavier_uniform_(layer.weight)

    if hasattr(layer, 'bias'):
        if layer.bias is not None:
            layer.bias.data.fill_(0.)


def init_bn(bn):
    """Initialize a Batchnorm layer."""
    bn.bias.data.fill_(0.)
    bn.weight.data.fill_(1.)


def init_gru(rnn):
    """Initialize a GRU layer."""

    def _concat_init(tensor, init_funcs):
        (length, fan_out) = tensor.shape
        fan_in = length // len(init_funcs)

        for (i, init_func) in enumerate(init_funcs):
            init_func(tensor[i * fan_in: (i + 1) * fan_in, :])

    def _inner_uniform(tensor):
        fan_in = nn.init._calculate_correct_fan(tensor, 'fan_in')
        nn.init.uniform_(tensor, -math.sqrt(3 / fan_in), math.sqrt(3 / fan_in))

    for i in range(rnn.num_layers):
        _concat_init(
            getattr(rnn, 'weight_ih_l{}'.format(i)),
            [_inner_uniform, _inner_uniform, _inner_uniform]
        )
        torch.nn.init.constant_(getattr(rnn, 'bias_ih_l{}'.format(i)), 0)

        _concat_init(
            getattr(rnn, 'weight_hh_l{}'.format(i)),
            [_inner_uniform, _inner_uniform, nn.init.orthogonal_]
        )
        torch.nn.init.constant_(getattr(rnn, 'bias_hh_l{}'.format(i)), 0)


def act(x, activation):
    if activation == 'relu':
        return F.relu_(x)
    elif activation == 'leaky_relu':
        return F.leaky_relu_(x, negative_slope=0.2)
    elif activation == 'swish':
        return x * torch.sigmoid(x)
    else:
        raise Exception('Incorrect activation!')
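Note: the prune() convention (trim one time frame after each transposed convolution) only matches the skip connections when the time axis is even and the frequency axis is odd, which is exactly what resunet_film.py arranges. A small shape check of the unconditioned pair, assuming the classes above are importable:

import torch
# from sound_extraction.model.modules import EncoderBlockRes4B, DecoderBlockRes4B

enc = EncoderBlockRes4B(in_channels=1, out_channels=32, downsample=(2, 2),
                        activation='relu', momentum=0.01)
dec = DecoderBlockRes4B(in_channels=32, out_channels=32, stride=(2, 2),
                        activation='relu', momentum=0.01)

x = torch.randn(2, 1, 64, 65)  # even time axis, odd frequency axis
x_pool, skip = enc(x)          # x_pool: (2, 32, 32, 32), skip: (2, 32, 64, 65)
y = dec(x_pool, skip)          # transposed conv to 65x65, then prune() trims time to 64
assert y.shape == skip.shape   # (2, 32, 64, 65)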
sound_extraction/model/resunet_film.py
ADDED
@@ -0,0 +1,110 @@
from .modules import *
import numpy as np


class UNetRes_FiLM(nn.Module):
    def __init__(self, channels, cond_embedding_dim, nsrc=1):
        super(UNetRes_FiLM, self).__init__()
        activation = 'relu'
        momentum = 0.01

        self.nsrc = nsrc
        self.channels = channels
        self.downsample_ratio = 2 ** 6  # this number equals 2^{#encoder_blocks}

        self.encoder_block1 = EncoderBlockRes2BCond(in_channels=channels * nsrc, out_channels=32,
                                                    downsample=(2, 2), activation=activation, momentum=momentum,
                                                    cond_embedding_dim=cond_embedding_dim)
        self.encoder_block2 = EncoderBlockRes2BCond(in_channels=32, out_channels=64,
                                                    downsample=(2, 2), activation=activation, momentum=momentum,
                                                    cond_embedding_dim=cond_embedding_dim)
        self.encoder_block3 = EncoderBlockRes2BCond(in_channels=64, out_channels=128,
                                                    downsample=(2, 2), activation=activation, momentum=momentum,
                                                    cond_embedding_dim=cond_embedding_dim)
        self.encoder_block4 = EncoderBlockRes2BCond(in_channels=128, out_channels=256,
                                                    downsample=(2, 2), activation=activation, momentum=momentum,
                                                    cond_embedding_dim=cond_embedding_dim)
        self.encoder_block5 = EncoderBlockRes2BCond(in_channels=256, out_channels=384,
                                                    downsample=(2, 2), activation=activation, momentum=momentum,
                                                    cond_embedding_dim=cond_embedding_dim)
        self.encoder_block6 = EncoderBlockRes2BCond(in_channels=384, out_channels=384,
                                                    downsample=(2, 2), activation=activation, momentum=momentum,
                                                    cond_embedding_dim=cond_embedding_dim)
        self.conv_block7 = ConvBlockResCond(in_channels=384, out_channels=384,
                                            kernel_size=(3, 3), activation=activation, momentum=momentum,
                                            cond_embedding_dim=cond_embedding_dim)
        self.decoder_block1 = DecoderBlockRes2BCond(in_channels=384, out_channels=384,
                                                    stride=(2, 2), activation=activation, momentum=momentum,
                                                    cond_embedding_dim=cond_embedding_dim)
        self.decoder_block2 = DecoderBlockRes2BCond(in_channels=384, out_channels=384,
                                                    stride=(2, 2), activation=activation, momentum=momentum,
                                                    cond_embedding_dim=cond_embedding_dim)
        self.decoder_block3 = DecoderBlockRes2BCond(in_channels=384, out_channels=256,
                                                    stride=(2, 2), activation=activation, momentum=momentum,
                                                    cond_embedding_dim=cond_embedding_dim)
        self.decoder_block4 = DecoderBlockRes2BCond(in_channels=256, out_channels=128,
                                                    stride=(2, 2), activation=activation, momentum=momentum,
                                                    cond_embedding_dim=cond_embedding_dim)
        self.decoder_block5 = DecoderBlockRes2BCond(in_channels=128, out_channels=64,
                                                    stride=(2, 2), activation=activation, momentum=momentum,
                                                    cond_embedding_dim=cond_embedding_dim)
        self.decoder_block6 = DecoderBlockRes2BCond(in_channels=64, out_channels=32,
                                                    stride=(2, 2), activation=activation, momentum=momentum,
                                                    cond_embedding_dim=cond_embedding_dim)

        self.after_conv_block1 = ConvBlockResCond(in_channels=32, out_channels=32,
                                                  kernel_size=(3, 3), activation=activation, momentum=momentum,
                                                  cond_embedding_dim=cond_embedding_dim)

        self.after_conv2 = nn.Conv2d(in_channels=32, out_channels=1,
                                     kernel_size=(1, 1), stride=(1, 1), padding=(0, 0), bias=True)

        self.init_weights()

    def init_weights(self):
        init_layer(self.after_conv2)

    def forward(self, sp, cond_vec, dec_cond_vec):
        """
        Args:
            sp: (batch_size, channels_num, time_steps, freq_bins) spectrogram
        Returns:
            (batch_size, 1, time_steps, freq_bins) mask logits
        """

        x = sp
        # Pad the time axis so it is evenly divisible by the downsample ratio.
        origin_len = x.shape[2]  # time_steps
        pad_len = int(np.ceil(x.shape[2] / self.downsample_ratio)) * self.downsample_ratio - origin_len
        x = F.pad(x, pad=(0, 0, 0, pad_len))
        # Trim the last two frequency bins (e.g. 513 -> 511) so the /2 poolings
        # and the 2n+1 transposed-conv upsamplings line up exactly.
        x = x[..., 0: x.shape[-1] - 2]  # (bs, channels, T, F)

        # UNet
        (x1_pool, x1) = self.encoder_block1(x, cond_vec)  # x1_pool: (bs, 32, T / 2, F / 2)
        (x2_pool, x2) = self.encoder_block2(x1_pool, cond_vec)  # x2_pool: (bs, 64, T / 4, F / 4)
        (x3_pool, x3) = self.encoder_block3(x2_pool, cond_vec)  # x3_pool: (bs, 128, T / 8, F / 8)
        (x4_pool, x4) = self.encoder_block4(x3_pool, dec_cond_vec)  # x4_pool: (bs, 256, T / 16, F / 16)
        (x5_pool, x5) = self.encoder_block5(x4_pool, dec_cond_vec)  # x5_pool: (bs, 384, T / 32, F / 32)
        (x6_pool, x6) = self.encoder_block6(x5_pool, dec_cond_vec)  # x6_pool: (bs, 384, T / 64, F / 64)
        x_center = self.conv_block7(x6_pool, dec_cond_vec)  # (bs, 384, T / 64, F / 64)
        x7 = self.decoder_block1(x_center, x6, dec_cond_vec)  # (bs, 384, T / 32, F / 32)
        x8 = self.decoder_block2(x7, x5, dec_cond_vec)  # (bs, 384, T / 16, F / 16)
        x9 = self.decoder_block3(x8, x4, cond_vec)  # (bs, 256, T / 8, F / 8)
        x10 = self.decoder_block4(x9, x3, cond_vec)  # (bs, 128, T / 4, F / 4)
        x11 = self.decoder_block5(x10, x2, cond_vec)  # (bs, 64, T / 2, F / 2)
        x12 = self.decoder_block6(x11, x1, cond_vec)  # (bs, 32, T, F)
        x = self.after_conv_block1(x12, cond_vec)  # (bs, 32, T, F)
        x = self.after_conv2(x)  # (bs, 1, T, F)

        # Recover the original shape: restore the two trimmed frequency bins
        # and cut the time axis back to its unpadded length.
        x = F.pad(x, pad=(0, 2))
        x = x[:, :, 0: origin_len, :]
        return x


if __name__ == "__main__":
    model = UNetRes_FiLM(channels=1, cond_embedding_dim=16)
    cond_vec = torch.randn((1, 16))
    dec_vec = cond_vec
    print(model(torch.randn((1, 1, 1001, 513)), cond_vec, dec_vec).size())
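Note: the two-bin frequency trim works because 511 = 2^9 - 1. Halving an odd size n gives (n - 1) / 2 via the floor in avg_pool2d, and the stride-2, kernel-3 transposed convolution maps m back to 2m + 1 exactly, so encoder and decoder frequency sizes pair up with no pruning. Illustrative check:

sizes = [513 - 2]       # 511 frequency bins after the trim
for _ in range(6):      # six encoder blocks, each pooling by 2
    sizes.append(sizes[-1] // 2)
print(sizes)            # [511, 255, 127, 63, 31, 15, 7]
assert all(2 * m + 1 == n for n, m in zip(sizes, sizes[1:]))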
sound_extraction/model/text_encoder.py
ADDED
@@ -0,0 +1,45 @@
import torch
import torch.nn as nn
from transformers import BertModel, BertTokenizer
import warnings
warnings.filterwarnings('ignore')

# pretrained model name: (model class, tokenizer class)
MODELS = {
    'prajjwal1/bert-mini': (BertModel, BertTokenizer),
}


class Text_Encoder(nn.Module):
    def __init__(self, device):
        super(Text_Encoder, self).__init__()
        self.base_model = 'prajjwal1/bert-mini'
        self.dropout = 0.1

        self.tokenizer = MODELS[self.base_model][1].from_pretrained(self.base_model)

        self.bert_layer = MODELS[self.base_model][0].from_pretrained(
            self.base_model,
            add_pooling_layer=False,
            hidden_dropout_prob=self.dropout,
            attention_probs_dropout_prob=self.dropout,
            output_hidden_states=True)

        self.linear_layer = nn.Sequential(nn.Linear(256, 256), nn.ReLU(inplace=True))

        self.device = device

    def tokenize(self, caption):
        tokenized = self.tokenizer(caption, add_special_tokens=False, padding=True, return_tensors='pt')
        input_ids = tokenized['input_ids']
        attns_mask = tokenized['attention_mask']

        input_ids = input_ids.to(self.device)
        attns_mask = attns_mask.to(self.device)
        return input_ids, attns_mask

    def forward(self, input_ids, attns_mask):
        output = self.bert_layer(input_ids=input_ids, attention_mask=attns_mask)[0]
        cls_embed = output[:, 0, :]  # first-token embedding (no [CLS] is added, see tokenize)
        text_embed = self.linear_layer(cls_embed)

        return text_embed, output  # text_embed: (batch, hidden_size)
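Note: since tokenize() passes add_special_tokens=False, no [CLS] token is inserted, so "cls_embed" above is really the first word-piece embedding. A standalone usage sketch (the caption text is arbitrary):

import torch
# from sound_extraction.model.text_encoder import Text_Encoder

encoder = Text_Encoder(device='cpu')
input_ids, attns_mask = encoder.tokenize(['a man speaking in a large room'])
text_embed, token_states = encoder(input_ids, attns_mask)
print(text_embed.shape)    # torch.Size([1, 256]); bert-mini's hidden size is 256
print(token_states.shape)  # torch.Size([1, num_tokens, 256])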
sound_extraction/useful_ckpts/LASSNet.pt
ADDED
@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:2c6a60910bc1db03d9ff7040d0e5906ab784431cb8b279cf4e295124e9e76fae
size 761532233
sound_extraction/utils/__pycache__/stft.cpython-38.pyc
ADDED
Binary file (4.76 kB)
sound_extraction/utils/__pycache__/wav_io.cpython-38.pyc
ADDED
Binary file (823 Bytes)
sound_extraction/utils/create_mixtures.py
ADDED
@@ -0,0 +1,98 @@
import torch
import numpy as np


def add_noise_and_scale(front, noise, snr_l=0, snr_h=0, scale_lower=1.0, scale_upper=1.0):
    """
    :param front: foreground audio (e.g. vocal), [samples, channel]; it is normalized here, so any input scale is fine
    :param noise: noise, [samples, channel], any scale
    :param snr_l: lower bound of the SNR range in dB, optional
    :param snr_h: upper bound of the SNR range in dB, optional
    :param scale_lower: lower bound of the random output scale, optional
    :param scale_upper: upper bound of the random output scale, optional
    :return: scaled front and noise (noisy = front + noise); all outputs are normalized within [-1, 1]
    """
    snr = None
    # Bring noise and vocal into the same range [-1, 1].
    noise, front = normalize_energy_torch(noise), normalize_energy_torch(front)

    if snr_l is not None and snr_h is not None:
        # Remix them at an SNR sampled from [snr_l, snr_h].
        front, noise, snr = _random_noise(front, noise, snr_l=snr_l, snr_h=snr_h)

    # Normalize noisy, noise and vocal energy into [-1, 1].
    noisy, noise, front = unify_energy_torch(noise + front, noise, front)

    # Apply one random scale to all three signals.
    scale = _random_scale(scale_lower, scale_upper)
    noisy, noise, front = noisy * scale, noise * scale, front * scale

    front, noise = _to_numpy(front), _to_numpy(noise)  # [num_samples]
    mixed_wav = front + noise

    return front, noise, mixed_wav, snr, scale


def _random_scale(lower=0.3, upper=0.9):
    return float(uniform_torch(lower, upper))


def _random_noise(clean, noise, snr_l=None, snr_h=None):
    snr = uniform_torch(snr_l, snr_h)
    clean_weight = 10 ** (float(snr) / 20)
    return clean, noise / clean_weight, snr


def _to_numpy(wav):
    return np.transpose(wav, (1, 0))[0].numpy()  # [num_samples]


def normalize_energy(audio, alpha=1):
    '''
    :param audio: 1d waveform, [batchsize, *]
    :param alpha: the output values range within [-alpha, alpha]
    :return: waveform whose values range within [-alpha, alpha]
    '''
    val_max = activelev(audio)
    return (audio / val_max) * alpha


def normalize_energy_torch(audio, alpha=1):
    '''
    Torch version of normalize_energy.
    :param audio: 1d waveform tensor
    :param alpha: the output values range within [-alpha, alpha]
    :return: waveform whose values range within [-alpha, alpha]
    '''
    val_max = activelev_torch([audio])
    return (audio / val_max) * alpha


def unify_energy(*args):
    max_amp = activelev(args)
    mix_scale = 1.0 / max_amp
    return [x * mix_scale for x in args]


def unify_energy_torch(*args):
    max_amp = activelev_torch(args)
    mix_scale = 1.0 / max_amp
    return [x * mix_scale for x in args]


def activelev(*args):
    '''
    Peak amplitude; TODO: align with the MATLAB implementation.
    '''
    return np.max(np.abs([*args]))


def activelev_torch(*args):
    '''
    Peak amplitude; TODO: align with the MATLAB implementation.
    '''
    res = []
    args = args[0]
    for each in args:
        res.append(torch.max(torch.abs(each)))
    return max(res)


def uniform_torch(lower, upper):
    if abs(lower - upper) < 1e-5:
        return upper
    return (upper - lower) * torch.rand(1) + lower


if __name__ == "__main__":
    wav1 = torch.randn(1, 32000)
    wav2 = torch.randn(1, 32000)
    front, noise, mixed_wav, snr, scale = add_noise_and_scale(wav1, wav2)
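Note: _random_noise keeps the foreground fixed and divides the noise amplitude by 10^(snr/20), the standard amplitude-domain form of a decibel SNR. Quick numeric check:

import torch
# from sound_extraction.utils.create_mixtures import _random_noise

clean = torch.ones(1, 100)
noise = torch.ones(1, 100)
_, scaled_noise, snr = _random_noise(clean, noise, snr_l=20, snr_h=20)
print(float(snr))                       # 20.0 (upper == lower, so no sampling)
print(scaled_noise.abs().max().item())  # 0.1 = 1 / 10**(20 / 20)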
sound_extraction/utils/stft.py
ADDED
@@ -0,0 +1,159 @@
import torch
import numpy as np
import torch.nn.functional as F
from torch.autograd import Variable
from scipy.signal import get_window
import librosa.util as librosa_util
from librosa.util import pad_center, tiny


def window_sumsquare(window, n_frames, hop_length=512, win_length=1024,
                     n_fft=1024, dtype=np.float32, norm=None):
    """
    Compute the sum-square envelope of a window function at a given hop length
    (ported from librosa 0.6).

    This is used to estimate modulation effects induced by windowing
    observations in short-time Fourier transforms.

    Parameters
    ----------
    window : string, tuple, number, callable, or list-like
        Window specification, as in `get_window`
    n_frames : int > 0
        The number of analysis frames
    hop_length : int > 0
        The number of samples to advance between frames
    win_length : int, optional
        The length of the window function. By default, this matches `n_fft`.
    n_fft : int > 0
        The length of each analysis frame.
    dtype : np.dtype
        The data type of the output

    Returns
    -------
    wss : np.ndarray, shape=(n_fft + hop_length * (n_frames - 1),)
        The sum-squared envelope of the window function
    """
    if win_length is None:
        win_length = n_fft

    n = n_fft + hop_length * (n_frames - 1)
    x = np.zeros(n, dtype=dtype)

    # Compute the squared window at the desired length
    win_sq = get_window(window, win_length, fftbins=True)
    win_sq = librosa_util.normalize(win_sq, norm=norm) ** 2
    win_sq = librosa_util.pad_center(win_sq, n_fft)

    # Fill the envelope
    for i in range(n_frames):
        sample = i * hop_length
        x[sample:min(n, sample + n_fft)] += win_sq[:max(0, min(n_fft, n - sample))]
    return x


class STFT(torch.nn.Module):
    """adapted from Prem Seetharaman's https://github.com/pseeth/pytorch-stft"""

    def __init__(self, filter_length=1024, hop_length=512, win_length=1024,
                 window='hann'):
        super(STFT, self).__init__()
        self.filter_length = filter_length
        self.hop_length = hop_length
        self.win_length = win_length
        self.window = window
        self.forward_transform = None
        scale = self.filter_length / self.hop_length
        fourier_basis = np.fft.fft(np.eye(self.filter_length))

        cutoff = int((self.filter_length / 2 + 1))
        fourier_basis = np.vstack([np.real(fourier_basis[:cutoff, :]),
                                   np.imag(fourier_basis[:cutoff, :])])

        forward_basis = torch.FloatTensor(fourier_basis[:, None, :])
        inverse_basis = torch.FloatTensor(
            np.linalg.pinv(scale * fourier_basis).T[:, None, :])

        if window is not None:
            assert filter_length >= win_length
            # get window and zero center pad it to filter_length
            fft_window = get_window(window, win_length, fftbins=True)
            fft_window = pad_center(fft_window, filter_length)
            fft_window = torch.from_numpy(fft_window).float()

            # window the bases
            forward_basis *= fft_window
            inverse_basis *= fft_window

        self.register_buffer('forward_basis', forward_basis.float())
        self.register_buffer('inverse_basis', inverse_basis.float())

    def transform(self, input_data):
        num_batches = input_data.size(0)
        num_samples = input_data.size(1)

        self.num_samples = num_samples

        # similar to librosa, reflect-pad the input
        input_data = input_data.view(num_batches, 1, num_samples)
        input_data = F.pad(
            input_data.unsqueeze(1),
            (int(self.filter_length / 2), int(self.filter_length / 2), 0, 0),
            mode='reflect')
        input_data = input_data.squeeze(1)

        forward_transform = F.conv1d(
            input_data,
            Variable(self.forward_basis, requires_grad=False),
            stride=self.hop_length,
            padding=0)

        cutoff = int((self.filter_length / 2) + 1)
        real_part = forward_transform[:, :cutoff, :]
        imag_part = forward_transform[:, cutoff:, :]

        magnitude = torch.sqrt(real_part ** 2 + imag_part ** 2)
        phase = torch.autograd.Variable(
            torch.atan2(imag_part.data, real_part.data))

        return magnitude, phase  # [batch_size, F(513), T(1251)]

    def inverse(self, magnitude, phase):
        recombine_magnitude_phase = torch.cat(
            [magnitude * torch.cos(phase), magnitude * torch.sin(phase)], dim=1)

        inverse_transform = F.conv_transpose1d(
            recombine_magnitude_phase,
            Variable(self.inverse_basis, requires_grad=False),
            stride=self.hop_length,
            padding=0)

        if self.window is not None:
            window_sum = window_sumsquare(
                self.window, magnitude.size(-1), hop_length=self.hop_length,
                win_length=self.win_length, n_fft=self.filter_length,
                dtype=np.float32)
            # remove modulation effects
            approx_nonzero_indices = torch.from_numpy(
                np.where(window_sum > tiny(window_sum))[0])
            window_sum = torch.autograd.Variable(
                torch.from_numpy(window_sum), requires_grad=False)
            window_sum = window_sum.cuda() if magnitude.is_cuda else window_sum
            inverse_transform[:, :, approx_nonzero_indices] /= window_sum[approx_nonzero_indices]

            # scale by hop ratio
            inverse_transform *= float(self.filter_length) / self.hop_length

        inverse_transform = inverse_transform[:, :, int(self.filter_length / 2):]
        inverse_transform = inverse_transform[:, :, :-int(self.filter_length / 2)]

        return inverse_transform  # [batch_size, 1, sample_num]

    def forward(self, input_data):
        self.magnitude, self.phase = self.transform(input_data)
        reconstruction = self.inverse(self.magnitude, self.phase)
        return reconstruction


if __name__ == '__main__':
    a = torch.randn(4, 320000)
    stft = STFT()
    mag, phase = stft.transform(a)
    # rec_a = stft.inverse(mag, phase)
    print(mag.shape)
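Note: forward() is just transform() followed by inverse(), so a round-trip comparison is the quickest sanity check; inverse() trims filter_length/2 samples from each end, so compare away from the edges (illustrative sketch):

import torch
# from sound_extraction.utils.stft import STFT

stft = STFT(filter_length=1024, hop_length=512, win_length=1024, window='hann')
x = torch.randn(1, 32000)
y = stft(x).squeeze(1)  # reconstruct: transform() then inverse()
n = min(x.shape[-1], y.shape[-1])
err = (x[:, 1024:n - 1024] - y[:, 1024:n - 1024]).abs().max()
print(float(err))       # should be close to zero away from the edges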
sound_extraction/utils/wav_io.py
ADDED
@@ -0,0 +1,23 @@
import librosa
import numpy as np
import scipy.io.wavfile


def load_wav(path):
    max_length = 32000 * 10
    wav = librosa.core.load(path, sr=32000)[0]
    # truncate audio longer than max_length
    if len(wav) > max_length:
        wav = wav[0:max_length]

    # pad audio up to max_length, 10 s for AudioCaps
    if len(wav) < max_length:
        wav = np.pad(wav, (0, max_length - len(wav)), 'constant')
    wav = wav[..., None]
    return wav


def save_wav(wav, path):
    wav *= 32767 / max(0.01, np.max(np.abs(wav)))
    scipy.io.wavfile.write(path, 32000, wav.astype(np.int16))
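Note: a small sketch tying the two helpers together ('in.wav' and 'out.wav' are placeholder paths):

# from sound_extraction.utils.wav_io import load_wav, save_wav

wav = load_wav('in.wav')        # float32, shape (320000, 1): 10 s at 32 kHz
print(wav.shape, wav.dtype)
save_wav(wav[:, 0], 'out.wav')  # peak-normalizes and writes 16-bit PCM at 32 kHz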