import torch
import torchaudio
from indextts.infer import IndexTTS
from indextts.utils.feature_extractors import MelSpectrogramFeatures
from torch.nn import functional as F

if __name__ == "__main__":
| """ | |
| Test the padding of text tokens in inference. | |
| ``` | |
| python tests/padding_test.py checkpoints | |
| python tests/padding_test.py IndexTTS-1.5 | |
| ``` | |
| """ | |
    import transformers
    transformers.set_seed(42)
    import sys
    sys.path.append("..")
    if len(sys.argv) > 1:
        model_dir = sys.argv[1]
    else:
        model_dir = "checkpoints"
    audio_prompt = "tests/sample_prompt.wav"
    tts = IndexTTS(cfg_path=f"{model_dir}/config.yaml", model_dir=model_dir, is_fp16=False, use_cuda_kernel=False)
    text = "晕 XUAN4 是 一 种 not very good GAN3 觉"
    text_tokens = tts.tokenizer.encode(text)
    text_tokens = torch.tensor(text_tokens, dtype=torch.int32, device=tts.device).unsqueeze(0)  # [1, L]
    audio, sr = torchaudio.load(audio_prompt)
    audio = torch.mean(audio, dim=0, keepdim=True)
    audio = torchaudio.transforms.Resample(sr, 24000)(audio)
    auto_conditioning = MelSpectrogramFeatures()(audio).to(tts.device)
    cond_mel_lengths = torch.tensor([auto_conditioning.shape[-1]]).to(tts.device)
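    # The reference audio is downmixed to mono, resampled to 24 kHz, and turned
    # into a mel spectrogram that serves as the conditioning prompt;
    # cond_mel_lengths records its frame count (presumably used for masking
    # inside the model).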
    with torch.no_grad():
        kwargs = {
            "cond_mel_lengths": cond_mel_lengths,
            "do_sample": False,
            "top_p": 0.8,
            "top_k": None,
            "temperature": 1.0,
            "num_return_sequences": 1,
            "length_penalty": 0.0,
            "num_beams": 1,
            "repetition_penalty": 10.0,
            "max_generate_length": 100,
        }
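        # With do_sample=False and num_beams=1 the decoding is greedy and
        # deterministic, so any difference between the runs below can only come
        # from how padding is handled, not from sampling noise.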
        # Baseline: inference on the unpadded text tokens.
        baseline = tts.gpt.inference_speech(auto_conditioning, text_tokens, **kwargs)
        baseline = baseline.squeeze(0)
        print("Running inference on padded text tokens...")
        pad_text_tokens = [
            F.pad(text_tokens, (8, 0), value=0),  # left-pad with BOS (0)
            F.pad(text_tokens, (0, 8), value=1),  # right-pad with EOS (1)
            F.pad(F.pad(text_tokens, (4, 0), value=0), (0, 4), value=1),  # pad both sides evenly
            F.pad(F.pad(text_tokens, (6, 0), value=0), (0, 2), value=1),  # pad both sides unevenly
            F.pad(F.pad(text_tokens, (0, 4), value=0), (0, 4), value=1),  # right-pad with BOS, then EOS
        ]
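        # Every variant adds exactly 8 padding tokens (8+0, 0+8, 4+4, 6+2, 4+4),
        # so all tensors share the shape [1, L+8] and can be concatenated into a
        # single [5, L+8] batch further below.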
        output_for_padded = []
        for t in pad_text_tokens:
            # Run inference on each padded variant individually.
            out = tts.gpt.inference_speech(auto_conditioning, t, **kwargs)
            output_for_padded.append(out.squeeze(0))
        # Batched inference over all padded variants at once.
        print("Running batched inference on padded text tokens...")
        batched_text_tokens = torch.cat(pad_text_tokens, dim=0).to(tts.device)
        assert len(pad_text_tokens) == batched_text_tokens.shape[0] and batched_text_tokens.ndim == 2
        batch_output = tts.gpt.inference_speech(auto_conditioning, batched_text_tokens, **kwargs)
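        # The same auto_conditioning prompt is reused for the whole [5, L+8]
        # batch; whether it is broadcast per sample is up to inference_speech,
        # which is exactly what this batched check exercises.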
        del pad_text_tokens
        mismatch_idx = []
        print("baseline:", baseline.shape, baseline)
        print("--" * 10)
        print("baseline vs padded output:")
        for i in range(len(output_for_padded)):
            if not baseline.equal(output_for_padded[i]):
                mismatch_idx.append(i)
        if len(mismatch_idx) > 0:
            print("mismatch:", mismatch_idx)
            for i in mismatch_idx:
                print(f"[{i}]: {output_for_padded[i]}")
        else:
            print("all matched")
        del output_for_padded
        print("--" * 10)
        print("baseline vs batched output:")
        mismatch_idx = []
        for i in range(batch_output.shape[0]):
            if not baseline.equal(batch_output[i]):
                mismatch_idx.append(i)
        if len(mismatch_idx) > 0:
            print("mismatch:", mismatch_idx)
            for i in mismatch_idx:
                print(f"[{i}]: {batch_output[i]}")
        else:
            print("all matched")
        print("Test finished.")