Spaces:
Running
on
Zero
Running
on
Zero
| # Copyright 2024 The YourMT3 Authors. | |
| # | |
| # Licensed under the Apache License, Version 2.0 (the "License"); | |
| # you may not use this file except in compliance with the License. | |
| # You may obtain a copy of the License at | |
| # | |
| # http://www.apache.org/licenses/LICENSE-2.0 | |
| # | |
| # Please see the details in the LICENSE file. | |
| """augment.py""" | |
| import numpy as np | |
| import random | |
| from collections import defaultdict | |
| from typing import Optional, Tuple, Union, Callable, Literal, DefaultDict, Set, Any, Dict, List | |
| from utils.note_event_dataclasses import NoteEvent, NoteEventListsBundle | |
| from utils.note2event import check_event_len_from_bundle, mix_note_event_lists_bundle, separate_by_subunit_programs_from_note_event_lists_bundle | |
| from utils.utils import dict_iterator, extend_dict | |
| from copy import deepcopy | |
| EPS = 1e-7 | |
| DRUM_PROGRAM = 128 | |
| UNANNOTATED_PROGRAM = 129 | |
| # ------------------------------------------------------------------------------------- | |
| # shared augmentation helper functions | |
| # ------------------------------------------------------------------------------------- | |
def audio_random_submix_fn(x: np.ndarray,
                           random_amp_range: Optional[List[float]] = None,
                           mask: Optional[np.ndarray] = None,
                           normalize: bool = True,
                           dtype: np.dtype = np.float32) -> Tuple[np.ndarray, np.ndarray]:
    """Apply random per-stem gains and produce a submixed (summed) version of the input.

    Supports batch-wise matrix processing.

    Parameters:
    - x (np.ndarray): Input audio tensor with shape (b, c, t).
    - random_amp_range (List[float], optional): [min_amp, max_amp] for the uniform gain draw.
        Defaults to [0.6, 1.2].
    - mask (np.ndarray, optional): Binary stem mask with shape (b, c); masked stems get zero
        gain. Defaults to None.
    - normalize (bool): If True (default), peak-normalize both stems and mix by the mix's
        maximum absolute amplitude.
    - dtype (np.dtype): Data type for the sampled gains. Defaults to np.float32.

    Returns:
    - Tuple[np.ndarray, np.ndarray]: (processed stems with shape (b, c, t),
      mix with shape (b, 1, t)).
    """
    batch, n_stems, _ = x.shape
    if random_amp_range is None:
        random_amp_range = [0.6, 1.2]
    if len(random_amp_range) != 2:
        raise ValueError(
            f"random_amp_range should be a list of two floats, [min_amp, max_amp] or None, but got {random_amp_range}")
    amp_lo, amp_hi = random_amp_range

    # One independent gain per (batch, stem).
    gains = np.random.uniform(amp_lo, amp_hi, size=(batch, n_stems)).astype(dtype)
    if mask is not None:
        gains *= mask  # zero-out masked stems, (b, c)

    stems = x * gains[:, :, np.newaxis]  # (b, c, t)
    mix = np.sum(stems, axis=1, keepdims=True)  # (b, 1, t)

    if normalize:
        # Peak-normalize by the mix amplitude; EPS guards against an all-silent mix.
        peak = np.max(np.abs(mix), axis=2, keepdims=True) + EPS  # (b, 1, 1)
        stems /= peak
        mix /= peak
    return stems, mix
def audio_random_submix_processor(sampled_data: Dict[str, Any],
                                  random_amp_range: List[float] = [0.6, 1.2],
                                  audio_masks: Optional[List[Optional[np.ndarray]]] = None,
                                  update_audio_segments: bool = True,
                                  create_processed_audio_array: bool = True) -> None:
    """Randomly submix audio from sampled data.

    Args:
        sampled_data: a dictionary containing sampled data.
            ['audio_segments']: a list of audio segments with length B, each element with
            shape (1, num_stems, T)
        random_amp_range: a list of two floats, [min_amp, max_amp]. (The default list is
            read-only here, so the shared mutable default is harmless.)
        audio_masks: a list of masks (or None). Each mask is a binary vector with shape
            (num_stems,). When None, all stems of every segment take part in the submix.
        update_audio_segments: if True (default), update sampled_data["audio_segments"] in-place.
        create_processed_audio_array: if True (default), create a new key
            "processed_audio_array" in sampled_data for the mixed audio.

    Returns:
        None (processed audio is stored in sampled_data["processed_audio_array"])

    Raises:
        ValueError: if both update_audio_segments and create_processed_audio_array are False,
            since the call would then have no observable effect.

    NOTE:
        - This function creates a new key "processed_audio_array" in sampled_data, in-place of `sampled_data`.
        - Input audio should exist in sampled_data["audio_segments"].
        - The created sampled_data["processed_audio_array"] has shape of (B, 1, T)
    """
    if update_audio_segments is False and create_processed_audio_array is False:
        # Bug fix: the message previously referenced a nonexistent parameter name
        # "create_processed_audio_mix".
        raise ValueError("At least one of update_audio_segments and create_processed_audio_array should be True.")

    b = len(sampled_data["audio_segments"])  # sub-batch size
    t = sampled_data["audio_segments"][0].shape[2]  # audio length
    if create_processed_audio_array is True:
        sampled_data["processed_audio_array"] = np.zeros((b, 1, t), dtype=np.float32)

    # A missing mask list means "no masking" for every segment; unifying both cases here
    # avoids duplicating the loop body (the two branches were previously identical except
    # for the mask argument).
    masks: List[Optional[np.ndarray]] = [None] * b if audio_masks is None else audio_masks
    for i, (audio_segment, mask) in enumerate(zip(sampled_data["audio_segments"], masks)):
        processed_audio_stems, processed_audio_mix = audio_random_submix_fn(x=audio_segment,
                                                                            random_amp_range=random_amp_range,
                                                                            mask=mask)
        if create_processed_audio_array is True:
            sampled_data["processed_audio_array"][i, :, :] = processed_audio_mix
        if update_audio_segments is True:
            sampled_data["audio_segments"][i] = processed_audio_stems
