# HuggingFace Spaces page residue from copy/paste (badge text), kept as a
# comment so the file parses:
# Spaces: Running on Zero
| import os | |
| import torch | |
| import torchaudio | |
| import numpy as np | |
| import random | |
| import whisper | |
| import fire | |
| from argparse import Namespace | |
| from data.tokenizer import ( | |
| AudioTokenizer, | |
| TextTokenizer, | |
| ) | |
| from models import voice_star | |
| from inference_tts_utils import inference_one_sample | |
| ############################################################ | |
| # Utility Functions | |
| ############################################################ | |
def seed_everything(seed=1):
    """
    Seed every relevant RNG (Python hash, `random`, NumPy, torch CPU/CUDA)
    so a run is reproducible.

    Args:
        seed: integer seed applied to all generators (default 1).
    """
    os.environ['PYTHONHASHSEED'] = str(seed)
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed(seed)
    # Also seed every CUDA device, not just the current one (no-op on CPU-only builds).
    torch.cuda.manual_seed_all(seed)
    # Trade kernel-autotuning speed for deterministic cuDNN behavior.
    torch.backends.cudnn.benchmark = False
    torch.backends.cudnn.deterministic = True
def estimate_duration(ref_audio_path, text, ref_text=None):
    """
    Estimate target speech duration (seconds) from a seconds-per-character rate.

    BUG in the original: the rate was computed from the *target* text itself,
    so ``len(text) * (audio_duration / len(text))`` cancelled and the function
    always returned the reference audio's own duration. Pass ``ref_text`` (the
    transcript of the reference audio) to get a genuinely text-length-dependent
    estimate; without it, the original behavior is preserved for callers.

    Args:
        ref_audio_path: path to the reference audio file (read via torchaudio.info).
        text: target text to be synthesized.
        ref_text: optional transcript of the reference audio; when given, the
            per-character rate is ``audio_duration / len(ref_text)``.

    Returns:
        Estimated duration in seconds (float); 0.0 for empty ``text``.
    """
    info = torchaudio.info(ref_audio_path)
    audio_duration = info.num_frames / info.sample_rate
    # Choose the text that defines the speaking rate; guard against empty input.
    rate_basis = ref_text if ref_text is not None else text
    spc = audio_duration / max(len(rate_basis), 1)  # seconds per character
    return len(text) * spc
