# Copyright 2024 The YourMT3 Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Please see the details in the LICENSE file.
import os
import json
import torch
import numpy as np
from torch.utils.data import DistributedSampler
from torch.utils.data import Dataset, Sampler
from torch.utils.data import RandomSampler, WeightedRandomSampler
from operator import itemgetter
from typing import List, Tuple, Iterator, Optional
from config.data_presets import data_preset_single_cfg, data_preset_multi_cfg
from config.config import shared_cfg

class DatasetFromSampler(Dataset):
    """Dataset to create indexes from `Sampler`. From the catalyst library.

    Args:
        sampler: PyTorch sampler
    """

    def __init__(self, sampler: Sampler):
        """Initialisation for DatasetFromSampler."""
        self.sampler = sampler
        self.sampler_list = None

    def __getitem__(self, index: int):
        """Gets an element of the dataset.

        Args:
            index: index of the element in the dataset

        Returns:
            Single element by index
        """
        if self.sampler_list is None:
            # Materialize the sampler once so repeated indexing is cheap.
            self.sampler_list = list(self.sampler)
        return self.sampler_list[index]

    def __len__(self) -> int:
        """
        Returns:
            int: length of the dataset
        """
        return len(self.sampler)
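
# Usage sketch (illustrative, not part of the original module): DatasetFromSampler
# turns a sampler's index stream into an indexable dataset, which is what lets
# DistributedSamplerWrapper below shard any sampler's output across ranks.
#
#     >>> base = RandomSampler(range(10))
#     >>> ds = DatasetFromSampler(base)
#     >>> len(ds)  # same length as the wrapped sampler
#     10
#     >>> ds[0]  # first index drawn by the (now materialized) sampler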

class DistributedSamplerWrapper(DistributedSampler):
    """Wrapper over `Sampler` for distributed training.

    Allows any sampler to be used in distributed mode. From
    https://github.com/catalyst-team/catalyst/blob/master/catalyst/data/sampler.py

    It is especially useful in conjunction with
    `torch.nn.parallel.DistributedDataParallel`. In such a case, each
    process can pass a DistributedSamplerWrapper instance as a DataLoader
    sampler, and load a subset of subsampled data of the original dataset
    that is exclusive to it.

    .. note::
        Sampler is assumed to be of constant size.
    """

    def __init__(
        self,
        sampler,
        num_replicas: Optional[int] = None,
        rank: Optional[int] = None,
        shuffle: bool = True,
    ):
        """
        Args:
            sampler: Sampler used for subsampling
            num_replicas (int, optional): Number of processes participating in
                distributed training
            rank (int, optional): Rank of the current process
                within ``num_replicas``
            shuffle (bool, optional): If true (default),
                sampler will shuffle the indices
        """
        super(DistributedSamplerWrapper, self).__init__(
            DatasetFromSampler(sampler),
            num_replicas=num_replicas,
            rank=rank,
            shuffle=shuffle,
        )
        self.sampler = sampler

    def __iter__(self) -> Iterator[int]:
        """Iterate over sampler.

        Returns:
            python iterator
        """
        # Re-materialize the wrapped sampler so each epoch draws fresh indices.
        self.dataset = DatasetFromSampler(self.sampler)
        indexes_of_indexes = super().__iter__()
        subsampler_indexes = self.dataset
        return iter(itemgetter(*indexes_of_indexes)(subsampler_indexes))
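
# Usage sketch (illustrative): wrap any custom sampler for DDP. Each rank then
# receives an exclusive shard of the indices drawn by the inner sampler.
#
#     >>> inner = WeightedRandomSampler(weights=[0.1] * 100, num_samples=50)
#     >>> sampler = DistributedSamplerWrapper(inner, num_replicas=2, rank=0)
#     >>> loader = torch.utils.data.DataLoader(my_dataset, sampler=sampler)  # `my_dataset` is hypothetical
#     >>> sampler.set_epoch(epoch)  # per epoch, so shuffling differs across epochs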

def discount_to_target(samples: np.ndarray, target_sum: int) -> np.ndarray:
    """Discounts samples to a target sum.

    NOTE: this function is deprecated.

    This function adjusts an array of sample values so that their sum equals a
    target sum, while ensuring that each element remains greater than or equal
    to 1 and attempting to maintain a distribution similar to the original.

    Example 1:
        samples = np.array([3, 1, 1, 1, 1, 1])
        target_sum = 7
        discounted_samples = discount_to_target(samples, target_sum)
        # [2, 1, 1, 1, 1, 1]

    Example 2:
        samples = np.array([3, 1, 10, 1, 1, 1])
        target_sum = 7
        # [1, 1, 2, 1, 1, 1]

    Parameters:
        samples (np.ndarray): Original array of sample values.
        target_sum (int): The desired sum of the sample array.

    Returns:
        np.ndarray: Adjusted array of sample values whose sum should equal the
            target sum, and where each element is greater than or equal to 1.
    """
    samples = samples.copy().astype(int)
    if samples.sum() <= target_sum:
        # NOTE: when the sum is already at or below the target, only a single
        # unit is added, so the result may still fall short of target_sum.
        samples[0] += 1
        return samples

    while samples.sum() > target_sum:
        # Indices of all elements larger than 1.
        indices_to_discount = np.where(samples > 1)[0]
        if indices_to_discount.size == 0:
            # No elements left to discount; we cannot reach target_sum
            # without going below 1.
            print("Cannot reach target sum without going below 1 for some elements.")
            return samples
        discount_count = int(min(len(indices_to_discount), samples.sum() - target_sum))
        indices_to_discount = indices_to_discount[:discount_count]
        samples[indices_to_discount] -= 1
    return samples
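
# Worked trace of the docstring's Example 2: starting from [3, 1, 10, 1, 1, 1]
# (sum 17) with target_sum 7, each pass subtracts 1 from every element still
# above 1, capped at the remaining surplus, until the sum reaches 7:
#
#     >>> discount_to_target(np.array([3, 1, 10, 1, 1, 1]), 7)
#     array([1, 1, 2, 1, 1, 1])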

def create_merged_train_dataset_info(data_preset_multi: dict, data_home: Optional[os.PathLike] = None):
    """Create merged dataset info from a multi-dataset preset.

    Args:
        data_preset_multi (dict): data preset multi
        data_home (os.PathLike, optional): path to data home. If None, uses the
            path defined in config/config.py.

    Returns:
        dict: merged dataset info
    """
    train_dataset_info = {
        "n_datasets": 0,
        "n_notes_per_dataset": None,  # TODO: not implemented yet...
        "n_files_per_dataset": [],
        "dataset_names": [],  # dataset names, in order of merging file lists
        "data_split_names": [],  # data split (preset) names, in order of merging file lists
        "index_ranges": [],  # index ranges of each dataset in the merged file list
        "dataset_weights": None,  # pre-defined list of dataset weights for sampling, if available
        "merged_file_list": {},
    }

    if data_home is None:
        data_home = shared_cfg["PATH"]["data_home"]
    assert os.path.exists(data_home)

    for dp in data_preset_multi["presets"]:
        train_dataset_info["n_datasets"] += 1
        dataset_name = data_preset_single_cfg[dp]["dataset_name"]
        train_dataset_info["dataset_names"].append(dataset_name)
        train_dataset_info["data_split_names"].append(dp)

        # Load the file list for the train split.
        if isinstance(data_preset_single_cfg[dp]["train_split"], str):
            train_split_name = data_preset_single_cfg[dp]["train_split"]
            file_list_path = os.path.join(data_home, 'yourmt3_indexes',
                                          f'{dataset_name}_{train_split_name}_file_list.json')
            # Check that the file list exists.
            if not os.path.exists(file_list_path):
                raise ValueError(f"File list {file_list_path} does not exist.")
            with open(file_list_path, 'r') as f:
                _file_list = json.load(f)
        elif isinstance(data_preset_single_cfg[dp]["train_split"], dict):
            _file_list = data_preset_single_cfg[dp]["train_split"]
        else:
            raise ValueError("Invalid train split.")

        # Merge the file list, re-keying entries by their position in the merged list.
        start_idx = len(train_dataset_info["merged_file_list"])
        for i, v in enumerate(_file_list.values()):
            train_dataset_info["merged_file_list"][start_idx + i] = v
        train_dataset_info["n_files_per_dataset"].append(len(_file_list))
        train_dataset_info["index_ranges"].append((start_idx, start_idx + len(_file_list)))

    # Set dataset weights.
    if "weights" in data_preset_multi.keys() and data_preset_multi["weights"] is not None:
        train_dataset_info["dataset_weights"] = data_preset_multi["weights"]
        assert len(train_dataset_info["dataset_weights"]) == train_dataset_info["n_datasets"]
    else:
        train_dataset_info["dataset_weights"] = np.ones(train_dataset_info["n_datasets"])
        print("No dataset weights specified, using equal weights for all datasets.")
    return train_dataset_info
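
# Shape sketch (hypothetical preset names and counts): for a multi-preset merging
# two datasets of 100 and 50 files, the returned info would look roughly like:
#
#     {
#         "n_datasets": 2,
#         "n_notes_per_dataset": None,
#         "n_files_per_dataset": [100, 50],
#         "dataset_names": ["dataset_a", "dataset_b"],          # hypothetical
#         "data_split_names": ["preset_a", "preset_b"],         # hypothetical
#         "index_ranges": [(0, 100), (100, 150)],
#         "dataset_weights": [1.0, 1.0],
#         "merged_file_list": {0: {...}, 1: {...}, ...},  # 150 entries, re-keyed 0..149
#     }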

def get_random_sampler(dataset, num_samples):
    """Returns a random sampler, wrapped for distributed mode when DDP is initialized."""
    if torch.distributed.is_initialized():
        return DistributedSamplerWrapper(sampler=RandomSampler(dataset, num_samples=num_samples))
    else:
        return RandomSampler(dataset, num_samples=num_samples)
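
# Usage sketch (illustrative): because the distributed wrapping happens inside,
# the call site is identical in single- and multi-GPU runs.
#
#     >>> sampler = get_random_sampler(my_dataset, num_samples=1000)  # `my_dataset` is hypothetical
#     >>> loader = torch.utils.data.DataLoader(my_dataset, batch_size=8, sampler=sampler)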

def get_weighted_random_sampler(dataset_weights: List[float],
                                dataset_index_ranges: List[Tuple[int, int]],
                                num_samples_per_epoch: Optional[int] = None,
                                replacement: bool = True) -> torch.utils.data.sampler.Sampler:
    """Get a (distributed) weighted random sampler.

    Args:
        dataset_weights (List[float]): list of n dataset weights for n datasets
        dataset_index_ranges (List[Tuple[int, int]]): list of dataset index ranges
        num_samples_per_epoch (Optional[int]): number of samples per epoch, typically the
            length of the entire dataset. Defaults to None, in which case the total
            number of samples across all datasets is used.
        replacement (bool, optional): sample with replacement. Defaults to True.

    Returns:
        (distributed) weighted random sampler
    """
    assert len(dataset_weights) == len(dataset_index_ranges)

    sample_weights = []
    # Index ranges are contiguous, so the end of the last range is the total count.
    n_total_samples_in_datasets = dataset_index_ranges[-1][1]
    if len(dataset_weights) > 1 and len(dataset_index_ranges) > 1:
        for dataset_weight, index_range in zip(dataset_weights, dataset_index_ranges):
            assert dataset_weight >= 0
            n_samples_in_dataset = index_range[1] - index_range[0]
            # Down-weight larger datasets so that smaller ones are sampled more often.
            sample_weight = dataset_weight * (1 - n_samples_in_dataset / n_total_samples_in_datasets)
            # Repeat the same weight for every sample in the dataset.
            sample_weights += [sample_weight] * (index_range[1] - index_range[0])
    elif len(dataset_weights) == 1 and len(dataset_index_ranges) == 1:
        # Single dataset: uniform weights.
        sample_weights = [1] * n_total_samples_in_datasets

    if num_samples_per_epoch is None:
        num_samples_per_epoch = n_total_samples_in_datasets
    sampler = WeightedRandomSampler(weights=sample_weights,
                                    num_samples=num_samples_per_epoch,
                                    replacement=replacement)
    if torch.distributed.is_initialized():
        return DistributedSamplerWrapper(sampler=sampler)
    else:
        return sampler
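
# Worked example (illustrative numbers): two datasets with 100 and 50 files and
# equal preset weights give per-sample weights of 1 * (1 - 100/150) ≈ 0.33 and
# 1 * (1 - 50/150) ≈ 0.67, so the smaller dataset is oversampled relative to its size.
#
#     >>> sampler = get_weighted_random_sampler(
#     ...     dataset_weights=[1.0, 1.0],
#     ...     dataset_index_ranges=[(0, 100), (100, 150)],
#     ...     num_samples_per_epoch=150)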

def get_list_of_weighted_random_samplers(num_samplers: int,
                                         dataset_weights: List[float],
                                         dataset_index_ranges: List[Tuple[int, int]],
                                         num_samples_per_epoch: Optional[int] = None,
                                         replacement: bool = True) -> List[torch.utils.data.sampler.Sampler]:
    """Get a list of (distributed) weighted random samplers.

    Args:
        num_samplers (int): number of samplers to create
        dataset_weights (List[float]): list of n dataset weights for n datasets
        dataset_index_ranges (List[Tuple[int, int]]): list of dataset index ranges
        num_samples_per_epoch (Optional[int]): number of samples per epoch, typically the
            length of the entire dataset. Defaults to None, in which case the total
            number of samples across all datasets is used.
        replacement (bool, optional): sample with replacement. Defaults to True.

    Returns:
        List of (distributed) weighted random samplers
    """
    assert num_samplers > 0
    samplers = []
    for _ in range(num_samplers):
        samplers.append(
            get_weighted_random_sampler(dataset_weights, dataset_index_ranges, num_samples_per_epoch, replacement))
    return samplers
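
# End-to-end sketch (illustrative; the preset key "my_preset" is hypothetical and
# depends on your config/data_presets.py):
#
#     >>> info = create_merged_train_dataset_info(data_preset_multi_cfg["my_preset"])
#     >>> samplers = get_list_of_weighted_random_samplers(
#     ...     num_samplers=4,
#     ...     dataset_weights=info["dataset_weights"],
#     ...     dataset_index_ranges=info["index_ranges"],
#     ...     num_samples_per_epoch=len(info["merged_file_list"]))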