import copy
import math
import os
import sys

import gradio as gr
import numpy as np
import spaces
import torch
from numba import jit
from transformers import (
    AutoConfig,
    AutoModelForCausalLM,
    AutoTokenizer,
    GenerationConfig,
    TextIteratorStreamer,
)

from vita_audio.data.processor.audio_processor import add_audio_input_contiguous
from vita_audio.tokenizer import get_audio_tokenizer

# Full-width and CJK punctuation, used to strip a single trailing punctuation mark
# from text queries before they are added to the task history.
PUNCTUATION = "！？｡。＂＃＄％＆＇（）＊＋，－／：；＜＝＞＠［＼］＾＿｀｛｜｝～｟｠｢｣､、〃》「」『』【】〔〕〖〗〘〙〚〛〜〝〞〟〰〾〿–—‘’‛“”„‟…‧﹏."
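

# float_to_int16 peak-scales a float waveform so its largest magnitude stays within
# the int16 range and converts it to np.int16, the format handed back to gr.Audio
# as a (sample_rate, samples) tuple.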
def float_to_int16(audio: np.ndarray) -> np.ndarray:
    am = int(math.ceil(float(np.abs(audio).max())) * 32768)
    am = 32767 * 32768 // am
    return np.multiply(audio, am).astype(np.int16)


def is_wav(file_path):
    wav_extensions = {'.wav'}
    _, ext = os.path.splitext(file_path)
    return ext.lower() in wav_extensions
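

# _parse_text renders user/model text as HTML for the Gradio Chatbot: triple-backtick
# fences become <pre><code> blocks, and Markdown/HTML special characters inside an
# open code block are escaped so they display literally.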
def _parse_text(text):
    lines = text.split("\n")
    lines = [line for line in lines if line != ""]
    count = 0
    for i, line in enumerate(lines):
        if "```" in line:
            count += 1
            items = line.split("`")
            if count % 2 == 1:
                lines[i] = f'<pre><code class="language-{items[-1]}">'
            else:
                lines[i] = "<br></code></pre>"
        else:
            if i > 0 and count % 2 == 1:
                line = line.replace("`", r"\`")
                line = line.replace("<", "&lt;")
                line = line.replace(">", "&gt;")
                line = line.replace(" ", "&nbsp;")
                line = line.replace("*", "&ast;")
                line = line.replace("_", "&lowbar;")
                line = line.replace("-", "&#45;")
                line = line.replace(".", "&#46;")
                line = line.replace("!", "&#33;")
                line = line.replace("(", "&#40;")
                line = line.replace(")", "&#41;")
                line = line.replace("$", "&#36;")
            lines[i] = "<br>" + line
    return "".join(lines)
def _launch_demo(model, tokenizer, audio_tokenizer):
    def predict(_chatbot, task_history, task):
        chat_query = task_history[-1][0]
        print(task_history)
        messages = []
        audio_path_list = []
        if task == 'Spoken QA':
            messages = [
                {
                    "role": "system",
                    # "content": "Your Name: Luke\nYour Gender: male\n\nRespond in a text-audio interleaved manner.",
                    # "content": "Your Name: Lucy\nYour Gender: female\nRespond in a text-audio interleaved manner.",
                    "content": "Your Name: Omni\nYour Gender: female\nRespond in a text-audio interleaved manner.",
                },
            ]
            for i, (q, a) in enumerate(task_history):
                if isinstance(q, (tuple, list)) and is_wav(q[0]):
                    audio_path_list.append(q[0])
                    messages = messages + [
                        {
                            "role": "user",
                            "content": "\n<|audio|>",
                        },
                    ]
                else:
                    messages = messages + [
                        {
                            "role": "user",
                            "content": q,
                        },
                    ]
                if a is not None:
                    messages = messages + [
                        {
                            "role": "assistant",
                            "content": a,
                        },
                    ]
            model.generation_config.do_sample = False
        elif task == 'TTS':
            for i, (q, a) in enumerate(task_history):
                if isinstance(q, (tuple, list)) and is_wav(q[0]):
                    audio_path_list.append(q[0])
                    messages = messages + [
                        {
                            "role": "user",
                            "content": "\n<|audio|>",
                        },
                    ]
                else:
                    messages = messages + [
                        {
                            "role": "user",
                            "content": f'Convert the text to speech.\n{q}',
                        },
                    ]
                if a is not None:
                    messages = messages + [
                        {
                            "role": "assistant",
                            "content": a,
                        },
                    ]
            model.generation_config.do_sample = True
        elif task == 'ASR':
            for i, (q, a) in enumerate(task_history):
                if isinstance(q, (tuple, list)) and is_wav(q[0]):
                    audio_path_list.append(q[0])
                    messages = messages + [
                        {
                            "role": "user",
                            "content": "Convert the speech to text.\n<|audio|>",
                        },
                    ]
                else:
                    messages = messages + [
                        {
                            "role": "user",
                            "content": f"{q}",
                        },
                    ]
                if a is not None:
                    messages = messages + [
                        {
                            "role": "assistant",
                            "content": a,
                        },
                    ]
            model.generation_config.do_sample = False

        add_generation_prompt = True
        input_ids = tokenizer.apply_chat_template(
            messages,
            tokenize=True,
            add_generation_prompt=add_generation_prompt,
            # return_tensors="pt",
        )
        input_ids, audios, audio_indices = add_audio_input_contiguous(
            input_ids, audio_path_list, tokenizer, audio_tokenizer
        )
        input_ids = torch.tensor([input_ids], dtype=torch.long).to("cuda")
        print("input", tokenizer.decode(input_ids[0], skip_special_tokens=False), flush=True)

        if audio_path_list == []:
            audios = None
            audio_indices = None

        outputs = model.generate(
            input_ids,
            audios=audios,
            audio_indices=audio_indices,
        )
        output = tokenizer.decode(outputs[0], skip_special_tokens=False)
        # print(f"{output=}", flush=True)

        audio_offset = tokenizer.convert_tokens_to_ids("<|audio_0|>")
        begin_of_audio = tokenizer.convert_tokens_to_ids("<|begin_of_audio|>")
        end_of_audio = tokenizer.convert_tokens_to_ids("<|end_of_audio|>")
        im_end = tokenizer.convert_tokens_to_ids("<|im_end|>")
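
        # VITA-Audio emits text and audio in one interleaved token stream: discrete
        # audio codec tokens occupy the vocabulary range starting at <|audio_0|>, so
        # every generated id >= audio_offset is an audio token (re-based to start at 0),
        # while the remaining ids, minus the audio delimiters and <|im_end|>, are text.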
        response = outputs[0][len(input_ids[0]):]
        audio_tokens = []
        text_tokens = []
        for token_id in response:
            if token_id >= audio_offset:
                audio_tokens.append(token_id - audio_offset)
            elif (token_id.item() != begin_of_audio) and (token_id.item() != end_of_audio) and (token_id.item() != im_end):
                text_tokens.append(token_id)

        if len(audio_tokens) > 0:
            tts_speech = audio_tokenizer.decode(audio_tokens)
            audio_np = float_to_int16(tts_speech.cpu().numpy())
            tts_speech = (22050, audio_np)
        else:
            tts_speech = None

        # import pdb;pdb.set_trace()
        history_response = tokenizer.decode(text_tokens)
        task_history[-1] = (chat_query, history_response)
        _chatbot[-1] = (chat_query, history_response)
        # print("query", chat_query)
        # print("task_history", task_history)
        # print(_chatbot)
        # print("answer: ", outputs)
        return _chatbot, tts_speech

    def add_text(history, task_history, text):
        task_text = text
        # import pdb;pdb.set_trace()
        if len(text) >= 2 and text[-1] in PUNCTUATION and text[-2] not in PUNCTUATION:
            task_text = text[:-1]
        history = history + [(_parse_text(text), None)]
        task_history = task_history + [(task_text, None)]
        return history, task_history, ""

    def add_audio(history, task_history, file):
        print(file)
        if file is None:
            return history, task_history
        history = history + [((file,), None)]
        task_history = task_history + [((file,), None)]
        return history, task_history

    def reset_user_input():
        # import pdb;pdb.set_trace()
        return gr.update(value="")

    def reset_state(task_history):
        task_history.clear()
        return []

    font_size = "2.5em"
    html = f"""
    <p align="center" style="font-size: {font_size}; line-height: 1;">
        <span style="display: inline-block; vertical-align: middle;">VITA-Audio-Plus-Vanilla</span>
    </p>
    <center>
    <font size=3>
    <p>
        <b>VITA-Audio</b> has been fully open-sourced on <a href='https://huggingface.co/VITA-MLLM'>🤗 Huggingface</a> and <a href='https://github.com/VITA-MLLM/VITA-Audio'>GitHub</a>. If you find VITA-Audio useful, a like ❤️ or a star ⭐ would be appreciated.
    </p>
    </font>
    <font size=3>
    <p>
        This demo deploys the VITA-Audio-Plus-Vanilla model in a non-streaming manner.
        The ASR and TTS tasks support single-turn dialogue only. In the Spoken QA task, only the generated text is kept as dialogue history, which keeps the context short.
    </p>
    </font>
    </center>
    """

    with gr.Blocks(title="VITA-Audio-Plus-Vanilla") as demo:
        gr.HTML(html)
        chatbot = gr.Chatbot(label='VITA-Audio-Plus-Vanilla', elem_classes="control-height", height=500)
        query = gr.Textbox(lines=2, label='Text Input')
        task_history = gr.State([])

        with gr.Row():
            add_text_button = gr.Button("Submit Text (提交文本)")
            add_audio_button = gr.Button("Submit Audio (提交音频)")
            empty_bin = gr.Button("🧹 Clear History (清除历史)")
            task = gr.Radio(
                choices=["ASR", "TTS", "Spoken QA"], label="TASK", value='Spoken QA'
            )

        with gr.Row(scale=1):
            record_btn = gr.Audio(
                sources=["microphone", "upload"],
                type="filepath",
                label="🎤 Record or Upload Audio (录音或上传音频)",
                show_download_button=True,
                waveform_options=gr.WaveformOptions(sample_rate=16000),
            )
            audio_output = gr.Audio(
                label="Play",
                streaming=True,
                autoplay=True,
                show_download_button=True,
            )

        add_text_button.click(
            add_text, [chatbot, task_history, query], [chatbot, task_history], show_progress=True
        ).then(
            reset_user_input, [], [query]
        ).then(
            predict, [chatbot, task_history, task], [chatbot, audio_output], show_progress=True
        )

        empty_bin.click(reset_state, [task_history], [chatbot], show_progress=True)

        add_audio_button.click(
            add_audio, [chatbot, task_history, record_btn], [chatbot, task_history], show_progress=True
        ).then(
            predict, [chatbot, task_history, task], [chatbot, audio_output], show_progress=True
        )

    demo.launch(
        show_error=True,
    )


def main():
    model_name_or_path = "VITA-MLLM/VITA-Audio-Plus-Vanilla"
    device_map = "cuda:0"

    sys.path.append("third_party/GLM-4-Voice/")
    sys.path.append("third_party/GLM-4-Voice/cosyvoice/")
    sys.path.append("third_party/GLM-4-Voice/third_party/Matcha-TTS/")

    from huggingface_hub import snapshot_download

    audio_tokenizer_path = snapshot_download(repo_id="THUDM/glm-4-voice-tokenizer")
    flow_path = snapshot_download(repo_id="THUDM/glm-4-voice-decoder")

    audio_tokenizer_rank = 0
    audio_tokenizer_type = "sensevoice_glm4voice"
    torch_dtype = torch.bfloat16

    audio_tokenizer = get_audio_tokenizer(
        audio_tokenizer_path, audio_tokenizer_type, flow_path=flow_path, rank=audio_tokenizer_rank
    )
    audio_tokenizer.load_model()
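    # The GLM-4-Voice tokenizer encodes input speech into discrete audio tokens, while
    # the downloaded decoder (flow-matching based, hence the CosyVoice / Matcha-TTS
    # paths above) converts generated audio tokens back into a waveform, which
    # predict() plays back at 22050 Hz.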

    from evaluation.get_chat_template import qwen2_chat_template as chat_template

    tokenizer = AutoTokenizer.from_pretrained(
        model_name_or_path,
        trust_remote_code=True,
        chat_template=chat_template,
    )
    # print(f"{tokenizer=}")
    # print(f"{tokenizer.get_chat_template()=}")

    model = AutoModelForCausalLM.from_pretrained(
        model_name_or_path,
        trust_remote_code=True,
        device_map=device_map,
        torch_dtype=torch_dtype,
        attn_implementation="flash_attention_2",
    ).eval()
    # print(f"{model.config.model_type=}")

    model.generation_config = GenerationConfig.from_pretrained(
        model_name_or_path, trust_remote_code=True
    )
    model.generation_config.max_new_tokens = 4096
    model.generation_config.chat_format = "chatml"
    model.generation_config.max_window_size = 8192
    model.generation_config.use_cache = True
    model.generation_config.do_sample = True
    model.generation_config.temperature = 1.0
    model.generation_config.top_k = 50
    model.generation_config.top_p = 1.0
    model.generation_config.num_beams = 1
    model.generation_config.pad_token_id = tokenizer.pad_token_id
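
    # mtp_inference_mode is a VITA-Audio-specific setting consumed by the model's
    # custom generation code; it controls how the multiple-token-prediction (MTP)
    # modules are used at inference time. The [8192, 10] value follows the reference demo.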
    model.generation_config.mtp_inference_mode = [8192, 10]

    _launch_demo(model, tokenizer, audio_tokenizer)


if __name__ == '__main__':
    main()