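# Standalone demo / benchmark script for VITA-Audio speech-to-speech inference.
# It loads a VITA-Audio checkpoint together with the GLM-4-Voice audio tokenizer and
# flow decoder, exercises text chat, ASR, TTS, voice cloning, and streaming
# speech-to-speech, and finally benchmarks decoding throughput under several
# mtp_inference_mode settings.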
import json
import logging
import os
import random
import re
import sys
import time
import uuid
from threading import Thread
from typing import Optional

import torch
import tqdm
from torch import nn
from transformers import AutoConfig, AutoModelForCausalLM, AutoTokenizer, TextIteratorStreamer
from transformers.generation import GenerationConfig

import torchaudio

from vita_audio.data.processor.audio_processor import add_audio_input_contiguous
from vita_audio.tokenizer import get_audio_tokenizer

logger = logging.getLogger(__name__)
logger.setLevel(logging.INFO)

torch.manual_seed(1234)
device_map = "cuda:0"
audio_tokenizer_rank = 0
torch_dtype = torch.bfloat16

# model_name_or_path = sys.argv[1]
# audio_tokenizer_path = sys.argv[2]
# flow_path = sys.argv[3]

if True:
# if False:
    # sensevoice glm4voice tokenizer
    sys.path.append("third_party/GLM-4-Voice/")
    sys.path.append("third_party/GLM-4-Voice/cosyvoice/")
    sys.path.append("third_party/GLM-4-Voice/third_party/Matcha-TTS/")

    audio_tokenizer_path = "/data/models/THUDM/glm-4-voice-tokenizer"
    flow_path = "/data/models/THUDM/glm-4-voice-decoder"
    audio_tokenizer_type = "sensevoice_glm4voice"

    model_name_or_path = "VITA-MLLM/VITA-Audio-Plus-Vanilla/"

# if True:
if False:
    # glm4voice tokenizer
    sys.path.append("third_party/GLM-4-Voice/")
    sys.path.append("third_party/GLM-4-Voice/cosyvoice/")
    sys.path.append("third_party/GLM-4-Voice/third_party/Matcha-TTS/")

    audio_tokenizer_path = "/data/models/THUDM/glm-4-voice-tokenizer"
    flow_path = "/data/models/THUDM/glm-4-voice-decoder"
    audio_tokenizer_type = "glm4voice"

    # model_name_or_path = "VITA-MLLM/VITA-Audio-Balance"
    model_name_or_path = "VITA-MLLM/VITA-Audio-Boost"

output_dir = "/data/output/LM/inference/"
os.makedirs(output_dir, exist_ok=True)
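# TextAudioIteratorStreamer extends TextIteratorStreamer so that audio tokens
# (ids >= the id of <|begin_of_audio|>) are flushed to the consumer immediately instead
# of being held back until a word boundary, and it throttles generation when the output
# queue backs up.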
class TextAudioIteratorStreamer(TextIteratorStreamer):
    def __init__(
        self,
        tokenizer: "AutoTokenizer",
        skip_prompt: bool = False,
        timeout: Optional[float] = None,
        **decode_kwargs,
    ):
        super().__init__(tokenizer, skip_prompt, timeout, **decode_kwargs)
        # self.audio_offset = tokenizer.convert_tokens_to_ids("<|audio_0|>")
        self.audio_offset = tokenizer.convert_tokens_to_ids("<|begin_of_audio|>")
        self.num_decode_tokens = 0

    def put(self, value):
        """
        Receives tokens, decodes them, and prints them to stdout as soon as they form entire words.
        """
        if len(value.shape) > 1 and value.shape[0] > 1:
            raise ValueError("TextStreamer only supports batch size 1")
        elif len(value.shape) > 1:
            value = value[0]

        if self.skip_prompt and self.next_tokens_are_prompt:
            self.next_tokens_are_prompt = False
            return

        self.num_decode_tokens += len(value)

        # Add the new token to the cache and decodes the entire thing.
        self.token_cache.extend(value.tolist())
        text = self.tokenizer.decode(self.token_cache, **self.decode_kwargs)

        # After the symbol for a new line, we flush the cache.
        if text.endswith("\n"):
            printable_text = text[self.print_len :]
            self.token_cache = []
            self.print_len = 0
        # If the last token is a CJK character, we print the characters.
        elif len(text) > 0 and self._is_chinese_char(ord(text[-1])):
            printable_text = text[self.print_len :]
            self.print_len += len(printable_text)
        # If the last token is an audio token, flush it immediately.
        elif self.token_cache[-1] >= self.audio_offset:
            printable_text = text[self.print_len :]
            self.print_len += len(printable_text)
        # Otherwise, prints until the last space char (simple heuristic to avoid printing incomplete words,
        # which may change with the subsequent token -- there are probably smarter ways to do this!)
        else:
            printable_text = text[self.print_len : text.rfind(" ") + 1]
            self.print_len += len(printable_text)

        self.on_finalized_text(printable_text)

        # Apply light back-pressure so generation does not run far ahead of the consumer.
        while self.text_queue.qsize() > 10:
            time.sleep(0.01)
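# BenchmarkIteratorStreamer skips detokenization entirely and streams raw token ids as
# space-separated strings, which keeps the timing loops below focused on model throughput.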
class BenchmarkIteratorStreamer(TextIteratorStreamer):
    def __init__(
        self,
        tokenizer: "AutoTokenizer",
        skip_prompt: bool = False,
        timeout: Optional[float] = None,
        **decode_kwargs,
    ):
        super().__init__(tokenizer, skip_prompt, timeout, **decode_kwargs)
        self.num_decode_tokens = 0

    def put(self, value):
        """
        Receives tokens and streams their raw ids as space-separated text, without decoding,
        so that throughput can be measured independently of detokenization.
        """
        if len(value.shape) > 1 and value.shape[0] > 1:
            raise ValueError("TextStreamer only supports batch size 1")
        elif len(value.shape) > 1:
            value = value[0]

        if self.skip_prompt and self.next_tokens_are_prompt:
            self.next_tokens_are_prompt = False
            return

        self.num_decode_tokens += len(value)

        printable_text = " ".join([str(x) for x in value.tolist()]) + " "
        self.on_finalized_text(printable_text)
def find_audio_segments_regex(text):
    """
    Find all substrings between <|begin_of_audio|> and <|end_of_audio|> using regex.

    Args:
        text (str): The input string to search through

    Returns:
        list: A list of all found audio segments (substrings between the delimiters)
    """
    pattern = re.compile(r"<\|begin_of_audio\|>(.*?)<\|end_of_audio\|>", re.DOTALL)
    segments = pattern.findall(text)
    return [segment.strip() for segment in segments]
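# Illustrative example:
#   find_audio_segments_regex("hi <|begin_of_audio|><|audio_1|><|audio_2|><|end_of_audio|>")
#   -> ["<|audio_1|><|audio_2|>"]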
def extract_token_ids_as_int(text):
    pattern = re.compile(r"<\|audio_(\d+)\|>")
    token_ids = pattern.findall(text)
    return [int(id) for id in token_ids]
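# Illustrative example: extract_token_ids_as_int("<|audio_12|><|audio_7|>") -> [12, 7]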
def custom_init_weights(module):
    if isinstance(module, torch.nn.Conv2d) or isinstance(module, torch.nn.Linear):
        torch.nn.init.xavier_uniform_(module.weight)
        if module.bias is not None:
            torch.nn.init.constant_(module.bias, 0)
    elif isinstance(module, torch.nn.BatchNorm2d) or isinstance(module, torch.nn.BatchNorm1d):
        torch.nn.init.constant_(module.weight, 1)
        torch.nn.init.constant_(module.bias, 0)
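# S2SInference bundles the VITA-Audio language model, its text tokenizer, and the audio
# tokenizer/decoder, and exposes plain and streaming inference helpers plus a few
# throughput benchmarks.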
class S2SInference:
    def __init__(
        self, model_name_or_path, audio_tokenizer_path, audio_tokenizer_type, flow_path=None
    ):
        config = AutoConfig.from_pretrained(
            model_name_or_path,
            trust_remote_code=True,
        )

        if "qwen2" in config.model_type.lower():
            from evaluation.get_chat_template import qwen2_chat_template as chat_template

            add_generation_prompt = True
            default_system_message = []

        if "hunyuan" in config.model_type.lower():
            from evaluation.get_chat_template import hunyuan_chat_template as chat_template

            add_generation_prompt = False
            default_system_message = [
                {
                    "role": "system",
                    "content": "You are a helpful AI assistant.",
                }
            ]

        luke_system_message = [
            {
                "role": "system",
                "content": "Your Name: Luke\nYour Gender: male\n\nRespond in a text-audio interleaved manner.",
            },
        ]

        tokenizer = AutoTokenizer.from_pretrained(
            model_name_or_path,
            trust_remote_code=True,
            chat_template=chat_template,
        )
        # print(f"{tokenizer=}")
        print(f"{tokenizer.get_chat_template()=}")

        model = AutoModelForCausalLM.from_pretrained(
            model_name_or_path,
            trust_remote_code=True,
            device_map=device_map,
            torch_dtype=torch_dtype,
            attn_implementation="flash_attention_2",
        ).eval()
        # print("model", model)
        print(f"{model.config.model_type=}")
        print(f"{model.hf_device_map=}")

        model.generation_config = GenerationConfig.from_pretrained(
            model_name_or_path, trust_remote_code=True
        )
        model.generation_config.max_new_tokens = 8192
        model.generation_config.chat_format = "chatml"
        model.generation_config.max_window_size = 8192
        model.generation_config.use_cache = True
        # model.generation_config.use_cache = False
        model.generation_config.do_sample = False
        model.generation_config.temperature = 1.0
        model.generation_config.top_k = 50
        model.generation_config.top_p = 1.0
        model.generation_config.num_beams = 1
        model.generation_config.pad_token_id = tokenizer.pad_token_id
        if model.config.model_type == "hunyuan":
            model.generation_config.eos_token_id = tokenizer.eos_id
        print(f"{model.generation_config=}")

        audio_tokenizer = get_audio_tokenizer(
            audio_tokenizer_path,
            audio_tokenizer_type,
            flow_path=flow_path,
            rank=audio_tokenizer_rank,
        )

        self.model = model
        self.tokenizer = tokenizer
        self.audio_tokenizer = audio_tokenizer
        self.add_generation_prompt = add_generation_prompt
        self.default_system_message = default_system_message
        self.luke_system_message = luke_system_message

        audio_0_id = tokenizer("<|audio_0|>").input_ids[0]
        print(f"{audio_0_id=}")
    def benchmark_forward(self, mtp_inference_mode):
        print("-" * 100)
        print("benchmark_forward...")
        print(f"{mtp_inference_mode=}")

        total_time = 0
        past_key_values = None
        use_cache = True

        self.model.input_ids = None
        self.model.inputs_embeds = None
        self.model.hidden_states = [None] * (self.model.config.num_nextn_predict_layers + 1)
        self.model.position_ids = None
        self.model.attention_mask = None
        self.model.mtp_idx = -1
        self.model.num_prefill_tokens = -1

        model_max_length = 1024

        if mtp_inference_mode is not None:
            ori_mtp_inference_mode = self.model.generation_config.mtp_inference_mode
            self.model._prepare_mtp_for_generation(mtp_inference_mode, model_max_length)
        else:
            self.model._prepare_mtp_for_generation(
                self.model.generation_config.mtp_inference_mode, model_max_length
            )

        for i in tqdm.tqdm(range(1, model_max_length + 1)):
            if use_cache:
                input_ids = torch.tensor([i - 1], dtype=torch.long).unsqueeze(0).to("cuda")
                position_ids = torch.tensor([i - 1], dtype=torch.long).unsqueeze(0).to("cuda")
            else:
                input_ids = torch.arange(i, dtype=torch.long).unsqueeze(0).to("cuda")
                position_ids = torch.arange(i, dtype=torch.long).unsqueeze(0).to("cuda")
            attention_mask = torch.tensor([1] * i, dtype=torch.float).unsqueeze(0).to("cuda")

            torch.cuda.synchronize()
            start = time.time()

            output = self.model(
                input_ids,
                attention_mask=attention_mask,
                position_ids=position_ids,
                past_key_values=past_key_values,
                use_cache=use_cache,
                num_logits_to_keep=1,
            )

            torch.cuda.synchronize()
            end = time.time()
            total_time += end - start
            # print(f"{i=} {total_time=}")

            past_key_values = output.past_key_values

        print()
        print(f"{total_time=}")
        print(f"second/token {total_time/model_max_length=}")
        print(f"token/second {model_max_length/total_time=}")

        if mtp_inference_mode is not None:
            self.model.mtp_inference_mode = ori_mtp_inference_mode
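    # benchmark_generate and benchmark_generate_stream time full generate() calls; the
    # weights are re-initialized with custom_init_weights, presumably so decoding runs to
    # max_new_tokens instead of stopping early at an EOS token.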
    def benchmark_generate(self, mtp_inference_mode):
        self.model.apply(custom_init_weights)

        print("-" * 100)
        print("benchmark_generate...")
        print(f"{mtp_inference_mode=}")

        total_time = 0
        self.model.generation_config.use_cache = True
        self.model.generation_config.max_new_tokens = 8192

        if mtp_inference_mode is not None:
            ori_mtp_inference_mode = self.model.generation_config.mtp_inference_mode
            self.model.generation_config.mtp_inference_mode = mtp_inference_mode

        input_ids = torch.tensor([0], dtype=torch.long).unsqueeze(0).to("cuda")

        torch.cuda.synchronize()
        start = time.time()

        output = self.model.generate(
            input_ids,
        )
        # print(f"{output.size()=}")

        torch.cuda.synchronize()
        end = time.time()
        total_time += end - start

        print()
        print(f"{total_time=}")
        print(f"second/token {total_time/output.size(1)=}")
        print(f"token/second {output.size(1)/total_time=}")

        if mtp_inference_mode is not None:
            self.model.generation_config.mtp_inference_mode = ori_mtp_inference_mode
    def benchmark_generate_stream(self, mtp_inference_mode):
        print("-" * 100)
        print("benchmark_generate_stream...")
        print(f"{mtp_inference_mode=}")

        self.model.apply(custom_init_weights)

        total_time = 0
        self.model.generation_config.use_cache = True

        # model_max_length = 8192
        model_max_length = 4096
        # model_max_length = 2048
        # model_max_length = 1024
        num_prefill_tokens = 32
        self.model.generation_config.max_new_tokens = model_max_length
        self.model.generation_config.do_sample = False

        if mtp_inference_mode is not None:
            ori_mtp_inference_mode = self.model.generation_config.mtp_inference_mode
            self.model.generation_config.mtp_inference_mode = mtp_inference_mode

        input_ids = torch.tensor([0] * num_prefill_tokens, dtype=torch.long).unsqueeze(0).to("cuda")

        streamer = BenchmarkIteratorStreamer(self.tokenizer, skip_prompt=True)
        generation_kwargs = dict(input_ids=input_ids, streamer=streamer)
        thread = Thread(target=self.model.generate, kwargs=generation_kwargs)

        token_decode_time = []
        torch.cuda.synchronize()
        start = time.time()

        thread.start()
        generated_text = ""
        for new_text in tqdm.tqdm(streamer, total=model_max_length):
            generated_text += new_text
            end = time.time()
            token_decode_time.append(end - start)
            yield new_text
        # print(f"{len(generated_text)}")

        torch.cuda.synchronize()
        end = time.time()
        total_time += end - start

        print()
        print(f"{token_decode_time[-1]=}")
        print(f"{streamer.num_decode_tokens=}")
        print(f"second/token {token_decode_time[-1]/streamer.num_decode_tokens=}")
        print(f"token/second {streamer.num_decode_tokens/token_decode_time[-1]=}")

        # if mtp_inference_mode is None:
        #     mtp_inference_mode = []
        # with open(f'token_decode_time_{str(mtp_inference_mode)}.json', 'w') as f:
        #     json.dump(token_decode_time, f)

        if mtp_inference_mode is not None:
            self.model.generation_config.mtp_inference_mode = ori_mtp_inference_mode
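    # run_infer builds a chat prompt (optionally embedding discrete audio tokens or
    # contiguous audio features for the prompt/user audio), runs a single generate()
    # call, and decodes any generated <|audio_k|> tokens back into a waveform.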
    def run_infer(
        self,
        audio_path=None,
        prompt_audio_path=None,
        stream_stride=4,
        max_returned_tokens=4096,
        sample_rate=16000,
        request_id="",
        audio_feats=None,
        message="",
        use_past=False,
        mode="luke",
        do_sample=False,
        mtp_inference_mode=None,
    ):
        AUD_TAG_TOKEN = "<|audio|>"
        AUD_CONTEXT_TOKEN = "<|context_of_audio|>"
        AUD_START_TOKEN = "<|begin_of_audio|>"
        AUD_END_TOKEN = "<|end_of_audio|>"

        if prompt_audio_path is not None:
            system_message = [
                {
                    "role": "system",
                    "content": "Your Voice: <|audio|>\n",
                },
            ]
        elif mode == "luke":
            system_message = self.luke_system_message
        else:
            system_message = self.default_system_message

        if prompt_audio_path is not None and self.audio_tokenizer.apply_to_role("user", is_discrete=True):
            # discrete codec
            audio_tokens = self.audio_tokenizer.encode(prompt_audio_path)
            audio_tokens = "".join(f"<|audio_{i}|>" for i in audio_tokens)
            system_message[-1]["content"] = system_message[-1]["content"].replace(
                "<|audio|>", f"<|begin_of_audio|>{audio_tokens}<|end_of_audio|>"
            )

        if audio_path is not None:
            messages = system_message + [
                {
                    "role": "user",
                    "content": message + "\n<|audio|>",
                },
            ]
        else:
            messages = system_message + [
                {
                    "role": "user",
                    "content": message,
                },
            ]

        if audio_path is not None and self.audio_tokenizer.apply_to_role("user", is_discrete=True):
            # discrete codec
            audio_tokens = self.audio_tokenizer.encode(audio_path)
            audio_tokens = "".join(f"<|audio_{i}|>" for i in audio_tokens)
            messages[-1]["content"] = messages[-1]["content"].replace(
                "<|audio|>", f"<|begin_of_audio|>{audio_tokens}<|end_of_audio|>"
            )

        input_ids = self.tokenizer.apply_chat_template(
            messages,
            tokenize=True,
            add_generation_prompt=self.add_generation_prompt,
        )

        if (audio_path is not None or prompt_audio_path is not None) and self.audio_tokenizer.apply_to_role(
            "user", is_contiguous=True
        ):
            # contiguous codec
            audio_paths = []
            if audio_path is not None:
                audio_paths.append(audio_path)
            if prompt_audio_path is not None:
                audio_paths.append(prompt_audio_path)
            input_ids, audios, audio_indices = add_audio_input_contiguous(
                input_ids, audio_paths, self.tokenizer, self.audio_tokenizer
            )
        else:
            audios = None
            audio_indices = None

        input_ids = torch.tensor([input_ids], dtype=torch.long).to("cuda")
        print("input", self.tokenizer.decode(input_ids[0], skip_special_tokens=False), flush=True)

        self.model.generation_config.do_sample = do_sample

        if mtp_inference_mode is not None:
            ori_mtp_inference_mode = self.model.generation_config.mtp_inference_mode
            self.model.generation_config.mtp_inference_mode = mtp_inference_mode

        outputs = self.model.generate(
            input_ids,
            audios=audios,
            audio_indices=audio_indices,
        )

        output = self.tokenizer.decode(outputs[0], skip_special_tokens=False)
        print(f"{output=}", flush=True)

        audio_offset = self.tokenizer.convert_tokens_to_ids("<|audio_0|>")
        audio_tokens = []
        for token_id in outputs[0]:
            if token_id >= audio_offset:
                audio_tokens.append(token_id - audio_offset)

        if len(audio_tokens) > 0:
            tts_speech = self.audio_tokenizer.decode(
                audio_tokens, source_speech_16k=prompt_audio_path
            )
        else:
            tts_speech = None

        if mtp_inference_mode is not None:
            self.model.generation_config.mtp_inference_mode = ori_mtp_inference_mode

        return output, tts_speech
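    # run_infer_stream is the streaming counterpart of run_infer: it launches generate()
    # on a background thread and yields incremental text/audio token strings through
    # TextAudioIteratorStreamer instead of waiting for the full sequence.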
    def run_infer_stream(
        self,
        audio_path=None,
        prompt_audio_path=None,
        stream_stride=4,
        max_returned_tokens=4096,
        sample_rate=16000,
        request_id="",
        audio_feats=None,
        message="",
        use_past=False,
        mode="luke",
        do_sample=False,
        mtp_inference_mode=None,
    ):
        if prompt_audio_path is not None:
            system_message = [
                {
                    "role": "system",
                    "content": "Your Voice: <|audio|>\n",
                },
            ]
        elif mode == "luke":
            system_message = self.luke_system_message
        else:
            system_message = self.default_system_message

        if prompt_audio_path is not None and self.audio_tokenizer.apply_to_role("user", is_discrete=True):
            # discrete codec
            audio_tokens = self.audio_tokenizer.encode(prompt_audio_path)
            audio_tokens = "".join(f"<|audio_{i}|>" for i in audio_tokens)
            system_message[-1]["content"] = system_message[-1]["content"].replace(
                "<|audio|>", f"<|begin_of_audio|>{audio_tokens}<|end_of_audio|>"
            )

        if audio_path is not None:
            messages = system_message + [
                {
                    "role": "user",
                    "content": message + "\n<|audio|>",
                },
            ]
        else:
            messages = system_message + [
                {
                    "role": "user",
                    "content": message,
                },
            ]

        if audio_path is not None and self.audio_tokenizer.apply_to_role("user", is_discrete=True):
            # discrete codec
            audio_tokens = self.audio_tokenizer.encode(audio_path)
            audio_tokens = "".join(f"<|audio_{i}|>" for i in audio_tokens)
            messages[-1]["content"] = messages[-1]["content"].replace(
                "<|audio|>", f"<|begin_of_audio|>{audio_tokens}<|end_of_audio|>"
            )

        input_ids = self.tokenizer.apply_chat_template(
            messages,
            tokenize=True,
            add_generation_prompt=self.add_generation_prompt,
        )

        if (audio_path is not None or prompt_audio_path is not None) and self.audio_tokenizer.apply_to_role(
            "user", is_contiguous=True
        ):
            # contiguous codec
            audio_paths = []
            if audio_path is not None:
                audio_paths.append(audio_path)
            if prompt_audio_path is not None:
                audio_paths.append(prompt_audio_path)
            input_ids, audios, audio_indices = add_audio_input_contiguous(
                input_ids, audio_paths, self.tokenizer, self.audio_tokenizer
            )
        else:
            audios = None
            audio_indices = None

        input_ids = torch.tensor([input_ids], dtype=torch.long).to("cuda")
        print("input", self.tokenizer.decode(input_ids[0], skip_special_tokens=False), flush=True)

        self.model.generation_config.do_sample = do_sample

        if mtp_inference_mode is not None:
            ori_mtp_inference_mode = self.model.generation_config.mtp_inference_mode
            self.model.generation_config.mtp_inference_mode = mtp_inference_mode

        streamer = TextAudioIteratorStreamer(self.tokenizer, skip_prompt=True)
        generation_kwargs = dict(
            input_ids=input_ids,
            audios=audios,
            audio_indices=audio_indices,
            streamer=streamer,
        )
        thread = Thread(target=self.model.generate, kwargs=generation_kwargs)
        thread.start()

        # generated_text = ""
        for new_text in streamer:
            # generated_text += new_text
            yield new_text
        # torch.cuda.synchronize()

        if mtp_inference_mode is not None:
            self.model.generation_config.mtp_inference_mode = ori_mtp_inference_mode
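# benchmark_llm sweeps the mtp_inference_mode settings tagged Vanilla/Balance/Boost/Turbo
# below (the tags mirror the VITA-Audio model variants referenced above) and runs all
# three benchmarks for each.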
def benchmark_llm():
    for mtp_inference_mode, tag in zip(
        [
            [8192, 0],
            [1, 4, 3, 8, 4, 10],
            [1, 10, 4, 10],
            [1, 10],
        ],
        [
            "Vanilla",
            "Balance",
            "Boost",
            "Turbo",
        ],
    ):
        print("=" * 100)
        print("benchmark_llm")
        print(f"{tag}")

        s2s_inference.benchmark_forward(mtp_inference_mode)

        s2s_inference.benchmark_generate(mtp_inference_mode)

        generated_text = ""
        for new_text in s2s_inference.benchmark_generate_stream(
            mtp_inference_mode=mtp_inference_mode
        ):
            generated_text += new_text
            # print(new_text, end="", flush=True)
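# benchmark_sts repeatedly runs streaming speech-to-speech inference on a few sample
# prompts, timing each decoded audio chunk and writing both the per-chunk and the
# concatenated waveforms to output_dir.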
def benchmark_sts():
    audio_paths = [
        "asset/介绍一下上海.wav",
        "asset/发表一个悲伤的演讲.wav",
        "asset/发表一个振奋人心的演讲.wav",
    ]

    for _ in range(10):
        print("=" * 100)
        print("benchmark_sts")

        audio_path = random.choice(audio_paths)
        print(f"{audio_path}")

        start = time.time()
        audio_idx = 0
        generated_text = ""
        all_tts_speech = []
        past_tts_speech_len = 0
        for new_text in s2s_inference.run_infer_stream(audio_path=audio_path):
            # print(new_text, end="", flush=True)
            generated_text += new_text
            if new_text == "<|end_of_audio|>":
                audio_tokens = extract_token_ids_as_int(generated_text)
                tts_speech = s2s_inference.audio_tokenizer.decode(audio_tokens, option_steps=1)
                tts_speech = tts_speech[past_tts_speech_len:]
                past_tts_speech_len += len(tts_speech)
                all_tts_speech.append(tts_speech)

                end = time.time()
                if audio_idx == 0:
                    print(audio_tokens)
                print(f"{audio_idx} audio chunk {end - start}")

                wav_path = os.path.join(output_dir, audio_path[:-4] + f"_{audio_idx}.wav")
                os.makedirs(os.path.dirname(wav_path), exist_ok=True)
                torchaudio.save(wav_path, tts_speech.unsqueeze(0), 22050, format="wav")

                audio_idx += 1
                start = time.time()

        wav_path = os.path.join(output_dir, audio_path[:-4] + ".wav")
        tts_speech = torch.cat(all_tts_speech, dim=0)
        os.makedirs(os.path.dirname(wav_path), exist_ok=True)
        torchaudio.save(wav_path, tts_speech.unsqueeze(0), 22050, format="wav")
# ==============================================================
# Text
def text_task():
    for text in [
        "How many helicopters can a human eat in one sitting?",
        "你叫什么名字?",
        "写一首诗",
        "介绍一下上海",
    ]:
        print("=" * 100)
        print("text_task")
        print(f"{text=}")

        output, _ = s2s_inference.run_infer(
            message=text,
            mode=None,
            # do_sample=True,
            mtp_inference_mode=[8192, 0],
        )
        print(f"{output=}", flush=True)
# ==============================================================
# Text stream
def text_stream_task():
    for text in [
        "你叫什么名字?",
    ]:
        print("=" * 100)
        print("text_stream_task")
        print(f"{text=}")

        generated_text = ""
        for new_text in s2s_inference.run_infer_stream(
            message=text,
            mode=None,
            # do_sample=True,
            mtp_inference_mode=[8192, 0],
        ):
            generated_text += new_text
            print(new_text, end="")
        print("")
# ==============================================================
# S2S
def sts_task():
    for audio_path in [
        "asset/介绍一下上海.wav",
        "asset/发表一个悲伤的演讲.wav",
        "asset/发表一个振奋人心的演讲.wav",
        "asset/piano.mp3",
    ]:
        print("=" * 100)
        print("sts_task")
        print(f"{audio_path=}")

        output, tts_speech = s2s_inference.run_infer(
            audio_path=audio_path,
        )

        wav_path = os.path.join(output_dir, audio_path[:-4] + ".wav")
        os.makedirs(os.path.dirname(wav_path), exist_ok=True)
        torchaudio.save(wav_path, tts_speech.unsqueeze(0), 22050, format="wav")
# ==============================================================
# S2S stream
def sts_stream_task():
    for audio_path in [
        "asset/介绍一下上海.wav",
    ]:
        print("=" * 100)
        print("sts_stream_task")
        print(f"{audio_path=}")

        generated_text = ""
        for new_text in s2s_inference.run_infer_stream(audio_path=audio_path):
            generated_text += new_text
            print(new_text, end="")
        print("")

        audio_decode_time = []
        audio_segments = find_audio_segments_regex(generated_text)
        for audio_idx, audio_segment in enumerate(audio_segments):
            start = time.time()
            audio_tokens = extract_token_ids_as_int(audio_segment)
            # print(audio_tokens)
            tts_speech = s2s_inference.audio_tokenizer.decode(audio_tokens)
            end = time.time()
            audio_decode_time.append(end - start)

            wav_path = os.path.join(output_dir, audio_path[:-4] + f"_{audio_idx}.wav")
            os.makedirs(os.path.dirname(wav_path), exist_ok=True)
            torchaudio.save(wav_path, tts_speech.unsqueeze(0), 22050, format="wav")
        # print(f"{audio_decode_time=}")
# ==============================================================
# ASR
def asr_task():
    for audio_path in [
        "/data/data/wenet-e2e/wenetspeech/data/cuts_TEST_NET.00000000/TES/TEST_NET_Y0000000020_5XD21BihDd8_S00395.wav",
        "/data/data/wenet-e2e/wenetspeech/data/cuts_TEST_NET.00000000/TES/TEST_NET_Y0000000000_-KTKHdZ2fb8_S00424.wav",
        "/data/data/wenet-e2e/wenetspeech/data/cuts_TEST_NET.00000000/TES/TEST_NET_Y0000000050_LOLTeK1BNMo_S00045.wav",
        "/data/data/fixie-ai/librispeech_asr/test.clean/2830-3980-0034.wav",
        "/data/data/fixie-ai/librispeech_asr/test.clean/237-134500-0040.wav",
    ]:
        print("=" * 100)
        print("asr_task")
        print(f"{audio_path=}")

        output, tts_speech = s2s_inference.run_infer(
            audio_path=audio_path,
            # message="Translate the speech to text.",
            message="Convert the speech to text.",
            mode=None,
        )
        print(f"{output=}", flush=True)
# ==============================================================
# TTS
def tts_task():
    TTS_texts = [
        "我们将为全球城市的可持续发展贡献力量。",
        "通天河 灵感大王",
        "他本是我莲花池里养大的金鱼,每日浮头听经,修成手段。那一柄九瓣铜锤,乃是一枝未开的菡萏,被他运炼成兵。不知是那一日,海潮泛涨,走到此间。我今早扶栏看花,却不见这厮出拜,掐指巡纹,算着他在此成精,害你师父,故此未及梳妆,运神功,织个竹篮儿擒他。",
        "一二三四五六七八九十",
        "One Two Tree Four Five Six Seven Eight Night Ten",
        "1 2 3 4 5 6 7 8 9 10",
        "12345678910",
        "两个黄鹂鸣翠柳,一行白鹭上青天。窗含西岭千秋雪,门泊东吴万里船。",
        "坡上立着一只鹅,坡下就是一条河。宽宽的河,肥肥的鹅,鹅要过河,河要渡鹅不知是鹅过河,还是河渡鹅?",
        "扁担长,板凳宽,扁担没有板凳宽,板凳没有扁担长。扁担绑在板凳上,板凳不让扁担绑在板凳上。",
        "化肥会挥发,黑化肥发灰,灰化肥发黑。黑化肥发灰会挥发;灰化肥挥发会发黑。黑化肥挥发发灰会花飞;灰化肥挥发发黑会飞花,黑灰化肥会挥发发灰黑讳为花飞;灰黑化肥会挥发发黑灰为讳飞花。",
        "圆桌儿、方桌儿没有腿儿,墨水瓶儿里没有水儿,花瓶里有花儿没有叶儿,练习本儿上写字儿没有准儿,甘蔗好吃净是节儿。西瓜挺大没有味儿,坛儿里的小米儿长了虫儿,鸡毛掸子成了棍儿,水缸沿儿上系围裙儿,耗子打更猫打盹儿,新买的小褂儿没钉扣儿,奶奶想说没有劲儿。",
        "起床歌:小宝宝,起得早,睁开眼,眯眯笑,咿呀呀,学说话,伸伸手,要人抱。穿衣歌小胳膊,穿袖子,穿上衣,扣扣子,小脚丫,穿裤子,穿上袜子穿鞋子。小镜子-小镜子,圆又圆,看宝宝,露笑脸。闭上眼,做个梦,变月亮,挂上天。小铃铛叮铃铃,叮铃铃,一会远,一会近。小宝宝,耳朵灵,听铃声,找到铃。学画画小宝宝,学画画,大蜡笔,手中拿,画小鸭,叫嘎嘎,画小马,骑回家。大鞋子大鞋子,像只船,爸爸穿,我也穿,一二一,向前走,走呀走,翻了船。逛公园逛公园,宝宝笑,东看看,西瞧瞧,花儿香,鸟儿叫,小草绿,小树摇。看画报小娃娃,看画报,睁大眼,仔细瞧,布娃娃,哈哈笑,伸伸手,要你抱。搭积木大积木,红黄兰,小宝宝,最爱玩,搭火车,钻山洞,盖高楼,连着天。小汽车小汽车,嘀嘀嘀,开过来,开过去,小宝宝,当司机,送妈妈,上班去。藏猫猫儿歌:躲猫猫,躲猫猫, 猫猫、猫猫在哪里?喵……猫咪在这里。",
    ]

    for text in TTS_texts:
        print("=" * 100)
        print("tts_task")
        print(f"{text=}")

        output, tts_speech = s2s_inference.run_infer(
            message="Convert the text to speech.\n" + text,
            mode=None,
            do_sample=True,
        )

        wav_path = os.path.join(output_dir, text[:16] + ".wav")
        os.makedirs(os.path.dirname(wav_path), exist_ok=True)
        torchaudio.save(wav_path, tts_speech.unsqueeze(0), 22050, format="wav")

    # ==============================================================
    # Clone TTS
    for text in TTS_texts:
        for prompt_audio_path in [
            "asset/2631296891109983590.wav",
            "asset/379838640-d5ff0815-74f8-4738-b0f1-477cfc8dcc2d.wav",
            "asset/4202818730519913143.wav",
        ]:
            print("=" * 100)
            print("tts_task")
            print(f"{text=} {prompt_audio_path=}")

            output, tts_speech = s2s_inference.run_infer(
                prompt_audio_path=prompt_audio_path,
                # message="Translate the text to speech.\n" + text,
                message="Convert the text to speech.\n" + text,
                mode=None,
                do_sample=True,
            )

            wav_path = os.path.join(output_dir, prompt_audio_path[:16] + "_" + text[:16] + ".wav")
            os.makedirs(os.path.dirname(wav_path), exist_ok=True)
            torchaudio.save(wav_path, tts_speech.unsqueeze(0), 22050, format="wav")
# ==============================================================
# TTS stream
def tts_stream_task():
    TTS_texts = [
        "他本是我莲花池里养大的金鱼,每日浮头听经,修成手段。那一柄九瓣铜锤,乃是一枝未开的菡萏,被他运炼成兵。不知是那一日,海潮泛涨,走到此间。我今早扶栏看花,却不见这厮出拜,掐指巡纹,算着他在此成精,害你师父,故此未及梳妆,运神功,织个竹篮儿擒他。",
    ]

    for text in TTS_texts:
        print("=" * 100)
        print("tts_stream_task")
        print(f"{text=}")

        generated_text = ""
        for new_text in s2s_inference.run_infer_stream(
            message="Convert the text to speech.\n" + text,
            mode=None,
            do_sample=True,
        ):
            generated_text += new_text
            print(new_text, end="")
        print("")

        audio_segments = find_audio_segments_regex(generated_text)
        for audio_idx, audio_segment in enumerate(audio_segments):
            audio_tokens = extract_token_ids_as_int(audio_segment)
            # print(audio_tokens)
            tts_speech = s2s_inference.audio_tokenizer.decode(audio_tokens)

            wav_path = os.path.join(output_dir, text[:16] + f"_{audio_idx}.wav")
            os.makedirs(os.path.dirname(wav_path), exist_ok=True)
            torchaudio.save(wav_path, tts_speech.unsqueeze(0), 22050, format="wav")
s2s_inference = S2SInference(
    model_name_or_path, audio_tokenizer_path, audio_tokenizer_type, flow_path=flow_path
)

text_task()
text_stream_task()
sts_task()
sts_stream_task()
asr_task()
tts_task()
tts_stream_task()

benchmark_sts()
benchmark_llm()