import spaces
from transformers import AutoTokenizer, AutoModelForCausalLM, pipeline
import torch
import soundfile as sf
from xcodec2.modeling_xcodec2 import XCodec2Model
import torchaudio
import gradio as gr
import tempfile
import os
import numpy as np
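
# --- Model setup ---
# Three models are loaded up front: the SmolKartoffel LLM that maps text to
# XCodec2 speech tokens, the XCodec2 codec that turns those tokens back into
# a 16 kHz waveform, and Whisper for transcribing user-supplied reference audio.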
llasa_1b = "SebastianBodza/SmolKartoffel-135M-v0.1"
tokenizer = AutoTokenizer.from_pretrained(llasa_1b, token=os.getenv("HF_TOKEN"))
model = AutoModelForCausalLM.from_pretrained(
    llasa_1b, trust_remote_code=True, device_map="cuda", token=os.getenv("HF_TOKEN")
)

model_path = "srinivasbilla/xcodec2"
Codec_model = XCodec2Model.from_pretrained(model_path)
Codec_model.eval().cuda()

whisper_turbo_pipe = pipeline(
    "automatic-speech-recognition",
    model="openai/whisper-large-v3-turbo",
    torch_dtype=torch.float16,
    device="cuda",
)
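
# Predefined reference voices. Each entry pairs a prompt audio file with the
# exact transcript of that audio (used as the text prefix during cloning) and
# a short description shown in the speaker dropdown.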
SPEAKERS = {
    "Male 1": {
        "path": "speakers/deep_speaker.mp3",
        "transcript": "Das große Tor von Minas Tirith brach erst, nachdem er die Ramme eingesetzt hatte.",
        "description": "Eine tiefe epische Männerstimme.",
    },
    "Male 2": {
        "path": "speakers/male_austrian_accent.mp3",
        "transcript": "Man kann sich auch leichter vorstellen, wie schwierig es ist, dass man Entscheidungen trifft, die allen passen.",
        "description": "Eine männliche Stimme mit österreichischem Akzent.",
    },
    "Male 3": {
        "path": "speakers/male_energic.mp3",
        "transcript": "Wo keine Infrastruktur, da auch keine Ansiedlung von IT-Unternehmen und deren Beschäftigten bzw. dem geeigneten Fachkräftenachwuchs. Kann man diese Rechnung so einfach aufmachen, wie es es tatsächlich um deren regionale Verteilung beschäftigt?",
        "description": "Eine männliche energische Stimme.",
    },
    "Male 4": {
        "path": "speakers/schneller_speaker.mp3",
        "transcript": "Genau, wenn wir alle Dächer voll machen, also alle Dächer von Einfamilienhäusern, alleine mit den Einfamilienhäusern können wir 20 Prozent des heutigen Strombedarfs decken.",
        "description": "Eine männliche Stimme mit schnellerem Tempo.",
    },
    "Female 1": {
        "path": "speakers/female_standard.mp3",
        "transcript": "Es wird ein Beispiel für ein barrierearmes Layout gegeben, sowie Tipps und ein Verweis auf eine Checkliste, die hilft, Barrierearmut in den eigenen Materialien zu prüfen bzw. umzusetzen.",
        "description": "Eine weibliche Stimme.",
    },
    "Female 2": {
        "path": "speakers/female_energic.mp3",
        "transcript": "Dunkel flog weiter durch das Wald. Er sah die Sterne am Phaneten an sich vorbeiziehen und fühlte sich frei und glücklich.",
        "description": "Eine weibliche Erzähler-Stimme.",
    },
    "Female 3": {
        "path": "speakers/austrian_accent.mp3",
        "transcript": "Die politische Europäische Union war geboren, verbrieft im Vertrag von Maastricht. Ab diesem Zeitpunkt bestehen zwei Vertragswerke.",
        "description": "Eine weibliche Stimme mit österreichischem Akzent.",
    },
    "Special 1": {
        "path": "speakers/low_audio.mp3",
        "transcript": "Druckplatten und Lasersensoren, um sicherzugehen, dass er auch da drin ist und",
        "description": "Eine männliche Stimme mit schlechter Audioqualität als Effekt.",
    },
}


def preview_speaker(display_name):
    """Returns the audio and transcript for preview."""
    speaker_name = speaker_display_dict[display_name]
    if speaker_name in SPEAKERS:
        waveform, sample_rate = torchaudio.load(SPEAKERS[speaker_name]["path"])
        return (sample_rate, waveform[0].numpy()), SPEAKERS[speaker_name]["transcript"]
    return None, ""


def normalize_audio(waveform: torch.Tensor, target_db: float = -20) -> torch.Tensor:
    """
    Normalize audio volume towards a target peak level and limit the gain range.

    Args:
        waveform (torch.Tensor): Input audio waveform
        target_db (float): Target peak level in dB (default: -20)

    Returns:
        torch.Tensor: Normalized audio waveform
    """
    # Measure the current peak level in dB
    eps = 1e-10
    current_db = 20 * torch.log10(torch.max(torch.abs(waveform)) + eps)
    # Calculate the required gain and limit it to the -3..+3 dB range
    gain_db = target_db - current_db
    gain_db = torch.clamp(gain_db, min=-3, max=3)
    # Apply the gain
    gain_factor = 10 ** (gain_db / 20)
    normalized = waveform * gain_factor
    # Final peak normalization (rescales the signal to full scale)
    max_amplitude = torch.max(torch.abs(normalized))
    if max_amplitude > 0:
        normalized = normalized / max_amplitude
    return normalized
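

# XCodec2 token ids are embedded in the LLM's vocabulary as strings of the
# form <|s_{id}|>, e.g. ids_to_speech_tokens([23456]) -> ["<|s_23456|>"] and
# extract_speech_ids(["<|s_23456|>"]) -> [23456]. The two helpers below
# convert between the integer and string representations.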
def ids_to_speech_tokens(speech_ids):
    speech_tokens_str = []
    for speech_id in speech_ids:
        speech_tokens_str.append(f"<|s_{speech_id}|>")
    return speech_tokens_str


def extract_speech_ids(speech_tokens_str):
    speech_ids = []
    for token_str in speech_tokens_str:
        if token_str.startswith("<|s_") and token_str.endswith("|>"):
            num_str = token_str[4:-2]
            num = int(num_str)
            speech_ids.append(num)
        else:
            print(f"Unexpected token: {token_str}")
    return speech_ids


def infer_with_speaker(
    display_name,
    target_text,
    temp,
    top_p_val,
    min_new_tokens,
    do_sample,
    progress=gr.Progress(),
):
    """Modified infer function that uses a predefined speaker."""
    speaker_name = speaker_display_dict[display_name]  # Get the actual speaker name
    if speaker_name not in SPEAKERS:
        return None, "Invalid speaker selected"
    return infer(
        SPEAKERS[speaker_name]["path"],
        target_text,
        temp,
        top_p_val,
        min_new_tokens,
        do_sample,
        SPEAKERS[speaker_name]["transcript"],  # Pass the predefined transcript
        progress,
    )


def infer_random(
    target_text, temp, top_p_val, min_new_tokens, do_sample, progress=gr.Progress()
):
    progress(0, "Generating speech with random voice...")
    if len(target_text) == 0:
        return None, "Please provide some text."
    elif len(target_text) > 500:
        gr.Warning("Text is too long. Please keep it under 500 characters.")
        target_text = target_text[:500]
    formatted_text = (
        f"<|TEXT_UNDERSTANDING_START|>{target_text}<|TEXT_UNDERSTANDING_END|>"
    )
    chat = [
        {
            "role": "user",
            "content": "Convert the text to speech:" + formatted_text,
        },
        {
            "role": "assistant",
            "content": "<|SPEECH_GENERATION_START|>",
        },
    ]
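    # The assistant turn is pre-seeded with <|SPEECH_GENERATION_START|>;
    # continue_final_message=True makes the model continue that message, so
    # generation starts emitting speech tokens directly.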
    input_ids = tokenizer.apply_chat_template(
        chat, tokenize=True, return_tensors="pt", continue_final_message=True
    )
    input_ids = input_ids.to("cuda")
    speech_end_id = tokenizer.convert_tokens_to_ids("<|SPEECH_GENERATION_END|>")
    with torch.no_grad():
        outputs = model.generate(
            input_ids,
            max_length=3000,
            eos_token_id=speech_end_id,
            do_sample=do_sample,
            top_p=top_p_val,
            temperature=temp,
            min_new_tokens=min_new_tokens,
        )
        # Extract the generated speech tokens (skip the prompt tokens and the EOS)
        generated_ids = outputs[0][input_ids.shape[1] : -1]
        speech_tokens = tokenizer.batch_decode(generated_ids, skip_special_tokens=False)
        raw_output = " ".join(speech_tokens)
        speech_tokens = tokenizer.batch_decode(generated_ids, skip_special_tokens=True)
        # Convert tokens (e.g., <|s_23456|>) to integers
        speech_tokens = extract_speech_ids(speech_tokens)
        speech_tokens = torch.tensor(speech_tokens).cuda().unsqueeze(0).unsqueeze(0)
        # Decode the speech tokens to a speech waveform
        gen_wav = Codec_model.decode_code(speech_tokens)
    progress(1, "Synthesized!")
    return (
        16000,
        gen_wav[0, 0, :].cpu().numpy(),
    ), raw_output
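

# Thin wrapper so the lambda in the "Clone" tab can forward the positional
# widget values while pinning transcribed_text=None (forcing Whisper
# transcription of the uploaded reference audio).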
def gradio_infer(*args, **kwargs):
    return infer(*args, **kwargs)


def infer(
    sample_audio_path,
    target_text,
    temp,
    top_p_val,
    min_new_tokens,
    do_sample,
    transcribed_text=None,
    progress=gr.Progress(),
):
    with tempfile.NamedTemporaryFile(delete=False, suffix=".wav") as f:
        progress(0, "Loading and trimming audio...")
        waveform, sample_rate = torchaudio.load(sample_audio_path)
        waveform = normalize_audio(waveform)
        if len(waveform[0]) / sample_rate > 15:
            gr.Warning("Trimming audio to the first 15 seconds.")
            waveform = waveform[:, : sample_rate * 15]
        waveform = torch.nn.functional.pad(
            waveform, (0, int(sample_rate * 0.5)), "constant", 0
        )
        # Check if the audio is stereo (i.e., has more than one channel)
        if waveform.size(0) > 1:
            # Convert stereo to mono by averaging the channels
            waveform_mono = torch.mean(waveform, dim=0, keepdim=True)
        else:
            # If already mono, just use the original waveform
            waveform_mono = waveform
        prompt_wav = torchaudio.transforms.Resample(
            orig_freq=sample_rate, new_freq=16000
        )(waveform_mono)
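        # The prompt is now mono 16 kHz, which is what XCodec2 expects and the
        # sample rate at which this function returns audio.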
        if transcribed_text is None:
            progress(0.3, "Transcribing audio...")
            prompt_text = whisper_turbo_pipe(prompt_wav[0].numpy())["text"].strip()
        else:
            prompt_text = transcribed_text
        progress(0.5, "Transcribed! Generating speech...")
        if len(target_text) == 0:
            return None
        elif len(target_text) > 500:
            gr.Warning("Text is too long. Please keep it under 500 characters.")
            target_text = target_text[:500]
        input_text = prompt_text + " " + target_text
        # TTS start!
        with torch.no_grad():
            # Encode the prompt wav
            vq_code_prompt = Codec_model.encode_code(input_waveform=prompt_wav)
            vq_code_prompt = vq_code_prompt[0, 0, :]
            # Convert int 12345 to token <|s_12345|>
            speech_ids_prefix = ids_to_speech_tokens(vq_code_prompt)
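            # Voice cloning works by priming the assistant turn: the prompt's
            # transcript is prepended to the target text and the prompt's
            # speech tokens are placed right after <|SPEECH_GENERATION_START|>,
            # so the model continues in the same voice.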
            formatted_text = (
                f"<|TEXT_UNDERSTANDING_START|>{input_text}<|TEXT_UNDERSTANDING_END|>"
            )
            # Tokenize the text and the speech prefix
            chat = [
                {
                    "role": "user",
                    "content": "Convert the text to speech:" + formatted_text,
                },
                {
                    "role": "assistant",
                    "content": "<|SPEECH_GENERATION_START|>"
                    + "".join(speech_ids_prefix),
                },
            ]
            input_ids = tokenizer.apply_chat_template(
                chat,
                tokenize=True,
                return_tensors="pt",
                continue_final_message=True,
            )
            input_ids = input_ids.to("cuda")
            speech_end_id = tokenizer.convert_tokens_to_ids("<|SPEECH_GENERATION_END|>")
            # Generate the speech autoregressively
            outputs = model.generate(
                input_ids,
                max_length=2048,  # We trained our model with a max length of 2048
                eos_token_id=speech_end_id,
                do_sample=do_sample,
                top_p=top_p_val,
                temperature=temp,
                min_new_tokens=min_new_tokens,
            )
            # Extract the speech tokens (keep the prompt's speech prefix, drop the EOS)
            generated_ids = outputs[0][input_ids.shape[1] - len(speech_ids_prefix) : -1]
            speech_tokens = tokenizer.batch_decode(
                generated_ids, skip_special_tokens=False
            )
            raw_output = " ".join(speech_tokens)  # Capture the raw tokens
            speech_tokens = tokenizer.batch_decode(
                generated_ids, skip_special_tokens=True
            )
            # Convert token <|s_23456|> to int 23456
            speech_tokens = extract_speech_ids(speech_tokens)
            speech_tokens = torch.tensor(speech_tokens).cuda().unsqueeze(0).unsqueeze(0)
            # Decode the speech tokens to a speech waveform
            gen_wav = Codec_model.decode_code(speech_tokens)
            # Keep only the newly generated part (trim the prompt audio)
            gen_wav = gen_wav[:, :, prompt_wav.shape[1] :]
        progress(1, "Synthesized!")
        return (
            16000,
            gen_wav[0, 0, :].cpu().numpy(),
        ), raw_output  # Return both the audio and the raw tokens
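

# --- Gradio UI ---
# Three generation tabs share the same sampling controls: "Clone"
# (user-supplied reference audio), "Speaker" (predefined voices), and
# "Random" (no audio prompt at all).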
with gr.Blocks() as app_tts:
    gr.Markdown("# Zero Shot Voice Clone TTS")
    with gr.Accordion("Model Settings", open=False):
        temperature = gr.Slider(
            minimum=0.1,
            maximum=1.0,
            value=0.8,
            step=0.1,
            label="Temperature",
            info="Higher values = more random/creative output",
        )
        top_p = gr.Slider(
            minimum=0.1,
            maximum=1.0,
            value=1.0,
            step=0.1,
            label="Top P",
            info="Nucleus sampling threshold",
        )
        min_new_tokens = gr.Slider(
            minimum=0,
            maximum=128,
            value=3,
            step=1,
            label="Min Length",
            info="If the model just produces a click, you can force it to create longer generations.",
        )
        do_sample = gr.Checkbox(
            label="Sample", value=True, info="Sample from the distribution"
        )
    ref_audio_input = gr.Audio(label="Reference Audio", type="filepath")
    gen_text_input = gr.Textbox(label="Text to Generate", lines=10)
    generate_btn = gr.Button("Synthesize", variant="primary")
    audio_output = gr.Audio(label="Synthesized Audio")
    raw_output_display = gr.Textbox(label="Raw Model Output", interactive=False)
    generate_btn.click(
        lambda *args: gradio_infer(*args, transcribed_text=None),
        inputs=[
            ref_audio_input,
            gen_text_input,
            temperature,
            top_p,
            min_new_tokens,
            do_sample,
        ],
        outputs=[audio_output, raw_output_display],  # Include both outputs
    )


with gr.Blocks() as app_speaker:
    gr.Markdown("# Predefined Speaker TTS")
    with gr.Accordion("Model Settings", open=False):
        temperature = gr.Slider(
            minimum=0.1,
            maximum=1.0,
            value=1.0,
            step=0.1,
            label="Temperature",
            info="Higher values = more random/creative output",
        )
        top_p = gr.Slider(
            minimum=0.1,
            maximum=1.0,
            value=1.0,
            step=0.1,
            label="Top P",
            info="Nucleus sampling threshold",
        )
        min_new_tokens = gr.Slider(
            minimum=0,
            maximum=128,
            value=3,
            step=1,
            label="Min Length",
            info="If the model just produces a click, you can force it to create longer generations.",
        )
        do_sample = gr.Checkbox(
            label="Sample", value=True, info="Sample from the distribution"
        )
    with gr.Row():
        speaker_display_dict = {
            f"{name} - {SPEAKERS[name]['description']}": name
            for name in SPEAKERS.keys()
        }
        speaker_dropdown = gr.Dropdown(
            choices=list(speaker_display_dict.keys()),
            label="Select Speaker",
            value=list(speaker_display_dict.keys())[0],
        )
        preview_btn = gr.Button("Preview Voice")
    with gr.Row():
        preview_audio = gr.Audio(label="Preview")
        preview_text = gr.Textbox(label="Original Transcript", interactive=False)
    gen_text_input = gr.Textbox(label="Text to Generate", lines=10)
    generate_btn = gr.Button("Synthesize", variant="primary")
    audio_output = gr.Audio(label="Synthesized Audio")
    raw_output_display = gr.Textbox(label="Raw Model Output", interactive=False)
    # Connect the preview button
    preview_btn.click(
        preview_speaker,
        inputs=[speaker_dropdown],
        outputs=[preview_audio, preview_text],
    )
    # Connect the generate button
    generate_btn.click(
        infer_with_speaker,
        inputs=[
            speaker_dropdown,
            gen_text_input,
            temperature,
            top_p,
            min_new_tokens,
            do_sample,
        ],
        outputs=[audio_output, raw_output_display],
    )


with gr.Blocks() as random:
    gr.Markdown("# Random Voice")
    with gr.Accordion("Model Settings", open=False):
        temperature = gr.Slider(
            minimum=0.1,
            maximum=1.0,
            value=1.0,
            step=0.1,
            label="Temperature",
            info="Higher values = more random/creative output",
        )
        top_p = gr.Slider(
            minimum=0.1,
            maximum=1.0,
            value=1.0,
            step=0.1,
            label="Top P",
            info="Nucleus sampling threshold",
        )
        min_new_tokens = gr.Slider(
            minimum=0,
            maximum=128,
            value=3,
            step=1,
            label="Min Length",
            info="If the model just produces a click, you can force it to create longer generations.",
        )
        do_sample = gr.Checkbox(
            label="Sample", value=True, info="Sample from the distribution"
        )
    # Note: for random voice generation, no reference audio is used.
    gen_text_input = gr.Textbox(label="Text to Generate", lines=10)
    generate_btn = gr.Button("Synthesize", variant="primary")
    audio_output = gr.Audio(label="Synthesized Audio")
    raw_output_display = gr.Textbox(label="Raw Model Output", interactive=False)
    generate_btn.click(
        infer_random,
        inputs=[
            gen_text_input,
            temperature,
            top_p,
            min_new_tokens,
            do_sample,
        ],
        outputs=[audio_output, raw_output_display],
    )


with gr.Blocks() as app_credits:
    gr.Markdown(
        """
# Credits
* [zhenye234](https://github.com/zhenye234) for the original [repo](https://github.com/zhenye234/LLaSA_training)
* [mrfakename](https://huggingface.co/mrfakename) for the [gradio demo code](https://huggingface.co/spaces/mrfakename/E2-F5-TTS)
"""
    )


with gr.Blocks() as app:
    gr.Markdown(
        """
# SmolKartoffel-135M-v0.1
"""
    )
    gr.TabbedInterface(
        [random, app_speaker, app_tts, app_credits],  # app_audiobook,
        [
            "Random",
            "Speaker",
            "Clone",
            # "Audiobook",
            "Credits",
        ],
    )

app.launch(ssr_mode=False)
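
# Rough dependency sketch (an assumption; versions are not pinned anywhere in
# this file): pip install torch torchaudio transformers gradio xcodec2 soundfile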