def drop_random_stems_from_bundle(sampled_data: Dict[str, Any], prob: float = 0.7) -> None:
    """
    Randomly drop stems from a bundle containing `note_event_segments` and `audio_segments`.
    It also updates `programs`, and adds `has_unannotated` info. This function serves as a
    utility for stem-based data augmentation used by `intra_stem_augment_processor` and
    `cross_stem_augment_processor`.

    Args:
        sampled_data: A dict of sampled data.
        prob: Per-program probability of KEEPING a stem (see the `rand < prob` selection
            below). For the unannotated program (129), `prob` is instead the probability
            of dropping it.

    Returns:
        None. The processed data is stored in-place within the `sampled_data` dictionary.

    Updated keys in sampled_data (in-place):
        sampled_data["note_event_segments"]: NoteEventListsBundle
        sampled_data["audio_segments"]: NoteEventListsBundle
        sampled_data["programs_segments"]: a list of list, drum program is 128. updated.
        sampled_data["has_unannotated_segments"]: a list of bool, True if unannotated program 129 is in use. Newly added.

    Removed keys in sampled_data (in-place):
        all other keys except for the above are removed.

    Function execution time: 16ms for bsz=36 with single worker
    """
    # Create a deep copy to avoid modifying the original data.
    note_event_segments = deepcopy(sampled_data["note_event_segments"])
    has_unannotated = []  # List of bool, True if unannotated program 129 is in use
    for i, (has_stems, note_events, tie_note_events, audio_segment, programs, is_drum) in enumerate(
            zip(sampled_data["has_stems_segments"], note_event_segments['note_events'],
                note_event_segments['tie_note_events'], sampled_data["audio_segments"],
                sampled_data["programs_segments"], sampled_data["is_drum_segments"])):
        # Make sure that programs is np.ndarray
        if not isinstance(programs, np.ndarray):
            programs = np.array(programs)
        if has_stems is True and UNANNOTATED_PROGRAM not in programs:
            # Fully-annotated, stem-separated segment: drop a random subset of programs.
            # Get unique and actual presence of instruments. 128 means drums, 129 means unannotated.
            uniq_programs = np.unique([ne.program if not ne.is_drum else 128 for ne in (tie_note_events + note_events)])
            # Debug: any drum activity in the events implies program 128 must be annotated.
            if DRUM_PROGRAM in uniq_programs:
                assert DRUM_PROGRAM in programs, "Drum program 128 not in programs"
            if is_drum.any():
                assert DRUM_PROGRAM in programs, "Drum program 128 not in programs"
            # Vectorized random choice for each unique program: each survives with
            # probability `prob`.
            rand_sel_prgs = uniq_programs[np.random.rand(len(uniq_programs)) < prob]
            if len(rand_sel_prgs) == 0 and len(uniq_programs) != 0:  # Make sure at least one program is active
                rand_sel_prgs = np.random.choice(uniq_programs, size=1)
            programs_mask = np.isin(programs, rand_sel_prgs).astype(np.int32)
            drums_mask = programs_mask * is_drum  # NOTE: if drums are not annotated as program 128, this would not work properly
            _programs_in_use = programs[programs_mask == 1]
            _drum_in_use = np.any(drums_mask == 1)  # True if any drum is in use
            # Drop note_events and tie_note_events in-place (slice assignment keeps the
            # list objects shared with `note_event_segments`).
            note_events[:] = [
                ne for ne in note_events
                if (not ne.is_drum and ne.program in _programs_in_use) or (ne.is_drum and _drum_in_use)
            ]
            tie_note_events[:] = [ne for ne in tie_note_events if ne.program in _programs_in_use]
            # Drop stems from audio_segments, update programs_segments
            sampled_data["audio_segments"][i] = audio_segment[:, programs_mask == 1, :]
            sampled_data["programs_segments"][i] = programs[programs_mask == 1]
            # Create has_unannotated
            has_unannotated.append(False)
        elif has_stems is True and UNANNOTATED_PROGRAM in programs:
            # If unannotated program is included in programs, we only drop 129 with a probability of `prob`.
            # `note_event_segments` remains the same.
            # TODO: Actually, we can drop any annotated programs, but current datasets are not the case.
            uniq_programs = np.unique([ne.program if not ne.is_drum else 128 for ne in (tie_note_events + note_events)])
            if np.random.rand() > prob:
                # keep unannotated program, and this will not allow further cross-stem augmentation.
                has_unannotated.append(True)
            else:
                # drop unannotated program (its audio stem and program entry only; events
                # are untouched since program 129 has no annotated events).
                assert UNANNOTATED_PROGRAM not in uniq_programs  # 129 is not included here...
                sampled_data["audio_segments"][i] = audio_segment[:, programs != 129, :]
                sampled_data["programs_segments"][i] = programs[programs != 129]
                has_unannotated.append(False)
        elif has_stems is False and UNANNOTATED_PROGRAM in programs:
            # No stems, but has unannotated program: cannot be used for cross-stem augmentation.
            has_unannotated.append(True)
        else:
            # No stems, no unannotated program: nothing to do.
            has_unannotated.append(False)
    # Update sampled_data in-place
    sampled_data["note_event_segments"] = note_event_segments
    sampled_data["has_unannotated_segments"] = has_unannotated
    # Remove all other keys except for the above, because they are not used in the downstream pipeline.
    keys_to_remove = ['is_drum_segments', 'has_stems_segments']
    for key in keys_to_remove:
        del sampled_data[key]