| ############################################################ | |
| # Main Inference Function | |
| ############################################################ | |
def run_inference(
    reference_speech="./demo/5895_34622_000026_000002.wav",
    target_text="I cannot believe that the same model can also do text to speech synthesis too! And you know what? this audio is 8 seconds long.",
    # Model
    model_name="VoiceStar_840M_30s",  # or VoiceStar_840M_40s; the latter is trained on maximally 40s long speech
    model_root="./pretrained",
    # Additional optional
    reference_text=None,     # if None => run whisper on reference_speech
    target_duration=None,    # if None => estimate from reference_speech and target_text
    # Default hyperparameters from snippet
    codec_audio_sr=16000,    # do not change
    codec_sr=50,             # do not change
    top_k=10,                # try 10, 20, 30, 40
    top_p=1,                 # do not change
    min_p=1,                 # do not change
    temperature=1,
    silence_tokens=None,     # do not change it
    kvcache=1,               # if OOM, set to 0
    multi_trial=None,        # do not change it
    repeat_prompt=1,         # increase to improve speaker similarity; but if repeated reference plus target duration exceeds the maximal training duration, quality may drop
    stop_repetition=3,       # will not use it
    sample_batch_size=1,     # do not change
    # Others
    seed=1,
    output_dir="./generated_tts",
    # Some snippet-based defaults
    cut_off_sec=100,         # do not adjust; we always use the entire reference speech. If you change it, also trim reference_text to the transcript of the audio that remains
):
    """
    Run VoiceStar TTS on a reference speech clip and save the generated audio.

    Loads (downloading if necessary) the VoiceStar checkpoint, transcribes the
    reference audio with Whisper when no transcript is given, estimates the
    target duration when none is given, then synthesizes ``target_text`` and
    writes ``<output_dir>/generated.wav``.

    Example:
        python inference_commandline.py \
            --reference_speech "./demo/5895_34622_000026_000002.wav" \
            --target_text "I cannot believe ... this audio is 10 seconds long." \
            --reference_text "(optional) text to use as prefix" \
            --target_duration (optional float)
    """
    # Seed everything for reproducibility.
    seed_everything(seed)

    # Allow argparse.Namespace inside the weights_only checkpoint load below.
    torch.serialization.add_safe_globals([Namespace])
    device = "cuda" if torch.cuda.is_available() else "cpu"

    # Locate the checkpoint; download it from the HF hub on first use.
    ckpt_fn = os.path.join(model_root, model_name + ".pth")
    if not os.path.exists(ckpt_fn):
        print(f"[Info] Downloading {model_name} checkpoint...")
        # Portable stdlib download (the original shelled out to wget, which
        # fails silently when absent and when model_root does not exist).
        os.makedirs(model_root, exist_ok=True)
        import urllib.request
        url = f"https://huggingface.co/pyp1/VoiceStar/resolve/main/{model_name}.pth?download=true"
        urllib.request.urlretrieve(url, ckpt_fn)

    bundle = torch.load(ckpt_fn, map_location=device, weights_only=True)
    args = bundle["args"]
    phn2num = bundle["phn2num"]

    model = voice_star.VoiceStar(args)
    model.load_state_dict(bundle["model"])
    model.to(device)
    model.eval()

    # If reference_text not provided, transcribe with Whisper large-v3-turbo.
    if reference_text is None:
        print("[Info] No reference_text provided, transcribing reference_speech with Whisper.")
        wh_model = whisper.load_model("large-v3-turbo")
        result = wh_model.transcribe(reference_speech)
        prefix_transcript = result["text"]
        print(f"[Info] Whisper transcribed text: {prefix_transcript}")
    else:
        prefix_transcript = reference_text

    # If target_duration not provided, estimate from reference speech + target_text.
    if target_duration is None:
        target_generation_length = estimate_duration(reference_speech, target_text)
        print(f"[Info] target_duration not provided, estimated as {target_generation_length:.2f} seconds. If not desired, please provide a target_duration.")
    else:
        target_generation_length = float(target_duration)

    # Pick the EnCodec signature matching the model's codebook count (from snippet).
    if args.n_codebooks == 4:
        signature = "./pretrained/encodec_6f79c6a8.th"
    elif args.n_codebooks == 8:
        signature = "./pretrained/encodec_8cb1024_giga.th"
    else:
        # Fallback: use the 6f79c6a8 checkpoint.
        signature = "./pretrained/encodec_6f79c6a8.th"

    # Defaults from snippet (None sentinels avoid mutable default arguments).
    if silence_tokens is None:
        silence_tokens = []
    if multi_trial is None:
        multi_trial = []

    delay_pattern_increment = args.n_codebooks + 1  # from snippet

    # Frame index at which the reference prompt is cut off (entire clip by default).
    info = torchaudio.info(reference_speech)
    prompt_end_frame = int(cut_off_sec * info.sample_rate)

    # Prepare tokenizers.
    audio_tokenizer = AudioTokenizer(signature=signature)
    text_tokenizer = TextTokenizer(backend="espeak")

    # Decoding hyperparameters, bundled as expected by inference_one_sample.
    decode_config = {
        'top_k': top_k,
        'top_p': top_p,
        'min_p': min_p,
        'temperature': temperature,
        'stop_repetition': stop_repetition,
        'kvcache': kvcache,
        'codec_audio_sr': codec_audio_sr,
        'codec_sr': codec_sr,
        'silence_tokens': silence_tokens,
        'sample_batch_size': sample_batch_size
    }

    # Run inference.
    print("[Info] Running TTS inference...")
    concated_audio, gen_audio = inference_one_sample(
        model, args, phn2num, text_tokenizer, audio_tokenizer,
        reference_speech, target_text,
        device, decode_config,
        prompt_end_frame=prompt_end_frame,
        target_generation_length=target_generation_length,
        delay_pattern_increment=delay_pattern_increment,
        prefix_transcript=prefix_transcript,
        multi_trial=multi_trial,
        repeat_prompt=repeat_prompt,
    )

    # The model returns a list of waveforms; pick the first.
    concated_audio, gen_audio = concated_audio[0].cpu(), gen_audio[0].cpu()

    # Save only the generated portion, as the snippet does.
    os.makedirs(output_dir, exist_ok=True)
    out_filename = "generated.wav"
    out_path = os.path.join(output_dir, out_filename)
    torchaudio.save(out_path, gen_audio, codec_audio_sr)
    print(f"[Success] Generated audio saved to {out_path}")
def main():
    """Command-line entry point: expose run_inference through python-fire."""
    fire.Fire(run_inference)
| if __name__ == "__main__": | |
| main() | |