Spaces:
Runtime error
Runtime error
| from math import sqrt | |
| import torch | |
| from torch import nn | |
| from Encoder import Encoder | |
| from Decoder import Decoder | |
| from Postnet import Postnet | |
| from GST import GST | |
| from utils import to_gpu, get_mask_from_lengths | |
| from fp16_optimizer import fp32_to_fp16, fp16_to_fp32 | |
# Seed torch's global RNG with a fixed value (1234) at this point in the module.
# NOTE(review): presumably for reproducible weight initialisation — confirm intent.
| torch.manual_seed(1234) | |
class tacotron_2(nn.Module):
    """Tacotron-2 style network whose encoder output is offset by a GST style embedding.

    The sub-modules (``Encoder``, ``Decoder``, ``Postnet``, ``GST``) are project
    classes; each one is constructed from the same hyper-parameter mapping.
    """

    def __init__(self, tacotron_hyperparams):
        """Build the symbol embedding and the four sub-modules.

        Args:
            tacotron_hyperparams: mapping read for the keys ``mask_padding``,
                ``fp16_run``, ``n_mel_channels``, ``number_frames_step``,
                ``n_symbols`` and ``symbols_embedding_length``; it is also
                handed unchanged to every sub-module constructor.
        """
        super().__init__()
        hp = tacotron_hyperparams

        self.mask_padding = hp['mask_padding']
        self.fp16_run = hp['fp16_run']
        self.n_mel_channels = hp['n_mel_channels']
        self.n_frames_per_step = hp['number_frames_step']

        vocab_size = hp['n_symbols']
        embed_dim = hp['symbols_embedding_length']
        self.embedding = nn.Embedding(vocab_size, embed_dim)

        # Uniform init in [-bound, bound] with
        # bound = sqrt(3) * sqrt(2 / (n_symbols + embedding_length)).
        # The original author flagged this initialisation for review.
        std = sqrt(2.0 / (vocab_size + embed_dim))
        bound = sqrt(3.0) * std
        self.embedding.weight.data.uniform_(-bound, bound)

        # Sub-modules are created in this fixed order: encoder, decoder, postnet, GST.
        self.encoder = Encoder(hp)
        self.decoder = Decoder(hp)
        self.postnet = Postnet(hp)
        self.gst = GST(hp)

    def parse_batch(self, batch):
        """Split a collated batch into model inputs and training targets.

        Args:
            batch: six-element sequence ``(text_padded, input_lengths,
                mel_padded, gate_padded, output_lengths, prosody_padded)``.
                The prosody tensor is the extra field used to train the GST
                tokens.

        Returns:
            ``(inputs, targets)`` where ``inputs`` is ``(text, input_lengths,
            mel, max_len, output_lengths, prosody)`` and ``targets`` is
            ``(mel, gate)``. Every tensor has been passed through ``to_gpu``
            and cast (``long`` for text and lengths, ``float`` otherwise).
        """
        text, text_lens, mel, gate, mel_lens, prosody = batch

        text = to_gpu(text).long()
        # Longest input length as a plain Python int, read before the lengths
        # tensor is passed through to_gpu.
        longest_text = int(torch.max(text_lens.data).item())
        text_lens = to_gpu(text_lens).long()
        mel = to_gpu(mel).float()
        gate = to_gpu(gate).float()
        mel_lens = to_gpu(mel_lens).long()
        prosody = to_gpu(prosody).float()

        model_inputs = (text, text_lens, mel, longest_text, mel_lens, prosody)
        targets = (mel, gate)
        return (model_inputs, targets)

    def parse_input(self, inputs):
        """Return ``inputs`` converted by ``fp32_to_fp16`` when ``fp16_run`` is set, else unchanged."""
        if self.fp16_run:
            return fp32_to_fp16(inputs)
        return inputs

    def parse_output(self, outputs, output_lengths=None):
        """Optionally mask padded frames in place, then undo the fp16 conversion.

        Args:
            outputs: indexable sequence whose first three entries are the mel
                output, the post-net mel output and the gate output.
            output_lengths: per-item output lengths; masking is skipped when
                this is ``None`` or when ``mask_padding`` is false.

        Returns:
            ``outputs`` itself, or ``fp16_to_fp32(outputs)`` when ``fp16_run``
            is set.
        """
        if self.mask_padding and output_lengths is not None:
            # Invert the helper's mask so True marks the positions to fill.
            pad = ~get_mask_from_lengths(output_lengths)
            # Replicate the 2-D mask across n_mel_channels and move that new
            # axis into position 1.
            pad = pad.expand(self.n_mel_channels, pad.size(0), pad.size(1)).permute(1, 0, 2)

            for mel in (outputs[0], outputs[1]):
                mel.data.masked_fill_(pad, 0.0)
            # Gate energies: padded positions are set to a large value (1e3).
            outputs[2].data.masked_fill_(pad[:, 0, :], 1e3)

        if self.fp16_run:
            return fp16_to_fp32(outputs)
        return outputs

    def forward(self, inputs):
        """Training forward pass.

        Args:
            inputs: six-element sequence laid out like the first tuple
                returned by ``parse_batch``.

        Returns:
            Result of ``parse_output`` applied to ``[mel_outputs,
            mel_outputs_postnet, gate_outputs, alignments, gst_scores]``.
        """
        (symbols, text_lens, mel_targets, _max_len,
         mel_lens, prosody) = self.parse_input(inputs)
        text_lens = text_lens.data
        mel_lens = mel_lens.data

        embedded = self.embedding(symbols).transpose(1, 2)
        memory = self.encoder(embedded, text_lens)

        # The GST style embedding is added to the encoder output before decoding.
        # NOTE: an earlier revision (left commented out) split the prosody tensor
        # into bin locations (index 0 on dim 1) and pitch/intensity rows (1:)
        # here; the GST module now receives the tensor whole.
        style, gst_scores = self.gst(prosody, mel_lens)  # original note: [N, 512] — TODO confirm
        memory = memory + style.expand_as(memory)

        mel_outputs, gate_outputs, alignments = self.decoder(
            memory, mel_targets, memory_lengths=text_lens)

        # The post-net output is added to the decoder mel output.
        mel_outputs_postnet = mel_outputs + self.postnet(mel_outputs)

        return self.parse_output(
            [mel_outputs, mel_outputs_postnet, gate_outputs, alignments, gst_scores],
            mel_lens)

    def inference(self, inputs, gst_scores):
        """Synthesis pass driven by caller-supplied GST scores.

        Args:
            inputs: symbol ids fed to the embedding layer.
            gst_scores: passed to ``self.gst.inference``; per the original
                comment it must be a torch tensor.

        Returns:
            Result of ``parse_output`` applied to ``[mel_outputs,
            mel_outputs_postnet, gate_outputs, alignments]`` (no length
            masking, since no output lengths are passed).
        """
        symbols = self.parse_input(inputs)
        embedded = self.embedding(symbols).transpose(1, 2)
        memory = self.encoder.inference(embedded)

        # GST inference: build the style embedding directly from the scores.
        style = self.gst.inference(gst_scores)
        memory = memory + style.expand_as(memory)

        mel_outputs, gate_outputs, alignments = self.decoder.inference(memory)
        mel_outputs_postnet = mel_outputs + self.postnet(mel_outputs)

        return self.parse_output(
            [mel_outputs, mel_outputs_postnet, gate_outputs, alignments])