| # ------------------------------------------------------------------------------------- | |
| # intra stem augmentation processor | |
| # ------------------------------------------------------------------------------------- | |
def intra_stem_augment_processor(sampled_data: Dict[str, Any],
                                 random_amp_range: List[float] = [0.6, 1.2],
                                 prob: float = 0.7,
                                 update_audio_segments: bool = True,
                                 submix_audio: bool = True) -> None:
    """
    Intra-stem augmentation: random stem dropout, optionally followed by a random submix.

    Expected input shapes:
        sampled_data:
            ['note_event_segments']['note_events']:
                List[List[NoteEvent]] with length B, each element is a list of NoteEvent
                with length num_notes
            ['note_event_segments']['tie_note_events']:
                List[List[NoteEvent]] with length B, each element is a list of NoteEvent
                with length num_tie_notes
            ['note_event_segments']['start_times']:
                List[float] with length B
            ['audio_segments']:
                np.ndarray with shape(B, num_stems, T)
            ['programs_segments']:
                np.ndarray with shape(num_stems,)
            ['is_drum_segments']:
                np.ndarray with shape(num_stems,)
            ['has_stems_segments']:
                List[bool] with length B

    Output (modified in-place):
        sampled_data:
            ['note_event_segments']:
                ['note_events']:
                ['tie_note_events']:
                ['start_times']: (not modified)
            ['audio_segments']:
                np.ndarray with shape(1, num_stems, T)
            ['processed_audio_array']:  # only if submix_audio is True
                np.ndarray with shape(B, 1, T)
            ['programs_segments']:
                List[np.ndarray] with length B, each element is a np.ndarray with shape(num_stems,)
            ['has_unannotated_segments']:
                List[bool] with length B

    Execution time: 27 ms for bsz=36 with single worker, including submix audio
    """
    # Stem dropout: each program survives with probability `prob` (0 < prob <= 1), and at
    # least one non-drum stem is guaranteed to remain. Mutates `sampled_data` in-place.
    drop_random_stems_from_bundle(sampled_data, prob=prob)

    if submix_audio:
        # Apply random per-stem gains, write the stems back, and store the resulting mix
        # under sampled_data["processed_audio_array"].
        audio_random_submix_processor(sampled_data=sampled_data,
                                      random_amp_range=random_amp_range,
                                      audio_masks=None,
                                      update_audio_segments=True,
                                      create_processed_audio_array=True)
    # When submix_audio is False (the cross-stem augmentation pipeline), the audio is left
    # unmixed for the caller to combine later.
