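"""Distributed ASR evaluation for VITA-Audio-style speech-language models.

Each rank transcribes its shard of a JSON manifest; hypotheses and references
are gathered on rank 0 and written to <json_name>_hyp.txt, <json_name>_ref.txt
and <json_name>_hyp_ref.json under --output_dir.

Expects to be launched with a distributed launcher such as torchrun, which
sets the WORLD_SIZE, RANK and LOCAL_RANK environment variables used below.
"""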
import argparse
import itertools
import json
import os
import random
import sys
import uuid
from datetime import timedelta
from functools import partial
from pathlib import Path

import torch
import tqdm
from datasets import load_dataset
from transformers import AutoConfig, AutoModelForCausalLM, AutoTokenizer
from transformers.generation import GenerationConfig

import torchaudio
from vita_audio.data.processor.audio_processor import add_audio_input_contiguous
from vita_audio.tokenizer import get_audio_tokenizer
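

# Collate by keeping per-sample lists (no padding or stacking): input_ids and
# audio features vary in length, and generation below runs one sample at a time.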
def collate_fn(batches):
    input_ids = [sample["input_ids"] for sample in batches]
    audios = [sample["audios"] for sample in batches]
    audio_indices = [sample["audio_indices"] for sample in batches]
    refs = [sample["ref"] for sample in batches]

    return input_ids, audios, audio_indices, refs


class ASRDataset(torch.utils.data.Dataset):
    def __init__(
        self,
        json_path,
        tokenizer,
        audio_tokenizer,
        default_system_message=None,
        add_generation_prompt=True,
    ):
        data = load_dataset("json", data_files=json_path, keep_in_memory=False)
        self.data = data["train"]

        self.tokenizer = tokenizer
        self.add_generation_prompt = add_generation_prompt
        self.audio_tokenizer = audio_tokenizer
        self.default_system_message = default_system_message

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        sample = self.data[idx]
        # print(f"sample {sample}")

        audio_path = sample["audios"][0]

        if self.audio_tokenizer.apply_to_role("user", is_discrete=True):
            # discrete codec
            audio_tokens = self.audio_tokenizer.encode(audio_path)
            audio_tokens = "".join(f"<|audio_{i}|>" for i in audio_tokens)
        else:
            audio_tokens = None

        messages = []
        if len(sample["messages"]) == 2:
            assert sample["messages"][0]["role"] == "user"
            assert sample["messages"][1]["role"] == "assistant"

            if self.default_system_message is not None:
                messages = self.default_system_message + messages

        elif len(sample["messages"]) == 3:
            assert sample["messages"][0]["role"] == "system"
            assert sample["messages"][1]["role"] == "user"
            assert sample["messages"][2]["role"] == "assistant"

        else:
            raise NotImplementedError

        # print(sample)
        for conv in sample["messages"][:-1]:
            new_conv = {}
            new_conv["role"] = conv["role"]

            content = conv["content"]
            if audio_tokens is not None:
                content = content.replace(
                    "<|audio|>", f"<|begin_of_audio|>{audio_tokens}<|end_of_audio|>"
                )
            new_conv["content"] = content

            messages.append(new_conv)

        input_ids = self.tokenizer.apply_chat_template(
            messages,
            tokenize=True,
            add_generation_prompt=self.add_generation_prompt,
            # return_tensors="pt",
        )

        ref = sample["messages"][-1]["content"]

        if self.audio_tokenizer.apply_to_role("user", is_contiguous=True):
            # contiguous codec
            input_ids, audios, audio_indices = add_audio_input_contiguous(
                input_ids, [audio_path], self.tokenizer, self.audio_tokenizer
            )
        else:
            audios = None
            audio_indices = None

        input_ids = torch.tensor([input_ids], dtype=torch.long)

        return {
            "input_ids": input_ids,
            "audios": audios,
            "audio_indices": audio_indices,
            "ref": ref,
        }
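

# Expected shape of each record in the --json_path manifest, inferred from
# __getitem__ above (illustrative sketch, not a normative schema):
#
#   {
#       "audios": ["path/to/utterance.wav"],
#       "messages": [
#           {"role": "user", "content": "... <|audio|> ..."},
#           {"role": "assistant", "content": "reference transcript"}
#       ]
#   }
#
# A three-message variant with a leading "system" turn is also accepted.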


class InferenceSampler(torch.utils.data.sampler.Sampler):
    """Shards dataset indices evenly across distributed ranks for inference."""

    def __init__(self, size):
        self._size = int(size)
        assert size > 0
        self._rank = torch.distributed.get_rank()
        self._world_size = torch.distributed.get_world_size()
        self._local_indices = self._get_local_indices(size, self._world_size, self._rank)

    @staticmethod
    def _get_local_indices(total_size, world_size, rank):
        shard_size = total_size // world_size
        left = total_size % world_size
        shard_sizes = [shard_size + int(r < left) for r in range(world_size)]

        begin = sum(shard_sizes[:rank])
        end = min(sum(shard_sizes[: rank + 1]), total_size)
        return range(begin, end)

    def __iter__(self):
        yield from self._local_indices

    def __len__(self):
        return len(self._local_indices)
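

# Run generation sample-by-sample. Generated ids at or above the id of
# <|audio_0|> are audio-codec tokens and are split off; the remaining text
# tokens are decoded into the hypothesis and paired with the reference.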
def inference(model, tokenizer, audio_tokenizer, dataloader, output_dir):
    audio_offset = tokenizer.convert_tokens_to_ids("<|audio_0|>")

    outputs = []
    for _, (batched_input_ids, batched_audios, batched_audio_indices, batched_ref) in enumerate(
        tqdm.tqdm(dataloader)
    ):
        for input_ids, audios, audio_indices, ref in zip(
            batched_input_ids, batched_audios, batched_audio_indices, batched_ref
        ):
            kwargs = {
                # "temperature": 0.2,
                # "top_p": 0.8,
                # "do_sample": False,
                # "temperature": 1.0,
                "max_new_tokens": max([len(x) for x in batched_ref]) + 10,
                "min_new_tokens": 1,
            }
            if audios is not None:
                kwargs["audios"] = audios
                kwargs["audio_indices"] = audio_indices

            responses = model.generate(
                input_ids=input_ids.cuda(),
                **kwargs,
            )

            response = responses[0][len(input_ids[0]) :]

            text_tokens = []
            audio_tokens = []
            for token_id in response:
                if token_id >= audio_offset:
                    audio_tokens.append(token_id - audio_offset)
                else:
                    text_tokens.append(token_id)

            hyp = tokenizer.decode(text_tokens, skip_special_tokens=True)

            outputs.append((hyp, ref))

            print("")
            print("=" * 100)
            print(f"{hyp=}")
            print(f"{ref=}")

    return outputs
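

# Entry point: parse arguments, set up the NCCL process group, load the model
# and tokenizers, shard the dataset across ranks, run inference, and write the
# merged results on rank 0.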
if __name__ == "__main__":
    parser = argparse.ArgumentParser(
        description="",
        formatter_class=argparse.RawDescriptionHelpFormatter,
    )
    parser.add_argument("--model_name_or_path", type=str, required=True, help="model_name_or_path")
    parser.add_argument(
        "--audio_tokenizer_path", type=str, required=True, help="audio_tokenizer_path"
    )
    parser.add_argument(
        "--audio_tokenizer_type", type=str, required=True, help="audio_tokenizer_type"
    )
    parser.add_argument("--flow_path", type=str, required=True, help="flow_path")
    parser.add_argument("--json_path", type=str, required=True, help="json_path")
    parser.add_argument("--output_dir", type=str, required=True, help="output_dir")
    parser.add_argument("--batch_size", type=int, default=1)
    parser.add_argument("--num_workers", type=int, default=0)

    args = parser.parse_args()
    print(f"{args=}")

    torch.distributed.init_process_group(
        backend="nccl",
        world_size=int(os.getenv("WORLD_SIZE", "1")),
        rank=int(os.getenv("RANK", "0")),
        timeout=timedelta(seconds=7200),
    )
    torch.cuda.set_device(int(os.getenv("LOCAL_RANK", "0")))

    random.seed(42)
    torch.manual_seed(42)

    config = AutoConfig.from_pretrained(
        args.model_name_or_path,
        trust_remote_code=True,
    )

    # ================================================================
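    # Pick the chat template and default system prompt for the backbone model
    # family reported by config.model_type. If none of these branches match,
    # chat_template and related variables remain undefined.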
    if "glm" in config.model_type.lower():
        from get_chat_template import glm4_chat_template as chat_template

        add_generation_prompt = True
        default_system_message = [
            {
                "role": "system",
                "content": "User will provide you with a speech instruction. Do it step by step. First, think about the instruction and respond in a interleaved manner, with 13 text token followed by 26 audio tokens.",
            }
        ]

    if "qwen2" in config.model_type.lower():
        from get_chat_template import qwen2_chat_template as chat_template

        add_generation_prompt = True
        default_system_message = []

    if "hunyuan" in config.model_type.lower():
        from get_chat_template import hunyuan_chat_template as chat_template

        add_generation_prompt = False
        default_system_message = [
            {
                "role": "system",
                "content": "You are a helpful AI assistant.",
            }
        ]

    # ================================================================
    print("Loading model")

    # device_map = "auto"
    device_map = "cuda"

    # torch_dtype = torch.float16
    torch_dtype = torch.bfloat16

    rank = torch.distributed.get_rank()

    audio_tokenizer = get_audio_tokenizer(
        args.audio_tokenizer_path, args.audio_tokenizer_type, flow_path=args.flow_path, rank=rank
    )

    tokenizer = AutoTokenizer.from_pretrained(
        args.model_name_or_path,
        trust_remote_code=True,
        chat_template=chat_template,
    )
    # print("tokenizer", tokenizer)

    model = AutoModelForCausalLM.from_pretrained(
        args.model_name_or_path,
        trust_remote_code=True,
        device_map=device_map,
        torch_dtype=torch_dtype,
        attn_implementation="flash_attention_2",
    ).eval()
    # print("model", model)

    model.generation_config = GenerationConfig.from_pretrained(
        args.model_name_or_path, trust_remote_code=True
    )
    model.generation_config.max_new_tokens = 4096
    model.generation_config.chat_format = "chatml"
    model.generation_config.max_window_size = 8192
    model.generation_config.use_cache = True
    model.generation_config.do_sample = False
    model.generation_config.pad_token_id = tokenizer.pad_token_id

    if model.config.model_type == "hunyuan":
        model.generation_config.eos_token_id = tokenizer.eos_id

    # ================================================================
    print("Loading data")

    dataset = ASRDataset(
        json_path=args.json_path,
        tokenizer=tokenizer,
        audio_tokenizer=audio_tokenizer,
        default_system_message=default_system_message,
        add_generation_prompt=add_generation_prompt,
    )

    dataloader = torch.utils.data.DataLoader(
        dataset=dataset,
        sampler=InferenceSampler(len(dataset)),
        batch_size=args.batch_size,
        num_workers=args.num_workers,
        pin_memory=True,
        drop_last=False,
        collate_fn=collate_fn,
    )

    # ================================================================
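    # Run inference on this rank's shard, gather the (hyp, ref) pairs from all
    # ranks, then write the merged outputs on rank 0.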
    outputs = inference(model, tokenizer, audio_tokenizer, dataloader, args.output_dir)

    torch.distributed.barrier()

    world_size = torch.distributed.get_world_size()
    merged_outputs = [None for _ in range(world_size)]
    torch.distributed.all_gather_object(merged_outputs, json.dumps(outputs))

    merged_outputs = [json.loads(_) for _ in merged_outputs]
    merged_outputs = list(itertools.chain.from_iterable(merged_outputs))

    if torch.distributed.get_rank() == 0:
        # json_name = Path("_".join(os.path.normpath(args.json_path).split(os.sep)[-2:])).stem
        json_name = Path(os.path.normpath(args.json_path).split(os.sep)[-1]).stem

        hyp_path = os.path.join(args.output_dir, f"{json_name}_hyp.txt")
        ref_path = os.path.join(args.output_dir, f"{json_name}_ref.txt")
        os.makedirs(os.path.dirname(hyp_path), exist_ok=True)

        with open(hyp_path, "w") as hyp_file, open(ref_path, "w") as ref_file:
            for sample_idx, (hyp, ref) in enumerate(merged_outputs):
                hyp_file.write(f"{sample_idx} {hyp}\n")
                ref_file.write(f"{sample_idx} {ref}\n")

        hyp_ref_path = os.path.join(args.output_dir, f"{json_name}_hyp_ref.json")
        with open(hyp_ref_path, "w") as hyp_ref_file:
            json.dump(merged_outputs, hyp_ref_file, indent=4)

    torch.distributed.barrier()
    print("Done.")