Spaces:
Runtime error
Runtime error
| import sys | |
| sys.path.append('..') | |
| from inference import AudioBartInference | |
| from tqdm import tqdm | |
| import os | |
| import pandas as pd | |
| import csv | |
# Default to exposing only CUDA device index 5 to this process, but keep any
# CUDA_VISIBLE_DEVICES value already exported by the caller or job scheduler
# instead of silently overwriting it.
os.environ.setdefault("CUDA_VISIBLE_DEVICES", "5")
if __name__ == "__main__":
    # Batch-inference script: run an AudioBartInference checkpoint over every
    # row of the test CSV and write one output row per input containing the
    # resolved input path, the first element of the model's prediction, and
    # the five reference captions copied from the test CSV.
    ckpt_path = "/data/jyk/aac_results/clap/clap/checkpoints/epoch_12"
    infer_module = AudioBartInference(ckpt_path)

    # True  -> read the 'file_path' column, resolve it under `base_path`, and
    #          call `infer_from_encodec`.
    # False -> read the 'file_name' column and call `infer`.
    from_encodec = True
    csv_path = "/workspace/audiobart/csv/test.csv"
    base_path = "/data/jyk/aac_dataset/clotho/encodec"
    save_path = "/workspace/audiobart/csv/predictions/prediction_clap.csv"
    caption_columns = ['caption_1', 'caption_2', 'caption_3', 'caption_4', 'caption_5']

    df = pd.read_csv(csv_path)

    # Make sure the destination directory exists before opening the output.
    os.makedirs(os.path.dirname(save_path), exist_ok=True)

    print(f"> Making Predictions for model {ckpt_path}...")
    skipped = 0

    # The context manager guarantees the CSV is flushed and closed even if
    # inference raises part-way through the dataset.
    with open(save_path, 'w', newline='', encoding='utf-8') as f:
        writer = csv.writer(f)
        writer.writerow(['file_path', 'prediction'] + caption_columns)

        for _, row in tqdm(df.iterrows(), total=len(df), dynamic_ncols=True, colour="red"):
            if from_encodec:
                # NOTE(review): the original indentation was lost; the join
                # with `base_path` is assumed to apply to the encodec branch
                # only -- confirm against the upstream script.
                wav_path = os.path.join(base_path, row['file_path'])
            else:
                wav_path = row['file_name']

            # The original check was a no-op (`pass`), so a missing input
            # still reached the model. Skip it explicitly and report it.
            if not os.path.exists(wav_path):
                tqdm.write(f"> Skipping missing input: {wav_path}")
                skipped += 1
                continue

            if from_encodec:
                prediction = infer_module.infer_from_encodec(wav_path)
            else:
                prediction = infer_module.infer(wav_path)

            # Only the first element of the returned prediction is stored.
            writer.writerow(
                [wav_path, prediction[0]] + [row[col] for col in caption_columns]
            )

    if skipped:
        print(f"> Skipped {skipped} of {len(df)} rows because the input file was missing.")