import argparse

def parse_args():
    parser = argparse.ArgumentParser(description="encode the dataset using the codec model")
    parser.add_argument('--root', type=str, default="/data/scratch/pyp/datasets/emilia", help="path to the dataset root directory")
    parser.add_argument('--sub_root', type=str, default="preprocessed", help="subdirectory under root")
    parser.add_argument('--encodec_name', type=str, default="encodec_6f79c6a8.th", help="filename of the codec model checkpoint")
    parser.add_argument('--n_workers', type=int, default=16, help="number of parallel dataloader worker processes")
    parser.add_argument('--batch_size', type=int, default=16, help="batch size for codec encoding; decrease it if you hit OOM. This is the batch size summed *over all GPUs*, so increase it if you are using more GPUs")
    parser.add_argument('--audio_sr', type=int, default=16000, help='input audio sample rate')
    parser.add_argument('--model_sr', type=int, default=16000, help='encodec input audio sample rate')
    parser.add_argument('--downsample_rate', type=int, default=320, help='encodec downsample rate')
    parser.add_argument('--model_code_sr', type=float, default=50, help='codec model code sample rate')
    parser.add_argument('--len_cap', type=float, default=1000, help='drop audios longer than this many seconds')
    parser.add_argument('--min_len', type=float, default=0.5, help='drop audios shorter than this many seconds')
    parser.add_argument('--partition', type=str, default="1/1", help='shard spec "k/N" for parallel processing: this job handles the k-th of N shards')
    parser.add_argument('--split', type=str, default='train', choices=['train', 'valid', 'test'])
    return parser.parse_args()
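# Example invocation (a sketch; the script filename and paths below are
# placeholders, substitute your own):
#   python encode_codec.py --root /path/to/emilia --sub_root preprocessed \
#       --encodec_name encodec_6f79c6a8.th --batch_size 16 --partition 1/4 --split train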
if __name__ == "__main__":
    import logging
    formatter = (
        "%(asctime)s [%(levelname)s] %(filename)s:%(lineno)d || %(message)s"
    )
    logging.basicConfig(format=formatter, level=logging.INFO)
    import os, sys
    import numpy as np
    import torch
    import torchaudio
    import tqdm
    import time

    args = parse_args()

    def sort_by_audio_len(lens):
        inds = np.argsort(lens).tolist()
        logging.info(f"longest: {lens[inds[-1]]/args.downsample_rate} encodec codes, {lens[inds[-1]]/args.model_sr:.2f} sec.")
        logging.info(f"shortest: {lens[inds[0]]/args.downsample_rate} encodec codes, {lens[inds[0]]/args.model_sr:.2f} sec.")
        logging.info(f"median: {lens[inds[len(inds)//2]]/args.downsample_rate} encodec codes, {lens[inds[len(inds)//2]]/args.model_sr:.2f} sec.")
        logging.info(f"95 percentile longest: {lens[inds[int(len(inds)*0.95)]]/args.downsample_rate} encodec codes, {lens[inds[int(len(inds)*0.95)]]/args.model_sr:.2f} sec.")
        return inds[::-1]
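    # Note on the function above: returning indices longest-first means each
    # sub-batch formed later groups clips of similar length, which keeps the
    # padding overhead of pad_sequence low (a design note, not an original comment).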
    def write_array_to_txt_file(array, filename):
        with open(filename, 'w') as f:
            for a in array[:-1]:
                f.write(' '.join(map(str, a)) + '\n')
            f.write(' '.join(map(str, array[-1])))
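    # On-disk layout produced by the writer above: one line per codebook, one
    # space-separated token id per frame. Illustrative 2-codebook example
    # (values made up):
    #   612 81 81 337
    #   18 992 105 105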
    class mydataset(torch.utils.data.Dataset):
        def __init__(self, split):
            super().__init__()
            self.split = split
            self.audio_dir = audio_dir
            manifest_fn = os.path.join(encodec_manifest_dir, split + ".txt")
            cur_sp = int(args.partition.split("/")[0]) - 1
            total_sp = int(args.partition.split("/")[1])
            with open(manifest_fn, "r") as rf:
                self.data = [l.strip().split("\t") for l in rf.readlines()][cur_sp::total_sp]

        def __len__(self):
            return len(self.data)

        def __getitem__(self, ind):
            try:
                afn = self.data[ind][0]
                fn = os.path.join(self.audio_dir, afn)
                audio, sr = torchaudio.load(fn)
                if sr != args.model_sr:
                    audio = torchaudio.transforms.Resample(sr, args.model_sr)(audio)
                    sr = args.model_sr
                assert sr == args.model_sr, sr
            except Exception as e:
                # logging.info(f"{e}")
                return None, None, None
            assert audio.ndim == 2 and audio.shape[0] == 1, audio.shape
            return audio.type(torch.float32).squeeze(0), audio.shape[-1], os.path.splitext(afn)[0]

        def collate(self, batch):
            lens, audios, segment_ids = [], [], []
            for item in batch:
                if item[0] is not None:
                    audios.append(item[0])
                    lens.append(item[1])
                    segment_ids.append(item[2])
            return audios, lens, segment_ids
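    # Note: collate silently drops items whose __getitem__ failed (those return
    # (None, None, None)), so a mega-batch can come back smaller than
    # mega_batch_size.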
    # roots
    sub_root = args.sub_root
    encodec_manifest_dir = os.path.join(args.root, sub_root, "manifest_for_codec")
    audio_dir = os.path.join(args.root, sub_root, "audio")
    save_manifest_dir = os.path.join(args.root, sub_root, "manifest_final_encodec")
    if args.encodec_name == "encodec_6f79c6a8.th":
        save_codes_dir = os.path.join(args.root, sub_root, "encodec_4cb")
    elif args.encodec_name == "encodec_8cb1024_giga.th":
        save_codes_dir = os.path.join(args.root, sub_root, "encodec_8cb")
    else:
        raise ValueError(f"unknown encodec_name: {args.encodec_name}")
    os.makedirs(save_manifest_dir, exist_ok=True)
    os.makedirs(save_codes_dir, exist_ok=True)

    def import_encodec():
        from encodec import get_compression_model
        userdir = os.path.expanduser("~")
        model = get_compression_model(os.path.join(userdir, "VoiceStar", f"pretrained/{args.encodec_name}"), encode_only=True, device="cuda")
        model = torch.nn.DataParallel(model)
        return model
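    # DataParallel splits each batch across all visible GPUs, which is why
    # --batch_size is the total over GPUs. Restrict the devices used via
    # CUDA_VISIBLE_DEVICES if needed (a usage note, not original code).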
    model = import_encodec()

    # setup dataloader
    mega_batch_size = 2048
    batch_size = args.batch_size
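    # Processing strategy: the dataloader yields "mega batches" of up to 2048
    # clips; each mega batch is sorted by length and cut into GPU-sized
    # sub-batches of --batch_size, so similar-length clips are padded together.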
    dataset = mydataset(args.split)
    if len(dataset) == 0:
        logging.info(f"no data found for split {args.split} partition {args.partition}")
        sys.exit(0)
    loader = torch.utils.data.DataLoader(dataset, batch_size=mega_batch_size, shuffle=False, drop_last=False, num_workers=args.n_workers, collate_fn=dataset.collate)
    split = args.split
    skip = 0
    logging.info(f"now processing split {split} partition {args.partition}...")
    mega_n_steps = int(np.ceil(len(loader.dataset) / mega_batch_size))
    logging.info(f"partition the split {split} into {mega_n_steps} parts, each has at most {mega_batch_size} samples")

    mani_fn = os.path.join(save_manifest_dir, f"{split}_{args.partition.replace('/', '=')}.txt")
    logging.info(f"manifest for split {split} partition {args.partition} will be saved at {mani_fn}")
    with open(mani_fn, "w") as mani_wf:
        # with open(mani_fn, "a") as mani_wf:  # resume from where we failed
        for m, mega_batch in enumerate(tqdm.tqdm(loader, mininterval=60, maxinterval=60)):
            logging.info("====================================")
            logging.info(f"now processing mega step {m+1}/{mega_n_steps}")
            try:
                lengths = np.array(mega_batch[1])
                sorted_inds = sort_by_audio_len(lengths)
                # drop samples shorter than args.min_len s or longer than args.len_cap s
                # (iterate in reverse so deletion does not shift pending indices)
                for j in range(len(sorted_inds))[::-1]:
                    if lengths[sorted_inds[j]] < args.model_sr * args.min_len or lengths[sorted_inds[j]] > args.model_sr * args.len_cap:
                        skip += 1
                        del sorted_inds[j]

                n_steps = int(np.ceil(len(sorted_inds) / batch_size))
                for n in tqdm.tqdm(range(n_steps), disable=True):
                    inds_used = sorted_inds[n * batch_size:(n + 1) * batch_size]
                    wav_batch = [mega_batch[0][ind] for ind in inds_used]
                    all_lens = [mega_batch[1][ind] for ind in inds_used]
                    segment_id_batch = [mega_batch[2][ind] for ind in inds_used]
                    padded_wav = torch.nn.utils.rnn.pad_sequence(wav_batch, batch_first=True).unsqueeze(1)  # [B, T] -> [B, 1, T]
                    # Extract discrete codes from EnCodec
                    with torch.no_grad():
                        if max(all_lens) > 300000 and len(all_lens) > 1:  # if utterances are long, simply pass half of them at a time
                            codes = []
                            inwav = padded_wav.cuda()
                            codes.append(model(inwav[:len(inwav) // 2])[0].cpu())
                            codes.append(model(inwav[len(inwav) // 2:])[0].cpu())
                            codes = torch.cat(codes, dim=0)
                        else:
                            encoded_frames = model(padded_wav.cuda())
                            codes = encoded_frames[0].cpu()  # [B, n_codebook, T]

                    for i, length in enumerate(all_lens):
                        save_fn = os.path.join(save_codes_dir, segment_id_batch[i] + ".txt")
                        actual_len = round(length / args.downsample_rate)  # 320 is the downsample rate for this model
                        cur_code = codes[i].tolist() if isinstance(codes, list) else codes[i, :, :actual_len].tolist()
                        os.makedirs(os.path.dirname(save_fn), exist_ok=True)
                        write_array_to_txt_file(cur_code, save_fn)
                        mani_wf.write(f"{segment_id_batch[i]}\t{len(cur_code[0])}\n")  # write to manifest file
            except Exception as e:
                print(f'exception!! at {m+1}')
                print(e)
                continue
    logging.info(f"split {split} partition {args.partition} has {len(loader.dataset)} samples in total, skipped {skip} due to utterance being too long or too short")
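    # Reading a saved code file back downstream (a minimal sketch; using
    # np.loadtxt on the reader side is an assumption, not part of this script):
    #   codes = np.loadtxt(os.path.join(save_codes_dir, segment_id + ".txt"),
    #                      dtype=np.int64)  # shape [n_codebook, T]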