Spaces:
Running
Running
| # Copyright (c) Facebook, Inc. and its affiliates. | |
| # All rights reserved. | |
| # | |
| # This source code is licensed under the license found in the | |
| # LICENSE file in the root directory of this source tree. | |
| import argparse | |
| import os | |
| from collections import defaultdict, namedtuple | |
| from pathlib import Path | |
| import musdb | |
| import numpy as np | |
| import torch as th | |
| import tqdm | |
| from torch.utils.data import DataLoader | |
| from .audio import AudioFile | |
ChunkInfo = namedtuple("ChunkInfo", ["file_index", "offset", "local_index"])


class Rawset:
    """
    Dataset of raw, normalized, float32 audio files.

    Files are discovered recursively under `path` and must be named
    `<name>.<stream>.raw`. Every track must provide the same set of stream
    indices, and that set must be exactly `0..K-1`. Each file holds
    interleaved float32 samples, `channels` values per time step.

    Args:
        path: root directory to scan (symlinked directories are followed).
        samples: number of time steps per item. If None, each item is a
            whole track and tracks are never dropped for being short.
        stride: step, in time steps, between consecutive chunks of the same
            track. Defaults to `samples` (non-overlapping chunks), or 0 when
            `samples` is None.
        channels: number of interleaved channels stored in each file.
        streams: stream indices returned by `__getitem__`. Defaults to all
            stream indices found on disk.

    Raises:
        ValueError: if no `.raw` file is found, or every track is shorter
            than `samples`.
    """

    def __init__(self, path, samples=None, stride=None, channels=2, streams=None):
        self.path = Path(path)
        self.channels = channels
        self.samples = samples
        if stride is None:
            stride = samples if samples is not None else 0
        self.stride = stride

        # (folder relative to root, track name) -> list of stream indices.
        entries = defaultdict(list)
        for root, folders, files in os.walk(self.path, followlinks=True):
            # Sorting `folders` in place makes os.walk's traversal deterministic.
            folders.sort()
            files.sort()
            for file in files:
                if file.endswith(".raw"):
                    file_path = Path(root) / file
                    name, stream = file_path.stem.rsplit('.', 1)
                    key = (file_path.parent.relative_to(self.path), name)
                    entries[key].append(int(stream))
        if not entries:
            raise ValueError(f"Empty dataset {self.path}")

        # One single sorted order drives _entries, _lengths and _sizes, so a
        # given file_index designates the same track in all three lists.
        all_entries = sorted(entries)
        ref_streams = sorted(entries[all_entries[0]])
        assert ref_streams == list(range(len(ref_streams)))
        self.streams = ref_streams if streams is None else streams

        self._entries = []
        self._lengths = []
        sizes = []
        for entry in all_entries:
            assert sorted(entries[entry]) == ref_streams
            # Each time step is `channels` float32 values of 4 bytes each.
            length = self._path(*entry).stat().st_size // (4 * channels)
            if samples is None:
                size = 1
            elif length < samples:
                continue  # too short to provide even one chunk
            else:
                size = (length - samples) // stride + 1
            self._entries.append(entry)
            self._lengths.append(length)
            sizes.append(size)
        if not sizes:
            raise ValueError(f"Empty dataset {self.path}")
        self._cumulative_sizes = np.cumsum(sizes)
        self._sizes = sizes

    def __len__(self):
        """Total number of items (chunks) across all kept tracks."""
        return int(self._cumulative_sizes[-1])

    def total_length(self):
        """Sum of the lengths, in time steps, of all kept tracks."""
        return sum(self._lengths)

    def chunk_info(self, index):
        """
        Locate item `index`.

        Returns:
            ChunkInfo with the track position in `_entries` (`file_index`),
            the chunk number within that track (`local_index`) and its start
            position in time steps (`offset = local_index * stride`).
        """
        file_index = np.searchsorted(self._cumulative_sizes, index, side='right')
        if file_index == 0:
            local_index = index
        else:
            local_index = index - self._cumulative_sizes[file_index - 1]
        return ChunkInfo(offset=local_index * self.stride,
                         file_index=file_index,
                         local_index=local_index)

    def _path(self, folder, name, stream=0):
        """Path of the raw file holding stream `stream` of track `folder/name`."""
        return self.path / folder / (name + f'.{stream}.raw')

    def __getitem__(self, index):
        """
        Read item `index`.

        Returns:
            float32 tensor of shape (len(self.streams), channels, length),
            where `length` is `samples`, or the full track length when
            `samples` is None.
        """
        chunk = self.chunk_info(index)
        entry = self._entries[chunk.file_index]
        length = self.samples or self._lengths[chunk.file_index]
        # Byte positions: each time step is `channels` float32 (4-byte) values.
        offset = chunk.offset * 4 * self.channels
        to_read = length * self.channels * 4
        streams = []
        for stream in self.streams:
            with open(self._path(*entry, stream=stream), 'rb') as file:
                file.seek(offset)
                content = file.read(to_read)
            assert len(content) == to_read
            # frombuffer yields a read-only view; copy so torch gets a writable array.
            content = np.frombuffer(content, dtype=np.float32).copy()
            # Stored time-major (interleaved); transpose to (channels, length).
            streams.append(th.from_numpy(content).view(length, self.channels).t())
        return th.stack(streams, dim=0)

    def name(self, index):
        """Relative path `folder / name` of the track that item `index` belongs to."""
        chunk = self.chunk_info(index)
        folder, name = self._entries[chunk.file_index]
        return folder / name
