Spaces:
Runtime error
Runtime error
| from __future__ import absolute_import, division, print_function, unicode_literals | |
| from typing import Tuple | |
| import sys | |
| from argparse import ArgumentParser | |
| import torch | |
| import numpy as np | |
| import os | |
| import json | |
| import torch | |
| sys.path.append(os.path.join(os.path.dirname(__file__), "../src/glow_tts")) | |
| from scipy.io.wavfile import write | |
| from hifi.env import AttrDict | |
| from hifi.models import Generator | |
| from text import text_to_sequence | |
| import commons | |
| import models | |
| import utils | |
def check_directory(dir):
    """Abort the program unless ``dir`` is an existing directory.

    Args:
        dir: Filesystem path that is expected to be a directory.

    Raises:
        SystemExit: With an explanatory message when ``dir`` is missing or
            is not a directory.
    """
    # isdir (rather than exists) so that a regular file at this path is
    # rejected here instead of failing later when the model files are read.
    if not os.path.isdir(dir):
        sys.exit("Error: {} directory does not exist".format(dir))
class TextToMel:
    """Turn text into a mel-spectrogram with a trained Glow-TTS model.

    The model directory is validated and the latest checkpoint is loaded once
    at construction time; ``generate_mel`` can then be called repeatedly.
    """

    def __init__(self, glow_model_dir, device="cuda"):
        """Load hyper-parameters and the Glow-TTS model.

        Args:
            glow_model_dir: Directory holding the Glow-TTS hyper-parameters
                and checkpoints. The process exits (via ``check_directory``)
                when it is not present.
            device: Torch device the model and its inputs are placed on,
                e.g. ``"cuda"``, ``"cuda:0"`` or ``"cpu"``.
        """
        self.glow_model_dir = glow_model_dir
        check_directory(self.glow_model_dir)
        self.device = device
        self.hps, self.glow_tts_model = self.load_glow_tts()

    def load_glow_tts(self):
        """Build the flow generator and restore its latest checkpoint.

        Returns:
            A ``(hps, model)`` tuple: the hyper-parameters read from the model
            directory and the model in eval mode on ``self.device``.
        """
        hps = utils.get_hparams_from_dir(self.glow_model_dir)
        checkpoint_path = utils.latest_checkpoint_path(self.glow_model_dir)
        # The symbol table is the punctuation list followed by the characters.
        symbols = list(hps.data.punc) + list(hps.data.chars)
        glow_tts_model = models.FlowGenerator(
            # One extra input symbol is reserved when "add_blank" is enabled.
            len(symbols) + getattr(hps.data, "add_blank", False),
            out_channels=hps.data.n_mel_channels,
            **hps.model
        )
        # Honour whatever device was requested (not only the literal "cuda"),
        # so that e.g. "cuda:0" is not silently left on the CPU.
        glow_tts_model.to(self.device)
        utils.load_checkpoint(checkpoint_path, glow_tts_model)
        glow_tts_model.decoder.store_inverse()
        glow_tts_model.eval()
        return hps, glow_tts_model

    def generate_mel(self, text, noise_scale=0.667, length_scale=1.0):
        """Synthesize a mel-spectrogram for ``text``.

        Args:
            text: The utterance to synthesize.
            noise_scale: Forwarded unchanged to the Glow-TTS model.
            length_scale: Forwarded unchanged to the Glow-TTS model.

        Returns:
            The first tensor of the model's first output group, left on
            ``self.device`` (presumably the generated mel-spectrogram with a
            leading batch dimension of 1 -- confirm against the model).
        """
        symbols = list(self.hps.data.punc) + list(self.hps.data.chars)
        cleaner = self.hps.data.text_cleaners
        if getattr(self.hps.data, "add_blank", False):
            text_norm = text_to_sequence(text, symbols, cleaner)
            # len(symbols) is the id one past the last real symbol; it is
            # handed to intersperse as the filler item.
            text_norm = commons.intersperse(text_norm, len(symbols))
        else:
            # If not using the "add_blank" option during training, adding
            # spaces at the beginning and the end of the utterance improves
            # quality.
            text = " " + text.strip() + " "
            text_norm = text_to_sequence(text, symbols, cleaner)

        # Add a leading batch axis: the model is called with a batch of one.
        sequence = np.array(text_norm)[None, :]
        x_tst = torch.from_numpy(sequence).long().to(self.device)
        x_tst_lengths = torch.tensor([x_tst.shape[1]], device=self.device)

        with torch.no_grad():
            (y_gen_tst, *_), *_ = self.glow_tts_model(
                x_tst,
                x_tst_lengths,
                gen=True,
                noise_scale=noise_scale,
                length_scale=length_scale,
            )

        # Drop the input tensors before emptying the cache so their memory
        # can be returned as well.
        del x_tst, x_tst_lengths
        torch.cuda.empty_cache()
        return y_gen_tst
