# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.

import logging
import os
import os.path as op

import torch
import torch.nn.functional as F
import numpy as np

from fairseq.data.audio.text_to_speech_dataset import TextToSpeechDatasetCreator
from fairseq.tasks import register_task
from fairseq.tasks.speech_to_text import SpeechToTextTask
from fairseq.speech_generator import (
    AutoRegressiveSpeechGenerator,
    NonAutoregressiveSpeechGenerator,
    TeacherForcingAutoRegressiveSpeechGenerator,
)

logging.basicConfig(
    format="%(asctime)s | %(levelname)s | %(name)s | %(message)s",
    datefmt="%Y-%m-%d %H:%M:%S",
    level=logging.INFO,
)
logger = logging.getLogger(__name__)

try:
    from tensorboardX import SummaryWriter
except ImportError:
    logger.info("Please install tensorboardX: pip install tensorboardX")
    SummaryWriter = None

@register_task("text_to_speech")
class TextToSpeechTask(SpeechToTextTask):
    @staticmethod
    def add_args(parser):
| parser.add_argument("data", help="manifest root path") | |
| parser.add_argument( | |
| "--config-yaml", | |
| type=str, | |
| default="config.yaml", | |
| help="Configuration YAML filename (under manifest root)", | |
| ) | |
| parser.add_argument( | |
| "--max-source-positions", | |
| default=1024, | |
| type=int, | |
| metavar="N", | |
| help="max number of tokens in the source sequence", | |
| ) | |
| parser.add_argument( | |
| "--max-target-positions", | |
| default=1200, | |
| type=int, | |
| metavar="N", | |
| help="max number of tokens in the target sequence", | |
| ) | |
| parser.add_argument("--n-frames-per-step", type=int, default=1) | |
| parser.add_argument("--eos-prob-threshold", type=float, default=0.5) | |
| parser.add_argument("--eval-inference", action="store_true") | |
| parser.add_argument("--eval-tb-nsample", type=int, default=8) | |
| parser.add_argument("--vocoder", type=str, default="griffin_lim") | |
| parser.add_argument("--spec-bwd-max-iter", type=int, default=8) | |
    def __init__(self, args, src_dict):
        super().__init__(args, src_dict)
        self.src_dict = src_dict
        self.sr = self.data_cfg.config.get("features").get("sample_rate")

        self.tensorboard_writer = None
        self.tensorboard_dir = ""
        if args.tensorboard_logdir and SummaryWriter is not None:
            self.tensorboard_dir = os.path.join(args.tensorboard_logdir, "valid_extra")

    def load_dataset(self, split, epoch=1, combine=False, **kwargs):
        is_train_split = split.startswith("train")
        pre_tokenizer = self.build_tokenizer(self.args)
        bpe_tokenizer = self.build_bpe(self.args)
        self.datasets[split] = TextToSpeechDatasetCreator.from_tsv(
            self.args.data,
            self.data_cfg,
            split,
            self.src_dict,
            pre_tokenizer,
            bpe_tokenizer,
            is_train_split=is_train_split,
            epoch=epoch,
            seed=self.args.seed,
            n_frames_per_step=self.args.n_frames_per_step,
            speaker_to_id=self.speaker_to_id,
        )

    @property
    def target_dictionary(self):
        return None

    @property
    def source_dictionary(self):
        return self.src_dict

    def get_speaker_embeddings_path(self):
        speaker_emb_path = None
        if self.data_cfg.config.get("speaker_emb_filename") is not None:
            speaker_emb_path = op.join(
                self.args.data, self.data_cfg.config.get("speaker_emb_filename")
            )
        return speaker_emb_path

    @classmethod
    def get_speaker_embeddings(cls, args):
        embed_speaker = None
        if args.speaker_to_id is not None:
            if args.speaker_emb_path is None:
                embed_speaker = torch.nn.Embedding(
                    len(args.speaker_to_id), args.speaker_embed_dim
                )
            else:
                speaker_emb_mat = np.load(args.speaker_emb_path)
                assert speaker_emb_mat.shape[1] == args.speaker_embed_dim
                embed_speaker = torch.nn.Embedding.from_pretrained(
                    torch.from_numpy(speaker_emb_mat),
                    freeze=True,
                )
                logger.info(
                    f"loaded speaker embeddings from {args.speaker_emb_path}. "
                    f"train embedding? {embed_speaker.weight.requires_grad}\n"
                    f"embeddings:\n{speaker_emb_mat}"
                )
        return embed_speaker

    def build_model(self, cfg, from_checkpoint=False):
        cfg.pitch_min = self.data_cfg.config["features"].get("pitch_min", None)
        cfg.pitch_max = self.data_cfg.config["features"].get("pitch_max", None)
        cfg.energy_min = self.data_cfg.config["features"].get("energy_min", None)
        cfg.energy_max = self.data_cfg.config["features"].get("energy_max", None)
        cfg.speaker_emb_path = self.get_speaker_embeddings_path()
        model = super().build_model(cfg, from_checkpoint)
        self.generator = None
        if getattr(cfg, "eval_inference", False):
            self.generator = self.build_generator([model], cfg)
        return model

    def build_generator(self, models, cfg, vocoder=None, **unused):
        if vocoder is None:
            vocoder = self.build_default_vocoder()
        model = models[0]
        if getattr(model, "NON_AUTOREGRESSIVE", False):
            return NonAutoregressiveSpeechGenerator(model, vocoder, self.data_cfg)
        else:
            generator = AutoRegressiveSpeechGenerator
            if getattr(cfg, "teacher_forcing", False):
                generator = TeacherForcingAutoRegressiveSpeechGenerator
                logger.info("Teacher forcing mode for generation")
            return generator(
                model,
                vocoder,
                self.data_cfg,
                max_iter=self.args.max_target_positions,
                eos_prob_threshold=self.args.eos_prob_threshold,
            )

    def build_default_vocoder(self):
        from fairseq.models.text_to_speech.vocoder import get_vocoder

        vocoder = get_vocoder(self.args, self.data_cfg)
        if torch.cuda.is_available() and not self.args.cpu:
            vocoder = vocoder.cuda()
        else:
            vocoder = vocoder.cpu()
        return vocoder

    def valid_step(self, sample, model, criterion):
        loss, sample_size, logging_output = super().valid_step(sample, model, criterion)
        if getattr(self.args, "eval_inference", False):
            hypos, inference_losses = self.valid_step_with_inference(
                sample, model, self.generator
            )
            for k, v in inference_losses.items():
                assert k not in logging_output
                logging_output[k] = v

            picked_id = 0
            if self.tensorboard_dir and (sample["id"] == picked_id).any():
                self.log_tensorboard(
                    sample,
                    hypos[: self.args.eval_tb_nsample],
                    model._num_updates,
                    is_na_model=getattr(model, "NON_AUTOREGRESSIVE", False),
                )
        return loss, sample_size, logging_output

    def valid_step_with_inference(self, sample, model, generator):
        hypos = generator.generate(model, sample, has_targ=True)

        losses = {
            "mcd_loss": 0.0,
            "targ_frames": 0.0,
            "pred_frames": 0.0,
            "nins": 0.0,
            "ndel": 0.0,
        }
        rets = batch_mel_cepstral_distortion(
            [hypo["targ_waveform"] for hypo in hypos],
            [hypo["waveform"] for hypo in hypos],
            self.sr,
            normalize_type=None,
        )
        for d, extra in rets:
            pathmap = extra[-1]
            losses["mcd_loss"] += d.item()
            losses["targ_frames"] += pathmap.size(0)
            losses["pred_frames"] += pathmap.size(1)
            losses["nins"] += (pathmap.sum(dim=1) - 1).sum().item()
            losses["ndel"] += (pathmap.sum(dim=0) - 1).sum().item()
        return hypos, losses

    def log_tensorboard(self, sample, hypos, num_updates, is_na_model=False):
        if self.tensorboard_writer is None:
            self.tensorboard_writer = SummaryWriter(self.tensorboard_dir)
        tb_writer = self.tensorboard_writer
        for b in range(len(hypos)):
            idx = sample["id"][b]
            text = sample["src_texts"][b]
            targ = hypos[b]["targ_feature"]
            pred = hypos[b]["feature"]
            attn = hypos[b]["attn"]
            if is_na_model:
                data = plot_tts_output(
                    [targ.transpose(0, 1), pred.transpose(0, 1)],
                    [f"target (idx={idx})", "output"],
                    attn,
                    "alignment",
                    ret_np=True,
                    suptitle=text,
                )
            else:
                eos_prob = hypos[b]["eos_prob"]
                data = plot_tts_output(
                    [targ.transpose(0, 1), pred.transpose(0, 1), attn],
                    [f"target (idx={idx})", "output", "alignment"],
                    eos_prob,
                    "eos prob",
                    ret_np=True,
                    suptitle=text,
                )
            tb_writer.add_image(
                f"inference_sample_{b}", data, num_updates, dataformats="HWC"
            )
            if hypos[b]["waveform"] is not None:
                targ_wave = hypos[b]["targ_waveform"].detach().cpu().float()
                pred_wave = hypos[b]["waveform"].detach().cpu().float()
                tb_writer.add_audio(
                    f"inference_targ_{b}", targ_wave, num_updates, sample_rate=self.sr
                )
                tb_writer.add_audio(
                    f"inference_pred_{b}", pred_wave, num_updates, sample_rate=self.sr
                )

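# A typical training invocation for this task (a sketch, not from the original
# file; exact flags depend on the fairseq version and the chosen model/criterion):
#
#   fairseq-train ${MANIFEST_ROOT} --config-yaml config.yaml \
#       --task text_to_speech --arch tts_transformer --criterion tacotron2 \
#       --train-subset train --valid-subset dev --eval-inference
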
def save_figure_to_numpy(fig):
    # np.fromstring on binary data is deprecated (removed in NumPy 2.0);
    # np.frombuffer is the supported equivalent for a raw RGB byte buffer
    data = np.frombuffer(fig.canvas.tostring_rgb(), dtype=np.uint8)
    data = data.reshape(fig.canvas.get_width_height()[::-1] + (3,))
    return data

DEFAULT_V_MIN = np.log(1e-5)


def plot_tts_output(
    data_2d,
    title_2d,
    data_1d,
    title_1d,
    figsize=(24, 4),
    v_min=DEFAULT_V_MIN,
    v_max=3,
    ret_np=False,
    suptitle="",
):
    try:
        import matplotlib.pyplot as plt
        from mpl_toolkits.axes_grid1 import make_axes_locatable
    except ImportError:
        raise ImportError("Please install Matplotlib: pip install matplotlib")

    data_2d = [
        x.detach().cpu().float().numpy() if isinstance(x, torch.Tensor) else x
        for x in data_2d
    ]

    fig, axes = plt.subplots(1, len(data_2d) + 1, figsize=figsize)
    if suptitle:
        fig.suptitle(suptitle[:400])  # capped at 400 chars
    axes = [axes] if len(data_2d) == 0 else axes
    for ax, x, name in zip(axes, data_2d, title_2d):
        ax.set_title(name)
        divider = make_axes_locatable(ax)
        cax = divider.append_axes("right", size="5%", pad=0.05)
        im = ax.imshow(
            x,
            origin="lower",
            aspect="auto",
            vmin=max(x.min(), v_min),
            vmax=min(x.max(), v_max),
        )
        fig.colorbar(im, cax=cax, orientation="vertical")

    if isinstance(data_1d, torch.Tensor):
        data_1d = data_1d.detach().cpu().numpy()
    axes[-1].plot(data_1d)
    axes[-1].set_title(title_1d)
    plt.tight_layout()

    if ret_np:
        fig.canvas.draw()
        data = save_figure_to_numpy(fig)
        plt.close(fig)
        return data

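# A minimal usage sketch (not from the original file; requires matplotlib):
# render two random "spectrograms" plus a 1-D curve into the HWC uint8 image
# format that log_tensorboard feeds to add_image above.
def _example_plot_tts_output():
    feats = [np.random.rand(80, 100), np.random.rand(80, 100)]
    img = plot_tts_output(
        feats, ["target", "output"], np.random.rand(100), "eos prob", ret_np=True
    )
    assert img.ndim == 3 and img.shape[-1] == 3  # (H, W, C) RGB
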
def antidiag_indices(offset, min_i=0, max_i=None, min_j=0, max_j=None):
    """
    for a (3, 4) matrix with min_i=1, max_i=3, min_j=1, max_j=4, outputs

    offset=2 (1, 1),
    offset=3 (2, 1), (1, 2)
    offset=4 (2, 2), (1, 3)
    offset=5 (2, 3)

    constraints:
        i + j = offset
        min_j <= j < max_j
        min_i <= offset - j < max_i
    """
    if max_i is None:
        max_i = offset + 1
    if max_j is None:
        max_j = offset + 1
    min_j = max(min_j, offset - max_i + 1, 0)
    max_j = min(max_j, offset - min_i + 1, offset + 1)
    j = torch.arange(min_j, max_j)
    i = offset - j
    return torch.stack([i, j])

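# A minimal usage sketch (not from the original file): the offset=3 row of the
# docstring example above, returned as stacked (i, j) index vectors.
def _example_antidiag_indices():
    ind = antidiag_indices(3, min_i=1, max_i=3, min_j=1, max_j=4)
    assert ind.tolist() == [[2, 1], [1, 2]]  # cells (2, 1) and (1, 2)
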
def batch_dynamic_time_warping(distance, shapes=None):
    """full batched DTW without any constraints

    distance: (batchsize, max_M, max_N) matrix
    shapes: (batchsize,) vector specifying (M, N) for each entry
    """
    # ptr: 0=left, 1=up-left, 2=up
    ptr2dij = {0: (0, -1), 1: (-1, -1), 2: (-1, 0)}

    bsz, m, n = distance.size()
    cumdist = torch.zeros_like(distance)
    backptr = torch.zeros_like(distance).type(torch.int32) - 1

    # initialize
    cumdist[:, 0, :] = distance[:, 0, :].cumsum(dim=-1)
    cumdist[:, :, 0] = distance[:, :, 0].cumsum(dim=-1)
    backptr[:, 0, :] = 0
    backptr[:, :, 0] = 2

    # DP with optimized anti-diagonal parallelization, O(M+N) steps
    for offset in range(2, m + n - 1):
        ind = antidiag_indices(offset, 1, m, 1, n)
        c = torch.stack(
            [
                cumdist[:, ind[0], ind[1] - 1],
                cumdist[:, ind[0] - 1, ind[1] - 1],
                cumdist[:, ind[0] - 1, ind[1]],
            ],
            dim=2,
        )
        v, b = c.min(dim=-1)
        backptr[:, ind[0], ind[1]] = b.int()
        cumdist[:, ind[0], ind[1]] = v + distance[:, ind[0], ind[1]]

    # backtrace
    pathmap = torch.zeros_like(backptr)
    for b in range(bsz):
        i = m - 1 if shapes is None else (shapes[b][0] - 1).item()
        j = n - 1 if shapes is None else (shapes[b][1] - 1).item()
        dtwpath = [(i, j)]
        while (i != 0 or j != 0) and len(dtwpath) < 10000:
            assert i >= 0 and j >= 0
            di, dj = ptr2dij[backptr[b, i, j].item()]
            i, j = i + di, j + dj
            dtwpath.append((i, j))
        dtwpath = dtwpath[::-1]
        indices = torch.from_numpy(np.array(dtwpath))
        pathmap[b, indices[:, 0], indices[:, 1]] = 1

    return cumdist, backptr, pathmap

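# A minimal usage sketch (not from the original file): DTW over one toy cost
# matrix; the returned 0/1 path map always connects (0, 0) to (M-1, N-1).
def _example_batch_dtw():
    dist = torch.rand(1, 3, 4)
    cumdist, backptr, pathmap = batch_dynamic_time_warping(dist)
    assert pathmap[0, 0, 0] == 1 and pathmap[0, -1, -1] == 1
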
def compute_l2_dist(x1, x2):
    """compute an (m, n) squared-L2 distance matrix from (m, d) and (n, d) matrices"""
    return torch.cdist(x1.unsqueeze(0), x2.unsqueeze(0), p=2).squeeze(0).pow(2)


def compute_rms_dist(x1, x2):
    l2_dist = compute_l2_dist(x1, x2)
    return (l2_dist / x1.size(1)).pow(0.5)

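# A minimal usage sketch (not from the original file): the RMS frame distance
# between 5-frame and 7-frame sequences of 13-dim features is a (5, 7) matrix.
def _example_rms_dist():
    x1, x2 = torch.rand(5, 13), torch.rand(7, 13)
    assert compute_rms_dist(x1, x2).shape == (5, 7)
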
def get_divisor(pathmap, normalize_type):
    if normalize_type is None:
        return 1
    elif normalize_type == "len1":
        return pathmap.size(0)
    elif normalize_type == "len2":
        return pathmap.size(1)
    elif normalize_type == "path":
        return pathmap.sum().item()
    else:
        raise ValueError(f"normalize_type {normalize_type} not supported")


def batch_compute_distortion(y1, y2, sr, feat_fn, dist_fn, normalize_type):
    d, s, x1, x2 = [], [], [], []
    for cur_y1, cur_y2 in zip(y1, y2):
        assert cur_y1.ndim == 1 and cur_y2.ndim == 1
        cur_x1 = feat_fn(cur_y1)
        cur_x2 = feat_fn(cur_y2)
        x1.append(cur_x1)
        x2.append(cur_x2)

        cur_d = dist_fn(cur_x1, cur_x2)
        d.append(cur_d)
        s.append(d[-1].size())
    max_m = max(ss[0] for ss in s)
    max_n = max(ss[1] for ss in s)
    d = torch.stack(
        [F.pad(dd, (0, max_n - dd.size(1), 0, max_m - dd.size(0))) for dd in d]
    )
    s = torch.LongTensor(s).to(d.device)
    cumdists, backptrs, pathmaps = batch_dynamic_time_warping(d, s)

    rets = []
    itr = zip(s, x1, x2, d, cumdists, backptrs, pathmaps)
    for (m, n), cur_x1, cur_x2, dist, cumdist, backptr, pathmap in itr:
        cumdist = cumdist[:m, :n]
        backptr = backptr[:m, :n]
        pathmap = pathmap[:m, :n]
        divisor = get_divisor(pathmap, normalize_type)

        distortion = cumdist[-1, -1] / divisor
        ret = distortion, (cur_x1, cur_x2, dist, cumdist, backptr, pathmap)
        rets.append(ret)
    return rets

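# A minimal usage sketch (not from the original file): generic DTW distortion
# over raw 1-D signals, using unfold as a stand-in feature extractor (the sr
# argument is unused by batch_compute_distortion itself).
def _example_batch_compute_distortion():
    y1 = [torch.randn(100), torch.randn(120)]
    y2 = [torch.randn(90), torch.randn(110)]

    def frame(y):
        return y.unfold(0, 10, 5)  # (n_frames, 10) features

    rets = batch_compute_distortion(
        y1, y2, None, frame, compute_rms_dist, normalize_type="path"
    )
    assert len(rets) == 2  # one (distortion, extras) tuple per pair
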
def batch_mel_cepstral_distortion(y1, y2, sr, normalize_type="path", mfcc_fn=None):
    """
    https://arxiv.org/pdf/2011.03568.pdf

    The root mean squared error computed on 13-dimensional MFCC using DTW for
    alignment. MFCC features are computed from an 80-channel log-mel
    spectrogram using a 50ms Hann window and hop of 12.5ms.

    y1: list of waveforms
    y2: list of waveforms
    sr: sampling rate
    """
    try:
        import torchaudio
    except ImportError:
        raise ImportError("Please install torchaudio: pip install torchaudio")

    if mfcc_fn is None or mfcc_fn.sample_rate != sr:
        melkwargs = {
            "n_fft": int(0.05 * sr),
            "win_length": int(0.05 * sr),
            "hop_length": int(0.0125 * sr),
            "f_min": 20,
            "n_mels": 80,
            "window_fn": torch.hann_window,
        }
        mfcc_fn = torchaudio.transforms.MFCC(
            sr, n_mfcc=13, log_mels=True, melkwargs=melkwargs
        ).to(y1[0].device)
    return batch_compute_distortion(
        y1,
        y2,
        sr,
        lambda y: mfcc_fn(y).transpose(-1, -2),
        compute_rms_dist,
        normalize_type,
    )

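# A minimal end-to-end sketch (not from the original file; requires torchaudio).
# Random noise gives a meaningless score; this only demonstrates the call
# signature and the (distortion, extras) return structure.
if __name__ == "__main__":
    sr = 16000
    y1 = [torch.randn(sr)]  # one second of fake "reference" audio
    y2 = [torch.randn(sr)]  # one second of fake "prediction"
    for distortion, extras in batch_mel_cepstral_distortion(y1, y2, sr):
        pathmap = extras[-1]
        print(f"MCD: {distortion.item():.3f} over {int(pathmap.sum())} path steps")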