import os

import torch
import torch.nn as nn
from transformers import BertModel, BertConfig, BertTokenizer

class CharEmbedding(nn.Module):
    def __init__(self, model_dir):
        super().__init__()
        self.tokenizer = BertTokenizer.from_pretrained(model_dir)
        self.bert_config = BertConfig.from_pretrained(model_dir)
        self.hidden_size = self.bert_config.hidden_size
        self.bert = BertModel(self.bert_config)
        # Project BERT hidden states down to 256-dim prosody features.
        self.proj = nn.Linear(self.hidden_size, 256)
        # Extra 3-way head carried by the training checkpoint; it is loaded
        # for compatibility but never called in forward() below.
        self.linear = nn.Linear(256, 3)

    def text2Token(self, text):
        # Character-level tokenization; no [CLS]/[SEP] markers are added.
        token = self.tokenizer.tokenize(text)
        txtid = self.tokenizer.convert_tokens_to_ids(token)
        return txtid

    def forward(self, inputs_ids, inputs_masks, tokens_type_ids):
        # [0] selects the last hidden state, shaped (batch, seq_len, hidden).
        out_seq = self.bert(input_ids=inputs_ids,
                            attention_mask=inputs_masks,
                            token_type_ids=tokens_type_ids)[0]
        out_seq = self.proj(out_seq)
        return out_seq
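
# A minimal sketch of what text2Token() returns, assuming './bert/' holds the
# repo's character-level Chinese BERT files. tokenize() adds no [CLS]/[SEP],
# so one id comes back per character. _demo_text2token is a hypothetical
# helper for illustration, not part of the original module.
def _demo_text2token(model_dir='./bert/'):
    tokenizer = BertTokenizer.from_pretrained(model_dir)
    ids = tokenizer.convert_tokens_to_ids(tokenizer.tokenize("你好"))
    print(ids)  # two ids, one per Chinese character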

class TTSProsody(object):
    def __init__(self, path, device):
        self.device = device
        self.char_model = CharEmbedding(path)
        # strict=False tolerates key mismatches between the checkpoint and the
        # module (e.g. weights present in only one of the two).
        self.char_model.load_state_dict(
            torch.load(
                os.path.join(path, 'prosody_model.pt'),
                map_location="cpu"
            ),
            strict=False
        )
        self.char_model.eval()
        self.char_model.to(self.device)

    def get_char_embeds(self, text):
        input_ids = self.char_model.text2Token(text)
        input_masks = [1] * len(input_ids)
        type_ids = [0] * len(input_ids)
        # Wrap each list in an outer dimension to form a batch of one.
        input_ids = torch.LongTensor([input_ids]).to(self.device)
        input_masks = torch.LongTensor([input_masks]).to(self.device)
        type_ids = torch.LongTensor([type_ids]).to(self.device)

        with torch.no_grad():
            char_embeds = self.char_model(
                input_ids, input_masks, type_ids).squeeze(0).cpu()
        return char_embeds

    def expand_for_phone(self, char_embeds, length):  # length[i]: phone count of char i
        assert char_embeds.size(0) == len(length)
        expand_vecs = list()
        for vec, leng in zip(char_embeds, length):
            # Repeat each character vector once per phone of that character.
            vec = vec.expand(leng, -1)
            expand_vecs.append(vec)
        expand_embeds = torch.cat(expand_vecs, 0)
        assert expand_embeds.size(0) == sum(length)
        return expand_embeds.numpy()
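
# A minimal sketch of expand_for_phone() on toy data; the phone counts below
# are made up for illustration. self is never used by the method, so passing
# None avoids loading any BERT weights.
def _demo_expand_for_phone():
    char_embeds = torch.randn(3, 256)  # 3 characters, 256-dim features each
    phone_counts = [2, 1, 3]           # hypothetical phones per character
    expanded = TTSProsody.expand_for_phone(None, char_embeds, phone_counts)
    assert expanded.shape == (sum(phone_counts), 256)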
| if __name__ == "__main__": | |
| device = torch.device("cuda" if torch.cuda.is_available() else "cpu") | |
| prosody = TTSProsody('./bert/', device) | |
| while True: | |
| text = input("请输入文本:") | |
| prosody.get_char_embeds(text) | |
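        # Illustration only: expand to the phone level with a made-up count of
        # two phones per character; real per-character phone counts come from
        # the pipeline's text front end, not from this script.
        phone_embeds = prosody.expand_for_phone(
            char_embeds, [2] * char_embeds.size(0))
        print("phone_embeds:", phone_embeds.shape)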