| # ------------------------------------------------------------------------------------- | |
| # cross-stem augmentation helper functions | |
| # ------------------------------------------------------------------------------------- | |
def combined_survival_and_stop(max_k: int = 5, tau: float = 0.3, alpha: float = 1.0) -> Tuple[np.ndarray, np.ndarray]:
    """
    Compute the survival function and prob_stop for exponential or Weibull distributions based on the value of alpha.

    - S(k) represents the probability of "surviving" up to k-th trial.
    - P_stop(k), the stopping probability at trial k, is the difference between the survival
      probabilities at k and k+1; the final entry absorbs the remaining tail mass S(max_k)
      so that the probabilities sum to 1.

    Parameters:
    - max_k (int) : Maximum number of trials. k=0, 1, ..., max_k. k=0 means no cross-stem augmentation.
    - tau (float) : Scale parameter. Represents average time to the first failure for exponential distribution.
                    For Weibull distribution, it influences the spread and shape of the distribution.
    - alpha (float) : Shape parameter. If alpha=1, the function reduces to exponential distribution.
                      Otherwise, it represents the Weibull distribution.

    Returns:
    - survival (array-like) : Computed survival function values, shape (max_k+1,).
    - prob_stop (array-like) : Computed stop probabilities, shape (max_k+1,); sums to 1.

    Example 1:
        >>> survival_exp, stop_exp = combined_survival_and_stop(max_k=5, tau=0.3, alpha=1.0)
        Exponential Survival: [1. 0.74081822 0.54881164 0.40656966 0.30119421 0.22313016]
        Exponential Stop Prob: [0.25918178 0.19200658 0.14224198 0.10537545 0.07806405 0.22313016]
        (Note: the last entry is the leftover tail mass S(max_k).)

    Example 2:
        max_k = 5
        survival_exp, stop_exp_03 = combined_survival_and_stop(max_k, 0.3, 1)
        survival_weibull, stop_weibull = combined_survival_and_stop(max_k, 0.3, 1.5)
        import matplotlib.pyplot as plt
        plt.plot(range(max_k+1), list(stop_exp_03), 'o-', label='Exponential (tau=0.3)')
        plt.plot(range(max_k+1), list(stop_weibull), 's-', label='Weibull (tau=0.3, alpha=1.5)')
        plt.title("Stop Probabilities"); plt.xlabel("k"); plt.ylabel("Probability")
        plt.legend(); plt.grid(True); plt.show()

    References:
    - Weibull, Waloddi. "A statistical distribution function of wide applicability." Journal of applied mechanics (1951).
    """
    # Generate k values based on max_k
    k_values = np.arange(max_k + 1)

    # Survival function: exponential when alpha == 1, Weibull otherwise.
    if alpha == 1:
        survival = np.exp(-k_values * tau)
    else:
        survival = np.exp(-np.power(k_values * tau, alpha))

    # Stop probability at k is S(k) - S(k+1); appending 0 makes the last entry S(max_k),
    # i.e. the tail mass, so the distribution is properly normalized.
    prob_stop_at_k = -np.diff(np.append(survival, 0.))
    return survival, prob_stop_at_k  # (max_k+1,), (max_k+1,)
def deterministic_random_ux_sampler(prob_stop_at_k: np.ndarray, bsz: int) -> np.ndarray:
    r"""
    Deterministic random sampler for sampling U\X for cross-stem augmentation.

    Args:
        prob_stop_at_k (array-like): Probabilities of stopping at k-th trial; must sum to 1
            (this is required by np.random.choice's `p` argument).
        bsz (int) : Batch size. Usually local batch size.

    Returns:
        ux_count_per_item (np.ndarray): Number of U\X sources to sample for each item in
            the batch, shape (bsz,), values in [0, len(prob_stop_at_k) - 1].

    Example:
        >>> max_k = 5; tau = 0.3; alpha = 1.0; bsz = 10
        >>> _, prob_stop_at_k = combined_survival_and_stop(max_k, tau, alpha)
        prob_stop_at_k: [0.25918178 0.19200658 0.14224198 0.10537545 0.07806405 0.22313016]
        >>> np.random.choice(np.arange(max_k+1), size=bsz, p=prob_stop_at_k)
        array([1, 4, 1, 3, 0, 3, 0, 2, 5, 0])
    """
    ux_count_per_item = np.random.choice(np.arange(len(prob_stop_at_k)), size=bsz, p=prob_stop_at_k)
    return ux_count_per_item
def check_programs_overlap(list_programs: List[np.ndarray], programs: np.ndarray) -> Tuple[np.ndarray, np.ndarray]:
    """
    Check if there is any instrument overlap between two lists of programs.

    Returns (overlapping programs, boolean mask of `programs` entries NOT seen in
    `list_programs`).

    Example:
        >>> list_programs = np.array([np.array([1,2,3]), np.array([5,6])], dtype=object)
        >>> print(check_programs_overlap(list_programs, np.array([np.array([1,7])], dtype=object)))  # Expected [1]
        >>> print(check_programs_overlap(list_programs, np.array([np.array([])], dtype=object)))  # Expected []
    """
    # Flatten every gathered program array into one membership set.
    known_programs: Set = set()
    for sublist in list_programs:
        known_programs.update(sublist)
    overlaps = np.array([p for p in programs if p in known_programs])
    uniq_prg_mask = np.array([p not in known_programs for p in programs])
    return overlaps, uniq_prg_mask
