Spaces:
Running
Running
Inital demo
Browse files- LICENSE +25 -0
- app.py +69 -0
- hifigan/config_v1_wavlm.json +40 -0
- hifigan/meldataset.py +208 -0
- hifigan/models.py +289 -0
- hifigan/train.py +335 -0
- hifigan/utils.py +73 -0
- hubconf.py +75 -0
- knnvc_utils.py +23 -0
- matcher.py +172 -0
- prematch_dataset.py +172 -0
- requirements.txt +5 -0
- wavlm/WavLM.py +743 -0
- wavlm/modules.py +827 -0
LICENSE
ADDED
|
@@ -0,0 +1,25 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
MIT License
|
| 2 |
+
|
| 3 |
+
Copyright (c) 2023 MediaLab, Department of Electrical & Electronic Engineering, Stellenbosch University
|
| 4 |
+
|
| 5 |
+
Permission is hereby granted, free of charge, to any person obtaining a copy
|
| 6 |
+
of this software and associated documentation files (the "Software"), to deal
|
| 7 |
+
in the Software without restriction, including without limitation the rights
|
| 8 |
+
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
|
| 9 |
+
copies of the Software, and to permit persons to whom the Software is
|
| 10 |
+
furnished to do so, subject to the following conditions:
|
| 11 |
+
|
| 12 |
+
The above copyright notice and this permission notice shall be included in all
|
| 13 |
+
copies or substantial portions of the Software, and you spend at least 10 seconds
|
| 14 |
+
thinking about whether the idea of copyright for Software actually makes sense
|
| 15 |
+
the first time you download the Software.
|
| 16 |
+
|
| 17 |
+
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
|
| 18 |
+
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
|
| 19 |
+
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
|
| 20 |
+
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
|
| 21 |
+
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
|
| 22 |
+
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
|
| 23 |
+
SOFTWARE.
|
| 24 |
+
|
| 25 |
+
|
app.py
ADDED
|
@@ -0,0 +1,69 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import torch
|
| 2 |
+
import torchaudio
|
| 3 |
+
import spaces
|
| 4 |
+
from typing import List
|
| 5 |
+
import soundfile as sf
|
| 6 |
+
|
| 7 |
+
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
|
| 8 |
+
knn_vc = torch.hub.load('bshall/knn-vc', 'knn_vc', prematched=True, trust_repo=True, pretrained=True, device=device)
|
| 9 |
+
|
| 10 |
+
|
| 11 |
+
def convert_voice(src_wav_path:str, ref_wav_paths, top_k:int):
|
| 12 |
+
|
| 13 |
+
query_seq = knn_vc.get_features(src_wav_path)
|
| 14 |
+
matching_set = knn_vc.get_matching_set([ref_wav_paths])
|
| 15 |
+
out_wav = knn_vc.match(query_seq, matching_set, topk=int(top_k))
|
| 16 |
+
|
| 17 |
+
with tempfile.NamedTemporaryFile(suffix=".wav", delete=False) as converted_file:
|
| 18 |
+
sf.write(converted_file.name, out_wav, 16000, "PCM_24")
|
| 19 |
+
|
| 20 |
+
return converted_file.name
|
| 21 |
+
|
| 22 |
+
|
| 23 |
+
title = """
|
| 24 |
+
<div style="text-align: center; max-width: 700px; margin: 0 auto;">
|
| 25 |
+
<div
|
| 26 |
+
style="display: inline-flex; align-items: center; gap: 0.8rem; font-size: 1.75rem;"
|
| 27 |
+
> <h1 style="font-weight: 900; margin-bottom: 7px; line-height: normal;">
|
| 28 |
+
KNN Voice Conversion
|
| 29 |
+
</h1> </div>
|
| 30 |
+
</div>
|
| 31 |
+
"""
|
| 32 |
+
|
| 33 |
+
description = """
|
| 34 |
+
Voice Conversion With Just k-Nearest Neighbors. The source and reference utterance(s) are encoded into self-supervised features using WavLM.
|
| 35 |
+
Each source feature is assigned to the mean of the k closest features from the reference.
|
| 36 |
+
The resulting feature sequence is then vocoded with HiFi-GAN to arrive at the converted waveform output.
|
| 37 |
+
"""
|
| 38 |
+
|
| 39 |
+
article = """
|
| 40 |
+
If the model contributes to your research please cite the following work:
|
| 41 |
+
|
| 42 |
+
Baas, M., van Niekerk, B., & Kamper, H. (2023). Voice conversion with just nearest neighbors. arXiv preprint arXiv:2305.18975.
|
| 43 |
+
|
| 44 |
+
demo contributed by [@wetdog](https://github.com/wetdog)
|
| 45 |
+
"""
|
| 46 |
+
demo = gr.Blocks()
|
| 47 |
+
with demo:
|
| 48 |
+
gr.Markdown(title)
|
| 49 |
+
gr.Markdown(description)
|
| 50 |
+
gr.Interface(
|
| 51 |
+
fn=convert_voice,
|
| 52 |
+
inputs=[
|
| 53 |
+
gr.Audio(type='filepath'),
|
| 54 |
+
gr.Audio(type='filepath'),
|
| 55 |
+
gr.Slider(
|
| 56 |
+
3,
|
| 57 |
+
10,
|
| 58 |
+
value=4,
|
| 59 |
+
step=1,
|
| 60 |
+
label="Top-k",
|
| 61 |
+
info=f"These default settings provide pretty good results, but feel free to modify the kNN topk",
|
| 62 |
+
)],
|
| 63 |
+
outputs=[gr.Audio(type='filepath')],
|
| 64 |
+
allow_flagging=False,)
|
| 65 |
+
gr.Markdown(article)
|
| 66 |
+
|
| 67 |
+
demo.queue(max_size=10)
|
| 68 |
+
demo.launch(show_api=False, server_name="0.0.0.0", server_port=7860)
|
| 69 |
+
|
hifigan/config_v1_wavlm.json
ADDED
|
@@ -0,0 +1,40 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
{
|
| 2 |
+
"resblock": "1",
|
| 3 |
+
"num_gpus": 0,
|
| 4 |
+
"batch_size": 16,
|
| 5 |
+
"learning_rate": 0.0002,
|
| 6 |
+
"adam_b1": 0.8,
|
| 7 |
+
"adam_b2": 0.99,
|
| 8 |
+
"lr_decay": 0.999,
|
| 9 |
+
"seed": 1234,
|
| 10 |
+
|
| 11 |
+
"upsample_rates": [10,8,2,2],
|
| 12 |
+
"upsample_kernel_sizes": [20,16,4,4],
|
| 13 |
+
"upsample_initial_channel": 512,
|
| 14 |
+
"resblock_kernel_sizes": [3,7,11],
|
| 15 |
+
"resblock_dilation_sizes": [[1,3,5], [1,3,5], [1,3,5]],
|
| 16 |
+
|
| 17 |
+
"hubert_dim": 1024,
|
| 18 |
+
"hifi_dim": 512,
|
| 19 |
+
|
| 20 |
+
"segment_size": 7040,
|
| 21 |
+
"num_mels": 80,
|
| 22 |
+
"num_freq": 1025,
|
| 23 |
+
"n_fft": 1024,
|
| 24 |
+
"hop_size": 320,
|
| 25 |
+
"win_size": 1024,
|
| 26 |
+
|
| 27 |
+
"sampling_rate": 16000,
|
| 28 |
+
|
| 29 |
+
"fmin": 0,
|
| 30 |
+
"fmax": 8000,
|
| 31 |
+
"fmax_for_loss": null,
|
| 32 |
+
|
| 33 |
+
"num_workers": 4,
|
| 34 |
+
|
| 35 |
+
"dist_config": {
|
| 36 |
+
"dist_backend": "nccl",
|
| 37 |
+
"dist_url": "tcp://localhost:54321",
|
| 38 |
+
"world_size": 1
|
| 39 |
+
}
|
| 40 |
+
}
|
hifigan/meldataset.py
ADDED
|
@@ -0,0 +1,208 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import math
|
| 2 |
+
import os
|
| 3 |
+
import random
|
| 4 |
+
from pathlib import Path
|
| 5 |
+
|
| 6 |
+
import librosa
|
| 7 |
+
import numpy as np
|
| 8 |
+
import pandas as pd
|
| 9 |
+
import torch
|
| 10 |
+
import torch.nn.functional as F
|
| 11 |
+
import torch.utils.data
|
| 12 |
+
import torchaudio
|
| 13 |
+
from librosa.filters import mel as librosa_mel_fn
|
| 14 |
+
from librosa.util import normalize
|
| 15 |
+
from scipy.io.wavfile import read
|
| 16 |
+
|
| 17 |
+
|
| 18 |
+
def load_wav(full_path):
|
| 19 |
+
#sampling_rate, data = read(full_path)
|
| 20 |
+
#return data, sampling_rate
|
| 21 |
+
data, sampling_rate = librosa.load(full_path, sr=None)
|
| 22 |
+
return data, sampling_rate
|
| 23 |
+
|
| 24 |
+
|
| 25 |
+
def dynamic_range_compression(x, C=1, clip_val=1e-5):
|
| 26 |
+
return np.log(np.clip(x, a_min=clip_val, a_max=None) * C)
|
| 27 |
+
|
| 28 |
+
|
| 29 |
+
def dynamic_range_decompression(x, C=1):
|
| 30 |
+
return np.exp(x) / C
|
| 31 |
+
|
| 32 |
+
|
| 33 |
+
def dynamic_range_compression_torch(x, C=1, clip_val=1e-5):
|
| 34 |
+
return torch.log(torch.clamp(x, min=clip_val) * C)
|
| 35 |
+
|
| 36 |
+
|
| 37 |
+
def dynamic_range_decompression_torch(x, C=1):
|
| 38 |
+
return torch.exp(x) / C
|
| 39 |
+
|
| 40 |
+
|
| 41 |
+
def spectral_normalize_torch(magnitudes):
|
| 42 |
+
output = dynamic_range_compression_torch(magnitudes)
|
| 43 |
+
return output
|
| 44 |
+
|
| 45 |
+
|
| 46 |
+
def spectral_de_normalize_torch(magnitudes):
|
| 47 |
+
output = dynamic_range_decompression_torch(magnitudes)
|
| 48 |
+
return output
|
| 49 |
+
|
| 50 |
+
|
| 51 |
+
mel_basis = {}
|
| 52 |
+
hann_window = {}
|
| 53 |
+
|
| 54 |
+
class LogMelSpectrogram(torch.nn.Module):
|
| 55 |
+
def __init__(self, n_fft, num_mels, sampling_rate, hop_size, win_size, fmin, fmax, center=False):
|
| 56 |
+
super().__init__()
|
| 57 |
+
self.melspctrogram = torchaudio.transforms.MelSpectrogram(
|
| 58 |
+
sample_rate=sampling_rate,
|
| 59 |
+
n_fft=n_fft,
|
| 60 |
+
win_length=win_size,
|
| 61 |
+
hop_length=hop_size,
|
| 62 |
+
center=center,
|
| 63 |
+
power=1.0,
|
| 64 |
+
norm="slaney",
|
| 65 |
+
onesided=True,
|
| 66 |
+
n_mels=num_mels,
|
| 67 |
+
mel_scale="slaney",
|
| 68 |
+
f_min=fmin,
|
| 69 |
+
f_max=fmax
|
| 70 |
+
)
|
| 71 |
+
self.n_fft = n_fft
|
| 72 |
+
self.hop_size = hop_size
|
| 73 |
+
|
| 74 |
+
def forward(self, wav):
|
| 75 |
+
wav = F.pad(wav, ((self.n_fft - self.hop_size) // 2, (self.n_fft - self.hop_size) // 2), "reflect")
|
| 76 |
+
mel = self.melspctrogram(wav)
|
| 77 |
+
logmel = torch.log(torch.clamp(mel, min=1e-5))
|
| 78 |
+
return logmel
|
| 79 |
+
|
| 80 |
+
|
| 81 |
+
def mel_spectrogram(y, n_fft, num_mels, sampling_rate, hop_size, win_size, fmin, fmax, center=False):
|
| 82 |
+
if torch.min(y) < -1.:
|
| 83 |
+
print('min value is ', torch.min(y))
|
| 84 |
+
if torch.max(y) > 1.:
|
| 85 |
+
print('max value is ', torch.max(y))
|
| 86 |
+
|
| 87 |
+
global mel_basis, hann_window
|
| 88 |
+
if fmax not in mel_basis:
|
| 89 |
+
mel = librosa_mel_fn(sampling_rate, n_fft, num_mels, fmin, fmax)
|
| 90 |
+
mel_basis[str(fmax)+'_'+str(y.device)] = torch.from_numpy(mel).float().to(y.device)
|
| 91 |
+
hann_window[str(y.device)] = torch.hann_window(win_size).to(y.device)
|
| 92 |
+
|
| 93 |
+
# print("Padding by", int((n_fft - hop_size)/2), y.shape)
|
| 94 |
+
# pre-padding
|
| 95 |
+
n_pad = hop_size - ( y.shape[1] % hop_size )
|
| 96 |
+
y = F.pad(y.unsqueeze(1), (0, n_pad), mode='reflect').squeeze(1)
|
| 97 |
+
# print("intermediate:", y.shape)
|
| 98 |
+
|
| 99 |
+
y = F.pad(y.unsqueeze(1), (int((n_fft-hop_size)/2), int((n_fft-hop_size)/2)), mode='reflect')
|
| 100 |
+
y = y.squeeze(1)
|
| 101 |
+
|
| 102 |
+
spec = torch.stft(y, n_fft, hop_length=hop_size, win_length=win_size, window=hann_window[str(y.device)],
|
| 103 |
+
center=center, pad_mode='reflect', normalized=False, onesided=True, return_complex=True)
|
| 104 |
+
spec = spec.abs().clamp_(3e-5)
|
| 105 |
+
# print("Post: ", y.shape, spec.shape)
|
| 106 |
+
|
| 107 |
+
spec = torch.matmul(mel_basis[str(fmax)+'_'+str(y.device)], spec)
|
| 108 |
+
spec = spectral_normalize_torch(spec)
|
| 109 |
+
|
| 110 |
+
return spec
|
| 111 |
+
|
| 112 |
+
|
| 113 |
+
def get_dataset_filelist(a):
|
| 114 |
+
train_df = pd.read_csv(a.input_training_file)
|
| 115 |
+
valid_df = pd.read_csv(a.input_validation_file)
|
| 116 |
+
return train_df, valid_df
|
| 117 |
+
|
| 118 |
+
|
| 119 |
+
class MelDataset(torch.utils.data.Dataset):
|
| 120 |
+
def __init__(self, training_files, segment_size, n_fft, num_mels,
|
| 121 |
+
hop_size, win_size, sampling_rate, fmin, fmax, split=True, shuffle=True, n_cache_reuse=1,
|
| 122 |
+
device=None, fmax_loss=None, fine_tuning=False, audio_root_path=None, feat_root_path=None, use_alt_melcalc=False):
|
| 123 |
+
self.audio_files = training_files
|
| 124 |
+
if shuffle:
|
| 125 |
+
self.audio_files = self.audio_files.sample(frac=1, random_state=1234)
|
| 126 |
+
self.segment_size = segment_size
|
| 127 |
+
self.sampling_rate = sampling_rate
|
| 128 |
+
self.split = split
|
| 129 |
+
self.n_fft = n_fft
|
| 130 |
+
self.num_mels = num_mels
|
| 131 |
+
self.hop_size = hop_size
|
| 132 |
+
self.win_size = win_size
|
| 133 |
+
self.fmin = fmin
|
| 134 |
+
self.fmax = fmax
|
| 135 |
+
self.fmax_loss = fmax_loss
|
| 136 |
+
self.cached_wav = None
|
| 137 |
+
self.n_cache_reuse = n_cache_reuse
|
| 138 |
+
self._cache_ref_count = 0
|
| 139 |
+
self.device = device
|
| 140 |
+
self.fine_tuning = fine_tuning
|
| 141 |
+
self.audio_root_path = Path(audio_root_path)
|
| 142 |
+
self.feat_root_path = Path(feat_root_path)
|
| 143 |
+
self.alt_melspec = LogMelSpectrogram(n_fft, num_mels, sampling_rate, hop_size, win_size, fmin, fmax)
|
| 144 |
+
self.use_alt_melcalc = use_alt_melcalc
|
| 145 |
+
|
| 146 |
+
def __getitem__(self, index):
|
| 147 |
+
row = self.audio_files.iloc[index]
|
| 148 |
+
if self._cache_ref_count == 0:
|
| 149 |
+
audio, sampling_rate = load_wav(self.audio_root_path/row.audio_path)
|
| 150 |
+
if not self.fine_tuning:
|
| 151 |
+
audio = normalize(audio) * 0.95
|
| 152 |
+
self.cached_wav = audio
|
| 153 |
+
if sampling_rate != self.sampling_rate:
|
| 154 |
+
raise ValueError("{} SR doesn't match target {} SR".format(
|
| 155 |
+
sampling_rate, self.sampling_rate))
|
| 156 |
+
self._cache_ref_count = self.n_cache_reuse
|
| 157 |
+
else:
|
| 158 |
+
audio = self.cached_wav
|
| 159 |
+
self._cache_ref_count -= 1
|
| 160 |
+
|
| 161 |
+
audio = torch.tensor(audio, dtype=torch.float32)
|
| 162 |
+
audio = audio.unsqueeze(0)
|
| 163 |
+
|
| 164 |
+
if not self.fine_tuning:
|
| 165 |
+
if self.split:
|
| 166 |
+
if audio.size(1) >= self.segment_size:
|
| 167 |
+
max_audio_start = audio.size(1) - self.segment_size
|
| 168 |
+
audio_start = random.randint(0, max_audio_start)
|
| 169 |
+
audio = audio[:, audio_start:audio_start+self.segment_size]
|
| 170 |
+
else:
|
| 171 |
+
audio = torch.nn.functional.pad(audio, (0, self.segment_size - audio.size(1)), 'constant')
|
| 172 |
+
|
| 173 |
+
if self.use_alt_melcalc:
|
| 174 |
+
mel = self.alt_melspec(audio)
|
| 175 |
+
else:
|
| 176 |
+
mel1 = mel_spectrogram(audio, self.n_fft, self.num_mels,
|
| 177 |
+
self.sampling_rate, self.hop_size, self.win_size, self.fmin, self.fmax,
|
| 178 |
+
center=False)
|
| 179 |
+
|
| 180 |
+
mel = mel.permute(0, 2, 1) # (1, dim, seq_len) --> (1, seq_len, dim)
|
| 181 |
+
else:
|
| 182 |
+
mel = torch.load(self.feat_root_path/row.feat_path, map_location='cpu').float()
|
| 183 |
+
|
| 184 |
+
if len(mel.shape) < 3:
|
| 185 |
+
mel = mel.unsqueeze(0) # (1, seq_len, dim)
|
| 186 |
+
|
| 187 |
+
if self.split:
|
| 188 |
+
frames_per_seg = math.ceil(self.segment_size / self.hop_size)
|
| 189 |
+
|
| 190 |
+
if audio.size(1) >= self.segment_size:
|
| 191 |
+
mel_start = random.randint(0, mel.size(1) - frames_per_seg - 1)
|
| 192 |
+
mel = mel[:, mel_start:mel_start + frames_per_seg, :]
|
| 193 |
+
audio = audio[:, mel_start * self.hop_size:(mel_start + frames_per_seg) * self.hop_size]
|
| 194 |
+
else:
|
| 195 |
+
mel = torch.nn.functional.pad(mel, (0, 0, 0, frames_per_seg - mel.size(2)), 'constant')
|
| 196 |
+
audio = torch.nn.functional.pad(audio, (0, self.segment_size - audio.size(1)), 'constant')
|
| 197 |
+
|
| 198 |
+
|
| 199 |
+
if self.use_alt_melcalc:
|
| 200 |
+
mel_loss = self.alt_melspec(audio)
|
| 201 |
+
else:
|
| 202 |
+
mel_loss = mel_spectrogram(audio, self.n_fft, self.num_mels,
|
| 203 |
+
self.sampling_rate, self.hop_size, self.win_size, self.fmin, self.fmax_loss,
|
| 204 |
+
center=False)
|
| 205 |
+
return (mel.squeeze(), audio.squeeze(0), str(row.audio_path), mel_loss.squeeze())
|
| 206 |
+
|
| 207 |
+
def __len__(self):
|
| 208 |
+
return len(self.audio_files)
|
hifigan/models.py
ADDED
|
@@ -0,0 +1,289 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import torch
|
| 2 |
+
import torch.nn.functional as F
|
| 3 |
+
import torch.nn as nn
|
| 4 |
+
from torch.nn import Conv1d, ConvTranspose1d, AvgPool1d, Conv2d
|
| 5 |
+
from torch.nn.utils import weight_norm, remove_weight_norm, spectral_norm
|
| 6 |
+
from .utils import init_weights, get_padding
|
| 7 |
+
|
| 8 |
+
LRELU_SLOPE = 0.1
|
| 9 |
+
|
| 10 |
+
|
| 11 |
+
class ResBlock1(torch.nn.Module):
|
| 12 |
+
def __init__(self, h, channels, kernel_size=3, dilation=(1, 3, 5)):
|
| 13 |
+
super(ResBlock1, self).__init__()
|
| 14 |
+
self.h = h
|
| 15 |
+
self.convs1 = nn.ModuleList([
|
| 16 |
+
weight_norm(Conv1d(channels, channels, kernel_size, 1, dilation=dilation[0],
|
| 17 |
+
padding=get_padding(kernel_size, dilation[0]))),
|
| 18 |
+
weight_norm(Conv1d(channels, channels, kernel_size, 1, dilation=dilation[1],
|
| 19 |
+
padding=get_padding(kernel_size, dilation[1]))),
|
| 20 |
+
weight_norm(Conv1d(channels, channels, kernel_size, 1, dilation=dilation[2],
|
| 21 |
+
padding=get_padding(kernel_size, dilation[2])))
|
| 22 |
+
])
|
| 23 |
+
self.convs1.apply(init_weights)
|
| 24 |
+
|
| 25 |
+
self.convs2 = nn.ModuleList([
|
| 26 |
+
weight_norm(Conv1d(channels, channels, kernel_size, 1, dilation=1,
|
| 27 |
+
padding=get_padding(kernel_size, 1))),
|
| 28 |
+
weight_norm(Conv1d(channels, channels, kernel_size, 1, dilation=1,
|
| 29 |
+
padding=get_padding(kernel_size, 1))),
|
| 30 |
+
weight_norm(Conv1d(channels, channels, kernel_size, 1, dilation=1,
|
| 31 |
+
padding=get_padding(kernel_size, 1)))
|
| 32 |
+
])
|
| 33 |
+
self.convs2.apply(init_weights)
|
| 34 |
+
|
| 35 |
+
def forward(self, x):
|
| 36 |
+
for c1, c2 in zip(self.convs1, self.convs2):
|
| 37 |
+
xt = F.leaky_relu(x, LRELU_SLOPE)
|
| 38 |
+
xt = c1(xt)
|
| 39 |
+
xt = F.leaky_relu(xt, LRELU_SLOPE)
|
| 40 |
+
xt = c2(xt)
|
| 41 |
+
x = xt + x
|
| 42 |
+
return x
|
| 43 |
+
|
| 44 |
+
def remove_weight_norm(self):
|
| 45 |
+
for l in self.convs1:
|
| 46 |
+
remove_weight_norm(l)
|
| 47 |
+
for l in self.convs2:
|
| 48 |
+
remove_weight_norm(l)
|
| 49 |
+
|
| 50 |
+
|
| 51 |
+
class ResBlock2(torch.nn.Module):
|
| 52 |
+
def __init__(self, h, channels, kernel_size=3, dilation=(1, 3)):
|
| 53 |
+
super(ResBlock2, self).__init__()
|
| 54 |
+
self.h = h
|
| 55 |
+
self.convs = nn.ModuleList([
|
| 56 |
+
weight_norm(Conv1d(channels, channels, kernel_size, 1, dilation=dilation[0],
|
| 57 |
+
padding=get_padding(kernel_size, dilation[0]))),
|
| 58 |
+
weight_norm(Conv1d(channels, channels, kernel_size, 1, dilation=dilation[1],
|
| 59 |
+
padding=get_padding(kernel_size, dilation[1])))
|
| 60 |
+
])
|
| 61 |
+
self.convs.apply(init_weights)
|
| 62 |
+
|
| 63 |
+
def forward(self, x):
|
| 64 |
+
for c in self.convs:
|
| 65 |
+
xt = F.leaky_relu(x, LRELU_SLOPE)
|
| 66 |
+
xt = c(xt)
|
| 67 |
+
x = xt + x
|
| 68 |
+
return x
|
| 69 |
+
|
| 70 |
+
def remove_weight_norm(self):
|
| 71 |
+
for l in self.convs:
|
| 72 |
+
remove_weight_norm(l)
|
| 73 |
+
|
| 74 |
+
|
| 75 |
+
class Generator(torch.nn.Module):
|
| 76 |
+
def __init__(self, h):
|
| 77 |
+
super(Generator, self).__init__()
|
| 78 |
+
self.h = h
|
| 79 |
+
self.lin_pre = nn.Linear(h.hubert_dim, h.hifi_dim)
|
| 80 |
+
self.num_kernels = len(h.resblock_kernel_sizes)
|
| 81 |
+
self.num_upsamples = len(h.upsample_rates)
|
| 82 |
+
self.conv_pre = weight_norm(Conv1d(h.hifi_dim, h.upsample_initial_channel, 7, 1, padding=3))
|
| 83 |
+
resblock = ResBlock1 if h.resblock == '1' else ResBlock2
|
| 84 |
+
|
| 85 |
+
self.ups = nn.ModuleList()
|
| 86 |
+
for i, (u, k) in enumerate(zip(h.upsample_rates, h.upsample_kernel_sizes)):
|
| 87 |
+
|
| 88 |
+
self.ups.append(weight_norm(
|
| 89 |
+
ConvTranspose1d(h.upsample_initial_channel//(2**i), h.upsample_initial_channel//(2**(i+1)),
|
| 90 |
+
k, u, padding=(k-u)//2)))
|
| 91 |
+
|
| 92 |
+
self.resblocks = nn.ModuleList()
|
| 93 |
+
for i in range(len(self.ups)):
|
| 94 |
+
ch = h.upsample_initial_channel//(2**(i+1))
|
| 95 |
+
for j, (k, d) in enumerate(zip(h.resblock_kernel_sizes, h.resblock_dilation_sizes)):
|
| 96 |
+
self.resblocks.append(resblock(h, ch, k, d))
|
| 97 |
+
|
| 98 |
+
self.conv_post = weight_norm(Conv1d(ch, 1, 7, 1, padding=3))
|
| 99 |
+
self.ups.apply(init_weights)
|
| 100 |
+
self.conv_post.apply(init_weights)
|
| 101 |
+
|
| 102 |
+
def forward(self, x):
|
| 103 |
+
""" `x` as (bs, seq_len, dim), regular hifi assumes input of shape (bs, n_mels, seq_len) """
|
| 104 |
+
x = self.lin_pre(x)
|
| 105 |
+
x = x.permute(0, 2, 1) # (bs, seq_len, dim) --> (bs, dim, seq_len)
|
| 106 |
+
|
| 107 |
+
x = self.conv_pre(x)
|
| 108 |
+
for i in range(self.num_upsamples):
|
| 109 |
+
x = F.leaky_relu(x, LRELU_SLOPE)
|
| 110 |
+
x = self.ups[i](x)
|
| 111 |
+
xs = None
|
| 112 |
+
for j in range(self.num_kernels):
|
| 113 |
+
if xs is None:
|
| 114 |
+
xs = self.resblocks[i*self.num_kernels+j](x)
|
| 115 |
+
else:
|
| 116 |
+
xs += self.resblocks[i*self.num_kernels+j](x)
|
| 117 |
+
x = xs / self.num_kernels
|
| 118 |
+
x = F.leaky_relu(x)
|
| 119 |
+
x = self.conv_post(x)
|
| 120 |
+
x = torch.tanh(x)
|
| 121 |
+
|
| 122 |
+
return x
|
| 123 |
+
|
| 124 |
+
def remove_weight_norm(self):
|
| 125 |
+
print('Removing weight norm...')
|
| 126 |
+
for l in self.ups:
|
| 127 |
+
remove_weight_norm(l)
|
| 128 |
+
for l in self.resblocks:
|
| 129 |
+
l.remove_weight_norm()
|
| 130 |
+
remove_weight_norm(self.conv_pre)
|
| 131 |
+
remove_weight_norm(self.conv_post)
|
| 132 |
+
|
| 133 |
+
|
| 134 |
+
class DiscriminatorP(torch.nn.Module):
|
| 135 |
+
def __init__(self, period, kernel_size=5, stride=3, use_spectral_norm=False):
|
| 136 |
+
super(DiscriminatorP, self).__init__()
|
| 137 |
+
self.period = period
|
| 138 |
+
norm_f = weight_norm if use_spectral_norm == False else spectral_norm
|
| 139 |
+
self.convs = nn.ModuleList([
|
| 140 |
+
norm_f(Conv2d(1, 32, (kernel_size, 1), (stride, 1), padding=(get_padding(5, 1), 0))),
|
| 141 |
+
norm_f(Conv2d(32, 128, (kernel_size, 1), (stride, 1), padding=(get_padding(5, 1), 0))),
|
| 142 |
+
norm_f(Conv2d(128, 512, (kernel_size, 1), (stride, 1), padding=(get_padding(5, 1), 0))),
|
| 143 |
+
norm_f(Conv2d(512, 1024, (kernel_size, 1), (stride, 1), padding=(get_padding(5, 1), 0))),
|
| 144 |
+
norm_f(Conv2d(1024, 1024, (kernel_size, 1), 1, padding=(2, 0))),
|
| 145 |
+
])
|
| 146 |
+
self.conv_post = norm_f(Conv2d(1024, 1, (3, 1), 1, padding=(1, 0)))
|
| 147 |
+
|
| 148 |
+
def forward(self, x):
|
| 149 |
+
fmap = []
|
| 150 |
+
|
| 151 |
+
# 1d to 2d
|
| 152 |
+
b, c, t = x.shape
|
| 153 |
+
if t % self.period != 0: # pad first
|
| 154 |
+
n_pad = self.period - (t % self.period)
|
| 155 |
+
x = F.pad(x, (0, n_pad), "reflect")
|
| 156 |
+
t = t + n_pad
|
| 157 |
+
x = x.view(b, c, t // self.period, self.period)
|
| 158 |
+
|
| 159 |
+
for l in self.convs:
|
| 160 |
+
x = l(x)
|
| 161 |
+
x = F.leaky_relu(x, LRELU_SLOPE)
|
| 162 |
+
fmap.append(x)
|
| 163 |
+
x = self.conv_post(x)
|
| 164 |
+
fmap.append(x)
|
| 165 |
+
x = torch.flatten(x, 1, -1)
|
| 166 |
+
|
| 167 |
+
return x, fmap
|
| 168 |
+
|
| 169 |
+
|
| 170 |
+
class MultiPeriodDiscriminator(torch.nn.Module):
|
| 171 |
+
def __init__(self):
|
| 172 |
+
super(MultiPeriodDiscriminator, self).__init__()
|
| 173 |
+
self.discriminators = nn.ModuleList([
|
| 174 |
+
DiscriminatorP(2),
|
| 175 |
+
DiscriminatorP(3),
|
| 176 |
+
DiscriminatorP(5),
|
| 177 |
+
DiscriminatorP(7),
|
| 178 |
+
DiscriminatorP(11),
|
| 179 |
+
])
|
| 180 |
+
|
| 181 |
+
def forward(self, y, y_hat):
|
| 182 |
+
y_d_rs = []
|
| 183 |
+
y_d_gs = []
|
| 184 |
+
fmap_rs = []
|
| 185 |
+
fmap_gs = []
|
| 186 |
+
for i, d in enumerate(self.discriminators):
|
| 187 |
+
y_d_r, fmap_r = d(y)
|
| 188 |
+
y_d_g, fmap_g = d(y_hat)
|
| 189 |
+
y_d_rs.append(y_d_r)
|
| 190 |
+
fmap_rs.append(fmap_r)
|
| 191 |
+
y_d_gs.append(y_d_g)
|
| 192 |
+
fmap_gs.append(fmap_g)
|
| 193 |
+
|
| 194 |
+
return y_d_rs, y_d_gs, fmap_rs, fmap_gs
|
| 195 |
+
|
| 196 |
+
|
| 197 |
+
class DiscriminatorS(torch.nn.Module):
|
| 198 |
+
def __init__(self, use_spectral_norm=False):
|
| 199 |
+
super(DiscriminatorS, self).__init__()
|
| 200 |
+
norm_f = weight_norm if use_spectral_norm == False else spectral_norm
|
| 201 |
+
self.convs = nn.ModuleList([
|
| 202 |
+
norm_f(Conv1d(1, 128, 15, 1, padding=7)),
|
| 203 |
+
norm_f(Conv1d(128, 128, 41, 2, groups=4, padding=20)),
|
| 204 |
+
norm_f(Conv1d(128, 256, 41, 2, groups=16, padding=20)),
|
| 205 |
+
norm_f(Conv1d(256, 512, 41, 4, groups=16, padding=20)),
|
| 206 |
+
norm_f(Conv1d(512, 1024, 41, 4, groups=16, padding=20)),
|
| 207 |
+
norm_f(Conv1d(1024, 1024, 41, 1, groups=16, padding=20)),
|
| 208 |
+
norm_f(Conv1d(1024, 1024, 5, 1, padding=2)),
|
| 209 |
+
])
|
| 210 |
+
self.conv_post = norm_f(Conv1d(1024, 1, 3, 1, padding=1))
|
| 211 |
+
|
| 212 |
+
def forward(self, x):
|
| 213 |
+
fmap = []
|
| 214 |
+
for l in self.convs:
|
| 215 |
+
x = l(x)
|
| 216 |
+
x = F.leaky_relu(x, LRELU_SLOPE)
|
| 217 |
+
fmap.append(x)
|
| 218 |
+
x = self.conv_post(x)
|
| 219 |
+
fmap.append(x)
|
| 220 |
+
x = torch.flatten(x, 1, -1)
|
| 221 |
+
|
| 222 |
+
return x, fmap
|
| 223 |
+
|
| 224 |
+
|
| 225 |
+
class MultiScaleDiscriminator(torch.nn.Module):
|
| 226 |
+
def __init__(self):
|
| 227 |
+
super(MultiScaleDiscriminator, self).__init__()
|
| 228 |
+
self.discriminators = nn.ModuleList([
|
| 229 |
+
DiscriminatorS(use_spectral_norm=True),
|
| 230 |
+
DiscriminatorS(),
|
| 231 |
+
DiscriminatorS(),
|
| 232 |
+
])
|
| 233 |
+
self.meanpools = nn.ModuleList([
|
| 234 |
+
AvgPool1d(4, 2, padding=2),
|
| 235 |
+
AvgPool1d(4, 2, padding=2)
|
| 236 |
+
])
|
| 237 |
+
|
| 238 |
+
def forward(self, y, y_hat):
|
| 239 |
+
y_d_rs = []
|
| 240 |
+
y_d_gs = []
|
| 241 |
+
fmap_rs = []
|
| 242 |
+
fmap_gs = []
|
| 243 |
+
for i, d in enumerate(self.discriminators):
|
| 244 |
+
if i != 0:
|
| 245 |
+
y = self.meanpools[i-1](y)
|
| 246 |
+
y_hat = self.meanpools[i-1](y_hat)
|
| 247 |
+
y_d_r, fmap_r = d(y)
|
| 248 |
+
y_d_g, fmap_g = d(y_hat)
|
| 249 |
+
y_d_rs.append(y_d_r)
|
| 250 |
+
fmap_rs.append(fmap_r)
|
| 251 |
+
y_d_gs.append(y_d_g)
|
| 252 |
+
fmap_gs.append(fmap_g)
|
| 253 |
+
|
| 254 |
+
return y_d_rs, y_d_gs, fmap_rs, fmap_gs
|
| 255 |
+
|
| 256 |
+
|
| 257 |
+
def feature_loss(fmap_r, fmap_g):
|
| 258 |
+
loss = 0
|
| 259 |
+
for dr, dg in zip(fmap_r, fmap_g):
|
| 260 |
+
for rl, gl in zip(dr, dg):
|
| 261 |
+
loss += torch.mean(torch.abs(rl - gl))
|
| 262 |
+
|
| 263 |
+
return loss*2
|
| 264 |
+
|
| 265 |
+
|
| 266 |
+
def discriminator_loss(disc_real_outputs, disc_generated_outputs):
|
| 267 |
+
loss = 0
|
| 268 |
+
r_losses = []
|
| 269 |
+
g_losses = []
|
| 270 |
+
for dr, dg in zip(disc_real_outputs, disc_generated_outputs):
|
| 271 |
+
r_loss = torch.mean((1-dr)**2)
|
| 272 |
+
g_loss = torch.mean(dg**2)
|
| 273 |
+
loss += (r_loss + g_loss)
|
| 274 |
+
r_losses.append(r_loss.item())
|
| 275 |
+
g_losses.append(g_loss.item())
|
| 276 |
+
|
| 277 |
+
return loss, r_losses, g_losses
|
| 278 |
+
|
| 279 |
+
|
| 280 |
+
def generator_loss(disc_outputs):
|
| 281 |
+
loss = 0
|
| 282 |
+
gen_losses = []
|
| 283 |
+
for dg in disc_outputs:
|
| 284 |
+
l = torch.mean((1-dg)**2)
|
| 285 |
+
gen_losses.append(l)
|
| 286 |
+
loss += l
|
| 287 |
+
|
| 288 |
+
return loss, gen_losses
|
| 289 |
+
|
hifigan/train.py
ADDED
|
@@ -0,0 +1,335 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import argparse
|
| 2 |
+
import itertools
|
| 3 |
+
import json
|
| 4 |
+
import os
|
| 5 |
+
import time
|
| 6 |
+
|
| 7 |
+
import torch
|
| 8 |
+
import torch.multiprocessing as mp
|
| 9 |
+
import torch.nn.functional as F
|
| 10 |
+
from fastprogress import master_bar, progress_bar
|
| 11 |
+
from torch.cuda.amp.grad_scaler import GradScaler
|
| 12 |
+
from torch.distributed import init_process_group
|
| 13 |
+
from torch.nn.parallel import DistributedDataParallel
|
| 14 |
+
from torch.utils.data import DataLoader, DistributedSampler
|
| 15 |
+
from torch.utils.tensorboard import SummaryWriter
|
| 16 |
+
|
| 17 |
+
from .meldataset import (LogMelSpectrogram, MelDataset, get_dataset_filelist,
|
| 18 |
+
mel_spectrogram)
|
| 19 |
+
from .models import (Generator, MultiPeriodDiscriminator,
|
| 20 |
+
MultiScaleDiscriminator, discriminator_loss, feature_loss,
|
| 21 |
+
generator_loss)
|
| 22 |
+
from .utils import (AttrDict, build_env, load_checkpoint, plot_spectrogram,
|
| 23 |
+
save_checkpoint, scan_checkpoint)
|
| 24 |
+
|
| 25 |
+
torch.backends.cudnn.benchmark = True
|
| 26 |
+
USE_ALT_MELCALC = True
|
| 27 |
+
|
| 28 |
+
|
| 29 |
+
def train(rank, a, h):
|
| 30 |
+
if h.num_gpus > 1:
|
| 31 |
+
init_process_group(backend=h.dist_config['dist_backend'], init_method=h.dist_config['dist_url'],
|
| 32 |
+
world_size=h.dist_config['world_size'] * h.num_gpus, rank=rank)
|
| 33 |
+
|
| 34 |
+
torch.cuda.manual_seed(h.seed)
|
| 35 |
+
device = torch.device('cuda:{:d}'.format(rank))
|
| 36 |
+
|
| 37 |
+
generator = Generator(h).to(device)
|
| 38 |
+
mpd = MultiPeriodDiscriminator().to(device)
|
| 39 |
+
msd = MultiScaleDiscriminator().to(device)
|
| 40 |
+
|
| 41 |
+
if rank == 0:
|
| 42 |
+
print(generator)
|
| 43 |
+
os.makedirs(a.checkpoint_path, exist_ok=True)
|
| 44 |
+
print("checkpoints directory : ", a.checkpoint_path)
|
| 45 |
+
|
| 46 |
+
if os.path.isdir(a.checkpoint_path):
|
| 47 |
+
cp_g = scan_checkpoint(a.checkpoint_path, 'g_')
|
| 48 |
+
cp_do = scan_checkpoint(a.checkpoint_path, 'do_')
|
| 49 |
+
|
| 50 |
+
steps = 0
|
| 51 |
+
if cp_g is None or cp_do is None:
|
| 52 |
+
state_dict_do = None
|
| 53 |
+
last_epoch = -1
|
| 54 |
+
else:
|
| 55 |
+
state_dict_g = load_checkpoint(cp_g, device)
|
| 56 |
+
state_dict_do = load_checkpoint(cp_do, device)
|
| 57 |
+
generator.load_state_dict(state_dict_g['generator'])
|
| 58 |
+
mpd.load_state_dict(state_dict_do['mpd'])
|
| 59 |
+
msd.load_state_dict(state_dict_do['msd'])
|
| 60 |
+
steps = state_dict_do['steps'] + 1
|
| 61 |
+
last_epoch = state_dict_do['epoch']
|
| 62 |
+
print(f"Restored checkpoint from {cp_g} and {cp_do}")
|
| 63 |
+
|
| 64 |
+
if h.num_gpus > 1:
|
| 65 |
+
print("Multi-gpu detected")
|
| 66 |
+
generator = DistributedDataParallel(generator, device_ids=[rank]).to(device)
|
| 67 |
+
mpd = DistributedDataParallel(mpd, device_ids=[rank]).to(device)
|
| 68 |
+
msd = DistributedDataParallel(msd, device_ids=[rank]).to(device)
|
| 69 |
+
|
| 70 |
+
optim_g = torch.optim.AdamW(generator.parameters(), h.learning_rate, betas=[h.adam_b1, h.adam_b2])
|
| 71 |
+
optim_d = torch.optim.AdamW(itertools.chain(msd.parameters(), mpd.parameters()),
|
| 72 |
+
h.learning_rate, betas=[h.adam_b1, h.adam_b2])
|
| 73 |
+
|
| 74 |
+
if state_dict_do is not None:
|
| 75 |
+
optim_g.load_state_dict(state_dict_do['optim_g'])
|
| 76 |
+
optim_d.load_state_dict(state_dict_do['optim_d'])
|
| 77 |
+
|
| 78 |
+
scheduler_g = torch.optim.lr_scheduler.ExponentialLR(optim_g, gamma=h.lr_decay, last_epoch=last_epoch)
|
| 79 |
+
scheduler_d = torch.optim.lr_scheduler.ExponentialLR(optim_d, gamma=h.lr_decay, last_epoch=last_epoch)
|
| 80 |
+
if a.fp16:
|
| 81 |
+
scaler_g = GradScaler()
|
| 82 |
+
scaler_d = GradScaler()
|
| 83 |
+
|
| 84 |
+
train_df, valid_df = get_dataset_filelist(a)
|
| 85 |
+
|
| 86 |
+
trainset = MelDataset(train_df, h.segment_size, h.n_fft, h.num_mels,
|
| 87 |
+
h.hop_size, h.win_size, h.sampling_rate, h.fmin, h.fmax, n_cache_reuse=0,
|
| 88 |
+
shuffle=False if h.num_gpus > 1 else True, fmax_loss=h.fmax_for_loss, device=device,
|
| 89 |
+
fine_tuning=a.fine_tuning,
|
| 90 |
+
audio_root_path=a.audio_root_path, feat_root_path=a.feature_root_path,
|
| 91 |
+
use_alt_melcalc=USE_ALT_MELCALC)
|
| 92 |
+
|
| 93 |
+
train_sampler = DistributedSampler(trainset) if h.num_gpus > 1 else None
|
| 94 |
+
|
| 95 |
+
train_loader = DataLoader(trainset, num_workers=h.num_workers, shuffle=False,
|
| 96 |
+
sampler=train_sampler,
|
| 97 |
+
batch_size=h.batch_size,
|
| 98 |
+
pin_memory=True,
|
| 99 |
+
persistent_workers=True,
|
| 100 |
+
drop_last=True)
|
| 101 |
+
|
| 102 |
+
alt_melspec = LogMelSpectrogram(h.n_fft, h.num_mels, h.sampling_rate, h.hop_size, h.win_size, h.fmin, h.fmax).to(device)
|
| 103 |
+
|
| 104 |
+
if rank == 0:
|
| 105 |
+
validset = MelDataset(valid_df, h.segment_size, h.n_fft, h.num_mels,
|
| 106 |
+
h.hop_size, h.win_size, h.sampling_rate, h.fmin, h.fmax, False, False, n_cache_reuse=0,
|
| 107 |
+
fmax_loss=h.fmax_for_loss, device=device, fine_tuning=a.fine_tuning,
|
| 108 |
+
audio_root_path=a.audio_root_path, feat_root_path=a.feature_root_path,
|
| 109 |
+
use_alt_melcalc=USE_ALT_MELCALC)
|
| 110 |
+
validation_loader = DataLoader(validset, num_workers=1, shuffle=False,
|
| 111 |
+
sampler=None,
|
| 112 |
+
batch_size=1,
|
| 113 |
+
pin_memory=True,
|
| 114 |
+
persistent_workers=True,
|
| 115 |
+
drop_last=True)
|
| 116 |
+
|
| 117 |
+
sw = SummaryWriter(os.path.join(a.checkpoint_path, 'logs'))
|
| 118 |
+
|
| 119 |
+
generator.train()
|
| 120 |
+
mpd.train()
|
| 121 |
+
msd.train()
|
| 122 |
+
|
| 123 |
+
if rank == 0: mb = master_bar(range(max(0, last_epoch), a.training_epochs))
|
| 124 |
+
else: mb = range(max(0, last_epoch), a.training_epochs)
|
| 125 |
+
|
| 126 |
+
for epoch in mb:
|
| 127 |
+
if rank == 0:
|
| 128 |
+
start = time.time()
|
| 129 |
+
mb.write("Epoch: {}".format(epoch+1))
|
| 130 |
+
|
| 131 |
+
if h.num_gpus > 1:
|
| 132 |
+
train_sampler.set_epoch(epoch)
|
| 133 |
+
|
| 134 |
+
if rank == 0: pb = progress_bar(enumerate(train_loader), total=len(train_loader), parent=mb)
|
| 135 |
+
else: pb = enumerate(train_loader)
|
| 136 |
+
|
| 137 |
+
|
| 138 |
+
for i, batch in pb:
|
| 139 |
+
if rank == 0:
|
| 140 |
+
start_b = time.time()
|
| 141 |
+
x, y, _, y_mel = batch
|
| 142 |
+
x = x.to(device, non_blocking=True)
|
| 143 |
+
y = y.to(device, non_blocking=True)
|
| 144 |
+
y_mel = y_mel.to(device, non_blocking=True)
|
| 145 |
+
y = y.unsqueeze(1)
|
| 146 |
+
|
| 147 |
+
with torch.cuda.amp.autocast(enabled=a.fp16):
|
| 148 |
+
y_g_hat = generator(x)
|
| 149 |
+
if USE_ALT_MELCALC:
|
| 150 |
+
y_g_hat_mel = alt_melspec(y_g_hat.squeeze(1))
|
| 151 |
+
else:
|
| 152 |
+
y_g_hat_mel = mel_spectrogram(y_g_hat.squeeze(1), h.n_fft, h.num_mels, h.sampling_rate, h.hop_size, h.win_size,
|
| 153 |
+
h.fmin, h.fmax_for_loss)
|
| 154 |
+
# print(x.shape, y_g_hat.shape, y_g_hat_mel.shape, y_mel.shape, y.shape)
|
| 155 |
+
optim_d.zero_grad()
|
| 156 |
+
|
| 157 |
+
with torch.cuda.amp.autocast(enabled=a.fp16):
|
| 158 |
+
# MPD
|
| 159 |
+
y_df_hat_r, y_df_hat_g, _, _ = mpd(y, y_g_hat.detach())
|
| 160 |
+
loss_disc_f, losses_disc_f_r, losses_disc_f_g = discriminator_loss(y_df_hat_r, y_df_hat_g)
|
| 161 |
+
|
| 162 |
+
# MSD
|
| 163 |
+
y_ds_hat_r, y_ds_hat_g, _, _ = msd(y, y_g_hat.detach())
|
| 164 |
+
loss_disc_s, losses_disc_s_r, losses_disc_s_g = discriminator_loss(y_ds_hat_r, y_ds_hat_g)
|
| 165 |
+
|
| 166 |
+
loss_disc_all = loss_disc_s + loss_disc_f
|
| 167 |
+
|
| 168 |
+
if a.fp16:
|
| 169 |
+
scaler_d.scale(loss_disc_all).backward()
|
| 170 |
+
scaler_d.step(optim_d)
|
| 171 |
+
scaler_d.update()
|
| 172 |
+
else:
|
| 173 |
+
loss_disc_all.backward()
|
| 174 |
+
optim_d.step()
|
| 175 |
+
|
| 176 |
+
# Generator
|
| 177 |
+
optim_g.zero_grad()
|
| 178 |
+
|
| 179 |
+
with torch.cuda.amp.autocast(enabled=a.fp16):
|
| 180 |
+
# L1 Mel-Spectrogram Loss
|
| 181 |
+
loss_mel = F.l1_loss(y_mel, y_g_hat_mel) * 45
|
| 182 |
+
|
| 183 |
+
y_df_hat_r, y_df_hat_g, fmap_f_r, fmap_f_g = mpd(y, y_g_hat)
|
| 184 |
+
y_ds_hat_r, y_ds_hat_g, fmap_s_r, fmap_s_g = msd(y, y_g_hat)
|
| 185 |
+
loss_fm_f = feature_loss(fmap_f_r, fmap_f_g)
|
| 186 |
+
loss_fm_s = feature_loss(fmap_s_r, fmap_s_g)
|
| 187 |
+
loss_gen_f, losses_gen_f = generator_loss(y_df_hat_g)
|
| 188 |
+
loss_gen_s, losses_gen_s = generator_loss(y_ds_hat_g)
|
| 189 |
+
loss_gen_all = loss_gen_s + loss_gen_f + loss_fm_s + loss_fm_f + loss_mel
|
| 190 |
+
|
| 191 |
+
if a.fp16:
|
| 192 |
+
scaler_g.scale(loss_gen_all).backward()
|
| 193 |
+
scaler_g.step(optim_g)
|
| 194 |
+
scaler_g.update()
|
| 195 |
+
else:
|
| 196 |
+
loss_gen_all.backward()
|
| 197 |
+
optim_g.step()
|
| 198 |
+
|
| 199 |
+
if rank == 0:
|
| 200 |
+
# STDOUT logging
|
| 201 |
+
if steps % a.stdout_interval == 0:
|
| 202 |
+
with torch.no_grad():
|
| 203 |
+
mel_error = F.l1_loss(y_mel, y_g_hat_mel).item()
|
| 204 |
+
|
| 205 |
+
mb.write('Steps : {:,d}, Gen Loss Total : {:4.3f}, Mel-Spec. Error : {:4.3f}, sec/batch : {:4.3f}, peak mem: {:5.2f}GB'. \
|
| 206 |
+
format(steps, loss_gen_all, mel_error, time.time() - start_b, torch.cuda.max_memory_allocated()/1e9))
|
| 207 |
+
mb.child.comment = "Steps : {:,d}, Gen Loss Total : {:4.3f}, Mel-Spec. Error : {:4.3f}". \
|
| 208 |
+
format(steps, loss_gen_all, mel_error)
|
| 209 |
+
|
| 210 |
+
|
| 211 |
+
# checkpointing
|
| 212 |
+
if steps % a.checkpoint_interval == 0 and steps != 0:
|
| 213 |
+
checkpoint_path = "{}/g_{:08d}.pt".format(a.checkpoint_path, steps)
|
| 214 |
+
save_checkpoint(checkpoint_path,
|
| 215 |
+
{'generator': (generator.module if h.num_gpus > 1 else generator).state_dict()})
|
| 216 |
+
checkpoint_path = "{}/do_{:08d}.pt".format(a.checkpoint_path, steps)
|
| 217 |
+
save_checkpoint(checkpoint_path,
|
| 218 |
+
{'mpd': (mpd.module if h.num_gpus > 1
|
| 219 |
+
else mpd).state_dict(),
|
| 220 |
+
'msd': (msd.module if h.num_gpus > 1
|
| 221 |
+
else msd).state_dict(),
|
| 222 |
+
'optim_g': optim_g.state_dict(), 'optim_d': optim_d.state_dict(), 'steps': steps,
|
| 223 |
+
'epoch': epoch})
|
| 224 |
+
|
| 225 |
+
# Tensorboard summary logging
|
| 226 |
+
if steps % a.summary_interval == 0:
|
| 227 |
+
sw.add_scalar("training/gen_loss_total", loss_gen_all, steps)
|
| 228 |
+
sw.add_scalar("training/mel_spec_error", mel_error, steps)
|
| 229 |
+
sw.add_scalar("training/disc_loss_total", loss_disc_all, steps)
|
| 230 |
+
|
| 231 |
+
# Validation
|
| 232 |
+
if steps % a.validation_interval == 0: # and steps != 0:
|
| 233 |
+
generator.eval()
|
| 234 |
+
torch.cuda.empty_cache()
|
| 235 |
+
val_err_tot = 0
|
| 236 |
+
with torch.no_grad():
|
| 237 |
+
for j, batch in progress_bar(enumerate(validation_loader), total=len(validation_loader), parent=mb):
|
| 238 |
+
x, y, _, y_mel = batch
|
| 239 |
+
y_g_hat = generator(x.to(device))
|
| 240 |
+
y_mel = y_mel.to(device, non_blocking=True)
|
| 241 |
+
if USE_ALT_MELCALC:
|
| 242 |
+
y_g_hat_mel = alt_melspec(y_g_hat.squeeze(1))
|
| 243 |
+
if y_g_hat_mel.shape[-1] != y_mel.shape[-1]:
|
| 244 |
+
# pad it
|
| 245 |
+
n_pad = h.hop_size
|
| 246 |
+
y_g_hat = F.pad(y_g_hat, (n_pad//2, n_pad - n_pad//2))
|
| 247 |
+
y_g_hat_mel = alt_melspec(y_g_hat.squeeze(1))
|
| 248 |
+
else:
|
| 249 |
+
y_g_hat_mel = mel_spectrogram(y_g_hat.squeeze(1), h.n_fft, h.num_mels, h.sampling_rate,
|
| 250 |
+
h.hop_size, h.win_size,
|
| 251 |
+
h.fmin, h.fmax_for_loss)
|
| 252 |
+
#print('valid', x.shape, y_g_hat.shape, y_g_hat_mel.shape, y_mel.shape, y.shape)
|
| 253 |
+
val_err_tot += F.l1_loss(y_mel, y_g_hat_mel).item()
|
| 254 |
+
|
| 255 |
+
if j <= 4:
|
| 256 |
+
if steps == 0:
|
| 257 |
+
sw.add_audio('gt/y_{}'.format(j), y[0], steps, h.sampling_rate)
|
| 258 |
+
sw.add_figure('gt/y_spec_{}'.format(j), plot_spectrogram(x[0]), steps)
|
| 259 |
+
|
| 260 |
+
sw.add_audio('generated/y_hat_{}'.format(j), y_g_hat[0], steps, h.sampling_rate)
|
| 261 |
+
if USE_ALT_MELCALC:
|
| 262 |
+
y_hat_spec = alt_melspec(y_g_hat.squeeze(1))
|
| 263 |
+
else:
|
| 264 |
+
y_hat_spec = mel_spectrogram(y_g_hat.squeeze(1), h.n_fft, h.num_mels, h.sampling_rate,
|
| 265 |
+
h.hop_size, h.win_size,
|
| 266 |
+
h.fmin, h.fmax_for_loss)
|
| 267 |
+
|
| 268 |
+
sw.add_figure('generated/y_hat_spec_{}'.format(j),
|
| 269 |
+
plot_spectrogram(y_hat_spec.squeeze(0).cpu().numpy()), steps)
|
| 270 |
+
|
| 271 |
+
val_err = val_err_tot / (j+1)
|
| 272 |
+
sw.add_scalar("validation/mel_spec_error", val_err, steps)
|
| 273 |
+
mb.write(f"validation run complete at {steps:,d} steps. validation mel spec error: {val_err:5.4f}")
|
| 274 |
+
|
| 275 |
+
generator.train()
|
| 276 |
+
sw.add_scalar("memory/max_allocated_gb", torch.cuda.max_memory_allocated()/1e9, steps)
|
| 277 |
+
sw.add_scalar("memory/max_reserved_gb", torch.cuda.max_memory_reserved()/1e9, steps)
|
| 278 |
+
torch.cuda.reset_peak_memory_stats()
|
| 279 |
+
torch.cuda.reset_accumulated_memory_stats()
|
| 280 |
+
|
| 281 |
+
steps += 1
|
| 282 |
+
|
| 283 |
+
scheduler_g.step()
|
| 284 |
+
scheduler_d.step()
|
| 285 |
+
|
| 286 |
+
if rank == 0:
|
| 287 |
+
print('Time taken for epoch {} is {} sec\n'.format(epoch + 1, int(time.time() - start)))
|
| 288 |
+
|
| 289 |
+
|
| 290 |
+
def main():
|
| 291 |
+
print('Initializing Training Process..')
|
| 292 |
+
|
| 293 |
+
parser = argparse.ArgumentParser()
|
| 294 |
+
|
| 295 |
+
parser.add_argument('--group_name', default=None)
|
| 296 |
+
parser.add_argument('--audio_root_path', required=True)
|
| 297 |
+
parser.add_argument('--feature_root_path', required=True)
|
| 298 |
+
parser.add_argument('--input_training_file', default='LJSpeech-1.1/training.txt')
|
| 299 |
+
parser.add_argument('--input_validation_file', default='LJSpeech-1.1/validation.txt')
|
| 300 |
+
parser.add_argument('--checkpoint_path', default='cp_hifigan')
|
| 301 |
+
parser.add_argument('--config', default='')
|
| 302 |
+
parser.add_argument('--training_epochs', default=1500, type=int)
|
| 303 |
+
parser.add_argument('--stdout_interval', default=5, type=int)
|
| 304 |
+
parser.add_argument('--checkpoint_interval', default=5000, type=int)
|
| 305 |
+
parser.add_argument('--summary_interval', default=25, type=int)
|
| 306 |
+
parser.add_argument('--validation_interval', default=1000, type=int)
|
| 307 |
+
parser.add_argument('--fp16', default=False, type=bool)
|
| 308 |
+
parser.add_argument('--fine_tuning', action='store_true')
|
| 309 |
+
|
| 310 |
+
a = parser.parse_args()
|
| 311 |
+
print(a)
|
| 312 |
+
with open(a.config) as f:
|
| 313 |
+
data = f.read()
|
| 314 |
+
|
| 315 |
+
json_config = json.loads(data)
|
| 316 |
+
h = AttrDict(json_config)
|
| 317 |
+
build_env(a.config, 'config.json', a.checkpoint_path)
|
| 318 |
+
|
| 319 |
+
torch.manual_seed(h.seed)
|
| 320 |
+
if torch.cuda.is_available():
|
| 321 |
+
torch.cuda.manual_seed(h.seed)
|
| 322 |
+
h.num_gpus = torch.cuda.device_count()
|
| 323 |
+
h.batch_size = int(h.batch_size / h.num_gpus)
|
| 324 |
+
print('Batch size per GPU :', h.batch_size)
|
| 325 |
+
else:
|
| 326 |
+
pass
|
| 327 |
+
|
| 328 |
+
if h.num_gpus > 1:
|
| 329 |
+
mp.spawn(train, nprocs=h.num_gpus, args=(a, h,))
|
| 330 |
+
else:
|
| 331 |
+
train(0, a, h)
|
| 332 |
+
|
| 333 |
+
|
| 334 |
+
if __name__ == '__main__':
|
| 335 |
+
main()
|
hifigan/utils.py
ADDED
|
@@ -0,0 +1,73 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import glob
|
| 2 |
+
import os
|
| 3 |
+
import shutil
|
| 4 |
+
|
| 5 |
+
import torch
|
| 6 |
+
from torch.nn.utils import weight_norm
|
| 7 |
+
import json
|
| 8 |
+
|
| 9 |
+
|
| 10 |
+
def plot_spectrogram(spectrogram):
|
| 11 |
+
import matplotlib.pylab as plt
|
| 12 |
+
import matplotlib
|
| 13 |
+
matplotlib.use("Agg")
|
| 14 |
+
fig, ax = plt.subplots(figsize=(10, 2))
|
| 15 |
+
im = ax.imshow(spectrogram, aspect="auto", origin="lower",
|
| 16 |
+
interpolation='none')
|
| 17 |
+
plt.colorbar(im, ax=ax)
|
| 18 |
+
|
| 19 |
+
fig.canvas.draw()
|
| 20 |
+
plt.close()
|
| 21 |
+
|
| 22 |
+
return fig
|
| 23 |
+
|
| 24 |
+
|
| 25 |
+
def init_weights(m, mean=0.0, std=0.01):
|
| 26 |
+
classname = m.__class__.__name__
|
| 27 |
+
if classname.find("Conv") != -1:
|
| 28 |
+
m.weight.data.normal_(mean, std)
|
| 29 |
+
|
| 30 |
+
|
| 31 |
+
def apply_weight_norm(m):
|
| 32 |
+
classname = m.__class__.__name__
|
| 33 |
+
if classname.find("Conv") != -1:
|
| 34 |
+
weight_norm(m)
|
| 35 |
+
|
| 36 |
+
|
| 37 |
+
def get_padding(kernel_size, dilation=1):
|
| 38 |
+
return int((kernel_size*dilation - dilation)/2)
|
| 39 |
+
|
| 40 |
+
|
| 41 |
+
def load_checkpoint(filepath, device):
|
| 42 |
+
assert os.path.isfile(filepath)
|
| 43 |
+
print("Loading '{}'".format(filepath))
|
| 44 |
+
checkpoint_dict = torch.load(filepath, map_location=device)
|
| 45 |
+
print("Complete.")
|
| 46 |
+
return checkpoint_dict
|
| 47 |
+
|
| 48 |
+
|
| 49 |
+
def save_checkpoint(filepath, obj):
|
| 50 |
+
print("Saving checkpoint to {}".format(filepath))
|
| 51 |
+
torch.save(obj, filepath)
|
| 52 |
+
print("Complete.")
|
| 53 |
+
|
| 54 |
+
|
| 55 |
+
def scan_checkpoint(cp_dir, prefix):
|
| 56 |
+
pattern = os.path.join(cp_dir, prefix + '*')
|
| 57 |
+
cp_list = glob.glob(pattern)
|
| 58 |
+
if len(cp_list) == 0:
|
| 59 |
+
return None
|
| 60 |
+
return sorted(cp_list)[-1]
|
| 61 |
+
|
| 62 |
+
|
| 63 |
+
class AttrDict(dict):
|
| 64 |
+
def __init__(self, *args, **kwargs):
|
| 65 |
+
super(AttrDict, self).__init__(*args, **kwargs)
|
| 66 |
+
self.__dict__ = self
|
| 67 |
+
|
| 68 |
+
|
| 69 |
+
def build_env(config, config_name, path):
|
| 70 |
+
t_path = os.path.join(path, config_name)
|
| 71 |
+
if config != t_path:
|
| 72 |
+
os.makedirs(path, exist_ok=True)
|
| 73 |
+
shutil.copyfile(config, os.path.join(path, config_name))
|
hubconf.py
ADDED
|
@@ -0,0 +1,75 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
dependencies = ['torch', 'torchaudio', 'numpy']
|
| 2 |
+
|
| 3 |
+
import torch
|
| 4 |
+
from torch import Tensor
|
| 5 |
+
import torch.nn as nn
|
| 6 |
+
import torch.nn.functional as F
|
| 7 |
+
import logging
|
| 8 |
+
import json
|
| 9 |
+
from pathlib import Path
|
| 10 |
+
|
| 11 |
+
|
| 12 |
+
from wavlm.WavLM import WavLM, WavLMConfig
|
| 13 |
+
from hifigan.models import Generator as HiFiGAN
|
| 14 |
+
from hifigan.utils import AttrDict
|
| 15 |
+
from matcher import KNeighborsVC
|
| 16 |
+
|
| 17 |
+
|
| 18 |
+
def knn_vc(pretrained=True, progress=True, prematched=True, device='cuda') -> KNeighborsVC:
|
| 19 |
+
""" Load kNN-VC (WavLM encoder and HiFiGAN decoder). Optionally use vocoder trained on `prematched` data. """
|
| 20 |
+
hifigan, hifigan_cfg = hifigan_wavlm(pretrained, progress, prematched, device)
|
| 21 |
+
wavlm = wavlm_large(pretrained, progress, device)
|
| 22 |
+
knnvc = KNeighborsVC(wavlm, hifigan, hifigan_cfg, device)
|
| 23 |
+
return knnvc
|
| 24 |
+
|
| 25 |
+
|
| 26 |
+
def hifigan_wavlm(pretrained=True, progress=True, prematched=True, device='cuda') -> HiFiGAN:
|
| 27 |
+
""" Load pretrained hifigan trained to vocode wavlm features. Optionally use weights trained on `prematched` data. """
|
| 28 |
+
cp = Path(__file__).parent.absolute()
|
| 29 |
+
|
| 30 |
+
with open(cp/'hifigan'/'config_v1_wavlm.json') as f:
|
| 31 |
+
data = f.read()
|
| 32 |
+
json_config = json.loads(data)
|
| 33 |
+
h = AttrDict(json_config)
|
| 34 |
+
device = torch.device(device)
|
| 35 |
+
|
| 36 |
+
generator = HiFiGAN(h).to(device)
|
| 37 |
+
|
| 38 |
+
if pretrained:
|
| 39 |
+
if prematched:
|
| 40 |
+
url = "https://github.com/bshall/knn-vc/releases/download/v0.1/prematch_g_02500000.pt"
|
| 41 |
+
else:
|
| 42 |
+
url = "https://github.com/bshall/knn-vc/releases/download/v0.1/g_02500000.pt"
|
| 43 |
+
state_dict_g = torch.hub.load_state_dict_from_url(
|
| 44 |
+
url,
|
| 45 |
+
map_location=device,
|
| 46 |
+
progress=progress
|
| 47 |
+
)
|
| 48 |
+
generator.load_state_dict(state_dict_g['generator'])
|
| 49 |
+
generator.eval()
|
| 50 |
+
generator.remove_weight_norm()
|
| 51 |
+
print(f"[HiFiGAN] Generator loaded with {sum([p.numel() for p in generator.parameters()]):,d} parameters.")
|
| 52 |
+
return generator, h
|
| 53 |
+
|
| 54 |
+
|
| 55 |
+
def wavlm_large(pretrained=True, progress=True, device='cuda') -> WavLM:
|
| 56 |
+
"""Load the WavLM large checkpoint from the original paper. See https://github.com/microsoft/unilm/tree/master/wavlm for details. """
|
| 57 |
+
if torch.cuda.is_available() == False:
|
| 58 |
+
if str(device) != 'cpu':
|
| 59 |
+
logging.warning(f"Overriding device {device} to cpu since no GPU is available.")
|
| 60 |
+
device = 'cpu'
|
| 61 |
+
checkpoint = torch.hub.load_state_dict_from_url(
|
| 62 |
+
"https://github.com/bshall/knn-vc/releases/download/v0.1/WavLM-Large.pt",
|
| 63 |
+
map_location=device,
|
| 64 |
+
progress=progress
|
| 65 |
+
)
|
| 66 |
+
|
| 67 |
+
cfg = WavLMConfig(checkpoint['cfg'])
|
| 68 |
+
device = torch.device(device)
|
| 69 |
+
model = WavLM(cfg)
|
| 70 |
+
if pretrained:
|
| 71 |
+
model.load_state_dict(checkpoint['model'])
|
| 72 |
+
model = model.to(device)
|
| 73 |
+
model.eval()
|
| 74 |
+
print(f"WavLM-Large loaded with {sum([p.numel() for p in model.parameters()]):,d} parameters.")
|
| 75 |
+
return model
|
knnvc_utils.py
ADDED
|
@@ -0,0 +1,23 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import numpy as np
|
| 2 |
+
|
| 3 |
+
def generate_matrix_from_index(A, len=25):
|
| 4 |
+
matrix = np.zeros(len, dtype=float)
|
| 5 |
+
matrix[A] = 1
|
| 6 |
+
return matrix
|
| 7 |
+
|
| 8 |
+
|
| 9 |
+
def retrieve_index_from_matrix(matrix):
|
| 10 |
+
A = np.where(matrix == 1)[0]
|
| 11 |
+
return A
|
| 12 |
+
|
| 13 |
+
if __name__ == '__main__':
|
| 14 |
+
# Generating a matrix from index A
|
| 15 |
+
A = 6
|
| 16 |
+
matrix = generate_matrix_from_index(A)
|
| 17 |
+
print("Generated Matrix:")
|
| 18 |
+
print(matrix)
|
| 19 |
+
|
| 20 |
+
# Retrieving index A from the matrix
|
| 21 |
+
retrieved_A = retrieve_index_from_matrix(matrix)
|
| 22 |
+
print("Retrieved Index A:")
|
| 23 |
+
print(retrieved_A)
|
matcher.py
ADDED
|
@@ -0,0 +1,172 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
|
| 2 |
+
from pathlib import Path
|
| 3 |
+
|
| 4 |
+
import torch
|
| 5 |
+
import torch.nn as nn
|
| 6 |
+
import torch.nn.functional as F
|
| 7 |
+
import torchaudio
|
| 8 |
+
import torchaudio.transforms as T
|
| 9 |
+
from hifigan.models import Generator as HiFiGAN
|
| 10 |
+
from hifigan.utils import AttrDict
|
| 11 |
+
from torch import Tensor
|
| 12 |
+
from torchaudio.sox_effects import apply_effects_tensor
|
| 13 |
+
from wavlm.WavLM import WavLM
|
| 14 |
+
from knnvc_utils import generate_matrix_from_index
|
| 15 |
+
|
| 16 |
+
|
| 17 |
+
SPEAKER_INFORMATION_LAYER = 6
|
| 18 |
+
SPEAKER_INFORMATION_WEIGHTS = generate_matrix_from_index(SPEAKER_INFORMATION_LAYER)
|
| 19 |
+
|
| 20 |
+
|
| 21 |
+
def fast_cosine_dist(source_feats: Tensor, matching_pool: Tensor, device: str = 'cpu') -> Tensor:
|
| 22 |
+
""" Like torch.cdist, but fixed dim=-1 and for cosine distance."""
|
| 23 |
+
source_norms = torch.norm(source_feats, p=2, dim=-1).to(device)
|
| 24 |
+
matching_norms = torch.norm(matching_pool, p=2, dim=-1)
|
| 25 |
+
dotprod = -torch.cdist(source_feats[None].to(device), matching_pool[None], p=2)[0]**2 + source_norms[:, None]**2 + matching_norms[None]**2
|
| 26 |
+
dotprod /= 2
|
| 27 |
+
|
| 28 |
+
dists = 1 - ( dotprod / (source_norms[:, None] * matching_norms[None]) )
|
| 29 |
+
return dists
|
| 30 |
+
|
| 31 |
+
|
| 32 |
+
class KNeighborsVC(nn.Module):
|
| 33 |
+
|
| 34 |
+
def __init__(self,
|
| 35 |
+
wavlm: WavLM,
|
| 36 |
+
hifigan: HiFiGAN,
|
| 37 |
+
hifigan_cfg: AttrDict,
|
| 38 |
+
device='cuda'
|
| 39 |
+
) -> None:
|
| 40 |
+
""" kNN-VC matcher.
|
| 41 |
+
Arguments:
|
| 42 |
+
- `wavlm` : trained WavLM model
|
| 43 |
+
- `hifigan`: trained hifigan model
|
| 44 |
+
- `hifigan_cfg`: hifigan config to use for vocoding.
|
| 45 |
+
"""
|
| 46 |
+
super().__init__()
|
| 47 |
+
# set which features to extract from wavlm
|
| 48 |
+
self.weighting = torch.tensor(SPEAKER_INFORMATION_WEIGHTS, device=device)[:, None]
|
| 49 |
+
# load hifigan
|
| 50 |
+
self.hifigan = hifigan.eval()
|
| 51 |
+
self.h = hifigan_cfg
|
| 52 |
+
# store wavlm
|
| 53 |
+
self.wavlm = wavlm.eval()
|
| 54 |
+
self.device = torch.device(device)
|
| 55 |
+
self.sr = self.h.sampling_rate
|
| 56 |
+
self.hop_length = 320
|
| 57 |
+
|
| 58 |
+
def get_matching_set(self, wavs: list[Path] | list[Tensor], weights=None, vad_trigger_level=7) -> Tensor:
|
| 59 |
+
""" Get concatenated wavlm features for the matching set using all waveforms in `wavs`,
|
| 60 |
+
specified as either a list of paths or list of loaded waveform tensors of
|
| 61 |
+
shape (channels, T), assumed to be of 16kHz sample rate.
|
| 62 |
+
Optionally specify custom WavLM feature weighting with `weights`.
|
| 63 |
+
"""
|
| 64 |
+
feats = []
|
| 65 |
+
for p in wavs:
|
| 66 |
+
feats.append(self.get_features(p, weights=self.weighting if weights is None else weights, vad_trigger_level=vad_trigger_level))
|
| 67 |
+
|
| 68 |
+
feats = torch.concat(feats, dim=0).cpu()
|
| 69 |
+
return feats
|
| 70 |
+
|
| 71 |
+
|
| 72 |
+
@torch.inference_mode()
|
| 73 |
+
def vocode(self, c: Tensor) -> Tensor:
|
| 74 |
+
""" Vocode features with hifigan. `c` is of shape (bs, seq_len, c_dim) """
|
| 75 |
+
y_g_hat = self.hifigan(c)
|
| 76 |
+
y_g_hat = y_g_hat.squeeze(1)
|
| 77 |
+
return y_g_hat
|
| 78 |
+
|
| 79 |
+
|
| 80 |
+
@torch.inference_mode()
|
| 81 |
+
def get_features(self, path, weights=None, vad_trigger_level=0):
|
| 82 |
+
"""Returns features of `path` waveform as a tensor of shape (seq_len, dim), optionally perform VAD trimming
|
| 83 |
+
on start/end with `vad_trigger_level`.
|
| 84 |
+
"""
|
| 85 |
+
# load audio
|
| 86 |
+
if weights == None: weights = self.weighting
|
| 87 |
+
if type(path) in [str, Path]:
|
| 88 |
+
x, sr = torchaudio.load(path, normalize=True)
|
| 89 |
+
else:
|
| 90 |
+
x: Tensor = path
|
| 91 |
+
sr = self.sr
|
| 92 |
+
if x.dim() == 1: x = x[None]
|
| 93 |
+
|
| 94 |
+
if not sr == self.sr :
|
| 95 |
+
print(f"resample {sr} to {self.sr} in {path}")
|
| 96 |
+
x = torchaudio.functional.resample(x, orig_freq=sr, new_freq=self.sr)
|
| 97 |
+
sr = self.sr
|
| 98 |
+
|
| 99 |
+
# trim silence from front and back
|
| 100 |
+
if vad_trigger_level > 1e-3:
|
| 101 |
+
transform = T.Vad(sample_rate=sr, trigger_level=vad_trigger_level)
|
| 102 |
+
x_front_trim = transform(x)
|
| 103 |
+
# original way, disabled because it lacks windows support
|
| 104 |
+
#waveform_reversed, sr = apply_effects_tensor(x_front_trim, sr, [["reverse"]])
|
| 105 |
+
waveform_reversed = torch.flip(x_front_trim, (-1,))
|
| 106 |
+
waveform_reversed_front_trim = transform(waveform_reversed)
|
| 107 |
+
waveform_end_trim = torch.flip(waveform_reversed_front_trim, (-1,))
|
| 108 |
+
#waveform_end_trim, sr = apply_effects_tensor(
|
| 109 |
+
# waveform_reversed_front_trim, sr, [["reverse"]]
|
| 110 |
+
#)
|
| 111 |
+
x = waveform_end_trim
|
| 112 |
+
|
| 113 |
+
# extract the representation of each layer
|
| 114 |
+
wav_input_16khz = x.to(self.device)
|
| 115 |
+
if torch.allclose(weights, self.weighting):
|
| 116 |
+
# use fastpath
|
| 117 |
+
features = self.wavlm.extract_features(wav_input_16khz, output_layer=SPEAKER_INFORMATION_LAYER, ret_layer_results=False)[0]
|
| 118 |
+
features = features.squeeze(0)
|
| 119 |
+
else:
|
| 120 |
+
# use slower weighted
|
| 121 |
+
rep, layer_results = self.wavlm.extract_features(wav_input_16khz, output_layer=self.wavlm.cfg.encoder_layers, ret_layer_results=True)[0]
|
| 122 |
+
features = torch.cat([x.transpose(0, 1) for x, _ in layer_results], dim=0) # (n_layers, seq_len, dim)
|
| 123 |
+
# save full sequence
|
| 124 |
+
features = ( features*weights[:, None] ).sum(dim=0) # (seq_len, dim)
|
| 125 |
+
|
| 126 |
+
return features
|
| 127 |
+
|
| 128 |
+
|
| 129 |
+
@torch.inference_mode()
|
| 130 |
+
def match(self, query_seq: Tensor, matching_set: Tensor, synth_set: Tensor = None,
|
| 131 |
+
topk: int = 4, tgt_loudness_db: float | None = -16,
|
| 132 |
+
target_duration: float | None = None, device: str | None = None) -> Tensor:
|
| 133 |
+
""" Given `query_seq`, `matching_set`, and `synth_set` tensors of shape (N, dim), perform kNN regression matching
|
| 134 |
+
with k=`topk`. Inputs:
|
| 135 |
+
- `query_seq`: Tensor (N1, dim) of the input/source query features.
|
| 136 |
+
- `matching_set`: Tensor (N2, dim) of the matching set used as the 'training set' for the kNN algorithm.
|
| 137 |
+
- `synth_set`: optional Tensor (N2, dim) corresponding to the matching set. We use the matching set to assign each query
|
| 138 |
+
vector to a vector in the matching set, and then use the corresponding vector from the synth set during HiFiGAN synthesis.
|
| 139 |
+
By default, and for best performance, this should be identical to the matching set.
|
| 140 |
+
- `topk`: k in the kNN -- the number of nearest neighbors to average over.
|
| 141 |
+
- `tgt_loudness_db`: float db used to normalize the output volume. Set to None to disable.
|
| 142 |
+
- `target_duration`: if set to a float, interpolate resulting waveform duration to be equal to this value in seconds.
|
| 143 |
+
- `device`: if None, uses default device at initialization. Otherwise uses specified device
|
| 144 |
+
Returns:
|
| 145 |
+
- converted waveform of shape (T,)
|
| 146 |
+
"""
|
| 147 |
+
device = torch.device(device) if device is not None else self.device
|
| 148 |
+
if synth_set is None: synth_set = matching_set.to(device)
|
| 149 |
+
else: synth_set = synth_set.to(device)
|
| 150 |
+
matching_set = matching_set.to(device)
|
| 151 |
+
query_seq = query_seq.to(device)
|
| 152 |
+
|
| 153 |
+
if target_duration is not None:
|
| 154 |
+
target_samples = int(target_duration*self.sr)
|
| 155 |
+
scale_factor = (target_samples/self.hop_length) / query_seq.shape[0] # n_targ_feats / n_input_feats
|
| 156 |
+
query_seq = F.interpolate(query_seq.T[None], scale_factor=scale_factor, mode='linear')[0].T
|
| 157 |
+
|
| 158 |
+
dists = fast_cosine_dist(query_seq, matching_set, device=device)
|
| 159 |
+
best = dists.topk(k=topk, largest=False, dim=-1)
|
| 160 |
+
out_feats = synth_set[best.indices].mean(dim=1)
|
| 161 |
+
|
| 162 |
+
prediction = self.vocode(out_feats[None].to(device)).cpu().squeeze()
|
| 163 |
+
|
| 164 |
+
# normalization
|
| 165 |
+
if tgt_loudness_db is not None:
|
| 166 |
+
src_loudness = torchaudio.functional.loudness(prediction[None], self.h.sampling_rate)
|
| 167 |
+
tgt_loudness = tgt_loudness_db
|
| 168 |
+
pred_wav = torchaudio.functional.gain(prediction, tgt_loudness - src_loudness)
|
| 169 |
+
else: pred_wav = prediction
|
| 170 |
+
return pred_wav
|
| 171 |
+
|
| 172 |
+
|
prematch_dataset.py
ADDED
|
@@ -0,0 +1,172 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import argparse
|
| 2 |
+
import gc
|
| 3 |
+
import os
|
| 4 |
+
import sys
|
| 5 |
+
import time
|
| 6 |
+
from pathlib import Path
|
| 7 |
+
|
| 8 |
+
import numpy as np
|
| 9 |
+
import pandas as pd
|
| 10 |
+
import torch
|
| 11 |
+
import torch.nn as nn
|
| 12 |
+
import torch.nn.functional as F
|
| 13 |
+
import torchaudio
|
| 14 |
+
from fastprogress.fastprogress import master_bar, progress_bar
|
| 15 |
+
from torch import Tensor
|
| 16 |
+
|
| 17 |
+
from hubconf import wavlm_large
|
| 18 |
+
|
| 19 |
+
DOWNSAMPLE_FACTOR = 320
|
| 20 |
+
|
| 21 |
+
global feature_cache
|
| 22 |
+
feature_cache = {}
|
| 23 |
+
global synthesis_cache
|
| 24 |
+
synthesis_cache = {}
|
| 25 |
+
|
| 26 |
+
def make_librispeech_df(root_path: Path) -> pd.DataFrame:
|
| 27 |
+
all_files = []
|
| 28 |
+
folders = ['train-clean-100', 'dev-clean']
|
| 29 |
+
print(f"[LIBRISPEECH] Computing folders {folders}")
|
| 30 |
+
for f in folders:
|
| 31 |
+
all_files.extend(list((root_path/f).rglob('**/*.flac')))
|
| 32 |
+
speakers = ['ls-' + f.stem.split('-')[0] for f in all_files]
|
| 33 |
+
df = pd.DataFrame({'path': all_files, 'speaker': speakers})
|
| 34 |
+
return df
|
| 35 |
+
|
| 36 |
+
|
| 37 |
+
def main(args):
|
| 38 |
+
device = torch.device(args.device)
|
| 39 |
+
SYNTH_WEIGHTINGS = F.one_hot(torch.tensor(args.synthesis_layer), num_classes=25).float().to(device)[:, None]
|
| 40 |
+
MATCH_WEIGHTINGS = F.one_hot(torch.tensor(args.matching_layer), num_classes=25).float().to(device)[:, None]
|
| 41 |
+
|
| 42 |
+
print(f"Matching weightings: {MATCH_WEIGHTINGS.squeeze()}\nSynthesis weightings: {SYNTH_WEIGHTINGS.squeeze()}")
|
| 43 |
+
ls_df = make_librispeech_df(Path(args.librispeech_path))
|
| 44 |
+
|
| 45 |
+
print(f"Loading wavlm.")
|
| 46 |
+
wavlm = wavlm_large(pretrained=True, progress=True, device=args.device)
|
| 47 |
+
|
| 48 |
+
np.random.seed(args.seed)
|
| 49 |
+
torch.manual_seed(args.seed)
|
| 50 |
+
extract(ls_df, wavlm, args.device, Path(args.librispeech_path), Path(args.out_path), SYNTH_WEIGHTINGS, MATCH_WEIGHTINGS)
|
| 51 |
+
print("All done!", flush=True)
|
| 52 |
+
|
| 53 |
+
|
| 54 |
+
def path2pools(path: Path, wavlm: nn.Module(), match_weights: Tensor, synth_weights: Tensor, device):
|
| 55 |
+
"""Given a waveform `path`, compute the matching pool"""
|
| 56 |
+
|
| 57 |
+
uttrs_from_same_spk = sorted(list(path.parent.rglob('**/*.flac')))
|
| 58 |
+
uttrs_from_same_spk.remove(path)
|
| 59 |
+
matching_pool = []
|
| 60 |
+
synth_pool = []
|
| 61 |
+
for pth in uttrs_from_same_spk:
|
| 62 |
+
if pth in feature_cache and pth in synthesis_cache:
|
| 63 |
+
matching_feats = feature_cache[pth].float() # (seq_len, dim)
|
| 64 |
+
synth_feats = synthesis_cache[pth].float() # (seq_len, dim)
|
| 65 |
+
else:
|
| 66 |
+
feats = get_full_features(pth, wavlm, device)
|
| 67 |
+
matching_feats = ( feats*match_weights[:, None] ).sum(dim=0) # (seq_len, dim)
|
| 68 |
+
synth_feats = ( feats*synth_weights[:, None] ).sum(dim=0) # (seq_len, dim)
|
| 69 |
+
feature_cache[pth] = matching_feats.half().cpu()
|
| 70 |
+
synthesis_cache[pth] = synth_feats.half().cpu()
|
| 71 |
+
|
| 72 |
+
matching_pool.append(matching_feats.cpu())
|
| 73 |
+
synth_pool.append(synth_feats.cpu())
|
| 74 |
+
matching_pool = torch.concat(matching_pool, dim=0)
|
| 75 |
+
synth_pool = torch.concat(synth_pool, dim=0)
|
| 76 |
+
return matching_pool, synth_pool # (N, dim)
|
| 77 |
+
|
| 78 |
+
|
| 79 |
+
@torch.inference_mode()
|
| 80 |
+
def get_full_features(path, wavlm, device):
|
| 81 |
+
|
| 82 |
+
x, sr = torchaudio.load(path)
|
| 83 |
+
assert sr == 16000
|
| 84 |
+
# This does not work i.t.o the hifigan training.
|
| 85 |
+
# x = F.pad(x, (DOWNSAMPLE_FACTOR//2, DOWNSAMPLE_FACTOR - DOWNSAMPLE_FACTOR//2), value=0)
|
| 86 |
+
# This does.
|
| 87 |
+
n_pad = DOWNSAMPLE_FACTOR - (x.shape[-1] % DOWNSAMPLE_FACTOR)
|
| 88 |
+
x = F.pad(x, (0, n_pad), value=0)
|
| 89 |
+
|
| 90 |
+
# extract the representation of each layer
|
| 91 |
+
wav_input_16khz = x.to(device)
|
| 92 |
+
rep, layer_results = wavlm.extract_features(wav_input_16khz, output_layer=wavlm.cfg.encoder_layers, ret_layer_results=True)[0]
|
| 93 |
+
features = torch.cat([x.transpose(0, 1) for x, _ in layer_results], dim=0) # (n_layers, seq_len, dim)
|
| 94 |
+
|
| 95 |
+
return features
|
| 96 |
+
|
| 97 |
+
|
| 98 |
+
def fast_cosine_dist(source_feats, matching_pool):
|
| 99 |
+
source_norms = torch.norm(source_feats, p=2, dim=-1)
|
| 100 |
+
matching_norms = torch.norm(matching_pool, p=2, dim=-1)
|
| 101 |
+
dotprod = -torch.cdist(source_feats[None], matching_pool[None], p=2)[0]**2 + source_norms[:, None]**2 + matching_norms[None]**2
|
| 102 |
+
dotprod /= 2
|
| 103 |
+
|
| 104 |
+
dists = 1 - ( dotprod / (source_norms[:, None] * matching_norms[None]) )
|
| 105 |
+
return dists
|
| 106 |
+
|
| 107 |
+
|
| 108 |
+
@torch.inference_mode()
|
| 109 |
+
def extract(df: pd.DataFrame, wavlm: nn.Module, device, ls_path: Path, out_path: Path, synth_weights: Tensor, match_weights: Tensor):
|
| 110 |
+
|
| 111 |
+
pb = progress_bar(df.iterrows(), total=len(df))
|
| 112 |
+
|
| 113 |
+
for i, row in pb:
|
| 114 |
+
rel_path = Path(row.path).relative_to(ls_path)
|
| 115 |
+
targ_path = (out_path/rel_path).with_suffix('.pt')
|
| 116 |
+
if args.resume:
|
| 117 |
+
if targ_path.is_file(): continue
|
| 118 |
+
# if targ_path.is_file(): continue
|
| 119 |
+
os.makedirs(targ_path.parent, exist_ok=True)
|
| 120 |
+
|
| 121 |
+
if Path(row.path) in feature_cache:
|
| 122 |
+
source_feats = feature_cache[Path(row.path)].float()
|
| 123 |
+
else:
|
| 124 |
+
source_feats = get_full_features(row.path, wavlm, device)
|
| 125 |
+
source_feats = ( source_feats*match_weights[:, None] ).sum(dim=0) # (seq_len, dim)
|
| 126 |
+
|
| 127 |
+
matching_pool, synth_pool = path2pools(row.path, wavlm, match_weights, synth_weights, device)
|
| 128 |
+
|
| 129 |
+
if not args.prematch:
|
| 130 |
+
out_feats = source_feats.cpu()
|
| 131 |
+
else:
|
| 132 |
+
dists = fast_cosine_dist(source_feats.cpu(), matching_pool.cpu()).cpu()
|
| 133 |
+
best = dists.topk(k=args.topk, dim=-1, largest=False) # (src_len, 4)
|
| 134 |
+
out_feats = synth_pool[best.indices].mean(dim=1) # (N, dim)
|
| 135 |
+
|
| 136 |
+
# save matched sequence
|
| 137 |
+
if i < 3: print("Feature has shape: ", out_feats.shape, flush=True)
|
| 138 |
+
# 3. save
|
| 139 |
+
torch.save(out_feats.cpu().half(), str(targ_path))
|
| 140 |
+
if hasattr(pb, 'child'):
|
| 141 |
+
pb.child.comment = str(rel_path)
|
| 142 |
+
pb.child.wait_for = min(pb.child.wait_for, 10)
|
| 143 |
+
pb.main_bar.comment = str(rel_path)
|
| 144 |
+
else:
|
| 145 |
+
pb.wait_for = min(pb.wait_for, 10)
|
| 146 |
+
pb.comment = str(rel_path)
|
| 147 |
+
|
| 148 |
+
|
| 149 |
+
if i % 1000 == 0:
|
| 150 |
+
print(f"Done {i:,d}/{len(df):,d}", flush=True)
|
| 151 |
+
feature_cache.clear()
|
| 152 |
+
synthesis_cache.clear()
|
| 153 |
+
gc.collect()
|
| 154 |
+
time.sleep(4)
|
| 155 |
+
|
| 156 |
+
|
| 157 |
+
if __name__ == '__main__':
|
| 158 |
+
parser = argparse.ArgumentParser(description="Compute matched wavlm features for a librispeech dataset")
|
| 159 |
+
|
| 160 |
+
parser.add_argument('--librispeech_path', required=True, type=str)
|
| 161 |
+
parser.add_argument('--seed', default=123, type=int)
|
| 162 |
+
parser.add_argument('--out_path', required=True, type=str)
|
| 163 |
+
parser.add_argument('--device', default='cuda', type=str)
|
| 164 |
+
parser.add_argument('--topk', type=int, default=4)
|
| 165 |
+
parser.add_argument('--matching_layer', type=int, default=6)
|
| 166 |
+
parser.add_argument('--synthesis_layer', type=int, default=6)
|
| 167 |
+
parser.add_argument('--prematch', action='store_true', help='prematch')
|
| 168 |
+
parser.add_argument('--resume', action='store_true')
|
| 169 |
+
|
| 170 |
+
args = parser.parse_args()
|
| 171 |
+
main(args)
|
| 172 |
+
|
requirements.txt
ADDED
|
@@ -0,0 +1,5 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
torch
|
| 2 |
+
torchaudio
|
| 3 |
+
soundfile
|
| 4 |
+
gradio
|
| 5 |
+
spaces
|
wavlm/WavLM.py
ADDED
|
@@ -0,0 +1,743 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# --------------------------------------------------------
|
| 2 |
+
# WavLM: Large-Scale Self-Supervised Pre-training for Full Stack Speech Processing (https://arxiv.org/abs/2110.13900.pdf)
|
| 3 |
+
# Github source: https://github.com/microsoft/unilm/tree/master/wavlm
|
| 4 |
+
# Copyright (c) 2021 Microsoft
|
| 5 |
+
# Licensed under The MIT License [see LICENSE for details]
|
| 6 |
+
# Based on fairseq code bases
|
| 7 |
+
# https://github.com/pytorch/fairseq
|
| 8 |
+
# --------------------------------------------------------
|
| 9 |
+
|
| 10 |
+
import math
|
| 11 |
+
import logging
|
| 12 |
+
from typing import List, Optional, Tuple
|
| 13 |
+
|
| 14 |
+
import numpy as np
|
| 15 |
+
|
| 16 |
+
import torch
|
| 17 |
+
import torch.nn as nn
|
| 18 |
+
import torch.nn.functional as F
|
| 19 |
+
from torch.nn import LayerNorm
|
| 20 |
+
from .modules import (
|
| 21 |
+
Fp32GroupNorm,
|
| 22 |
+
Fp32LayerNorm,
|
| 23 |
+
GradMultiply,
|
| 24 |
+
MultiheadAttention,
|
| 25 |
+
SamePad,
|
| 26 |
+
init_bert_params,
|
| 27 |
+
get_activation_fn,
|
| 28 |
+
TransposeLast,
|
| 29 |
+
GLU_Linear,
|
| 30 |
+
)
|
| 31 |
+
|
| 32 |
+
logger = logging.getLogger(__name__)
|
| 33 |
+
|
| 34 |
+
|
| 35 |
+
def compute_mask_indices(
|
| 36 |
+
shape: Tuple[int, int],
|
| 37 |
+
padding_mask: Optional[torch.Tensor],
|
| 38 |
+
mask_prob: float,
|
| 39 |
+
mask_length: int,
|
| 40 |
+
mask_type: str = "static",
|
| 41 |
+
mask_other: float = 0.0,
|
| 42 |
+
min_masks: int = 0,
|
| 43 |
+
no_overlap: bool = False,
|
| 44 |
+
min_space: int = 0,
|
| 45 |
+
) -> np.ndarray:
|
| 46 |
+
"""
|
| 47 |
+
Computes random mask spans for a given shape
|
| 48 |
+
|
| 49 |
+
Args:
|
| 50 |
+
shape: the the shape for which to compute masks.
|
| 51 |
+
should be of size 2 where first element is batch size and 2nd is timesteps
|
| 52 |
+
padding_mask: optional padding mask of the same size as shape, which will prevent masking padded elements
|
| 53 |
+
mask_prob: probability for each token to be chosen as start of the span to be masked. this will be multiplied by
|
| 54 |
+
number of timesteps divided by length of mask span to mask approximately this percentage of all elements.
|
| 55 |
+
however due to overlaps, the actual number will be smaller (unless no_overlap is True)
|
| 56 |
+
mask_type: how to compute mask lengths
|
| 57 |
+
static = fixed size
|
| 58 |
+
uniform = sample from uniform distribution [mask_other, mask_length*2]
|
| 59 |
+
normal = sample from normal distribution with mean mask_length and stdev mask_other. mask is min 1 element
|
| 60 |
+
poisson = sample from possion distribution with lambda = mask length
|
| 61 |
+
min_masks: minimum number of masked spans
|
| 62 |
+
no_overlap: if false, will switch to an alternative recursive algorithm that prevents spans from overlapping
|
| 63 |
+
min_space: only used if no_overlap is True, this is how many elements to keep unmasked between spans
|
| 64 |
+
"""
|
| 65 |
+
|
| 66 |
+
bsz, all_sz = shape
|
| 67 |
+
mask = np.full((bsz, all_sz), False)
|
| 68 |
+
|
| 69 |
+
all_num_mask = int(
|
| 70 |
+
# add a random number for probabilistic rounding
|
| 71 |
+
mask_prob * all_sz / float(mask_length)
|
| 72 |
+
+ np.random.rand()
|
| 73 |
+
)
|
| 74 |
+
|
| 75 |
+
all_num_mask = max(min_masks, all_num_mask)
|
| 76 |
+
|
| 77 |
+
mask_idcs = []
|
| 78 |
+
for i in range(bsz):
|
| 79 |
+
if padding_mask is not None:
|
| 80 |
+
sz = all_sz - padding_mask[i].long().sum().item()
|
| 81 |
+
num_mask = int(
|
| 82 |
+
# add a random number for probabilistic rounding
|
| 83 |
+
mask_prob * sz / float(mask_length)
|
| 84 |
+
+ np.random.rand()
|
| 85 |
+
)
|
| 86 |
+
num_mask = max(min_masks, num_mask)
|
| 87 |
+
else:
|
| 88 |
+
sz = all_sz
|
| 89 |
+
num_mask = all_num_mask
|
| 90 |
+
|
| 91 |
+
if mask_type == "static":
|
| 92 |
+
lengths = np.full(num_mask, mask_length)
|
| 93 |
+
elif mask_type == "uniform":
|
| 94 |
+
lengths = np.random.randint(mask_other, mask_length * 2 + 1, size=num_mask)
|
| 95 |
+
elif mask_type == "normal":
|
| 96 |
+
lengths = np.random.normal(mask_length, mask_other, size=num_mask)
|
| 97 |
+
lengths = [max(1, int(round(x))) for x in lengths]
|
| 98 |
+
elif mask_type == "poisson":
|
| 99 |
+
lengths = np.random.poisson(mask_length, size=num_mask)
|
| 100 |
+
lengths = [int(round(x)) for x in lengths]
|
| 101 |
+
else:
|
| 102 |
+
raise Exception("unknown mask selection " + mask_type)
|
| 103 |
+
|
| 104 |
+
if sum(lengths) == 0:
|
| 105 |
+
lengths[0] = min(mask_length, sz - 1)
|
| 106 |
+
|
| 107 |
+
if no_overlap:
|
| 108 |
+
mask_idc = []
|
| 109 |
+
|
| 110 |
+
def arrange(s, e, length, keep_length):
|
| 111 |
+
span_start = np.random.randint(s, e - length)
|
| 112 |
+
mask_idc.extend(span_start + i for i in range(length))
|
| 113 |
+
|
| 114 |
+
new_parts = []
|
| 115 |
+
if span_start - s - min_space >= keep_length:
|
| 116 |
+
new_parts.append((s, span_start - min_space + 1))
|
| 117 |
+
if e - span_start - keep_length - min_space > keep_length:
|
| 118 |
+
new_parts.append((span_start + length + min_space, e))
|
| 119 |
+
return new_parts
|
| 120 |
+
|
| 121 |
+
parts = [(0, sz)]
|
| 122 |
+
min_length = min(lengths)
|
| 123 |
+
for length in sorted(lengths, reverse=True):
|
| 124 |
+
lens = np.fromiter(
|
| 125 |
+
(e - s if e - s >= length + min_space else 0 for s, e in parts),
|
| 126 |
+
np.int,
|
| 127 |
+
)
|
| 128 |
+
l_sum = np.sum(lens)
|
| 129 |
+
if l_sum == 0:
|
| 130 |
+
break
|
| 131 |
+
probs = lens / np.sum(lens)
|
| 132 |
+
c = np.random.choice(len(parts), p=probs)
|
| 133 |
+
s, e = parts.pop(c)
|
| 134 |
+
parts.extend(arrange(s, e, length, min_length))
|
| 135 |
+
mask_idc = np.asarray(mask_idc)
|
| 136 |
+
else:
|
| 137 |
+
min_len = min(lengths)
|
| 138 |
+
if sz - min_len <= num_mask:
|
| 139 |
+
min_len = sz - num_mask - 1
|
| 140 |
+
|
| 141 |
+
mask_idc = np.random.choice(sz - min_len, num_mask, replace=False)
|
| 142 |
+
|
| 143 |
+
mask_idc = np.asarray(
|
| 144 |
+
[
|
| 145 |
+
mask_idc[j] + offset
|
| 146 |
+
for j in range(len(mask_idc))
|
| 147 |
+
for offset in range(lengths[j])
|
| 148 |
+
]
|
| 149 |
+
)
|
| 150 |
+
|
| 151 |
+
mask_idcs.append(np.unique(mask_idc[mask_idc < sz]))
|
| 152 |
+
|
| 153 |
+
min_len = min([len(m) for m in mask_idcs])
|
| 154 |
+
for i, mask_idc in enumerate(mask_idcs):
|
| 155 |
+
if len(mask_idc) > min_len:
|
| 156 |
+
mask_idc = np.random.choice(mask_idc, min_len, replace=False)
|
| 157 |
+
mask[i, mask_idc] = True
|
| 158 |
+
|
| 159 |
+
return mask
|
| 160 |
+
|
| 161 |
+
|
| 162 |
+
class WavLMConfig:
|
| 163 |
+
def __init__(self, cfg=None):
|
| 164 |
+
self.extractor_mode: str = "default" # mode for feature extractor. default has a single group norm with d groups in the first conv block, whereas layer_norm has layer norms in every block (meant to use with normalize=True)
|
| 165 |
+
self.encoder_layers: int = 12 # num encoder layers in the transformer
|
| 166 |
+
|
| 167 |
+
self.encoder_embed_dim: int = 768 # encoder embedding dimension
|
| 168 |
+
self.encoder_ffn_embed_dim: int = 3072 # encoder embedding dimension for FFN
|
| 169 |
+
self.encoder_attention_heads: int = 12 # num encoder attention heads
|
| 170 |
+
self.activation_fn: str = "gelu" # activation function to use
|
| 171 |
+
|
| 172 |
+
self.layer_norm_first: bool = False # apply layernorm first in the transformer
|
| 173 |
+
self.conv_feature_layers: str = "[(512,10,5)] + [(512,3,2)] * 4 + [(512,2,2)] * 2" # string describing convolutional feature extraction layers in form of a python list that contains [(dim, kernel_size, stride), ...]
|
| 174 |
+
self.conv_bias: bool = False # include bias in conv encoder
|
| 175 |
+
self.feature_grad_mult: float = 1.0 # multiply feature extractor var grads by this
|
| 176 |
+
|
| 177 |
+
self.normalize: bool = False # normalize input to have 0 mean and unit variance during training
|
| 178 |
+
|
| 179 |
+
# dropouts
|
| 180 |
+
self.dropout: float = 0.1 # dropout probability for the transformer
|
| 181 |
+
self.attention_dropout: float = 0.1 # dropout probability for attention weights
|
| 182 |
+
self.activation_dropout: float = 0.0 # dropout probability after activation in FFN
|
| 183 |
+
self.encoder_layerdrop: float = 0.0 # probability of dropping a tarnsformer layer
|
| 184 |
+
self.dropout_input: float = 0.0 # dropout to apply to the input (after feat extr)
|
| 185 |
+
self.dropout_features: float = 0.0 # dropout to apply to the features (after feat extr)
|
| 186 |
+
|
| 187 |
+
# masking
|
| 188 |
+
self.mask_length: int = 10 # mask length
|
| 189 |
+
self.mask_prob: float = 0.65 # probability of replacing a token with mask
|
| 190 |
+
self.mask_selection: str = "static" # how to choose mask length
|
| 191 |
+
self.mask_other: float = 0 # secondary mask argument (used for more complex distributions), see help in compute_mask_indicesh
|
| 192 |
+
self.no_mask_overlap: bool = False # whether to allow masks to overlap
|
| 193 |
+
self.mask_min_space: int = 1 # min space between spans (if no overlap is enabled)
|
| 194 |
+
|
| 195 |
+
# channel masking
|
| 196 |
+
self.mask_channel_length: int = 10 # length of the mask for features (channels)
|
| 197 |
+
self.mask_channel_prob: float = 0.0 # probability of replacing a feature with 0
|
| 198 |
+
self.mask_channel_selection: str = "static" # how to choose mask length for channel masking
|
| 199 |
+
self.mask_channel_other: float = 0 # secondary mask argument (used for more complex distributions), see help in compute_mask_indices
|
| 200 |
+
self.no_mask_channel_overlap: bool = False # whether to allow channel masks to overlap
|
| 201 |
+
self.mask_channel_min_space: int = 1 # min space between spans (if no overlap is enabled)
|
| 202 |
+
|
| 203 |
+
# positional embeddings
|
| 204 |
+
self.conv_pos: int = 128 # number of filters for convolutional positional embeddings
|
| 205 |
+
self.conv_pos_groups: int = 16 # number of groups for convolutional positional embedding
|
| 206 |
+
|
| 207 |
+
# relative position embedding
|
| 208 |
+
self.relative_position_embedding: bool = False # apply relative position embedding
|
| 209 |
+
self.num_buckets: int = 320 # number of buckets for relative position embedding
|
| 210 |
+
self.max_distance: int = 1280 # maximum distance for relative position embedding
|
| 211 |
+
self.gru_rel_pos: bool = False # apply gated relative position embedding
|
| 212 |
+
|
| 213 |
+
if cfg is not None:
|
| 214 |
+
self.update(cfg)
|
| 215 |
+
|
| 216 |
+
def update(self, cfg: dict):
|
| 217 |
+
self.__dict__.update(cfg)
|
| 218 |
+
|
| 219 |
+
|
| 220 |
+
class WavLM(nn.Module):
|
| 221 |
+
def __init__(
|
| 222 |
+
self,
|
| 223 |
+
cfg: WavLMConfig,
|
| 224 |
+
) -> None:
|
| 225 |
+
super().__init__()
|
| 226 |
+
logger.info(f"WavLM Config: {cfg.__dict__}")
|
| 227 |
+
|
| 228 |
+
self.cfg = cfg
|
| 229 |
+
feature_enc_layers = eval(cfg.conv_feature_layers)
|
| 230 |
+
self.embed = feature_enc_layers[-1][0]
|
| 231 |
+
|
| 232 |
+
self.feature_extractor = ConvFeatureExtractionModel(
|
| 233 |
+
conv_layers=feature_enc_layers,
|
| 234 |
+
dropout=0.0,
|
| 235 |
+
mode=cfg.extractor_mode,
|
| 236 |
+
conv_bias=cfg.conv_bias,
|
| 237 |
+
)
|
| 238 |
+
|
| 239 |
+
self.post_extract_proj = (
|
| 240 |
+
nn.Linear(self.embed, cfg.encoder_embed_dim)
|
| 241 |
+
if self.embed != cfg.encoder_embed_dim
|
| 242 |
+
else None
|
| 243 |
+
)
|
| 244 |
+
|
| 245 |
+
self.mask_prob = cfg.mask_prob
|
| 246 |
+
self.mask_selection = cfg.mask_selection
|
| 247 |
+
self.mask_other = cfg.mask_other
|
| 248 |
+
self.mask_length = cfg.mask_length
|
| 249 |
+
self.no_mask_overlap = cfg.no_mask_overlap
|
| 250 |
+
self.mask_min_space = cfg.mask_min_space
|
| 251 |
+
|
| 252 |
+
self.mask_channel_prob = cfg.mask_channel_prob
|
| 253 |
+
self.mask_channel_selection = cfg.mask_channel_selection
|
| 254 |
+
self.mask_channel_other = cfg.mask_channel_other
|
| 255 |
+
self.mask_channel_length = cfg.mask_channel_length
|
| 256 |
+
self.no_mask_channel_overlap = cfg.no_mask_channel_overlap
|
| 257 |
+
self.mask_channel_min_space = cfg.mask_channel_min_space
|
| 258 |
+
|
| 259 |
+
self.dropout_input = nn.Dropout(cfg.dropout_input)
|
| 260 |
+
self.dropout_features = nn.Dropout(cfg.dropout_features)
|
| 261 |
+
|
| 262 |
+
self.feature_grad_mult = cfg.feature_grad_mult
|
| 263 |
+
|
| 264 |
+
self.mask_emb = nn.Parameter(
|
| 265 |
+
torch.FloatTensor(cfg.encoder_embed_dim).uniform_()
|
| 266 |
+
)
|
| 267 |
+
|
| 268 |
+
self.encoder = TransformerEncoder(cfg)
|
| 269 |
+
self.layer_norm = LayerNorm(self.embed)
|
| 270 |
+
|
| 271 |
+
def apply_mask(self, x, padding_mask):
|
| 272 |
+
B, T, C = x.shape
|
| 273 |
+
if self.mask_prob > 0:
|
| 274 |
+
mask_indices = compute_mask_indices(
|
| 275 |
+
(B, T),
|
| 276 |
+
padding_mask,
|
| 277 |
+
self.mask_prob,
|
| 278 |
+
self.mask_length,
|
| 279 |
+
self.mask_selection,
|
| 280 |
+
self.mask_other,
|
| 281 |
+
min_masks=2,
|
| 282 |
+
no_overlap=self.no_mask_overlap,
|
| 283 |
+
min_space=self.mask_min_space,
|
| 284 |
+
)
|
| 285 |
+
mask_indices = torch.from_numpy(mask_indices).to(x.device)
|
| 286 |
+
x[mask_indices] = self.mask_emb
|
| 287 |
+
else:
|
| 288 |
+
mask_indices = None
|
| 289 |
+
|
| 290 |
+
if self.mask_channel_prob > 0:
|
| 291 |
+
mask_channel_indices = compute_mask_indices(
|
| 292 |
+
(B, C),
|
| 293 |
+
None,
|
| 294 |
+
self.mask_channel_prob,
|
| 295 |
+
self.mask_channel_length,
|
| 296 |
+
self.mask_channel_selection,
|
| 297 |
+
self.mask_channel_other,
|
| 298 |
+
no_overlap=self.no_mask_channel_overlap,
|
| 299 |
+
min_space=self.mask_channel_min_space,
|
| 300 |
+
)
|
| 301 |
+
mask_channel_indices = (
|
| 302 |
+
torch.from_numpy(mask_channel_indices)
|
| 303 |
+
.to(x.device)
|
| 304 |
+
.unsqueeze(1)
|
| 305 |
+
.expand(-1, T, -1)
|
| 306 |
+
)
|
| 307 |
+
x[mask_channel_indices] = 0
|
| 308 |
+
|
| 309 |
+
return x, mask_indices
|
| 310 |
+
|
| 311 |
+
def forward_padding_mask(
|
| 312 |
+
self, features: torch.Tensor, padding_mask: torch.Tensor,
|
| 313 |
+
) -> torch.Tensor:
|
| 314 |
+
extra = padding_mask.size(1) % features.size(1)
|
| 315 |
+
if extra > 0:
|
| 316 |
+
padding_mask = padding_mask[:, :-extra]
|
| 317 |
+
padding_mask = padding_mask.view(
|
| 318 |
+
padding_mask.size(0), features.size(1), -1
|
| 319 |
+
)
|
| 320 |
+
padding_mask = padding_mask.all(-1)
|
| 321 |
+
return padding_mask
|
| 322 |
+
|
| 323 |
+
def extract_features(
|
| 324 |
+
self,
|
| 325 |
+
source: torch.Tensor,
|
| 326 |
+
padding_mask: Optional[torch.Tensor] = None,
|
| 327 |
+
mask: bool = False,
|
| 328 |
+
ret_conv: bool = False,
|
| 329 |
+
output_layer: Optional[int] = None,
|
| 330 |
+
ret_layer_results: bool = False,
|
| 331 |
+
):
|
| 332 |
+
|
| 333 |
+
if self.feature_grad_mult > 0:
|
| 334 |
+
features = self.feature_extractor(source)
|
| 335 |
+
if self.feature_grad_mult != 1.0:
|
| 336 |
+
features = GradMultiply.apply(features, self.feature_grad_mult)
|
| 337 |
+
else:
|
| 338 |
+
with torch.no_grad():
|
| 339 |
+
features = self.feature_extractor(source)
|
| 340 |
+
|
| 341 |
+
features = features.transpose(1, 2)
|
| 342 |
+
features = self.layer_norm(features)
|
| 343 |
+
|
| 344 |
+
if padding_mask is not None:
|
| 345 |
+
padding_mask = self.forward_padding_mask(features, padding_mask)
|
| 346 |
+
|
| 347 |
+
if self.post_extract_proj is not None:
|
| 348 |
+
features = self.post_extract_proj(features)
|
| 349 |
+
|
| 350 |
+
features = self.dropout_input(features)
|
| 351 |
+
|
| 352 |
+
if mask:
|
| 353 |
+
x, mask_indices = self.apply_mask(
|
| 354 |
+
features, padding_mask
|
| 355 |
+
)
|
| 356 |
+
else:
|
| 357 |
+
x = features
|
| 358 |
+
|
| 359 |
+
# feature: (B, T, D), float
|
| 360 |
+
# target: (B, T), long
|
| 361 |
+
# x: (B, T, D), float
|
| 362 |
+
# padding_mask: (B, T), bool
|
| 363 |
+
# mask_indices: (B, T), bool
|
| 364 |
+
x, layer_results = self.encoder(
|
| 365 |
+
x,
|
| 366 |
+
padding_mask=padding_mask,
|
| 367 |
+
layer=None if output_layer is None else output_layer - 1
|
| 368 |
+
)
|
| 369 |
+
|
| 370 |
+
res = {"x": x, "padding_mask": padding_mask, "features": features, "layer_results": layer_results}
|
| 371 |
+
|
| 372 |
+
feature = res["features"] if ret_conv else res["x"]
|
| 373 |
+
if ret_layer_results:
|
| 374 |
+
feature = (feature, res["layer_results"])
|
| 375 |
+
return feature, res["padding_mask"]
|
| 376 |
+
|
| 377 |
+
|
| 378 |
+
class ConvFeatureExtractionModel(nn.Module):
|
| 379 |
+
def __init__(
|
| 380 |
+
self,
|
| 381 |
+
conv_layers: List[Tuple[int, int, int]],
|
| 382 |
+
dropout: float = 0.0,
|
| 383 |
+
mode: str = "default",
|
| 384 |
+
conv_bias: bool = False,
|
| 385 |
+
conv_type: str = "default"
|
| 386 |
+
):
|
| 387 |
+
super().__init__()
|
| 388 |
+
|
| 389 |
+
assert mode in {"default", "layer_norm"}
|
| 390 |
+
|
| 391 |
+
def block(
|
| 392 |
+
n_in,
|
| 393 |
+
n_out,
|
| 394 |
+
k,
|
| 395 |
+
stride,
|
| 396 |
+
is_layer_norm=False,
|
| 397 |
+
is_group_norm=False,
|
| 398 |
+
conv_bias=False,
|
| 399 |
+
):
|
| 400 |
+
def make_conv():
|
| 401 |
+
conv = nn.Conv1d(n_in, n_out, k, stride=stride, bias=conv_bias)
|
| 402 |
+
nn.init.kaiming_normal_(conv.weight)
|
| 403 |
+
return conv
|
| 404 |
+
|
| 405 |
+
assert (
|
| 406 |
+
is_layer_norm and is_group_norm
|
| 407 |
+
) == False, "layer norm and group norm are exclusive"
|
| 408 |
+
|
| 409 |
+
if is_layer_norm:
|
| 410 |
+
return nn.Sequential(
|
| 411 |
+
make_conv(),
|
| 412 |
+
nn.Dropout(p=dropout),
|
| 413 |
+
nn.Sequential(
|
| 414 |
+
TransposeLast(),
|
| 415 |
+
Fp32LayerNorm(dim, elementwise_affine=True),
|
| 416 |
+
TransposeLast(),
|
| 417 |
+
),
|
| 418 |
+
nn.GELU(),
|
| 419 |
+
)
|
| 420 |
+
elif is_group_norm:
|
| 421 |
+
return nn.Sequential(
|
| 422 |
+
make_conv(),
|
| 423 |
+
nn.Dropout(p=dropout),
|
| 424 |
+
Fp32GroupNorm(dim, dim, affine=True),
|
| 425 |
+
nn.GELU(),
|
| 426 |
+
)
|
| 427 |
+
else:
|
| 428 |
+
return nn.Sequential(make_conv(), nn.Dropout(p=dropout), nn.GELU())
|
| 429 |
+
|
| 430 |
+
self.conv_type = conv_type
|
| 431 |
+
if self.conv_type == "default":
|
| 432 |
+
in_d = 1
|
| 433 |
+
self.conv_layers = nn.ModuleList()
|
| 434 |
+
for i, cl in enumerate(conv_layers):
|
| 435 |
+
assert len(cl) == 3, "invalid conv definition: " + str(cl)
|
| 436 |
+
(dim, k, stride) = cl
|
| 437 |
+
|
| 438 |
+
self.conv_layers.append(
|
| 439 |
+
block(
|
| 440 |
+
in_d,
|
| 441 |
+
dim,
|
| 442 |
+
k,
|
| 443 |
+
stride,
|
| 444 |
+
is_layer_norm=mode == "layer_norm",
|
| 445 |
+
is_group_norm=mode == "default" and i == 0,
|
| 446 |
+
conv_bias=conv_bias,
|
| 447 |
+
)
|
| 448 |
+
)
|
| 449 |
+
in_d = dim
|
| 450 |
+
elif self.conv_type == "conv2d":
|
| 451 |
+
in_d = 1
|
| 452 |
+
self.conv_layers = nn.ModuleList()
|
| 453 |
+
for i, cl in enumerate(conv_layers):
|
| 454 |
+
assert len(cl) == 3
|
| 455 |
+
(dim, k, stride) = cl
|
| 456 |
+
|
| 457 |
+
self.conv_layers.append(
|
| 458 |
+
torch.nn.Conv2d(in_d, dim, k, stride)
|
| 459 |
+
)
|
| 460 |
+
self.conv_layers.append(torch.nn.ReLU())
|
| 461 |
+
in_d = dim
|
| 462 |
+
elif self.conv_type == "custom":
|
| 463 |
+
in_d = 1
|
| 464 |
+
idim = 80
|
| 465 |
+
self.conv_layers = nn.ModuleList()
|
| 466 |
+
for i, cl in enumerate(conv_layers):
|
| 467 |
+
assert len(cl) == 3
|
| 468 |
+
(dim, k, stride) = cl
|
| 469 |
+
self.conv_layers.append(
|
| 470 |
+
torch.nn.Conv2d(in_d, dim, k, stride, padding=1)
|
| 471 |
+
)
|
| 472 |
+
self.conv_layers.append(
|
| 473 |
+
torch.nn.LayerNorm([dim, idim])
|
| 474 |
+
)
|
| 475 |
+
self.conv_layers.append(torch.nn.ReLU())
|
| 476 |
+
in_d = dim
|
| 477 |
+
if (i + 1) % 2 == 0:
|
| 478 |
+
self.conv_layers.append(
|
| 479 |
+
torch.nn.MaxPool2d(2, stride=2, ceil_mode=True)
|
| 480 |
+
)
|
| 481 |
+
idim = int(math.ceil(idim / 2))
|
| 482 |
+
else:
|
| 483 |
+
pass
|
| 484 |
+
|
| 485 |
+
def forward(self, x, mask=None):
|
| 486 |
+
|
| 487 |
+
# BxT -> BxCxT
|
| 488 |
+
x = x.unsqueeze(1)
|
| 489 |
+
if self.conv_type == "custom":
|
| 490 |
+
for conv in self.conv_layers:
|
| 491 |
+
if isinstance(conv, nn.LayerNorm):
|
| 492 |
+
x = x.transpose(1, 2)
|
| 493 |
+
x = conv(x).transpose(1, 2)
|
| 494 |
+
else:
|
| 495 |
+
x = conv(x)
|
| 496 |
+
x = x.transpose(2, 3).contiguous()
|
| 497 |
+
x = x.view(x.size(0), -1, x.size(-1))
|
| 498 |
+
else:
|
| 499 |
+
for conv in self.conv_layers:
|
| 500 |
+
x = conv(x)
|
| 501 |
+
if self.conv_type == "conv2d":
|
| 502 |
+
b, c, t, f = x.size()
|
| 503 |
+
x = x.transpose(2, 3).contiguous().view(b, c * f, t)
|
| 504 |
+
return x
|
| 505 |
+
|
| 506 |
+
|
| 507 |
+
class TransformerEncoder(nn.Module):
|
| 508 |
+
def __init__(self, args):
|
| 509 |
+
super().__init__()
|
| 510 |
+
|
| 511 |
+
self.dropout = args.dropout
|
| 512 |
+
self.embedding_dim = args.encoder_embed_dim
|
| 513 |
+
|
| 514 |
+
self.pos_conv = nn.Conv1d(
|
| 515 |
+
self.embedding_dim,
|
| 516 |
+
self.embedding_dim,
|
| 517 |
+
kernel_size=args.conv_pos,
|
| 518 |
+
padding=args.conv_pos // 2,
|
| 519 |
+
groups=args.conv_pos_groups,
|
| 520 |
+
)
|
| 521 |
+
dropout = 0
|
| 522 |
+
std = math.sqrt((4 * (1.0 - dropout)) / (args.conv_pos * self.embedding_dim))
|
| 523 |
+
nn.init.normal_(self.pos_conv.weight, mean=0, std=std)
|
| 524 |
+
nn.init.constant_(self.pos_conv.bias, 0)
|
| 525 |
+
|
| 526 |
+
self.pos_conv = nn.utils.weight_norm(self.pos_conv, name="weight", dim=2)
|
| 527 |
+
self.pos_conv = nn.Sequential(self.pos_conv, SamePad(args.conv_pos), nn.GELU())
|
| 528 |
+
|
| 529 |
+
if hasattr(args, "relative_position_embedding"):
|
| 530 |
+
self.relative_position_embedding = args.relative_position_embedding
|
| 531 |
+
self.num_buckets = args.num_buckets
|
| 532 |
+
self.max_distance = args.max_distance
|
| 533 |
+
else:
|
| 534 |
+
self.relative_position_embedding = False
|
| 535 |
+
self.num_buckets = 0
|
| 536 |
+
self.max_distance = 0
|
| 537 |
+
|
| 538 |
+
self.layers = nn.ModuleList(
|
| 539 |
+
[
|
| 540 |
+
TransformerSentenceEncoderLayer(
|
| 541 |
+
embedding_dim=self.embedding_dim,
|
| 542 |
+
ffn_embedding_dim=args.encoder_ffn_embed_dim,
|
| 543 |
+
num_attention_heads=args.encoder_attention_heads,
|
| 544 |
+
dropout=self.dropout,
|
| 545 |
+
attention_dropout=args.attention_dropout,
|
| 546 |
+
activation_dropout=args.activation_dropout,
|
| 547 |
+
activation_fn=args.activation_fn,
|
| 548 |
+
layer_norm_first=args.layer_norm_first,
|
| 549 |
+
has_relative_attention_bias=(self.relative_position_embedding and i == 0),
|
| 550 |
+
num_buckets=self.num_buckets,
|
| 551 |
+
max_distance=self.max_distance,
|
| 552 |
+
gru_rel_pos=args.gru_rel_pos,
|
| 553 |
+
)
|
| 554 |
+
for i in range(args.encoder_layers)
|
| 555 |
+
]
|
| 556 |
+
)
|
| 557 |
+
|
| 558 |
+
self.layer_norm_first = args.layer_norm_first
|
| 559 |
+
self.layer_norm = LayerNorm(self.embedding_dim)
|
| 560 |
+
self.layerdrop = args.encoder_layerdrop
|
| 561 |
+
|
| 562 |
+
self.apply(init_bert_params)
|
| 563 |
+
|
| 564 |
+
def forward(self, x, padding_mask=None, streaming_mask=None, layer=None):
|
| 565 |
+
x, layer_results = self.extract_features(x, padding_mask, streaming_mask, layer)
|
| 566 |
+
|
| 567 |
+
if self.layer_norm_first and layer is None:
|
| 568 |
+
x = self.layer_norm(x)
|
| 569 |
+
|
| 570 |
+
return x, layer_results
|
| 571 |
+
|
| 572 |
+
def extract_features(self, x, padding_mask=None, streaming_mask=None, tgt_layer=None):
|
| 573 |
+
|
| 574 |
+
if padding_mask is not None:
|
| 575 |
+
x[padding_mask] = 0
|
| 576 |
+
|
| 577 |
+
x_conv = self.pos_conv(x.transpose(1, 2))
|
| 578 |
+
x_conv = x_conv.transpose(1, 2)
|
| 579 |
+
x += x_conv
|
| 580 |
+
|
| 581 |
+
if not self.layer_norm_first:
|
| 582 |
+
x = self.layer_norm(x)
|
| 583 |
+
|
| 584 |
+
x = F.dropout(x, p=self.dropout, training=self.training)
|
| 585 |
+
|
| 586 |
+
# B x T x C -> T x B x C
|
| 587 |
+
x = x.transpose(0, 1)
|
| 588 |
+
|
| 589 |
+
layer_results = []
|
| 590 |
+
z = None
|
| 591 |
+
if tgt_layer is not None:
|
| 592 |
+
layer_results.append((x, z))
|
| 593 |
+
r = None
|
| 594 |
+
pos_bias = None
|
| 595 |
+
for i, layer in enumerate(self.layers):
|
| 596 |
+
dropout_probability = np.random.random()
|
| 597 |
+
if not self.training or (dropout_probability > self.layerdrop):
|
| 598 |
+
x, z, pos_bias = layer(x, self_attn_padding_mask=padding_mask, need_weights=False,
|
| 599 |
+
self_attn_mask=streaming_mask, pos_bias=pos_bias)
|
| 600 |
+
if tgt_layer is not None:
|
| 601 |
+
layer_results.append((x, z))
|
| 602 |
+
if i == tgt_layer:
|
| 603 |
+
r = x
|
| 604 |
+
break
|
| 605 |
+
|
| 606 |
+
if r is not None:
|
| 607 |
+
x = r
|
| 608 |
+
|
| 609 |
+
# T x B x C -> B x T x C
|
| 610 |
+
x = x.transpose(0, 1)
|
| 611 |
+
|
| 612 |
+
return x, layer_results
|
| 613 |
+
|
| 614 |
+
|
| 615 |
+
class TransformerSentenceEncoderLayer(nn.Module):
|
| 616 |
+
"""
|
| 617 |
+
Implements a Transformer Encoder Layer used in BERT/XLM style pre-trained
|
| 618 |
+
models.
|
| 619 |
+
"""
|
| 620 |
+
|
| 621 |
+
def __init__(
|
| 622 |
+
self,
|
| 623 |
+
embedding_dim: float = 768,
|
| 624 |
+
ffn_embedding_dim: float = 3072,
|
| 625 |
+
num_attention_heads: float = 8,
|
| 626 |
+
dropout: float = 0.1,
|
| 627 |
+
attention_dropout: float = 0.1,
|
| 628 |
+
activation_dropout: float = 0.1,
|
| 629 |
+
activation_fn: str = "relu",
|
| 630 |
+
layer_norm_first: bool = False,
|
| 631 |
+
has_relative_attention_bias: bool = False,
|
| 632 |
+
num_buckets: int = 0,
|
| 633 |
+
max_distance: int = 0,
|
| 634 |
+
rescale_init: bool = False,
|
| 635 |
+
gru_rel_pos: bool = False,
|
| 636 |
+
) -> None:
|
| 637 |
+
|
| 638 |
+
super().__init__()
|
| 639 |
+
# Initialize parameters
|
| 640 |
+
self.embedding_dim = embedding_dim
|
| 641 |
+
self.dropout = dropout
|
| 642 |
+
self.activation_dropout = activation_dropout
|
| 643 |
+
|
| 644 |
+
# Initialize blocks
|
| 645 |
+
self.activation_name = activation_fn
|
| 646 |
+
self.activation_fn = get_activation_fn(activation_fn)
|
| 647 |
+
self.self_attn = MultiheadAttention(
|
| 648 |
+
self.embedding_dim,
|
| 649 |
+
num_attention_heads,
|
| 650 |
+
dropout=attention_dropout,
|
| 651 |
+
self_attention=True,
|
| 652 |
+
has_relative_attention_bias=has_relative_attention_bias,
|
| 653 |
+
num_buckets=num_buckets,
|
| 654 |
+
max_distance=max_distance,
|
| 655 |
+
rescale_init=rescale_init,
|
| 656 |
+
gru_rel_pos=gru_rel_pos,
|
| 657 |
+
)
|
| 658 |
+
|
| 659 |
+
self.dropout1 = nn.Dropout(dropout)
|
| 660 |
+
self.dropout2 = nn.Dropout(self.activation_dropout)
|
| 661 |
+
self.dropout3 = nn.Dropout(dropout)
|
| 662 |
+
|
| 663 |
+
self.layer_norm_first = layer_norm_first
|
| 664 |
+
|
| 665 |
+
# layer norm associated with the self attention layer
|
| 666 |
+
self.self_attn_layer_norm = LayerNorm(self.embedding_dim)
|
| 667 |
+
|
| 668 |
+
if self.activation_name == "glu":
|
| 669 |
+
self.fc1 = GLU_Linear(self.embedding_dim, ffn_embedding_dim, "swish")
|
| 670 |
+
else:
|
| 671 |
+
self.fc1 = nn.Linear(self.embedding_dim, ffn_embedding_dim)
|
| 672 |
+
self.fc2 = nn.Linear(ffn_embedding_dim, self.embedding_dim)
|
| 673 |
+
|
| 674 |
+
# layer norm associated with the position wise feed-forward NN
|
| 675 |
+
self.final_layer_norm = LayerNorm(self.embedding_dim)
|
| 676 |
+
|
| 677 |
+
def forward(
|
| 678 |
+
self,
|
| 679 |
+
x: torch.Tensor,
|
| 680 |
+
self_attn_mask: torch.Tensor = None,
|
| 681 |
+
self_attn_padding_mask: torch.Tensor = None,
|
| 682 |
+
need_weights: bool = False,
|
| 683 |
+
pos_bias=None
|
| 684 |
+
):
|
| 685 |
+
"""
|
| 686 |
+
LayerNorm is applied either before or after the self-attention/ffn
|
| 687 |
+
modules similar to the original Transformer imlementation.
|
| 688 |
+
"""
|
| 689 |
+
residual = x
|
| 690 |
+
|
| 691 |
+
if self.layer_norm_first:
|
| 692 |
+
x = self.self_attn_layer_norm(x)
|
| 693 |
+
x, attn, pos_bias = self.self_attn(
|
| 694 |
+
query=x,
|
| 695 |
+
key=x,
|
| 696 |
+
value=x,
|
| 697 |
+
key_padding_mask=self_attn_padding_mask,
|
| 698 |
+
need_weights=False,
|
| 699 |
+
attn_mask=self_attn_mask,
|
| 700 |
+
position_bias=pos_bias
|
| 701 |
+
)
|
| 702 |
+
x = self.dropout1(x)
|
| 703 |
+
x = residual + x
|
| 704 |
+
|
| 705 |
+
residual = x
|
| 706 |
+
x = self.final_layer_norm(x)
|
| 707 |
+
if self.activation_name == "glu":
|
| 708 |
+
x = self.fc1(x)
|
| 709 |
+
else:
|
| 710 |
+
x = self.activation_fn(self.fc1(x))
|
| 711 |
+
x = self.dropout2(x)
|
| 712 |
+
x = self.fc2(x)
|
| 713 |
+
x = self.dropout3(x)
|
| 714 |
+
x = residual + x
|
| 715 |
+
else:
|
| 716 |
+
x, attn, pos_bias = self.self_attn(
|
| 717 |
+
query=x,
|
| 718 |
+
key=x,
|
| 719 |
+
value=x,
|
| 720 |
+
key_padding_mask=self_attn_padding_mask,
|
| 721 |
+
need_weights=need_weights,
|
| 722 |
+
attn_mask=self_attn_mask,
|
| 723 |
+
position_bias=pos_bias
|
| 724 |
+
)
|
| 725 |
+
|
| 726 |
+
x = self.dropout1(x)
|
| 727 |
+
x = residual + x
|
| 728 |
+
|
| 729 |
+
x = self.self_attn_layer_norm(x)
|
| 730 |
+
|
| 731 |
+
residual = x
|
| 732 |
+
if self.activation_name == "glu":
|
| 733 |
+
x = self.fc1(x)
|
| 734 |
+
else:
|
| 735 |
+
x = self.activation_fn(self.fc1(x))
|
| 736 |
+
x = self.dropout2(x)
|
| 737 |
+
x = self.fc2(x)
|
| 738 |
+
x = self.dropout3(x)
|
| 739 |
+
x = residual + x
|
| 740 |
+
x = self.final_layer_norm(x)
|
| 741 |
+
|
| 742 |
+
return x, attn, pos_bias
|
| 743 |
+
|
wavlm/modules.py
ADDED
|
@@ -0,0 +1,827 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# --------------------------------------------------------
|
| 2 |
+
# WavLM: Large-Scale Self-Supervised Pre-training for Full Stack Speech Processing (https://arxiv.org/abs/2110.13900.pdf)
|
| 3 |
+
# Github source: https://github.com/microsoft/unilm/tree/master/wavlm
|
| 4 |
+
# Copyright (c) 2021 Microsoft
|
| 5 |
+
# Licensed under The MIT License [see LICENSE for details]
|
| 6 |
+
# Based on fairseq code bases
|
| 7 |
+
# https://github.com/pytorch/fairseq
|
| 8 |
+
# --------------------------------------------------------
|
| 9 |
+
|
| 10 |
+
import math
|
| 11 |
+
import warnings
|
| 12 |
+
from typing import Dict, Optional, Tuple
|
| 13 |
+
import torch
|
| 14 |
+
from torch import Tensor, nn
|
| 15 |
+
from torch.nn import Parameter
|
| 16 |
+
import torch.nn.functional as F
|
| 17 |
+
|
| 18 |
+
|
| 19 |
+
class TransposeLast(nn.Module):
|
| 20 |
+
def __init__(self, deconstruct_idx=None):
|
| 21 |
+
super().__init__()
|
| 22 |
+
self.deconstruct_idx = deconstruct_idx
|
| 23 |
+
|
| 24 |
+
def forward(self, x):
|
| 25 |
+
if self.deconstruct_idx is not None:
|
| 26 |
+
x = x[self.deconstruct_idx]
|
| 27 |
+
return x.transpose(-2, -1)
|
| 28 |
+
|
| 29 |
+
|
| 30 |
+
class Fp32LayerNorm(nn.LayerNorm):
|
| 31 |
+
def __init__(self, *args, **kwargs):
|
| 32 |
+
super().__init__(*args, **kwargs)
|
| 33 |
+
|
| 34 |
+
def forward(self, input):
|
| 35 |
+
output = F.layer_norm(
|
| 36 |
+
input.float(),
|
| 37 |
+
self.normalized_shape,
|
| 38 |
+
self.weight.float() if self.weight is not None else None,
|
| 39 |
+
self.bias.float() if self.bias is not None else None,
|
| 40 |
+
self.eps,
|
| 41 |
+
)
|
| 42 |
+
return output.type_as(input)
|
| 43 |
+
|
| 44 |
+
|
| 45 |
+
class Fp32GroupNorm(nn.GroupNorm):
|
| 46 |
+
def __init__(self, *args, **kwargs):
|
| 47 |
+
super().__init__(*args, **kwargs)
|
| 48 |
+
|
| 49 |
+
def forward(self, input):
|
| 50 |
+
output = F.group_norm(
|
| 51 |
+
input.float(),
|
| 52 |
+
self.num_groups,
|
| 53 |
+
self.weight.float() if self.weight is not None else None,
|
| 54 |
+
self.bias.float() if self.bias is not None else None,
|
| 55 |
+
self.eps,
|
| 56 |
+
)
|
| 57 |
+
return output.type_as(input)
|
| 58 |
+
|
| 59 |
+
|
| 60 |
+
class GradMultiply(torch.autograd.Function):
|
| 61 |
+
@staticmethod
|
| 62 |
+
def forward(ctx, x, scale):
|
| 63 |
+
ctx.scale = scale
|
| 64 |
+
res = x.new(x)
|
| 65 |
+
return res
|
| 66 |
+
|
| 67 |
+
@staticmethod
|
| 68 |
+
def backward(ctx, grad):
|
| 69 |
+
return grad * ctx.scale, None
|
| 70 |
+
|
| 71 |
+
|
| 72 |
+
class SamePad(nn.Module):
|
| 73 |
+
def __init__(self, kernel_size, causal=False):
|
| 74 |
+
super().__init__()
|
| 75 |
+
if causal:
|
| 76 |
+
self.remove = kernel_size - 1
|
| 77 |
+
else:
|
| 78 |
+
self.remove = 1 if kernel_size % 2 == 0 else 0
|
| 79 |
+
|
| 80 |
+
def forward(self, x):
|
| 81 |
+
if self.remove > 0:
|
| 82 |
+
x = x[:, :, : -self.remove]
|
| 83 |
+
return x
|
| 84 |
+
|
| 85 |
+
|
| 86 |
+
class Swish(nn.Module):
|
| 87 |
+
"""Swish function
|
| 88 |
+
"""
|
| 89 |
+
|
| 90 |
+
def __init__(self):
|
| 91 |
+
"""Construct an MultiHeadedAttention object."""
|
| 92 |
+
super(Swish, self).__init__()
|
| 93 |
+
self.act = torch.nn.Sigmoid()
|
| 94 |
+
|
| 95 |
+
def forward(self, x):
|
| 96 |
+
return x * self.act(x)
|
| 97 |
+
|
| 98 |
+
|
| 99 |
+
class GLU_Linear(nn.Module):
|
| 100 |
+
def __init__(self, input_dim, output_dim, glu_type="sigmoid", bias_in_glu=True):
|
| 101 |
+
super(GLU_Linear, self).__init__()
|
| 102 |
+
|
| 103 |
+
self.glu_type = glu_type
|
| 104 |
+
self.output_dim = output_dim
|
| 105 |
+
|
| 106 |
+
if glu_type == "sigmoid":
|
| 107 |
+
self.glu_act = torch.nn.Sigmoid()
|
| 108 |
+
elif glu_type == "swish":
|
| 109 |
+
self.glu_act = Swish()
|
| 110 |
+
elif glu_type == "relu":
|
| 111 |
+
self.glu_act = torch.nn.ReLU()
|
| 112 |
+
elif glu_type == "gelu":
|
| 113 |
+
self.glu_act = torch.nn.GELU()
|
| 114 |
+
|
| 115 |
+
if bias_in_glu:
|
| 116 |
+
self.linear = nn.Linear(input_dim, output_dim * 2, True)
|
| 117 |
+
else:
|
| 118 |
+
self.linear = nn.Linear(input_dim, output_dim * 2, False)
|
| 119 |
+
|
| 120 |
+
def forward(self, x):
|
| 121 |
+
# to be consistent with GLU_Linear, we assume the input always has the #channel (#dim) in the last dimension of the tensor, so need to switch the dimension first for 1D-Conv case
|
| 122 |
+
x = self.linear(x)
|
| 123 |
+
|
| 124 |
+
if self.glu_type == "bilinear":
|
| 125 |
+
x = (x[:, :, 0:self.output_dim] * x[:, :, self.output_dim:self.output_dim * 2])
|
| 126 |
+
else:
|
| 127 |
+
x = (x[:, :, 0:self.output_dim] * self.glu_act(x[:, :, self.output_dim:self.output_dim * 2]))
|
| 128 |
+
|
| 129 |
+
return x
|
| 130 |
+
|
| 131 |
+
|
| 132 |
+
def gelu_accurate(x):
|
| 133 |
+
if not hasattr(gelu_accurate, "_a"):
|
| 134 |
+
gelu_accurate._a = math.sqrt(2 / math.pi)
|
| 135 |
+
return (
|
| 136 |
+
0.5 * x * (1 + torch.tanh(gelu_accurate._a * (x + 0.044715 * torch.pow(x, 3))))
|
| 137 |
+
)
|
| 138 |
+
|
| 139 |
+
|
| 140 |
+
def gelu(x: torch.Tensor) -> torch.Tensor:
|
| 141 |
+
return torch.nn.functional.gelu(x.float()).type_as(x)
|
| 142 |
+
|
| 143 |
+
|
| 144 |
+
def get_activation_fn(activation: str):
|
| 145 |
+
"""Returns the activation function corresponding to `activation`"""
|
| 146 |
+
|
| 147 |
+
if activation == "relu":
|
| 148 |
+
return F.relu
|
| 149 |
+
elif activation == "gelu":
|
| 150 |
+
return gelu
|
| 151 |
+
elif activation == "gelu_fast":
|
| 152 |
+
warnings.warn(
|
| 153 |
+
"--activation-fn=gelu_fast has been renamed to gelu_accurate"
|
| 154 |
+
)
|
| 155 |
+
return gelu_accurate
|
| 156 |
+
elif activation == "gelu_accurate":
|
| 157 |
+
return gelu_accurate
|
| 158 |
+
elif activation == "tanh":
|
| 159 |
+
return torch.tanh
|
| 160 |
+
elif activation == "linear":
|
| 161 |
+
return lambda x: x
|
| 162 |
+
elif activation == "glu":
|
| 163 |
+
return lambda x: x
|
| 164 |
+
else:
|
| 165 |
+
raise RuntimeError("--activation-fn {} not supported".format(activation))
|
| 166 |
+
|
| 167 |
+
|
| 168 |
+
def init_bert_params(module):
|
| 169 |
+
"""
|
| 170 |
+
Initialize the weights specific to the BERT Model.
|
| 171 |
+
This overrides the default initializations depending on the specified arguments.
|
| 172 |
+
1. If normal_init_linear_weights is set then weights of linear
|
| 173 |
+
layer will be initialized using the normal distribution and
|
| 174 |
+
bais will be set to the specified value.
|
| 175 |
+
2. If normal_init_embed_weights is set then weights of embedding
|
| 176 |
+
layer will be initialized using the normal distribution.
|
| 177 |
+
3. If normal_init_proj_weights is set then weights of
|
| 178 |
+
in_project_weight for MultiHeadAttention initialized using
|
| 179 |
+
the normal distribution (to be validated).
|
| 180 |
+
"""
|
| 181 |
+
|
| 182 |
+
def normal_(data):
|
| 183 |
+
# with FSDP, module params will be on CUDA, so we cast them back to CPU
|
| 184 |
+
# so that the RNG is consistent with and without FSDP
|
| 185 |
+
data.copy_(
|
| 186 |
+
data.cpu().normal_(mean=0.0, std=0.02).to(data.device)
|
| 187 |
+
)
|
| 188 |
+
|
| 189 |
+
if isinstance(module, nn.Linear):
|
| 190 |
+
normal_(module.weight.data)
|
| 191 |
+
if module.bias is not None:
|
| 192 |
+
module.bias.data.zero_()
|
| 193 |
+
if isinstance(module, nn.Embedding):
|
| 194 |
+
normal_(module.weight.data)
|
| 195 |
+
if module.padding_idx is not None:
|
| 196 |
+
module.weight.data[module.padding_idx].zero_()
|
| 197 |
+
if isinstance(module, MultiheadAttention):
|
| 198 |
+
normal_(module.q_proj.weight.data)
|
| 199 |
+
normal_(module.k_proj.weight.data)
|
| 200 |
+
normal_(module.v_proj.weight.data)
|
| 201 |
+
|
| 202 |
+
|
| 203 |
+
def quant_noise(module, p, block_size):
|
| 204 |
+
"""
|
| 205 |
+
Wraps modules and applies quantization noise to the weights for
|
| 206 |
+
subsequent quantization with Iterative Product Quantization as
|
| 207 |
+
described in "Training with Quantization Noise for Extreme Model Compression"
|
| 208 |
+
|
| 209 |
+
Args:
|
| 210 |
+
- module: nn.Module
|
| 211 |
+
- p: amount of Quantization Noise
|
| 212 |
+
- block_size: size of the blocks for subsequent quantization with iPQ
|
| 213 |
+
|
| 214 |
+
Remarks:
|
| 215 |
+
- Module weights must have the right sizes wrt the block size
|
| 216 |
+
- Only Linear, Embedding and Conv2d modules are supported for the moment
|
| 217 |
+
- For more detail on how to quantize by blocks with convolutional weights,
|
| 218 |
+
see "And the Bit Goes Down: Revisiting the Quantization of Neural Networks"
|
| 219 |
+
- We implement the simplest form of noise here as stated in the paper
|
| 220 |
+
which consists in randomly dropping blocks
|
| 221 |
+
"""
|
| 222 |
+
|
| 223 |
+
# if no quantization noise, don't register hook
|
| 224 |
+
if p <= 0:
|
| 225 |
+
return module
|
| 226 |
+
|
| 227 |
+
# supported modules
|
| 228 |
+
assert isinstance(module, (nn.Linear, nn.Embedding, nn.Conv2d))
|
| 229 |
+
|
| 230 |
+
# test whether module.weight has the right sizes wrt block_size
|
| 231 |
+
is_conv = module.weight.ndim == 4
|
| 232 |
+
|
| 233 |
+
# 2D matrix
|
| 234 |
+
if not is_conv:
|
| 235 |
+
assert (
|
| 236 |
+
module.weight.size(1) % block_size == 0
|
| 237 |
+
), "Input features must be a multiple of block sizes"
|
| 238 |
+
|
| 239 |
+
# 4D matrix
|
| 240 |
+
else:
|
| 241 |
+
# 1x1 convolutions
|
| 242 |
+
if module.kernel_size == (1, 1):
|
| 243 |
+
assert (
|
| 244 |
+
module.in_channels % block_size == 0
|
| 245 |
+
), "Input channels must be a multiple of block sizes"
|
| 246 |
+
# regular convolutions
|
| 247 |
+
else:
|
| 248 |
+
k = module.kernel_size[0] * module.kernel_size[1]
|
| 249 |
+
assert k % block_size == 0, "Kernel size must be a multiple of block size"
|
| 250 |
+
|
| 251 |
+
def _forward_pre_hook(mod, input):
|
| 252 |
+
# no noise for evaluation
|
| 253 |
+
if mod.training:
|
| 254 |
+
if not is_conv:
|
| 255 |
+
# gather weight and sizes
|
| 256 |
+
weight = mod.weight
|
| 257 |
+
in_features = weight.size(1)
|
| 258 |
+
out_features = weight.size(0)
|
| 259 |
+
|
| 260 |
+
# split weight matrix into blocks and randomly drop selected blocks
|
| 261 |
+
mask = torch.zeros(
|
| 262 |
+
in_features // block_size * out_features, device=weight.device
|
| 263 |
+
)
|
| 264 |
+
mask.bernoulli_(p)
|
| 265 |
+
mask = mask.repeat_interleave(block_size, -1).view(-1, in_features)
|
| 266 |
+
|
| 267 |
+
else:
|
| 268 |
+
# gather weight and sizes
|
| 269 |
+
weight = mod.weight
|
| 270 |
+
in_channels = mod.in_channels
|
| 271 |
+
out_channels = mod.out_channels
|
| 272 |
+
|
| 273 |
+
# split weight matrix into blocks and randomly drop selected blocks
|
| 274 |
+
if mod.kernel_size == (1, 1):
|
| 275 |
+
mask = torch.zeros(
|
| 276 |
+
int(in_channels // block_size * out_channels),
|
| 277 |
+
device=weight.device,
|
| 278 |
+
)
|
| 279 |
+
mask.bernoulli_(p)
|
| 280 |
+
mask = mask.repeat_interleave(block_size, -1).view(-1, in_channels)
|
| 281 |
+
else:
|
| 282 |
+
mask = torch.zeros(
|
| 283 |
+
weight.size(0), weight.size(1), device=weight.device
|
| 284 |
+
)
|
| 285 |
+
mask.bernoulli_(p)
|
| 286 |
+
mask = (
|
| 287 |
+
mask.unsqueeze(2)
|
| 288 |
+
.unsqueeze(3)
|
| 289 |
+
.repeat(1, 1, mod.kernel_size[0], mod.kernel_size[1])
|
| 290 |
+
)
|
| 291 |
+
|
| 292 |
+
# scale weights and apply mask
|
| 293 |
+
mask = mask.to(
|
| 294 |
+
torch.bool
|
| 295 |
+
) # x.bool() is not currently supported in TorchScript
|
| 296 |
+
s = 1 / (1 - p)
|
| 297 |
+
mod.weight.data = s * weight.masked_fill(mask, 0)
|
| 298 |
+
|
| 299 |
+
module.register_forward_pre_hook(_forward_pre_hook)
|
| 300 |
+
return module
|
| 301 |
+
|
| 302 |
+
|
| 303 |
+
class MultiheadAttention(nn.Module):
|
| 304 |
+
"""Multi-headed attention.
|
| 305 |
+
|
| 306 |
+
See "Attention Is All You Need" for more details.
|
| 307 |
+
"""
|
| 308 |
+
|
| 309 |
+
def __init__(
|
| 310 |
+
self,
|
| 311 |
+
embed_dim,
|
| 312 |
+
num_heads,
|
| 313 |
+
kdim=None,
|
| 314 |
+
vdim=None,
|
| 315 |
+
dropout=0.0,
|
| 316 |
+
bias=True,
|
| 317 |
+
add_bias_kv=False,
|
| 318 |
+
add_zero_attn=False,
|
| 319 |
+
self_attention=False,
|
| 320 |
+
encoder_decoder_attention=False,
|
| 321 |
+
q_noise=0.0,
|
| 322 |
+
qn_block_size=8,
|
| 323 |
+
has_relative_attention_bias=False,
|
| 324 |
+
num_buckets=32,
|
| 325 |
+
max_distance=128,
|
| 326 |
+
gru_rel_pos=False,
|
| 327 |
+
rescale_init=False,
|
| 328 |
+
):
|
| 329 |
+
super().__init__()
|
| 330 |
+
self.embed_dim = embed_dim
|
| 331 |
+
self.kdim = kdim if kdim is not None else embed_dim
|
| 332 |
+
self.vdim = vdim if vdim is not None else embed_dim
|
| 333 |
+
self.qkv_same_dim = self.kdim == embed_dim and self.vdim == embed_dim
|
| 334 |
+
|
| 335 |
+
self.num_heads = num_heads
|
| 336 |
+
self.dropout_module = nn.Dropout(dropout)
|
| 337 |
+
|
| 338 |
+
self.has_relative_attention_bias = has_relative_attention_bias
|
| 339 |
+
self.num_buckets = num_buckets
|
| 340 |
+
self.max_distance = max_distance
|
| 341 |
+
if self.has_relative_attention_bias:
|
| 342 |
+
self.relative_attention_bias = nn.Embedding(num_buckets, num_heads)
|
| 343 |
+
|
| 344 |
+
self.head_dim = embed_dim // num_heads
|
| 345 |
+
self.q_head_dim = self.head_dim
|
| 346 |
+
self.k_head_dim = self.head_dim
|
| 347 |
+
assert (
|
| 348 |
+
self.head_dim * num_heads == self.embed_dim
|
| 349 |
+
), "embed_dim must be divisible by num_heads"
|
| 350 |
+
self.scaling = self.head_dim ** -0.5
|
| 351 |
+
|
| 352 |
+
self.self_attention = self_attention
|
| 353 |
+
self.encoder_decoder_attention = encoder_decoder_attention
|
| 354 |
+
|
| 355 |
+
assert not self.self_attention or self.qkv_same_dim, (
|
| 356 |
+
"Self-attention requires query, key and " "value to be of the same size"
|
| 357 |
+
)
|
| 358 |
+
|
| 359 |
+
k_bias = True
|
| 360 |
+
if rescale_init:
|
| 361 |
+
k_bias = False
|
| 362 |
+
|
| 363 |
+
k_embed_dim = embed_dim
|
| 364 |
+
q_embed_dim = embed_dim
|
| 365 |
+
|
| 366 |
+
self.k_proj = quant_noise(
|
| 367 |
+
nn.Linear(self.kdim, k_embed_dim, bias=k_bias), q_noise, qn_block_size
|
| 368 |
+
)
|
| 369 |
+
self.v_proj = quant_noise(
|
| 370 |
+
nn.Linear(self.vdim, embed_dim, bias=bias), q_noise, qn_block_size
|
| 371 |
+
)
|
| 372 |
+
self.q_proj = quant_noise(
|
| 373 |
+
nn.Linear(embed_dim, q_embed_dim, bias=bias), q_noise, qn_block_size
|
| 374 |
+
)
|
| 375 |
+
|
| 376 |
+
self.out_proj = quant_noise(
|
| 377 |
+
nn.Linear(embed_dim, embed_dim, bias=bias), q_noise, qn_block_size
|
| 378 |
+
)
|
| 379 |
+
|
| 380 |
+
if add_bias_kv:
|
| 381 |
+
self.bias_k = Parameter(torch.Tensor(1, 1, embed_dim))
|
| 382 |
+
self.bias_v = Parameter(torch.Tensor(1, 1, embed_dim))
|
| 383 |
+
else:
|
| 384 |
+
self.bias_k = self.bias_v = None
|
| 385 |
+
|
| 386 |
+
self.add_zero_attn = add_zero_attn
|
| 387 |
+
|
| 388 |
+
self.gru_rel_pos = gru_rel_pos
|
| 389 |
+
if self.gru_rel_pos:
|
| 390 |
+
self.grep_linear = nn.Linear(self.q_head_dim, 8)
|
| 391 |
+
self.grep_a = nn.Parameter(torch.ones(1, num_heads, 1, 1))
|
| 392 |
+
|
| 393 |
+
self.reset_parameters()
|
| 394 |
+
|
| 395 |
+
def reset_parameters(self):
|
| 396 |
+
if self.qkv_same_dim:
|
| 397 |
+
# Empirically observed the convergence to be much better with
|
| 398 |
+
# the scaled initialization
|
| 399 |
+
nn.init.xavier_uniform_(self.k_proj.weight, gain=1 / math.sqrt(2))
|
| 400 |
+
nn.init.xavier_uniform_(self.v_proj.weight, gain=1 / math.sqrt(2))
|
| 401 |
+
nn.init.xavier_uniform_(self.q_proj.weight, gain=1 / math.sqrt(2))
|
| 402 |
+
else:
|
| 403 |
+
nn.init.xavier_uniform_(self.k_proj.weight)
|
| 404 |
+
nn.init.xavier_uniform_(self.v_proj.weight)
|
| 405 |
+
nn.init.xavier_uniform_(self.q_proj.weight)
|
| 406 |
+
|
| 407 |
+
nn.init.xavier_uniform_(self.out_proj.weight)
|
| 408 |
+
if self.out_proj.bias is not None:
|
| 409 |
+
nn.init.constant_(self.out_proj.bias, 0.0)
|
| 410 |
+
if self.bias_k is not None:
|
| 411 |
+
nn.init.xavier_normal_(self.bias_k)
|
| 412 |
+
if self.bias_v is not None:
|
| 413 |
+
nn.init.xavier_normal_(self.bias_v)
|
| 414 |
+
if self.has_relative_attention_bias:
|
| 415 |
+
nn.init.xavier_normal_(self.relative_attention_bias.weight)
|
| 416 |
+
|
| 417 |
+
def _relative_positions_bucket(self, relative_positions, bidirectional=True):
|
| 418 |
+
num_buckets = self.num_buckets
|
| 419 |
+
max_distance = self.max_distance
|
| 420 |
+
relative_buckets = 0
|
| 421 |
+
|
| 422 |
+
if bidirectional:
|
| 423 |
+
num_buckets = num_buckets // 2
|
| 424 |
+
relative_buckets += (relative_positions > 0).to(torch.long) * num_buckets
|
| 425 |
+
relative_positions = torch.abs(relative_positions)
|
| 426 |
+
else:
|
| 427 |
+
relative_positions = -torch.min(relative_positions, torch.zeros_like(relative_positions))
|
| 428 |
+
|
| 429 |
+
max_exact = num_buckets // 2
|
| 430 |
+
is_small = relative_positions < max_exact
|
| 431 |
+
|
| 432 |
+
relative_postion_if_large = max_exact + (
|
| 433 |
+
torch.log(relative_positions.float() / max_exact)
|
| 434 |
+
/ math.log(max_distance / max_exact)
|
| 435 |
+
* (num_buckets - max_exact)
|
| 436 |
+
).to(torch.long)
|
| 437 |
+
relative_postion_if_large = torch.min(
|
| 438 |
+
relative_postion_if_large, torch.full_like(relative_postion_if_large, num_buckets - 1)
|
| 439 |
+
)
|
| 440 |
+
|
| 441 |
+
relative_buckets += torch.where(is_small, relative_positions, relative_postion_if_large)
|
| 442 |
+
return relative_buckets
|
| 443 |
+
|
| 444 |
+
def compute_bias(self, query_length, key_length):
|
| 445 |
+
context_position = torch.arange(query_length, dtype=torch.long)[:, None]
|
| 446 |
+
memory_position = torch.arange(key_length, dtype=torch.long)[None, :]
|
| 447 |
+
relative_position = memory_position - context_position
|
| 448 |
+
relative_position_bucket = self._relative_positions_bucket(
|
| 449 |
+
relative_position,
|
| 450 |
+
bidirectional=True
|
| 451 |
+
)
|
| 452 |
+
relative_position_bucket = relative_position_bucket.to(self.relative_attention_bias.weight.device)
|
| 453 |
+
values = self.relative_attention_bias(relative_position_bucket)
|
| 454 |
+
values = values.permute([2, 0, 1])
|
| 455 |
+
return values
|
| 456 |
+
|
| 457 |
+
def forward(
|
| 458 |
+
self,
|
| 459 |
+
query,
|
| 460 |
+
key: Optional[Tensor],
|
| 461 |
+
value: Optional[Tensor],
|
| 462 |
+
key_padding_mask: Optional[Tensor] = None,
|
| 463 |
+
incremental_state: Optional[Dict[str, Dict[str, Optional[Tensor]]]] = None,
|
| 464 |
+
need_weights: bool = True,
|
| 465 |
+
static_kv: bool = False,
|
| 466 |
+
attn_mask: Optional[Tensor] = None,
|
| 467 |
+
before_softmax: bool = False,
|
| 468 |
+
need_head_weights: bool = False,
|
| 469 |
+
position_bias: Optional[Tensor] = None
|
| 470 |
+
) -> Tuple[Tensor, Optional[Tensor], Optional[Tensor]]:
|
| 471 |
+
"""Input shape: Time x Batch x Channel
|
| 472 |
+
|
| 473 |
+
Args:
|
| 474 |
+
key_padding_mask (ByteTensor, optional): mask to exclude
|
| 475 |
+
keys that are pads, of shape `(batch, src_len)`, where
|
| 476 |
+
padding elements are indicated by 1s.
|
| 477 |
+
need_weights (bool, optional): return the attention weights,
|
| 478 |
+
averaged over heads (default: False).
|
| 479 |
+
attn_mask (ByteTensor, optional): typically used to
|
| 480 |
+
implement causal attention, where the mask prevents the
|
| 481 |
+
attention from looking forward in time (default: None).
|
| 482 |
+
before_softmax (bool, optional): return the raw attention
|
| 483 |
+
weights and values before the attention softmax.
|
| 484 |
+
need_head_weights (bool, optional): return the attention
|
| 485 |
+
weights for each head. Implies *need_weights*. Default:
|
| 486 |
+
return the average attention weights over all heads.
|
| 487 |
+
"""
|
| 488 |
+
if need_head_weights:
|
| 489 |
+
need_weights = True
|
| 490 |
+
|
| 491 |
+
is_tpu = query.device.type == "xla"
|
| 492 |
+
|
| 493 |
+
tgt_len, bsz, embed_dim = query.size()
|
| 494 |
+
src_len = tgt_len
|
| 495 |
+
assert embed_dim == self.embed_dim
|
| 496 |
+
assert list(query.size()) == [tgt_len, bsz, embed_dim]
|
| 497 |
+
if key is not None:
|
| 498 |
+
src_len, key_bsz, _ = key.size()
|
| 499 |
+
if not torch.jit.is_scripting():
|
| 500 |
+
assert key_bsz == bsz
|
| 501 |
+
assert value is not None
|
| 502 |
+
assert src_len, bsz == value.shape[:2]
|
| 503 |
+
|
| 504 |
+
if self.has_relative_attention_bias and position_bias is None:
|
| 505 |
+
position_bias = self.compute_bias(tgt_len, src_len)
|
| 506 |
+
position_bias = position_bias.unsqueeze(0).repeat(bsz, 1, 1, 1).view(bsz * self.num_heads, tgt_len, src_len)
|
| 507 |
+
|
| 508 |
+
if (
|
| 509 |
+
not is_tpu # don't use PyTorch version on TPUs
|
| 510 |
+
and incremental_state is None
|
| 511 |
+
and not static_kv
|
| 512 |
+
# A workaround for quantization to work. Otherwise JIT compilation
|
| 513 |
+
# treats bias in linear module as method.
|
| 514 |
+
and not torch.jit.is_scripting()
|
| 515 |
+
and self.q_head_dim == self.head_dim
|
| 516 |
+
):
|
| 517 |
+
assert key is not None and value is not None
|
| 518 |
+
assert attn_mask is None
|
| 519 |
+
|
| 520 |
+
attn_mask_rel_pos = None
|
| 521 |
+
if position_bias is not None:
|
| 522 |
+
attn_mask_rel_pos = position_bias
|
| 523 |
+
if self.gru_rel_pos:
|
| 524 |
+
query_layer = query.transpose(0, 1)
|
| 525 |
+
new_x_shape = query_layer.size()[:-1] + (self.num_heads, -1)
|
| 526 |
+
query_layer = query_layer.view(*new_x_shape)
|
| 527 |
+
query_layer = query_layer.permute(0, 2, 1, 3)
|
| 528 |
+
_B, _H, _L, __ = query_layer.size()
|
| 529 |
+
|
| 530 |
+
gate_a, gate_b = torch.sigmoid(self.grep_linear(query_layer).view(
|
| 531 |
+
_B, _H, _L, 2, 4).sum(-1, keepdim=False)).chunk(2, dim=-1)
|
| 532 |
+
gate_a_1 = gate_a * (gate_b * self.grep_a - 1.0) + 2.0
|
| 533 |
+
attn_mask_rel_pos = gate_a_1.view(bsz * self.num_heads, -1, 1) * position_bias
|
| 534 |
+
|
| 535 |
+
attn_mask_rel_pos = attn_mask_rel_pos.view((-1, tgt_len, tgt_len))
|
| 536 |
+
k_proj_bias = self.k_proj.bias
|
| 537 |
+
if k_proj_bias is None:
|
| 538 |
+
k_proj_bias = torch.zeros_like(self.q_proj.bias)
|
| 539 |
+
|
| 540 |
+
x, attn = F.multi_head_attention_forward(
|
| 541 |
+
query,
|
| 542 |
+
key,
|
| 543 |
+
value,
|
| 544 |
+
self.embed_dim,
|
| 545 |
+
self.num_heads,
|
| 546 |
+
torch.empty([0]),
|
| 547 |
+
torch.cat((self.q_proj.bias, self.k_proj.bias, self.v_proj.bias)),
|
| 548 |
+
self.bias_k,
|
| 549 |
+
self.bias_v,
|
| 550 |
+
self.add_zero_attn,
|
| 551 |
+
self.dropout_module.p,
|
| 552 |
+
self.out_proj.weight,
|
| 553 |
+
self.out_proj.bias,
|
| 554 |
+
self.training,
|
| 555 |
+
# self.training or self.dropout_module.apply_during_inference,
|
| 556 |
+
key_padding_mask,
|
| 557 |
+
need_weights,
|
| 558 |
+
attn_mask_rel_pos,
|
| 559 |
+
use_separate_proj_weight=True,
|
| 560 |
+
q_proj_weight=self.q_proj.weight,
|
| 561 |
+
k_proj_weight=self.k_proj.weight,
|
| 562 |
+
v_proj_weight=self.v_proj.weight,
|
| 563 |
+
)
|
| 564 |
+
return x, attn, position_bias
|
| 565 |
+
|
| 566 |
+
if incremental_state is not None:
|
| 567 |
+
saved_state = self._get_input_buffer(incremental_state)
|
| 568 |
+
if saved_state is not None and "prev_key" in saved_state:
|
| 569 |
+
# previous time steps are cached - no need to recompute
|
| 570 |
+
# key and value if they are static
|
| 571 |
+
if static_kv:
|
| 572 |
+
assert self.encoder_decoder_attention and not self.self_attention
|
| 573 |
+
key = value = None
|
| 574 |
+
else:
|
| 575 |
+
saved_state = None
|
| 576 |
+
|
| 577 |
+
if self.self_attention:
|
| 578 |
+
q = self.q_proj(query)
|
| 579 |
+
k = self.k_proj(query)
|
| 580 |
+
v = self.v_proj(query)
|
| 581 |
+
elif self.encoder_decoder_attention:
|
| 582 |
+
# encoder-decoder attention
|
| 583 |
+
q = self.q_proj(query)
|
| 584 |
+
if key is None:
|
| 585 |
+
assert value is None
|
| 586 |
+
k = v = None
|
| 587 |
+
else:
|
| 588 |
+
k = self.k_proj(key)
|
| 589 |
+
v = self.v_proj(key)
|
| 590 |
+
|
| 591 |
+
else:
|
| 592 |
+
assert key is not None and value is not None
|
| 593 |
+
q = self.q_proj(query)
|
| 594 |
+
k = self.k_proj(key)
|
| 595 |
+
v = self.v_proj(value)
|
| 596 |
+
q *= self.scaling
|
| 597 |
+
|
| 598 |
+
if self.bias_k is not None:
|
| 599 |
+
assert self.bias_v is not None
|
| 600 |
+
k = torch.cat([k, self.bias_k.repeat(1, bsz, 1)])
|
| 601 |
+
v = torch.cat([v, self.bias_v.repeat(1, bsz, 1)])
|
| 602 |
+
if attn_mask is not None:
|
| 603 |
+
attn_mask = torch.cat(
|
| 604 |
+
[attn_mask, attn_mask.new_zeros(attn_mask.size(0), 1)], dim=1
|
| 605 |
+
)
|
| 606 |
+
if key_padding_mask is not None:
|
| 607 |
+
key_padding_mask = torch.cat(
|
| 608 |
+
[
|
| 609 |
+
key_padding_mask,
|
| 610 |
+
key_padding_mask.new_zeros(key_padding_mask.size(0), 1),
|
| 611 |
+
],
|
| 612 |
+
dim=1,
|
| 613 |
+
)
|
| 614 |
+
|
| 615 |
+
q = (
|
| 616 |
+
q.contiguous()
|
| 617 |
+
.view(tgt_len, bsz * self.num_heads, self.q_head_dim)
|
| 618 |
+
.transpose(0, 1)
|
| 619 |
+
)
|
| 620 |
+
if k is not None:
|
| 621 |
+
k = (
|
| 622 |
+
k.contiguous()
|
| 623 |
+
.view(-1, bsz * self.num_heads, self.k_head_dim)
|
| 624 |
+
.transpose(0, 1)
|
| 625 |
+
)
|
| 626 |
+
if v is not None:
|
| 627 |
+
v = (
|
| 628 |
+
v.contiguous()
|
| 629 |
+
.view(-1, bsz * self.num_heads, self.head_dim)
|
| 630 |
+
.transpose(0, 1)
|
| 631 |
+
)
|
| 632 |
+
|
| 633 |
+
if saved_state is not None:
|
| 634 |
+
# saved states are stored with shape (bsz, num_heads, seq_len, head_dim)
|
| 635 |
+
if "prev_key" in saved_state:
|
| 636 |
+
_prev_key = saved_state["prev_key"]
|
| 637 |
+
assert _prev_key is not None
|
| 638 |
+
prev_key = _prev_key.view(bsz * self.num_heads, -1, self.head_dim)
|
| 639 |
+
if static_kv:
|
| 640 |
+
k = prev_key
|
| 641 |
+
else:
|
| 642 |
+
assert k is not None
|
| 643 |
+
k = torch.cat([prev_key, k], dim=1)
|
| 644 |
+
src_len = k.size(1)
|
| 645 |
+
if "prev_value" in saved_state:
|
| 646 |
+
_prev_value = saved_state["prev_value"]
|
| 647 |
+
assert _prev_value is not None
|
| 648 |
+
prev_value = _prev_value.view(bsz * self.num_heads, -1, self.head_dim)
|
| 649 |
+
if static_kv:
|
| 650 |
+
v = prev_value
|
| 651 |
+
else:
|
| 652 |
+
assert v is not None
|
| 653 |
+
v = torch.cat([prev_value, v], dim=1)
|
| 654 |
+
prev_key_padding_mask: Optional[Tensor] = None
|
| 655 |
+
if "prev_key_padding_mask" in saved_state:
|
| 656 |
+
prev_key_padding_mask = saved_state["prev_key_padding_mask"]
|
| 657 |
+
assert k is not None and v is not None
|
| 658 |
+
key_padding_mask = MultiheadAttention._append_prev_key_padding_mask(
|
| 659 |
+
key_padding_mask=key_padding_mask,
|
| 660 |
+
prev_key_padding_mask=prev_key_padding_mask,
|
| 661 |
+
batch_size=bsz,
|
| 662 |
+
src_len=k.size(1),
|
| 663 |
+
static_kv=static_kv,
|
| 664 |
+
)
|
| 665 |
+
|
| 666 |
+
saved_state["prev_key"] = k.view(bsz, self.num_heads, -1, self.head_dim)
|
| 667 |
+
saved_state["prev_value"] = v.view(bsz, self.num_heads, -1, self.head_dim)
|
| 668 |
+
saved_state["prev_key_padding_mask"] = key_padding_mask
|
| 669 |
+
# In this branch incremental_state is never None
|
| 670 |
+
assert incremental_state is not None
|
| 671 |
+
incremental_state = self._set_input_buffer(incremental_state, saved_state)
|
| 672 |
+
assert k is not None
|
| 673 |
+
assert k.size(1) == src_len
|
| 674 |
+
|
| 675 |
+
# This is part of a workaround to get around fork/join parallelism
|
| 676 |
+
# not supporting Optional types.
|
| 677 |
+
if key_padding_mask is not None and key_padding_mask.dim() == 0:
|
| 678 |
+
key_padding_mask = None
|
| 679 |
+
|
| 680 |
+
if key_padding_mask is not None:
|
| 681 |
+
assert key_padding_mask.size(0) == bsz
|
| 682 |
+
assert key_padding_mask.size(1) == src_len
|
| 683 |
+
|
| 684 |
+
if self.add_zero_attn:
|
| 685 |
+
assert v is not None
|
| 686 |
+
src_len += 1
|
| 687 |
+
k = torch.cat([k, k.new_zeros((k.size(0), 1) + k.size()[2:])], dim=1)
|
| 688 |
+
v = torch.cat([v, v.new_zeros((v.size(0), 1) + v.size()[2:])], dim=1)
|
| 689 |
+
if attn_mask is not None:
|
| 690 |
+
attn_mask = torch.cat(
|
| 691 |
+
[attn_mask, attn_mask.new_zeros(attn_mask.size(0), 1)], dim=1
|
| 692 |
+
)
|
| 693 |
+
if key_padding_mask is not None:
|
| 694 |
+
key_padding_mask = torch.cat(
|
| 695 |
+
[
|
| 696 |
+
key_padding_mask,
|
| 697 |
+
torch.zeros(key_padding_mask.size(0), 1).type_as(
|
| 698 |
+
key_padding_mask
|
| 699 |
+
),
|
| 700 |
+
],
|
| 701 |
+
dim=1,
|
| 702 |
+
)
|
| 703 |
+
|
| 704 |
+
attn_weights = torch.bmm(q, k.transpose(1, 2))
|
| 705 |
+
attn_weights = self.apply_sparse_mask(attn_weights, tgt_len, src_len, bsz)
|
| 706 |
+
|
| 707 |
+
assert list(attn_weights.size()) == [bsz * self.num_heads, tgt_len, src_len]
|
| 708 |
+
|
| 709 |
+
if attn_mask is not None:
|
| 710 |
+
attn_mask = attn_mask.unsqueeze(0)
|
| 711 |
+
attn_weights += attn_mask
|
| 712 |
+
|
| 713 |
+
if key_padding_mask is not None:
|
| 714 |
+
# don't attend to padding symbols
|
| 715 |
+
attn_weights = attn_weights.view(bsz, self.num_heads, tgt_len, src_len)
|
| 716 |
+
if not is_tpu:
|
| 717 |
+
attn_weights = attn_weights.masked_fill(
|
| 718 |
+
key_padding_mask.unsqueeze(1).unsqueeze(2).to(torch.bool),
|
| 719 |
+
float("-inf"),
|
| 720 |
+
)
|
| 721 |
+
else:
|
| 722 |
+
attn_weights = attn_weights.transpose(0, 2)
|
| 723 |
+
attn_weights = attn_weights.masked_fill(key_padding_mask, float("-inf"))
|
| 724 |
+
attn_weights = attn_weights.transpose(0, 2)
|
| 725 |
+
attn_weights = attn_weights.view(bsz * self.num_heads, tgt_len, src_len)
|
| 726 |
+
|
| 727 |
+
if before_softmax:
|
| 728 |
+
return attn_weights, v, position_bias
|
| 729 |
+
|
| 730 |
+
if position_bias is not None:
|
| 731 |
+
if self.gru_rel_pos == 1:
|
| 732 |
+
query_layer = q.view(bsz, self.num_heads, tgt_len, self.q_head_dim)
|
| 733 |
+
_B, _H, _L, __ = query_layer.size()
|
| 734 |
+
gate_a, gate_b = torch.sigmoid(self.grep_linear(query_layer).view(
|
| 735 |
+
_B, _H, _L, 2, 4).sum(-1, keepdim=False)).chunk(2, dim=-1)
|
| 736 |
+
gate_a_1 = gate_a * (gate_b * self.grep_a - 1.0) + 2.0
|
| 737 |
+
position_bias = gate_a_1.view(bsz * self.num_heads, -1, 1) * position_bias
|
| 738 |
+
|
| 739 |
+
position_bias = position_bias.view(attn_weights.size())
|
| 740 |
+
|
| 741 |
+
attn_weights = attn_weights + position_bias
|
| 742 |
+
|
| 743 |
+
attn_weights_float = F.softmax(
|
| 744 |
+
attn_weights, dim=-1
|
| 745 |
+
)
|
| 746 |
+
attn_weights = attn_weights_float.type_as(attn_weights)
|
| 747 |
+
attn_probs = self.dropout_module(attn_weights)
|
| 748 |
+
|
| 749 |
+
assert v is not None
|
| 750 |
+
attn = torch.bmm(attn_probs, v)
|
| 751 |
+
assert list(attn.size()) == [bsz * self.num_heads, tgt_len, self.head_dim]
|
| 752 |
+
attn = attn.transpose(0, 1).contiguous().view(tgt_len, bsz, embed_dim)
|
| 753 |
+
attn = self.out_proj(attn)
|
| 754 |
+
attn_weights: Optional[Tensor] = None
|
| 755 |
+
if need_weights:
|
| 756 |
+
attn_weights = attn_weights_float.view(
|
| 757 |
+
bsz, self.num_heads, tgt_len, src_len
|
| 758 |
+
).transpose(1, 0)
|
| 759 |
+
if not need_head_weights:
|
| 760 |
+
# average attention weights over heads
|
| 761 |
+
attn_weights = attn_weights.mean(dim=0)
|
| 762 |
+
|
| 763 |
+
return attn, attn_weights, position_bias
|
| 764 |
+
|
| 765 |
+
@staticmethod
|
| 766 |
+
def _append_prev_key_padding_mask(
|
| 767 |
+
key_padding_mask: Optional[Tensor],
|
| 768 |
+
prev_key_padding_mask: Optional[Tensor],
|
| 769 |
+
batch_size: int,
|
| 770 |
+
src_len: int,
|
| 771 |
+
static_kv: bool,
|
| 772 |
+
) -> Optional[Tensor]:
|
| 773 |
+
# saved key padding masks have shape (bsz, seq_len)
|
| 774 |
+
if prev_key_padding_mask is not None and static_kv:
|
| 775 |
+
new_key_padding_mask = prev_key_padding_mask
|
| 776 |
+
elif prev_key_padding_mask is not None and key_padding_mask is not None:
|
| 777 |
+
new_key_padding_mask = torch.cat(
|
| 778 |
+
[prev_key_padding_mask.float(), key_padding_mask.float()], dim=1
|
| 779 |
+
)
|
| 780 |
+
# During incremental decoding, as the padding token enters and
|
| 781 |
+
# leaves the frame, there will be a time when prev or current
|
| 782 |
+
# is None
|
| 783 |
+
elif prev_key_padding_mask is not None:
|
| 784 |
+
if src_len > prev_key_padding_mask.size(1):
|
| 785 |
+
filler = torch.zeros(
|
| 786 |
+
(batch_size, src_len - prev_key_padding_mask.size(1)),
|
| 787 |
+
device=prev_key_padding_mask.device,
|
| 788 |
+
)
|
| 789 |
+
new_key_padding_mask = torch.cat(
|
| 790 |
+
[prev_key_padding_mask.float(), filler.float()], dim=1
|
| 791 |
+
)
|
| 792 |
+
else:
|
| 793 |
+
new_key_padding_mask = prev_key_padding_mask.float()
|
| 794 |
+
elif key_padding_mask is not None:
|
| 795 |
+
if src_len > key_padding_mask.size(1):
|
| 796 |
+
filler = torch.zeros(
|
| 797 |
+
(batch_size, src_len - key_padding_mask.size(1)),
|
| 798 |
+
device=key_padding_mask.device,
|
| 799 |
+
)
|
| 800 |
+
new_key_padding_mask = torch.cat(
|
| 801 |
+
[filler.float(), key_padding_mask.float()], dim=1
|
| 802 |
+
)
|
| 803 |
+
else:
|
| 804 |
+
new_key_padding_mask = key_padding_mask.float()
|
| 805 |
+
else:
|
| 806 |
+
new_key_padding_mask = prev_key_padding_mask
|
| 807 |
+
return new_key_padding_mask
|
| 808 |
+
|
| 809 |
+
def _get_input_buffer(
|
| 810 |
+
self, incremental_state: Optional[Dict[str, Dict[str, Optional[Tensor]]]]
|
| 811 |
+
) -> Dict[str, Optional[Tensor]]:
|
| 812 |
+
result = self.get_incremental_state(incremental_state, "attn_state")
|
| 813 |
+
if result is not None:
|
| 814 |
+
return result
|
| 815 |
+
else:
|
| 816 |
+
empty_result: Dict[str, Optional[Tensor]] = {}
|
| 817 |
+
return empty_result
|
| 818 |
+
|
| 819 |
+
def _set_input_buffer(
|
| 820 |
+
self,
|
| 821 |
+
incremental_state: Dict[str, Dict[str, Optional[Tensor]]],
|
| 822 |
+
buffer: Dict[str, Optional[Tensor]],
|
| 823 |
+
):
|
| 824 |
+
return self.set_incremental_state(incremental_state, "attn_state", buffer)
|
| 825 |
+
|
| 826 |
+
def apply_sparse_mask(self, attn_weights, tgt_len: int, src_len: int, bsz: int):
|
| 827 |
+
return attn_weights
|