#!/usr/bin/env python3
"""
gradio_tts_app.py

Run:
    python gradio_tts_app.py

Then open the printed local or public URL in your browser.
"""
import os
import random
from argparse import Namespace

import numpy as np
import torch
import torchaudio
import whisper
import gradio as gr
import spaces
# ---------------------------------------------------------------------
# The following imports assume your local project structure:
#   data/tokenizer.py
#   models/voice_star.py
#   inference_tts_utils.py
# Adjust if needed.
# ---------------------------------------------------------------------
from data.tokenizer import AudioTokenizer, TextTokenizer
from models import voice_star
from inference_tts_utils import inference_one_sample
############################################################
# Utility Functions
############################################################

def seed_everything(seed=1):
    os.environ['PYTHONHASHSEED'] = str(seed)
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed(seed)
    torch.backends.cudnn.benchmark = False
    torch.backends.cudnn.deterministic = True
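# Note: cudnn.deterministic alone does not guarantee bit-exact results across
# GPUs or library versions; torch.use_deterministic_algorithms(True) is the
# stricter option if full determinism is required.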

def estimate_duration(ref_audio_path, ref_text, target_text):
    """
    Estimate the target duration by applying the reference audio's
    seconds-per-character rate to the target text.
    """
    info = torchaudio.info(ref_audio_path)
    audio_duration = info.num_frames / info.sample_rate
    spc = audio_duration / max(len(ref_text), 1)  # seconds per character
    return len(target_text) * spc
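# Example: a 6-second reference whose transcript is 120 characters yields
# 0.05 s/char, so an 80-character target text maps to an estimated 4.0 s.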

############################################################
# Main Inference Function
############################################################
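# On Hugging Face ZeroGPU Spaces ("Running on Zero"), GPU access is typically
# granted per call via the `spaces.GPU` decorator; applying it here is an
# assumption based on the `spaces` import above.
@spaces.GPU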
def run_inference(
    # User-adjustable parameters
    reference_speech="./demo/5895_34622_000026_000002.wav",
    target_text="VoiceStar is a very interesting model, it's duration controllable and can extrapolate",
    model_name="VoiceStar_840M_40s",
    model_root="./pretrained",
    reference_text=None,   # optional; transcribed with Whisper when omitted
    target_duration=None,  # optional; estimated from the reference when omitted
    top_k=10,              # try 10, 20, 30, 40
    temperature=1,
    kvcache=1,             # if OOM, set to 0
    repeat_prompt=1,       # use higher values to improve speaker similarity
    stop_repetition=3,
    seed=1,
    output_dir="./generated_tts",
    # Fixed parameters; do not change
    codec_audio_sr=16000,
    codec_sr=50,
    top_p=1,
    min_p=1,
    silence_tokens=None,
    multi_trial=None,
    sample_batch_size=1,
    cut_off_sec=100,
):
| """ | |
| Inference script for VoiceStar TTS. | |
| """ | |
| # 1. Set seed | |
| seed_everything(seed) | |
| # 2. Load model checkpoint | |
| torch.serialization.add_safe_globals([Namespace]) | |
| device = "cuda" if torch.cuda.is_available() else "cpu" | |
| ckpt_fn = os.path.join(model_root, model_name + ".pth") | |
| if not os.path.exists(ckpt_fn): | |
| # use wget to download | |
| print(f"[Info] Downloading {model_name} checkpoint...") | |
| os.system(f"wget https://huggingface.co/pyp1/VoiceStar/resolve/main/{model_name}.pth?download=true -O {ckpt_fn}") | |
| bundle = torch.load(ckpt_fn, map_location=device, weights_only=True) | |
| args = bundle["args"] | |
| phn2num = bundle["phn2num"] | |
| model = voice_star.VoiceStar(args) | |
| model.load_state_dict(bundle["model"]) | |
| model.to(device) | |
| model.eval() | |
    # 3. If reference_text is not provided, transcribe the reference speech with Whisper
    if reference_text is None:
        print("[Info] No reference_text provided. Transcribing reference_speech with Whisper (large-v3-turbo).")
        wh_model = whisper.load_model("large-v3-turbo")
        result = wh_model.transcribe(reference_speech)
        prefix_transcript = result["text"]
        print(f"[Info] Whisper transcribed text: {prefix_transcript}")
    else:
        prefix_transcript = reference_text

    # 4. If target_duration is not provided, estimate it from the reference speaking rate
    if target_duration is None:
        target_generation_length = estimate_duration(reference_speech, prefix_transcript, target_text)
        print(f"[Info] target_duration not provided; estimated as {target_generation_length:.2f}s. Pass target_duration to override.")
    else:
        target_generation_length = float(target_duration)
    # 5. Pick the EnCodec codec checkpoint matching the model's codebook count
    if args.n_codebooks == 4:
        signature = "./pretrained/encodec_6f79c6a8.th"
    elif args.n_codebooks == 8:
        signature = "./pretrained/encodec_8cb1024_giga.th"
    else:
        signature = "./pretrained/encodec_6f79c6a8.th"

    if silence_tokens is None:
        silence_tokens = []
    if multi_trial is None:
        multi_trial = []

    delay_pattern_increment = args.n_codebooks + 1
    info = torchaudio.info(reference_speech)
    prompt_end_frame = int(cut_off_sec * info.sample_rate)

    # 6. Tokenizers
    audio_tokenizer = AudioTokenizer(signature=signature)
    text_tokenizer = TextTokenizer(backend="espeak")
    # 7. decode_config
    decode_config = {
        "top_k": top_k,
        "top_p": top_p,
        "min_p": min_p,
        "temperature": temperature,
        "stop_repetition": stop_repetition,
        "kvcache": kvcache,
        "codec_audio_sr": codec_audio_sr,
        "codec_sr": codec_sr,
        "silence_tokens": silence_tokens,
        "sample_batch_size": sample_batch_size,
    }
    # 8. Run inference
    print("[Info] Running TTS inference...")
    concated_audio, gen_audio = inference_one_sample(
        model, args, phn2num, text_tokenizer, audio_tokenizer,
        reference_speech, target_text,
        device, decode_config,
        prompt_end_frame=prompt_end_frame,
        target_generation_length=target_generation_length,
        delay_pattern_increment=delay_pattern_increment,
        prefix_transcript=prefix_transcript,
        multi_trial=multi_trial,
        repeat_prompt=repeat_prompt,
    )

    # The model returns a list of waveforms, pick the first
    concated_audio, gen_audio = concated_audio[0].cpu(), gen_audio[0].cpu()

    # 9. Save generated audio
    os.makedirs(output_dir, exist_ok=True)
    out_filename = "generated.wav"
    out_path = os.path.join(output_dir, out_filename)
    torchaudio.save(out_path, gen_audio, codec_audio_sr)
    print(f"[Success] Generated audio saved to {out_path}")
    return out_path  # Return the path for Gradio to load
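
# A minimal headless usage sketch (an assumption; it presumes the default demo
# file and checkpoint paths above are available):
#
#   from gradio_tts_app import run_inference
#   run_inference(target_text="Hello from VoiceStar.")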

############################
# Transcription function
############################

def transcribe_audio(reference_speech):
    """
    Transcribe the uploaded reference audio with Whisper and return the text.
    If no file is provided, return an empty string.
    """
    if reference_speech is None:
        return ""
    audio_path = reference_speech  # the Audio component uses type="filepath"
    if not os.path.exists(audio_path):
        return "File not found."
    print("[Info] Transcribing with Whisper...")
    model = whisper.load_model("medium")  # or "large-v2", "large-v3-turbo", etc.
    result = model.transcribe(audio_path)
    return result["text"]

############################
# Gradio UI
############################

def main():
    with gr.Blocks() as demo:
        gr.Markdown("## VoiceStar TTS with Editable Reference Text")

        with gr.Row():
            reference_speech_input = gr.Audio(
                label="Reference Speech",
                type="filepath",
                elem_id="ref_speech"
            )
            transcribe_button = gr.Button("Transcribe")

        # The transcribed text appears here and can be edited
        reference_text_box = gr.Textbox(
            label="Reference Text (Editable)",
            placeholder="Click 'Transcribe' to auto-fill from reference speech...",
            lines=2
        )
        target_text_box = gr.Textbox(
            label="Target Text",
            value="VoiceStar is a very interesting model, it's duration controllable and can extrapolate to unseen duration.",
            lines=3
        )
        model_name_box = gr.Textbox(
            label="Model Name",
            value="VoiceStar_840M_40s"
        )
        model_root_box = gr.Textbox(
            label="Model Root Directory",
            value="./pretrained"  # matches run_inference's default; checkpoints are downloaded here
        )
        reference_duration_box = gr.Textbox(
            label="Target Duration (Optional)",
            placeholder="Leave empty for auto-estimate."
        )
        top_k_box = gr.Number(label="top_k", value=10)
        temperature_box = gr.Number(label="temperature", value=1.0)
        kvcache_box = gr.Number(label="kvcache (1 or 0)", value=1)
        repeat_prompt_box = gr.Number(label="repeat_prompt", value=1)
        stop_repetition_box = gr.Number(label="stop_repetition", value=3)
        seed_box = gr.Number(label="Random Seed", value=1)
        output_dir_box = gr.Textbox(label="Output Directory", value="./generated_tts")

        generate_button = gr.Button("Generate TTS")
        output_audio = gr.Audio(label="Generated Audio", type="filepath")

        # 1) When the user clicks "Transcribe", call `transcribe_audio`
        transcribe_button.click(
            fn=transcribe_audio,
            inputs=[reference_speech_input],
            outputs=[reference_text_box],
        )
        # 2) The actual TTS generation function.
        def gradio_inference(
            reference_speech,
            reference_text,
            target_text,
            model_name,
            model_root,
            target_duration,
            top_k,
            temperature,
            kvcache,
            repeat_prompt,
            stop_repetition,
            seed,
            output_dir
        ):
            # Convert any empty strings to None for optional fields
            dur = float(target_duration) if target_duration else None
            out_path = run_inference(
                reference_speech=reference_speech,
                reference_text=reference_text if reference_text else None,
                target_text=target_text,
                model_name=model_name,
                model_root=model_root,
                target_duration=dur,
                top_k=int(top_k),
                temperature=float(temperature),
                kvcache=int(kvcache),
                repeat_prompt=int(repeat_prompt),
                stop_repetition=int(stop_repetition),
                seed=int(seed),
                output_dir=output_dir
            )
            return out_path
        # 3) Link the "Generate TTS" button
        generate_button.click(
            fn=gradio_inference,
            inputs=[
                reference_speech_input,
                reference_text_box,
                target_text_box,
                model_name_box,
                model_root_box,
                reference_duration_box,
                top_k_box,
                temperature_box,
                kvcache_box,
                repeat_prompt_box,
                stop_repetition_box,
                seed_box,
                output_dir_box
            ],
            outputs=[output_audio],
        )

    demo.launch(server_name="0.0.0.0", server_port=7860, debug=True)

if __name__ == "__main__":
    main()