def regroup_program_and_audio_by_minimal_shared_subunits(
        gathered_programs: List[np.ndarray],
        gathered_audio_array: List[np.ndarray],
        max_num_groups: Optional[int] = None
) -> Tuple[List[Tuple[int, ...]], DefaultDict[Tuple[int, ...], List[Tuple[int, int]]]]:
    """Group gathered programs into the minimal shared subunits, keyed by program tuples.

    Args:
        gathered_programs: list of program arrays, one per gathered source.
        gathered_audio_array: list of audio arrays, each with shape (1, num_stems, T).
        max_num_groups: if not None, randomly merge groups until at most this many remain.

    Returns:
        grouped_programs: List[Tuple[int, ...]] — the group keys (sorted program tuples).
            (Annotation fixed: the function returns tuples, not lists.)
        grouped_prg2audio: maps each program tuple to a list of (source_idx, stem_idx)
            locations in `gathered_audio_array`, e.g.
            {(61, 69, 71, 72): [(0, 0)], (128,): [(1, 0)], ...}
    """
    # An audio entry "has stems" when its channel axis holds more than one stem.
    # (Fixed: the previous zip over gathered_programs ignored the programs variable.)
    gathered_has_stem = [audio_array.shape[1] > 1 for audio_array in gathered_audio_array]

    # Map each (source, stem) audio location to the program(s) it carries. Stem-less
    # sources collapse all their programs onto stem 0.
    audio2prg = defaultdict(list)
    for i, programs in enumerate(gathered_programs):
        for j, value in enumerate(programs):
            if gathered_has_stem[i]:
                audio2prg[(i, j)].append(value)
            else:
                audio2prg[(i, 0)].append(value)

    # Invert: identical (sorted) program tuples share one group.
    grouped_prg2audio = defaultdict(list)
    for k_tuple, v_list in audio2prg.items():
        grouped_prg2audio[tuple(sorted(v_list))].append(k_tuple)

    # Limit the number of groups by randomly merging pairs until under the cap.
    if max_num_groups is not None:
        while len(grouped_prg2audio) > max_num_groups:
            k1, k2 = random.sample(list(grouped_prg2audio.keys()), 2)
            grouped_prg2audio[k1].extend(grouped_prg2audio[k2])
            del grouped_prg2audio[k2]

    grouped_programs = list(grouped_prg2audio.keys())
    return grouped_programs, grouped_prg2audio
def audio_random_submix_by_regroup_program_processor(gathered_programs: List[np.ndarray],
                                                     gathered_audio_array: List[np.ndarray],
                                                     submix_random_amp_range: List[float] = [0.9, 1.0],
                                                     max_num_stems: int = 12) -> Tuple[List[List[int]], np.ndarray]:
    """Regroup programs into subunit programs, and submix the regrouped audio arrays.

    Args:
        gathered_programs: list of program arrays, one per gathered source.
            (Annotation fixed: this is a list of arrays, not a single np.ndarray —
            it is indexed as gathered_audio_array[src_idx] below.)
        gathered_audio_array: list of audio arrays, each with shape (1, num_stems, T).
        submix_random_amp_range: [min_amp, max_amp] gain range applied when several
            stems are submixed into one subunit.
        max_num_stems: maximum number of subunit stems; groups beyond this are merged.

    Returns:
        grouped_programs: List[List[int]] (tuples from the regrouping are converted
            to lists; annotation fixed accordingly)
        submix_audio_array: np.ndarray with shape (1, max_num_stems, T); unused stem
            slots remain zero.
    """
    # Regroup programs into subunit programs
    grouped_programs, grouped_prg2audio = regroup_program_and_audio_by_minimal_shared_subunits(
        gathered_programs, gathered_audio_array, max_num_groups=max_num_stems)

    # Submix subunit audio arrays, based on the regrouped programs
    n_frames = gathered_audio_array[0].shape[2]
    submix_audio_array = np.zeros((1, max_num_stems, n_frames), dtype=np.float32)
    for i, prgs in enumerate(grouped_programs):
        audio_ids = grouped_prg2audio[prgs]  # locations in gathered_audio_array, e.g. [(src_idx, stem_idx), ...]
        if len(audio_ids) == 1:
            # no need to submix, already a subunit
            src_idx, stem_idx = audio_ids[0]
            submix_audio_array[:, i, :] = gathered_audio_array[src_idx][:, [stem_idx], :]
        else:
            # submix audio from the elements of this subunit's programs
            _submix_audio_list = [gathered_audio_array[src_idx][:, [stem_idx], :] for (src_idx, stem_idx) in audio_ids]
            _submix_audio_arr = np.concatenate(_submix_audio_list, axis=1, dtype=np.float32)  # (1, C, T)
            # Gains applied without normalization, so relative levels survive the submix.
            _, _submix_audio_arr = audio_random_submix_fn(_submix_audio_arr,
                                                          random_amp_range=submix_random_amp_range,
                                                          normalize=False)
            submix_audio_array[:, i, :] = _submix_audio_arr
    return [list(prgs) for prgs in grouped_programs], submix_audio_array