class MusDBSet:
    """
    Map-style dataset over the tracks of a musdb database object.

    Item `index` is a `(track_name, audio)` pair. `audio` is the result of
    `AudioFile(track.path).read(...)`, read from the start of the track with
    the configured stream selection, sample rate and channel count.
    """

    def __init__(self, mus, streams=slice(None), samplerate=44100, channels=2):
        self.mus = mus
        self.streams = streams  # stream selector forwarded to AudioFile.read
        self.samplerate = samplerate
        self.channels = channels

    def __len__(self):
        tracks = self.mus.tracks
        return len(tracks)

    def __getitem__(self, index):
        track = self.mus.tracks[index]
        track_name = track.name
        reader = AudioFile(track.path)
        audio = reader.read(
            seek_time=0,
            streams=self.streams,
            samplerate=self.samplerate,
            channels=self.channels,
        )
        return track_name, audio
def _first_item(batch):
    """Collate function for `batch_size=1`: unwrap the single dataset item.

    Defined at module level rather than as a lambda so that it can be pickled
    when DataLoader worker processes are started with the `spawn` method.
    """
    return batch[0]


def build_raw(mus, destination, normalize, workers, samplerate, channels):
    """
    Decode every track of `mus` and dump each of its streams as a raw file.

    Stream `i` of a track named `name` is written to
    `destination / f"{name}.{i}.raw"` as float32 samples in time-major
    (interleaved) order, `channels` values per time step.

    Args:
        mus: musdb database object whose `.tracks` are decoded.
        destination: output directory (Path or str); created if missing.
        normalize: if True, every stream of a track is shifted and scaled by
            the mean and standard deviation of the channel-averaged first
            stream (the mixture).
        workers: number of DataLoader worker processes.
        samplerate: sample rate requested from the decoder.
        channels: number of channels requested from the decoder.
    """
    destination = Path(destination)
    destination.mkdir(parents=True, exist_ok=True)
    loader = DataLoader(MusDBSet(mus, channels=channels, samplerate=samplerate),
                        batch_size=1,
                        num_workers=workers,
                        collate_fn=_first_item)
    for name, streams in tqdm.tqdm(loader):
        # NOTE(review): assumes `streams` is (num_streams, channels, time),
        # matching the `.t()` below and the layout Rawset reads back.
        if normalize:
            ref = streams[0].mean(dim=0)  # use mono mixture as reference
            streams = (streams - ref.mean()) / ref.std()
        for index, stream in enumerate(streams):
            # (channels, time) -> (time, channels) so the bytes are interleaved.
            data = stream.t().numpy().tobytes()
            with open(destination / (name + f'.{index}.raw'), "wb") as out:
                out.write(data)
def main():
    """Command-line entry point: convert MusDB into raw train/valid datasets.

    Both splits are taken from MusDB's "train" subset and are written,
    normalized, under `<destination>/train` and `<destination>/valid`.
    """
    parser = argparse.ArgumentParser('rawset')
    integer_options = (
        ('--workers', 10),
        ('--samplerate', 44100),
        ('--channels', 2),
    )
    for flag, default in integer_options:
        parser.add_argument(flag, type=int, default=default)
    for positional in ('musdb', 'destination'):
        parser.add_argument(positional, type=Path)
    args = parser.parse_args()

    for split in ("train", "valid"):
        database = musdb.DB(root=args.musdb, subsets=["train"], split=split)
        build_raw(database,
                  args.destination / split,
                  normalize=True,
                  workers=args.workers,
                  samplerate=args.samplerate,
                  channels=args.channels)


if __name__ == "__main__":
    main()