class MelToWav:
    """Turn a mel-spectrogram into 16-bit PCM audio with a HiFi-GAN generator.

    The model directory is validated and the latest generator checkpoint is
    loaded once at construction time.
    """

    def __init__(self, hifi_model_dir, device="cuda"):
        """Load the HiFi-GAN configuration and generator.

        Args:
            hifi_model_dir: Directory holding ``config.json`` and the
                generator checkpoints. The process exits (via
                ``check_directory``) when it is not present.
            device: Torch device the generator and its inputs are placed on.
        """
        self.hifi_model_dir = hifi_model_dir
        check_directory(self.hifi_model_dir)
        self.device = device
        self.h, self.hifi_gan_generator = self.load_hifi_gan()

    def load_hifi_gan(self):
        """Build the generator from ``config.json`` and restore its weights.

        Returns:
            A ``(h, generator)`` tuple: the configuration as an ``AttrDict``
            and the generator in eval mode with weight norm removed.

        Raises:
            FileNotFoundError: If the resolved checkpoint path is not a file.
        """
        checkpoint_path = utils.latest_checkpoint_path(self.hifi_model_dir, regex="g_*")
        config_file = os.path.join(self.hifi_model_dir, "config.json")
        # Context manager so the config file handle is always closed.
        with open(config_file, encoding="utf-8") as config_fp:
            json_config = json.load(config_fp)
        h = AttrDict(json_config)
        torch.manual_seed(h.seed)
        generator = Generator(h).to(self.device)

        # Explicit check instead of ``assert`` so it survives ``python -O``.
        if not os.path.isfile(checkpoint_path):
            raise FileNotFoundError(
                "HiFi-GAN checkpoint not found: {}".format(checkpoint_path)
            )
        print("Loading '{}'".format(checkpoint_path))
        # NOTE: torch.load unpickles the file; only load trusted checkpoints.
        state_dict_g = torch.load(checkpoint_path, map_location=self.device)
        print("Complete.")
        generator.load_state_dict(state_dict_g["generator"])
        generator.eval()
        generator.remove_weight_norm()
        return h, generator

    def generate_wav(self, mel):
        """Run the vocoder on ``mel`` and return int16 audio.

        Args:
            mel: Tensor accepted by the HiFi-GAN generator; it is moved to
                ``self.device`` before the forward pass.

        Returns:
            A ``(audio, sampling_rate)`` tuple: ``audio`` is a NumPy ``int16``
            array and ``sampling_rate`` comes from the loaded configuration.
        """
        # Inference only: do not build an autograd graph for the forward pass.
        with torch.no_grad():
            y_g_hat = self.hifi_gan_generator(mel.to(self.device))  # passing through vocoder
            audio = y_g_hat.squeeze() * 32768.0
            # A sample of exactly +1.0 scales to 32768, which does not fit in
            # int16; clamp so it saturates instead of wrapping around.
            audio = torch.clamp(audio, min=-32768.0, max=32767.0)
        audio = audio.cpu().detach().numpy().astype("int16")
        # Drop the tensors before emptying the cache so their memory can be
        # returned as well.
        del y_g_hat
        del mel
        torch.cuda.empty_cache()
        return audio, self.h.sampling_rate
if __name__ == "__main__":
    # Command-line entry point: synthesize one utterance and save it as a WAV.
    cli = ArgumentParser()
    # (flags, extra keyword arguments), registered in help-output order.
    option_specs = [
        (("-m", "--model"), {"required": True}),
        (("-g", "--gan"), {"required": True}),
        (("-d", "--device"), {"default": "cpu"}),
        (("-t", "--text"), {"required": True}),
        (("-w", "--wav"), {"required": True}),
    ]
    for flags, extra in option_specs:
        cli.add_argument(*flags, type=str, **extra)
    opts = cli.parse_args()

    # Load both models first, then run text -> mel -> waveform.
    synthesizer = TextToMel(glow_model_dir=opts.model, device=opts.device)
    vocoder = MelToWav(hifi_model_dir=opts.gan, device=opts.device)
    spectrogram = synthesizer.generate_mel(opts.text)
    samples, sample_rate = vocoder.generate_wav(spectrogram)
    write(filename=opts.wav, rate=sample_rate, data=samples)