# cp from https://github.com/lifeiteng/vall-e/blob/main/valle/data/tokenizer.py
# Copyright 2023 (authors: Feiteng Li)
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import re
from dataclasses import asdict, dataclass
from typing import Any, Dict, List, Optional, Pattern, Union

import numpy as np
import torch
import torchaudio

# from encodec import EncodecModel
# from encodec.utils import convert_audio
# from lhotse.features import FeatureExtractor
# from lhotse.utils import Seconds, compute_num_frames
from phonemizer.backend import EspeakBackend
from phonemizer.backend.espeak.language_switch import LanguageSwitch
from phonemizer.backend.espeak.words_mismatch import WordMismatch
from phonemizer.punctuation import Punctuation
from phonemizer.separator import Separator

try:
    from pypinyin import Style, pinyin
    from pypinyin.style._utils import get_finals, get_initials
except Exception:
    # pypinyin is only needed for the Chinese backends; ignore if unavailable.
    pass


class PypinyinBackend:
    """Pypinyin backend for Chinese. Most of this code is adapted from espnet.

    There are two styles, "pypinyin" and "pypinyin_initials_finals": the first
    produces whole syllables like "ni1 hao3", the second splits each syllable
    into initial and final, like "n i1 h ao3".
    """

    def __init__(
        self,
        backend="pypinyin_initials_finals",
        punctuation_marks: Union[str, Pattern] = Punctuation.default_marks(),
    ) -> None:
        self.backend = backend
        self.punctuation_marks = punctuation_marks

    def phonemize(
        self, text: List[str], separator: Separator, strip=True, njobs=1
    ) -> List[str]:
        assert isinstance(text, List)
        phonemized = []
        for _text in text:
            _text = re.sub(" +", " ", _text.strip())
            _text = _text.replace(" ", separator.word)
            phones = []
            if self.backend == "pypinyin":
                for n, py in enumerate(
                    pinyin(
                        _text, style=Style.TONE3, neutral_tone_with_five=True
                    )
                ):
                    if all(c in self.punctuation_marks for c in py[0]):
                        if len(phones):
                            assert phones[-1] == separator.syllable
                            phones.pop(-1)
                        phones.extend(list(py[0]))
                    else:
                        phones.extend([py[0], separator.syllable])
            elif self.backend == "pypinyin_initials_finals":
                for n, py in enumerate(
                    pinyin(
                        _text, style=Style.TONE3, neutral_tone_with_five=True
                    )
                ):
                    if all(c in self.punctuation_marks for c in py[0]):
                        if len(phones):
                            assert phones[-1] == separator.syllable
                            phones.pop(-1)
                        phones.extend(list(py[0]))
                    else:
                        if py[0][-1].isalnum():
                            initial = get_initials(py[0], strict=False)
                            if py[0][-1].isdigit():
                                final = (
                                    get_finals(py[0][:-1], strict=False)
                                    + py[0][-1]
                                )
                            else:
                                final = get_finals(py[0], strict=False)
                            phones.extend(
                                [
                                    initial,
                                    separator.phone,
                                    final,
                                    separator.syllable,
                                ]
                            )
                        else:
                            raise ValueError(f"Unexpected pinyin token: {py[0]}")
            else:
                raise NotImplementedError
            phonemized.append(
                "".join(phones).rstrip(f"{separator.word}{separator.syllable}")
            )
        return phonemized
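

# A minimal usage sketch (not from the original file; assumes pypinyin is
# installed).  The word separator is appended to the punctuation marks, as
# TextTokenizer below does, so that it passes through untouched:
#
#   sep = Separator(word="_", syllable="-", phone="|")
#   backend = PypinyinBackend(
#       backend="pypinyin_initials_finals",
#       punctuation_marks=Punctuation.default_marks() + sep.word,
#   )
#   backend.phonemize(["你好 世界"], separator=sep)
#   # expected output, roughly: ["n|i3-h|ao3_sh|i4-j|ie4"]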


class TextTokenizer:
    """Phonemize Text."""

    def __init__(
        self,
        language="en-us",
        backend="espeak",
        separator=Separator(word="_", syllable="-", phone="|"),
        preserve_punctuation=True,
        punctuation_marks: Union[str, Pattern] = Punctuation.default_marks(),
        with_stress: bool = False,
        tie: Union[bool, str] = False,
        language_switch: LanguageSwitch = "keep-flags",
        words_mismatch: WordMismatch = "ignore",
    ) -> None:
        if backend == "espeak":
            phonemizer = EspeakBackend(
                language,
                punctuation_marks=punctuation_marks,
                preserve_punctuation=preserve_punctuation,
                with_stress=with_stress,
                tie=tie,
                language_switch=language_switch,
                words_mismatch=words_mismatch,
            )
        elif backend in ["pypinyin", "pypinyin_initials_finals"]:
            phonemizer = PypinyinBackend(
                backend=backend,
                punctuation_marks=punctuation_marks + separator.word,
            )
        else:
            raise NotImplementedError(f"{backend}")

        self.backend = phonemizer
        self.separator = separator

    def to_list(self, phonemized: str) -> List[str]:
        fields = []
        for word in phonemized.split(self.separator.word):
            # "ɐ m|iː|n?" ɹ|ɪ|z|ɜː|v; h|ɪ|z.
            pp = re.findall(r"\w+|[^\w\s]", word, re.UNICODE)
            fields.extend(
                [p for p in pp if p != self.separator.phone]
                + [self.separator.word]
            )
        assert len("".join(fields[:-1])) == len(phonemized) - phonemized.count(
            self.separator.phone
        )
        return fields[:-1]

    def __call__(self, text, strip=True) -> List[List[str]]:
        if isinstance(text, str):
            text = [text]

        phonemized = self.backend.phonemize(
            text, separator=self.separator, strip=strip, njobs=1
        )
        return [self.to_list(p) for p in phonemized]


def tokenize_text(tokenizer: TextTokenizer, text: str) -> List[str]:
    phonemes = tokenizer([text.strip()])
    return phonemes[0]  # k2symbols
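

# A minimal usage sketch (not from the original file; the "espeak" backend
# assumes the phonemizer package and an espeak-ng installation):
#
#   tokenizer = TextTokenizer(language="en-us", backend="espeak")
#   tokenize_text(tokenizer, "hello world")
#   # a flat list of phone symbols and word separators, roughly:
#   # ["h", "ə", "l", "oʊ", "_", "w", "ɜː", "l", "d"]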


def remove_encodec_weight_norm(model):
    """Fold weight normalization out of an EnCodec model's convolutions.

    This is an inference-time optimization; the model's outputs are unchanged
    (see the self-test under __main__ below).
    """
    from encodec.modules import SConv1d
    from encodec.modules.seanet import SConvTranspose1d, SEANetResnetBlock
    from torch.nn.utils import remove_weight_norm

    encoder = model.encoder.model
    for key in encoder._modules:
        if isinstance(encoder._modules[key], SEANetResnetBlock):
            remove_weight_norm(encoder._modules[key].shortcut.conv.conv)
            block_modules = encoder._modules[key].block._modules
            for skey in block_modules:
                if isinstance(block_modules[skey], SConv1d):
                    remove_weight_norm(block_modules[skey].conv.conv)
        elif isinstance(encoder._modules[key], SConv1d):
            remove_weight_norm(encoder._modules[key].conv.conv)

    decoder = model.decoder.model
    for key in decoder._modules:
        if isinstance(decoder._modules[key], SEANetResnetBlock):
            remove_weight_norm(decoder._modules[key].shortcut.conv.conv)
            block_modules = decoder._modules[key].block._modules
            for skey in block_modules:
                if isinstance(block_modules[skey], SConv1d):
                    remove_weight_norm(block_modules[skey].conv.conv)
        elif isinstance(decoder._modules[key], SConvTranspose1d):
            remove_weight_norm(decoder._modules[key].convtr.convtr)
        elif isinstance(decoder._modules[key], SConv1d):
            remove_weight_norm(decoder._modules[key].conv.conv)
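

# Illustrative sketch (not from the original file): remove_weight_norm folds
# the weight-norm (g, v) parametrization back into a single weight tensor,
# w = g * v / ||v||, so the module computes the same output without the
# per-forward renormalization.  On a plain Conv1d:
#
#   conv = torch.nn.utils.weight_norm(torch.nn.Conv1d(2, 4, 3))
#   x = torch.randn(1, 2, 10)
#   y_before = conv(x)
#   torch.nn.utils.remove_weight_norm(conv)
#   assert torch.allclose(y_before, conv(x), atol=1e-6)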


# class AudioTokenizer:
#     """EnCodec audio."""
#
#     def __init__(
#         self,
#         bandwidth: float = 6.0,
#         device: Any = None,
#     ) -> None:
#         # Instantiate a pretrained EnCodec model
#         model = EncodecModel.encodec_model_24khz()
#         model.set_target_bandwidth(bandwidth=bandwidth)
#         remove_encodec_weight_norm(model)
#
#         if not device:
#             device = torch.device("cpu")
#             if torch.cuda.is_available():
#                 device = torch.device("cuda:0")
#
#         self._device = device
#
#         self.codec = model.to(device)
#         self.sample_rate = model.sample_rate
#         self.channels = model.channels
#
#     @property
#     def device(self):
#         return self._device
#
#     def encode(self, wav: torch.Tensor) -> torch.Tensor:
#         return self.codec.encode(wav.to(self.device))
#
#     def decode(self, frames: torch.Tensor) -> torch.Tensor:
#         return self.codec.decode(frames)


# class AudioTokenizer:
#     """EnCodec audio."""
#
#     def __init__(
#         self,
#         bandwidth: float = 6.0,
#         device: Any = None,
#         hificodec=False,
#         signature=None,
#     ) -> None:
#         self.hificodec = hificodec
#         self.customized = signature is not None
#         if hificodec:
#             import sys
#             sys.path.append("/home/pyp/AcademiCodec")
#             from academicodec.models.hificodec.vqvae import VQVAE
#             config_path = "/home/pyp/AcademiCodec/egs/HiFi-Codec-16k-320d/config_16k_320d.json"
#             model_path = "/home/pyp/AcademiCodec/egs/HiFi-Codec-16k-320d/checkpoint/HiFi-Codec-16k-320d"
#             self.sample_rate = 16000
#             self.channels = 1
#             model = VQVAE(config_path, model_path, with_encoder=True)
#             model.generator.remove_weight_norm()
#             model.encoder.remove_weight_norm()
#             model.eval()
#         else:
#             if signature is not None:
#                 # use customized encodec model
#                 # import sys
#                 # sys.path.append("home/pyp/audiocraft")
#                 from audiocraft.solvers import CompressionSolver
#                 model_path = f'//sig/{signature}'
#                 model = CompressionSolver.model_from_checkpoint(model_path)
#                 self.sample_rate = model.sample_rate
#                 self.channels = model.channels
#             else:
#                 # Instantiate a pretrained EnCodec model
#                 model = EncodecModel.encodec_model_24khz()
#                 model.set_target_bandwidth(bandwidth=bandwidth)
#                 remove_encodec_weight_norm(model)
#                 self.sample_rate = model.sample_rate
#                 self.channels = model.channels
#
#         if not device:
#             device = torch.device("cpu")
#             if torch.cuda.is_available():
#                 device = torch.device("cuda:0")
#
#         self._device = device
#         self.codec = model.to(device)
#
#     @property
#     def device(self):
#         return self._device
#
#     def encode(self, wav: torch.Tensor) -> torch.Tensor:
#         if self.hificodec:
#             assert wav.ndim == 3 and wav.shape[:2] == torch.Size((1, 1)), wav.shape
#             wav = wav.squeeze(0)
#             codes = self.codec.encode(wav.to(self.device))  # [1, T, 4]
#             return [(codes.transpose(2, 1), None)]
#         elif self.customized:
#             codes = self.codec.encode(wav.to(self.device))
#             return [(codes[0], None)]
#         return self.codec.encode(wav.to(self.device))
#
#     def decode(self, frames: torch.Tensor) -> torch.Tensor:
#         if self.hificodec:
#             frames = frames[0][0]  # [1, 4, T]
#             assert frames.shape[:2] == torch.Size((1, 4))
#             audio = self.codec(frames.transpose(2, 1))
#             assert audio.shape[0] == 1, audio.shape
#             return audio
#         elif self.customized:
#             frames = frames[0][0]  # [1, 4, T]
#             return self.codec.decode(frames)
#         return self.codec.decode(frames)
#         # try:
#         #     return self.codec.decode(frames)
#         # except:
#         #     import logging
#         #     logging.info(f"error when decoding frame of shape: {frames[0][0].shape}")
#         #     self.codec.cpu()
#         #     ret = self.codec.cpu().decode([(frames[0][0].cpu(), None)])[0].to(self._device)
#         #     self.codec.to(self._device)
#         #     return [ret]


# def tokenize_audio(tokenizer: AudioTokenizer, audio_path: str, offset=-1, num_frames=-1):
#     # Load and pre-process the audio waveform
#     if offset != -1 and num_frames != -1:
#         wav, sr = torchaudio.load(audio_path, frame_offset=offset, num_frames=num_frames)
#     else:
#         wav, sr = torchaudio.load(audio_path)
#     wav = convert_audio(wav, sr, tokenizer.sample_rate, tokenizer.channels)
#     wav = wav.unsqueeze(0)
#
#     # Extract discrete codes from EnCodec
#     with torch.no_grad():
#         encoded_frames = tokenizer.encode(wav)
#     return encoded_frames


# @dataclass
# class AudioTokenConfig:
#     frame_shift: Seconds = 320.0 / 24000
#     num_quantizers: int = 8
#
#     def to_dict(self) -> Dict[str, Any]:
#         return asdict(self)
#
#     @staticmethod
#     def from_dict(data: Dict[str, Any]) -> "AudioTokenConfig":
#         return AudioTokenConfig(**data)


# class AudioTokenExtractor(FeatureExtractor):
#     name = "encodec"
#     config_type = AudioTokenConfig
#
#     def __init__(self, config: Optional[Any] = None):
#         super(AudioTokenExtractor, self).__init__(config)
#         self.tokenizer = AudioTokenizer()
#
#     def extract(
#         self, samples: Union[np.ndarray, torch.Tensor], sampling_rate: int
#     ) -> np.ndarray:
#         if not isinstance(samples, torch.Tensor):
#             samples = torch.from_numpy(samples)
#         if sampling_rate != self.tokenizer.sample_rate:
#             samples = convert_audio(
#                 samples,
#                 sampling_rate,
#                 self.tokenizer.sample_rate,
#                 self.tokenizer.channels,
#             )
#         if len(samples.shape) == 2:
#             samples = samples.unsqueeze(0)
#         else:
#             raise ValueError()
#
#         device = self.tokenizer.device
#         encoded_frames = self.tokenizer.encode(samples.detach().to(device))
#         codes = encoded_frames[0][0]  # [B, n_q, T]
#         if True:
#             duration = round(samples.shape[-1] / sampling_rate, ndigits=12)
#             expected_num_frames = compute_num_frames(
#                 duration=duration,
#                 frame_shift=self.frame_shift,
#                 sampling_rate=sampling_rate,
#             )
#             assert abs(codes.shape[-1] - expected_num_frames) <= 1
#             codes = codes[..., :expected_num_frames]
#         return codes.cpu().squeeze(0).permute(1, 0).numpy()
#
#     @property
#     def frame_shift(self) -> Seconds:
#         return self.config.frame_shift
#
#     def feature_dim(self, sampling_rate: int) -> int:
#         return self.config.num_quantizers
#
#     def pad_tensor_list(self, tensor_list, device, padding_value=0):
#         # Compute the length of each tensor
#         lengths = [tensor.shape[0] for tensor in tensor_list]
#         # Pad with pad_sequence
#         tensor_list = [torch.Tensor(t).to(device) for t in tensor_list]
#         padded_tensor = torch.nn.utils.rnn.pad_sequence(
#             tensor_list, batch_first=True, padding_value=padding_value
#         )
#         return padded_tensor, lengths
#
#     def extract_batch(self, samples, sampling_rate, lengths) -> np.ndarray:
#         samples = [wav.squeeze() for wav in samples]
#         device = self.tokenizer.device
#         samples, lengths = self.pad_tensor_list(samples, device)
#         samples = samples.unsqueeze(1)
#
#         if not isinstance(samples, torch.Tensor):
#             samples = torch.from_numpy(samples)
#         if len(samples.shape) != 3:
#             raise ValueError()
#         if sampling_rate != self.tokenizer.sample_rate:
#             samples = [
#                 convert_audio(
#                     wav,
#                     sampling_rate,
#                     self.tokenizer.sample_rate,
#                     self.tokenizer.channels,
#                 )
#                 for wav in samples
#             ]
#         # Extract discrete codes from EnCodec
#         with torch.no_grad():
#             encoded_frames = self.tokenizer.encode(samples.detach().to(device))
#         encoded_frames = encoded_frames[0][0]  # [B, n_q, T]
#         batch_codes = []
#         for b, length in enumerate(lengths):
#             codes = encoded_frames[b]
#             duration = round(length / sampling_rate, ndigits=12)
#             expected_num_frames = compute_num_frames(
#                 duration=duration,
#                 frame_shift=self.frame_shift,
#                 sampling_rate=sampling_rate,
#             )
#             batch_codes.append(codes[..., :expected_num_frames])
#         return [codes.cpu().permute(1, 0).numpy() for codes in batch_codes]


if __name__ == "__main__":
    # The top-level `from encodec import EncodecModel` is commented out above,
    # so import it here for this self-test.
    from encodec import EncodecModel

    device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

    model = EncodecModel.encodec_model_24khz()
    model.set_target_bandwidth(6.0)
    model.to(device)

    samples = (
        torch.from_numpy(np.random.random([4, 1, 30000]))
        .type(torch.float32)
        .to(device)
    )
    codes_norm = model.encode(samples)
    remove_encodec_weight_norm(model)
    codes_raw = model.encode(samples)

    # Folding out weight norm must not change the discrete codes.
    assert torch.allclose(codes_raw[0][0], codes_norm[0][0])