# Copyright 2024 The YourMT3 Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Please see the details in the LICENSE file.
import os
import json
import time
import hashlib
import requests
import tarfile
import warnings
import argparse
from typing import Tuple, Union, Optional, List, Dict, Any
from collections import Counter

import numpy as np
from tqdm import tqdm

from utils.note_event_dataclasses import Note, NoteEvent, Event
from utils.note2event import note2note_event, slice_multiple_note_events_and_ties_to_bundle
from utils.midi import note_event2midi
from utils.event2note import merge_zipped_note_events_and_ties_to_notes
from utils.metrics import compute_track_metrics
from utils.tokenizer import EventTokenizer, NoteEventTokenizer
from config.vocabulary import GM_INSTR_FULL, GM_INSTR_CLASS_PLUS
from config.config import shared_cfg


def get_checksum(file_path: os.PathLike, buffer_size: int = 65536) -> str:
    """Return the MD5 checksum of a file, reading it in chunks of `buffer_size` bytes."""
    md5 = hashlib.md5()
    with open(file_path, "rb") as f:
        while True:
            data = f.read(buffer_size)
            if not data:
                break
            md5.update(data)
    return md5.hexdigest()
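
# Illustrative usage of get_checksum (hypothetical file name; the checksum value is made up):
#   >>> get_checksum("some_dataset.tar.gz")
#   '3c59dc048e8850243be8079a5c74d079'
#   Compare the result against the md5 published alongside the archive.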


def download_and_extract(data_home: os.PathLike,
                         url: str,
                         remove_tar_file: bool = True,
                         check_sum: Optional[str] = None,
                         zenodo_token: Optional[str] = None) -> None:
    """Download a .tar.gz archive to `data_home`, verify its checksum, and extract it."""
    file_name = url.split("/")[-1].split("?")[0]
    tar_path = os.path.join(data_home, file_name)
    if not os.path.exists(data_home):
        os.makedirs(data_home)

    if zenodo_token is not None:
        url_with_token = f"{url}&token={zenodo_token}" if "?download=1" in url else f"{url}?token={zenodo_token}"
    else:
        url_with_token = url

    response = requests.get(url_with_token, stream=True)
    # Check HTTP status
    if response.status_code != 200:
        print(f"Failed to download file. Status code: {response.status_code}")
        return

    total_size = int(response.headers.get('content-length', 0))
    with open(tar_path, "wb") as f:
        for chunk in tqdm(response.iter_content(chunk_size=8192),
                          total=total_size // 8192,
                          unit='KB',
                          desc=file_name):
            f.write(chunk)

    _check_sum = get_checksum(tar_path)
    print(f"Checksum (md5): {_check_sum}")
    if check_sum is not None and check_sum != _check_sum:
        raise ValueError(f"Checksum doesn't match! Expected: {check_sum}, Actual: {_check_sum}")

    with tarfile.open(tar_path, "r:gz") as tar:
        tar.extractall(data_home)

    if remove_tar_file:
        os.remove(tar_path)


def create_inverse_vocab(vocab: Dict) -> Dict:
    """Map each program in `vocab` to a (primary_program, instrument_group_name) tuple."""
    inverse_vocab = {}
    for k, vnp in vocab.items():
        for v in vnp:
            inverse_vocab[v] = (vnp[0], k)  # (primary program, instrument group name)
    return inverse_vocab
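
# Illustrative mapping (toy vocabulary, not one of the project presets):
#   >>> create_inverse_vocab({"Piano": [0, 1, 2], "Organ": [16, 17]})
#   {0: (0, 'Piano'), 1: (0, 'Piano'), 2: (0, 'Piano'), 16: (16, 'Organ'), 17: (16, 'Organ')}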


def create_program2channel_vocab(program_vocab: Dict, drum_program: int = 128, force_assign_13_ch: bool = False):
    """
    Create a direct mapping from programs to channel indices, instrument groups, and primary programs.

    Args:
        program_vocab (dict): A dictionary of program vocabularies.
        drum_program (int): The program number for drums. Default: 128.

    Returns:
        program2channel_vocab (dict): A mapping from program to channel index, instrument group, and
            primary program, e.g.
            {
                0: {'channel': 0, 'instrument_group': 'Piano', 'primary_program': 0},
                1: {'channel': 1, 'instrument_group': 'Chromatic Percussion', 'primary_program': 8},
                ...
                100: {'channel': 11, 'instrument_group': 'Singing Voice', 'primary_program': 100},
                128: {'channel': 12, 'instrument_group': 'Drums', 'primary_program': 128}
            }
            "primary_program" is currently unused.
        num_channels (int): The number of channels, typically len(program_vocab) + 1 (for drums).
    """
    num_channels = len(program_vocab) + 1
    program2channel_vocab = {}
    for idx, (instrument_group, programs) in enumerate(program_vocab.items()):
        if idx > num_channels:
            raise ValueError(
                f"📕 The number of channels ({num_channels}) is less than the number of instrument groups ({idx})")
        for program in programs:
            if program in program2channel_vocab:
                raise ValueError(f"📕 program {program} is duplicated in program_vocab")
            else:
                program2channel_vocab[program] = {
                    "channel": int(idx),
                    "instrument_group": str(instrument_group),
                    "primary_program": int(programs[0]),
                }

    # Add drums
    if drum_program in program2channel_vocab.keys():
        raise ValueError(
            f"📕 drum_program {drum_program} is duplicated in program_vocab. program_vocab should not include drums or program 128"
        )
    else:
        program2channel_vocab[drum_program] = {
            "channel": idx + 1,
            "instrument_group": "Drums",
            "primary_program": drum_program,
        }
    return program2channel_vocab, num_channels


def write_model_output_as_npy(data, output_dir, track_id):
    """Save raw model output for a track as a pickled .npy file under <output_dir>/model_output/."""
    output_dir = os.path.join(output_dir, "model_output")
    os.makedirs(output_dir, exist_ok=True)
    output_file = os.path.join(output_dir, f"output_{track_id}.npy")
    np.save(output_file, data, allow_pickle=True)


def write_model_output_as_midi(notes: List[Note],
                               output_dir: os.PathLike,
                               track_id: str,
                               output_inverse_vocab: Optional[Dict] = None,
                               output_dir_suffix: Optional[str] = None) -> None:
    """Write transcribed notes as a MIDI file under <output_dir>/model_output[/<suffix>]/."""
    if output_dir_suffix is not None:
        output_dir = os.path.join(output_dir, f"model_output/{output_dir_suffix}")
    else:
        output_dir = os.path.join(output_dir, "model_output")
    os.makedirs(output_dir, exist_ok=True)
    output_file = os.path.join(output_dir, f"{track_id}.mid")

    new_notes = notes  # fall back to the original notes when no inverse vocab is given
    if output_inverse_vocab is not None:
        # Convert the notes to the output vocabulary
        new_notes = []
        for note in notes:
            if note.is_drum:
                new_notes.append(note)
            else:
                new_notes.append(
                    Note(is_drum=note.is_drum,
                         program=output_inverse_vocab.get(note.program, [note.program])[0],
                         onset=note.onset,
                         offset=note.offset,
                         pitch=note.pitch,
                         velocity=note.velocity))
    note_events = note2note_event(new_notes, return_activity=False)
    note_event2midi(note_events, output_file, output_inverse_vocab=output_inverse_vocab)


def write_err_cnt_as_json(
    track_id: str,
    output_dir: os.PathLike,
    output_dir_suffix: Optional[str] = None,
    note_err_cnt: Optional[Counter] = None,
    note_event_err_cnt: Optional[Counter] = None,
):
    """Write note / note-event decoding error counters as a JSON file."""
    if output_dir_suffix is not None:
        output_dir = os.path.join(output_dir, f"model_output/{output_dir_suffix}")
    else:
        output_dir = os.path.join(output_dir, "model_output")
    os.makedirs(output_dir, exist_ok=True)
    output_file = os.path.join(output_dir, f"error_count_{track_id}.json")

    output_dict = {}
    if note_err_cnt is not None:
        output_dict['note_err_cnt'] = dict(note_err_cnt)
    if note_event_err_cnt is not None:
        output_dict['note_event_err_cnt'] = dict(note_event_err_cnt)
    output_str = json.dumps(output_dict, indent=4)
    with open(output_file, 'w') as json_file:
        json_file.write(output_str)


class Timer:
    """A simple timer class to measure elapsed time.

    Usage:
        with Timer() as t:
            # Your code here
            time.sleep(2)
        t.print_elapsed_time()
    """

    def __init__(self) -> None:
        self.start_time = None
        self.end_time = None

    def start(self) -> None:
        self.start_time = time.time()

    def stop(self) -> None:
        self.end_time = time.time()

    def elapsed_time(self) -> float:
        if self.start_time is None:
            raise ValueError("Timer has not been started yet.")
        if self.end_time is None:
            raise ValueError("Timer has not been stopped yet.")
        return self.end_time - self.start_time

    def print_elapsed_time(self, message: Optional[str] = None) -> float:
        elapsed_seconds = self.elapsed_time()
        minutes, seconds = divmod(elapsed_seconds, 60)
        milliseconds = (elapsed_seconds % 1) * 1000
        if message is not None:
            text = f"⏰ {message}: {int(minutes)}m {int(seconds)}s {milliseconds:.2f}ms"
        else:
            text = f"⏰ elapsed time: {int(minutes)}m {int(seconds)}s {milliseconds:.2f}ms"
        print(text)
        return elapsed_seconds

    def reset(self) -> None:
        self.start_time = None
        self.end_time = None

    def __enter__(self) -> 'Timer':
        self.start()
        return self

    def __exit__(self, exc_type, exc_val, exc_tb) -> None:
        self.stop()


def merge_file_lists(file_lists: List[Dict]) -> Dict[int, Any]:
    """Merge file lists from different datasets, and return a reindexed
    dictionary of the merged file list."""
    merged_file_list = {}
    index = 0
    for file_list in file_lists:
        for v in file_list.values():
            merged_file_list[index] = v
            index += 1
    return merged_file_list
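
# Illustrative behavior (toy file lists; keys are reassigned 0..N-1 in input order):
#   >>> merge_file_lists([{0: {"audio_file": "a.wav"}}, {0: {"audio_file": "b.wav"}}])
#   {0: {'audio_file': 'a.wav'}, 1: {'audio_file': 'b.wav'}}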


def merge_splits(splits: List[str], dataset_name: Union[str, List[str]]) -> Dict[int, Any]:
    """Merge multiple splits, possibly from different datasets, and return a
    reindexed dictionary of the merged file list."""
    n_splits = len(splits)
    if isinstance(dataset_name, str):
        dataset_name = [dataset_name] * n_splits
    elif isinstance(dataset_name, list) and len(dataset_name) != n_splits:
        raise ValueError("The number of dataset names in the list must equal the number of splits.")

    # Load file_list dictionaries
    data_home = shared_cfg['PATH']['data_home']
    file_lists = []  # list of dictionaries
    for i, s in enumerate(splits):
        json_file = f"{data_home}/yourmt3_indexes/{dataset_name[i]}_{s}_file_list.json"
        # Fix for a missing file_list in an incomplete dataset package
        if not os.path.exists(json_file):
            warnings.warn(
                f"File list {json_file} does not exist. If you don't have a complete dataset package, ignore this warning..."
            )
            return {}
        with open(json_file, 'r') as j:
            file_lists.append(json.load(j))

    merged_file_list = merge_file_lists(file_lists)  # reindexed, merged file list
    return merged_file_list


def reindex_file_list_keys(file_list: Dict[str, Any]) -> Dict[int, Any]:
    """Reindex file list keys from 0 to the total count."""
    reindexed_file_list = {}
    for i, (k, v) in enumerate(file_list.items()):
        reindexed_file_list[i] = v
    return reindexed_file_list
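
# Illustrative behavior (string keys, e.g. loaded from JSON, become consecutive integers):
#   >>> reindex_file_list_keys({"3": {"track_id": "a"}, "7": {"track_id": "b"}})
#   {0: {'track_id': 'a'}, 1: {'track_id': 'b'}}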


def remove_ids_from_file_list(file_list: Dict[str, Any],
                              selected_ids: List[int],
                              reindex: bool = True) -> Dict[int, Any]:
    """Remove selected ids from file list."""
    key = None
    for v in file_list.values():
        # Search for a key that contains 'id'
        for k in v.keys():
            if 'id' in k:
                key = k
                break
        if key:
            break
    if key is None:
        raise ValueError("No key contains 'id'.")

    # Generate a new file list with the selected ids removed
    selected_ids = [str(id) for id in selected_ids]  # ids to str
    file_list = {k: v for k, v in file_list.items() if str(v[key]) not in selected_ids}
    if reindex:
        return reindex_file_list_keys(file_list)
    else:
        return file_list
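
# Illustrative behavior (toy file list; the first key containing 'id' is used for matching):
#   >>> remove_ids_from_file_list({0: {"track_id": 10}, 1: {"track_id": 11}}, selected_ids=[10])
#   {0: {'track_id': 11}}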


def deduplicate_splits(split_a: Union[str, Dict],
                       split_b: Union[str, Dict],
                       dataset_name: Optional[str] = None,
                       reindex: bool = True) -> Dict[int, Any]:
    """Remove entries in file_list A that also appear in file_list B,
    and return a reindexed dictionary of files."""
    data_home = shared_cfg['PATH']['data_home']

    if isinstance(split_a, str):
        json_file_a = f"{data_home}/yourmt3_indexes/{dataset_name}_{split_a}_file_list.json"
        with open(json_file_a, 'r') as j:
            file_list_a = json.load(j)
    elif isinstance(split_a, dict):
        file_list_a = split_a

    if isinstance(split_b, str):
        json_file_b = f"{data_home}/yourmt3_indexes/{dataset_name}_{split_b}_file_list.json"
        with open(json_file_b, 'r') as j:
            file_list_b = json.load(j)
    elif isinstance(split_b, dict):
        file_list_b = split_b

    # Get the key that contains 'id' from file_list_a entries
    id_key = None
    for v in file_list_a.values():
        for k in v.keys():
            if 'id' in k:
                id_key = k
                break
        if id_key:
            break
    if id_key is None:
        raise ValueError("No key contains 'id' in file_list_a.")

    # Get IDs from file_list_b entries
    ids_b = set(str(v.get(id_key, '')) for v in file_list_b.values())
    # Extract IDs from file_list_a entries
    ids_a = [str(v.get(id_key, '')) for v in file_list_a.values()]
    # Remove IDs from file_list_a that are also in file_list_b
    ids_to_remove = list(set(ids_a).intersection(ids_b))
    filtered_file_list_a = remove_ids_from_file_list(file_list_a, ids_to_remove, reindex)
    return filtered_file_list_a
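
# Illustrative behavior with in-memory dicts (no JSON index files needed in this form):
#   >>> a = {0: {"track_id": "x"}, 1: {"track_id": "y"}}
#   >>> b = {0: {"track_id": "y"}}
#   >>> deduplicate_splits(a, b)
#   {0: {'track_id': 'x'}}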


def merge_vocab(vocab_list: List[Dict]) -> Dict[str, Any]:
    """Merge vocabularies from different datasets, keeping the first occurrence of each key."""
    merged_vocab = {}
    for vocab in vocab_list:
        for k, v in vocab.items():
            if k not in merged_vocab.keys():
                merged_vocab[k] = v
    return merged_vocab
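
# Illustrative behavior (toy vocabularies; the first occurrence of a key wins):
#   >>> merge_vocab([{"Piano": [0]}, {"Piano": [0, 1], "Organ": [16]}])
#   {'Piano': [0], 'Organ': [16]}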


def assert_note_events_almost_equal(actual_note_events,
                                    predicted_note_events,
                                    ignore_time=False,
                                    ignore_activity=True,
                                    delta=5.1e-3):
    """
    Asserts that the given lists of Note instances are equal up to a small
    floating-point tolerance, similar to `assertAlmostEqual` of `unittest`.
    Tolerance is 5e-3 by default, which is 5 ms for 100 ticks-per-second.

    If `ignore_time` is True, then the time field is ignored (useful for
    comparing tie note events, default is False).
    If `ignore_activity` is True, then the activity field is ignored (default
    is True).
    """
    assert len(actual_note_events) == len(predicted_note_events)
    for j, (actual_note_event, predicted_note_event) in enumerate(zip(actual_note_events, predicted_note_events)):
        if ignore_time is False:
            assert abs(actual_note_event.time - predicted_note_event.time) <= delta, \
                (j, actual_note_event, predicted_note_event)
        assert actual_note_event.is_drum == predicted_note_event.is_drum, (j, actual_note_event, predicted_note_event)
        assert actual_note_event.program == predicted_note_event.program, (j, actual_note_event, predicted_note_event)
        assert actual_note_event.pitch == predicted_note_event.pitch, (j, actual_note_event, predicted_note_event)
        assert actual_note_event.velocity == predicted_note_event.velocity, \
            (j, actual_note_event, predicted_note_event)
        if ignore_activity is False:
            assert actual_note_event.activity == predicted_note_event.activity, \
                (j, actual_note_event, predicted_note_event)


def note_event2token2note_event_sanity_check(note_events: List[NoteEvent],
                                             notes: List[Note],
                                             report_err_cnt=False) -> Counter:
    """Round-trip sanity check: slice note events, tokenize, detokenize, and compare the reconstruction."""
    # Slice note events
    max_time = note_events[-1].time
    num_segs = int(max_time * 16000 // 32767 + 1)  # 32767 samples per segment at 16 kHz
    seg_len_sec = 32767 / 16000
    start_times = [i * seg_len_sec for i in range(num_segs)]
    note_event_segments = slice_multiple_note_events_and_ties_to_bundle(
        note_events,
        start_times,
        seg_len_sec,
    )

    # Encode
    tokenizer = NoteEventTokenizer()
    token_array = np.zeros((num_segs, 1024), dtype=np.int32)
    for i, tup in enumerate(list(zip(*note_event_segments.values()))):
        padded_tokens = tokenizer.encode_plus(*tup)
        token_array[i, :] = padded_tokens

    # Decode. Warning: "Invalid pitch event without program or velocity" --> solved
    zipped_note_events_and_tie, list_events, err_cnt = tokenizer.decode_list_batches([token_array],
                                                                                     start_times,
                                                                                     return_events=True)
    if report_err_cnt:
        # Report errors and do not break
        err_cnt_all = err_cnt
    else:
        assert len(err_cnt) == 0
        err_cnt_all = Counter()

    # First check: the number of empty note_events and tie_note_events
    cnt_org_empty = 0
    cnt_recon_empty = 0
    for i, (recon_note_events, recon_tie_note_events, _, _) in enumerate(zipped_note_events_and_tie):
        org_note_events = note_event_segments['note_events'][i]
        org_tie_note_events = note_event_segments['tie_note_events'][i]
        if org_note_events == []:
            cnt_org_empty += 1
        if recon_note_events == []:
            cnt_recon_empty += 1
        # assert len(org_note_events) == len(recon_note_events)  # passed after bug fix

    # Check the reconstruction of note_events
    for i, (recon_note_events, recon_tie_note_events, _, _) in enumerate(zipped_note_events_and_tie):
        org_note_events = note_event_segments['note_events'][i]
        org_tie_note_events = note_event_segments['tie_note_events'][i]
        org_note_events.sort(key=lambda n_ev: (n_ev.time, n_ev.is_drum, n_ev.program, n_ev.velocity, n_ev.pitch))
        org_tie_note_events.sort(key=lambda n_ev: (n_ev.program, n_ev.pitch))
        recon_note_events.sort(key=lambda n_ev: (n_ev.time, n_ev.is_drum, n_ev.program, n_ev.velocity, n_ev.pitch))
        recon_tie_note_events.sort(key=lambda n_ev: (n_ev.program, n_ev.pitch))
        # assert_note_events_almost_equal(org_note_events, recon_note_events)
        # assert_note_events_almost_equal(
        #     org_tie_note_events, recon_tie_note_events, ignore_time=True)

    # Check notes: of course this fails, with a lot of warnings for the cut-off 20s
    recon_notes, err_cnt = merge_zipped_note_events_and_ties_to_notes(zipped_note_events_and_tie, fix_offset=False)
    # assert len(err_cnt) == 0  # this error is due to the cut-off 5 seconds...

    # Check metrics
    drum_metric, non_drum_metric, instr_metric = compute_track_metrics(recon_notes,
                                                                       notes,
                                                                       eval_vocab=GM_INSTR_FULL,
                                                                       onset_tolerance=0.005)  # 5 ms
    if not np.isnan(non_drum_metric['offset_f']) and non_drum_metric['offset_f'] != 1.0:
        warnings.warn(f"non_drum_metric['offset_f'] = {non_drum_metric['offset_f']}")
    assert non_drum_metric['onset_f'] > 0.99
    if not np.isnan(drum_metric['onset_f_drum']) and drum_metric['offset_f'] != 1.0:
        warnings.warn(f"drum_metric['offset_f'] = {drum_metric['offset_f']}")
    assert drum_metric['offset_f'] > 0.99
    return err_cnt_all + Counter(err_cnt)


def str2bool(v):
    """
    Converts a string value to a boolean value.

    Args:
        v: The string value to convert.

    Returns:
        The boolean value equivalent of the input string.

    Raises:
        ArgumentTypeError: If the input string is not a valid boolean value.
    """
    if v.lower() in ('yes', 'true', 't', 'y', '1'):
        return True
    elif v.lower() in ('no', 'false', 'f', 'n', '0'):
        return False
    else:
        raise argparse.ArgumentTypeError('Boolean value expected.')
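
# Typical argparse wiring (sketch; the flag name is only an example):
#   parser = argparse.ArgumentParser()
#   parser.add_argument('--use_fp16', type=str2bool, default=False)
#   parser.parse_args(['--use_fp16', 'yes']).use_fp16  # -> True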


def freq_to_midi(freq):
    """Convert a frequency in Hz to the nearest MIDI note number (A4 = 440 Hz = 69)."""
    return round(69 + 12 * np.log2(freq / 440))
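
# Sanity check: A4 = 440 Hz maps to MIDI note 69; one octave up adds 12 semitones.
#   >>> freq_to_midi(440.0)
#   69
#   >>> freq_to_midi(880.0)
#   81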


def dict_iterator(d: Dict):
    """
    Iterate over a dictionary of lists in parallel.

    For each position, yields a newly created dictionary that maps every key to a
    single-element list holding the value at that position.
    """
    for values in zip(*d.values()):
        yield {k: [v] for k, v in zip(d.keys(), values)}
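
# Illustrative iteration (toy dict of equal-length lists):
#   >>> list(dict_iterator({'a': [1, 2], 'b': [3, 4]}))
#   [{'a': [1], 'b': [3]}, {'a': [2], 'b': [4]}]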


def extend_dict(dict1: dict, dict2: dict) -> None:
    """
    Extends the lists in dict1 with the corresponding lists in dict2.
    Modifies dict1 in place and does not return anything.

    Args:
        dict1 (dict): The dictionary to be extended.
        dict2 (dict): The dictionary with lists to extend dict1.

    Example:
        dict1 = {'a': [1, 2, 3], 'b': [4, 5, 6]}
        dict2 = {'a': [10], 'b': [17]}
        extend_dict(dict1, dict2)
        print(dict1)  # Outputs: {'a': [1, 2, 3, 10], 'b': [4, 5, 6, 17]}
    """
    for key in dict1:
        dict1[key].extend(dict2[key])