| # ------------------------------------------------------------------------------------- | |
| # cross stem augmentation processor | |
| # ------------------------------------------------------------------------------------- | |
def cross_stem_augment_processor(
        sampled_data: Dict[str, Any],
        sampled_ids: np.ndarray,
        get_rand_segments_from_cache_fn: Callable,
        random_amp_range: List[float] = [0.6, 1.2],
        stem_iaug_prob: float = 0.7,
        stem_xaug_policy: Dict = {
            # NOTE(review): shared mutable default — read-only in this function, but a
            # caller mutating it would affect every subsequent call. Confirm acceptable.
            "max_k": 3,  # max number of external sources used for cross-stem augmentations
            "tau": 0.3,  # exponential decay rate for cross-stem augmentation
            "alpha": 1.0,  # shape parameter for Weibull distribution. set 1.0 for exponential.
            "max_subunit_stems": 12,  # the number of subunit stems to be reduced to
            "p_include_singing": 0.8,  # probability of including singing for cross augmented examples. if None, use base probability.
            "no_instr_overlap": True,
            "no_drum_overlap": True,
            "uhat_intra_stem_augment": True,
        },
        max_l: int = 1024,
        precomputed_prob_stop_at_k: Optional[np.ndarray] = None,
        mix_audio: bool = True,
        create_subunit_note_events: bool = False) -> None:
    """
    Cross-stem augmentation

    Args:
        sampled_data: a dictionary containing sampled data.
            ['note_event_segments']: a list of NoteEventListsBundle with length B
            ['audio_segments']: a list of audio segments with length B, each element with shape (1, num_stems, T)
            ['programs_segments']: a list of programs with length B, each element with shape (num_stems,)
            ['has_unannotated_segments']: a list of bool with length B
        sampled_ids: a numpy array of sampled ids used in sampled_data. (B,)
        get_rand_segments_from_cache_fn: a function for getting random segments from cache.
        random_amp_range: a list of two floats, [min_amp, max_amp]
        stem_iaug_prob: a float, probability of intra-stem augmentation
        stem_xaug_policy: a dictionary of cross-stem augmentation policy
            - max_k (int) : Maximum number of trials. k=0, 1, ..., max_k. k=0 means no cross-stem augmentation.
            - tau (float) : Scale parameter. Represents average time to the first failure for exponential distribution.
                            For Weibull distribution, it influences the spread and shape of the distribution.
            - alpha (float) : Shape parameter. If alpha=1, the function reduces to exponential distribution.
                              Otherwise, it represents the Weibull distribution.
            - max_subunit_stems (int): Maximum number of subunit stems. If larger, they are reduced to this number
                                       by submix. Default: 12
            - p_include_singing (float): Probability of including singing for cross augmented examples. If None, use
                                         base probability. NOTE(review): read below but never
                                         used in this function body — confirm intended.
            - no_instr_overlap (bool): If True, do not allow instrument overlap between X and U\\X.
            - no_drum_overlap (bool): If True, do not allow drum overlap between X and U\\X.
            - uhat_intra_stem_augment (bool): If True, apply intra-stem augmentation to U\\X.
        max_l: an int, maximum number of note events in a note event list. Default: 1024
        precomputed_prob_stop_at_k: a numpy array of precomputed prob_stop_at_k. If None, it will be computed every time.
        mix_audio: a bool, if True, mix audio from X and U\\X. Default: True
        create_subunit_note_events: a bool, if True, create subunit note events. This is necessary for multi channel
                                    decoder training. Default is False.

    Returns:
        None (processed data is stored in-place within the `sampled_data` dictionary)

    Updated keys in sampled_data (in-place):
        sampled_data["subunit_programs_segments"]: List[List[np.ndarray]], with length B
        sampled_data["subunit_note_event_segments"]: List[NoteEventListsBundle], with length B
        sampled_data["subunit_audio_array"]: np.ndarray with shape (B, max_subunit_stems, T)
        sampled_data["programs_segments"]: List[np.ndarray], with length B
        sampled_data["note_event_segments"]: NoteEventListsBundle
        sampled_data["has_unannotated_segments"]: List[bool], with length B
        sampled_data["processed_audio_array"]: np.ndarray with shape (B, 1, T)

    Removed keys in sampled_data (in-place):
        all other keys except for the above are removed.
    """
    # Setup parameters
    max_k = stem_xaug_policy["max_k"]
    tau = stem_xaug_policy["tau"]
    alpha = stem_xaug_policy.get("alpha", 1.0)
    max_subunit_stems = stem_xaug_policy.get("max_subunit_stems", 12)
    p_include_singing = stem_xaug_policy.get("p_include_singing", None)  # NOTE(review): unused below
    no_instr_overlap = stem_xaug_policy["no_instr_overlap"]
    no_drum_overlap = stem_xaug_policy["no_drum_overlap"]
    uhat_intra_stem_augment = stem_xaug_policy["uhat_intra_stem_augment"]
    bsz = len(sampled_ids)  # local batch size
    n_frames = sampled_data["audio_segments"][0].shape[2]

    if precomputed_prob_stop_at_k is None:
        _, prob_stop_at_k = combined_survival_and_stop(max_k, tau, alpha)
    else:
        prob_stop_at_k = precomputed_prob_stop_at_k

    # Draw, per batch item, how many external sources (0..max_k) to mix in.
    ux_count_per_item = deterministic_random_ux_sampler(prob_stop_at_k, bsz)
    ux_count_sum = int(np.sum(ux_count_per_item))

    # X_in: sampled_data, which we have already applied intra-stem augmentation
    # U\X: ux_sampled_data, complement of X in U
    ux_sampled_data, _ = get_rand_segments_from_cache_fn(
        num_segments=ux_count_sum,
        use_ordered_read_pos=False,  # fully random sampling segments from cache
        sample_excluding_ids=sampled_ids)

    # Randomly drop stems from U\X, and update audio stems without submixing audio.
    if uhat_intra_stem_augment is True:
        intra_stem_augment_processor(sampled_data=ux_sampled_data,
                                     random_amp_range=random_amp_range,
                                     prob=stem_iaug_prob,
                                     update_audio_segments=True,
                                     submix_audio=False)

    # Loop for creating X_hat. Both iterators yield one batch item at a time;
    # `iter_ux` is consumed incrementally as external sources are requested.
    iter_ux = iter(
        zip(
            ux_sampled_data['audio_segments'],
            dict_iterator(ux_sampled_data['note_event_segments']),
            ux_sampled_data['programs_segments'],
            ux_sampled_data['has_unannotated_segments'],
        ))
    iter_x_in = iter(
        zip(
            sampled_data['audio_segments'],
            dict_iterator(sampled_data['note_event_segments']),
            sampled_data['programs_segments'],
            sampled_data['has_unannotated_segments'],
        ))
    x_hat = {
        "subunit_programs_segments": [],  # List[List[np.ndarray]], with length B
        "subunit_note_event_segments": [],  # List[NoteEventListsBundle], with length B
        "subunit_audio_array": np.zeros((bsz, max_subunit_stems, n_frames),
                                        dtype=np.float32),  # (B, max_submix_stems, T)
        "programs_segments": [],  # List[np.ndarray], with length B
        "note_event_segments": {
            "note_events": [],
            "tie_note_events": [],
            "start_times": []
        },  # NoteEventListsBundle
        "has_unannotated_segments": [],  # List[bool], with length B
        "processed_audio_array": np.zeros((bsz, 1, n_frames), dtype=np.float32),  # mixed audio array, (B, 1, T)
    }
    for i, (audio_array, ne_bundle, programs, has_unannotated) in enumerate(iter_x_in):
        num_ux_samples = ux_count_per_item[i]
        if num_ux_samples > 0 and has_unannotated is False:
            # gather the main source and k external sources
            gathered_programs = [programs]
            gathered_ne_bundle = ne_bundle  # mutable, but ok because `dict_iterator` yields new dict
            gathered_audio_array = [audio_array]
            for k in range(num_ux_samples):
                # Get next external source
                ex_audio_array, ex_ne_bundle, ex_programs, ex_has_unannotated = next(iter_ux)
                ex_prg_mask = None  # None: no need to mask external programs
                ex_has_stem = bool(ex_audio_array.shape[1] > 1)
                """Criteria for skipping sources"""
                if ex_has_unannotated is True:
                    continue
                """Criteria for instrument overlap and drum overlap """
                instr_overlap, uniq_ex_prg_mask = check_programs_overlap(gathered_programs, ex_programs)
                if no_instr_overlap is True and len(instr_overlap) > 0:
                    if np.any(uniq_ex_prg_mask) and ex_has_stem is True:
                        # mask out non-unique external programs
                        ex_prg_mask = uniq_ex_prg_mask
                    else:
                        # print(i, k, num_ux_samples, ex_programs,
                        #       'Warning: no unique external programs, skip this source')
                        continue  # no unique external programs, skip this source
                else:
                    # programs are already unique or we don't care about overlap
                    pass

                if no_drum_overlap is True and no_instr_overlap is False and DRUM_PROGRAM in instr_overlap:
                    non_drum_ex_prg_mask = np.array([prg != DRUM_PROGRAM for prg in ex_programs])
                    if np.any(non_drum_ex_prg_mask):
                        # mask only drum external programs
                        ex_prg_mask = non_drum_ex_prg_mask
                    else:
                        # print(i, k, num_ux_samples, ex_programs,
                        #       'Warning: no non-drum external programs, skip this source')
                        continue  # drum overlapped, but no non-drum programs, skip this source
                else:
                    pass
                """Criteria for stopping iteration with respect to max length"""
                if check_event_len_from_bundle(gathered_ne_bundle, ex_ne_bundle, max_len=max_l) is False:
                    # print(i, k, num_ux_samples, 'Warning: max length reached, stop iteration')
                    break

                # Apply mask and gather
                if ex_prg_mask is None:
                    gathered_programs.append(ex_programs)
                    extend_dict(gathered_ne_bundle, ex_ne_bundle)
                    gathered_audio_array.append(ex_audio_array)
                else:
                    # apply mask to external programs, and add to list
                    ex_programs = ex_programs[ex_prg_mask]
                    gathered_programs.append(ex_programs)
                    # drop note_events with masked programs, and extend dictionary
                    _ex_has_drum = np.any(ex_programs == DRUM_PROGRAM)
                    ex_ne_bundle["note_events"][0] = [
                        ne for ne in ex_ne_bundle["note_events"][0]
                        if (not ne.is_drum and ne.program in ex_programs) or (ne.is_drum and _ex_has_drum)
                    ]
                    ex_ne_bundle["tie_note_events"][0] = [
                        ne for ne in ex_ne_bundle["tie_note_events"][0] if ne.program in ex_programs
                    ]
                    extend_dict(gathered_ne_bundle, ex_ne_bundle)
                    # apply mask to external audio_array, and add to list
                    gathered_audio_array.append(ex_audio_array[:, ex_prg_mask, :])

            # Regroup gathered programs, and create submix by subunit programs
            subunit_programs, subunit_audio_array = audio_random_submix_by_regroup_program_processor(
                gathered_programs, gathered_audio_array, max_num_stems=max_subunit_stems)
            mixed_ne_bundle = mix_note_event_lists_bundle(gathered_ne_bundle,
                                                          sort=True,
                                                          start_time_to_zero=True,
                                                          use_deepcopy=True)
            if create_subunit_note_events is True:
                subunit_ne_bundle = separate_by_subunit_programs_from_note_event_lists_bundle(mixed_ne_bundle,
                                                                                              subunit_programs,
                                                                                              start_time_to_zero=False,
                                                                                              sort=True)
            else:
                subunit_ne_bundle = None
            x_hat["subunit_note_event_segments"].append(subunit_ne_bundle)
            x_hat["subunit_programs_segments"].append(subunit_programs)
            x_hat["subunit_audio_array"][i, :subunit_audio_array.shape[1], :] = subunit_audio_array  # (B, C, T)
            x_hat["programs_segments"].append(np.concatenate(gathered_programs, axis=0))
            extend_dict(x_hat["note_event_segments"], mixed_ne_bundle)
            x_hat["has_unannotated_segments"].append(has_unannotated)
        else:
            # No cross-stem augmentation for this item: pass it through, only reducing
            # the stem count when it exceeds max_subunit_stems.
            num_stems = audio_array.shape[1]
            if num_stems > max_subunit_stems:
                # If num_stems exceeds max_subunit_stems, regroup/submix down to max_subunit_stems stems
                subunit_programs, subunit_audio_array = audio_random_submix_by_regroup_program_processor(
                    [programs], [audio_array], max_num_stems=max_subunit_stems)
            else:
                subunit_programs = [programs]
                subunit_audio_array = audio_array
            x_hat["subunit_programs_segments"].append(subunit_programs)
            x_hat["subunit_audio_array"][i, :subunit_audio_array.shape[1], :] = subunit_audio_array
            if create_subunit_note_events is True:
                subunit_ne_bundle = separate_by_subunit_programs_from_note_event_lists_bundle(ne_bundle,
                                                                                              subunit_programs,
                                                                                              start_time_to_zero=True,
                                                                                              sort=True)
            else:
                subunit_ne_bundle = None
            x_hat["subunit_note_event_segments"].append(subunit_ne_bundle)
            x_hat["programs_segments"].append(programs)
            extend_dict(x_hat["note_event_segments"], ne_bundle)
            x_hat["has_unannotated_segments"].append(has_unannotated)

    # Mix subunit audio and update subunit audio arrays
    if mix_audio is True:
        amp_applied_stem_arr, mix_audio_arr = audio_random_submix_fn(x_hat["subunit_audio_array"],
                                                                     random_amp_range=random_amp_range,
                                                                     mask=None,
                                                                     normalize=True)
        x_hat["subunit_audio_array"] = amp_applied_stem_arr  # (B, C, T)
        x_hat["processed_audio_array"] = mix_audio_arr  # (B, 1, T)

    # Update sampled_data in-place
    sampled_data["subunit_programs_segments"] = x_hat["subunit_programs_segments"]
    sampled_data["subunit_note_event_segments"] = x_hat["subunit_note_event_segments"]
    sampled_data["subunit_audio_array"] = x_hat["subunit_audio_array"]
    sampled_data["programs_segments"] = x_hat["programs_segments"]
    sampled_data["note_event_segments"] = x_hat["note_event_segments"]
    sampled_data["has_unannotated_segments"] = x_hat["has_unannotated_segments"]
    sampled_data["processed_audio_array"] = x_hat["processed_audio_array"]
    del sampled_data["audio_segments"]