import inspect
import os
from abc import ABC, abstractmethod
from typing import Any, Dict, List, Mapping, Optional, Sequence, Union
import numpy as np
import torch
import torchaudio as ta
from pytorch_lightning.utilities.types import EVAL_DATALOADERS, TRAIN_DATALOADERS
from torch.utils import data
from pytorch_lightning import LightningDataModule
from torch.utils.data import Dataset, DataLoader, IterableDataset
def from_datasets(
    train_dataset: Optional[Union[Dataset, Sequence[Dataset], Mapping[str, Dataset]]] = None,
    val_dataset: Optional[Union[Dataset, Sequence[Dataset]]] = None,
    test_dataset: Optional[Union[Dataset, Sequence[Dataset]]] = None,
    predict_dataset: Optional[Union[Dataset, Sequence[Dataset]]] = None,
    batch_size: int = 1,
    num_workers: int = 0,
    **datamodule_kwargs: Any,
) -> "LightningDataModule":
    """Build a ``LightningDataModule`` whose dataloader hooks serve the given datasets.

    Args:
        train_dataset: Dataset, sequence of datasets, or mapping of datasets
            used for training (training loaders are shuffled).
        val_dataset: Dataset or sequence of datasets for validation.
        test_dataset: Dataset or sequence of datasets for testing.
        predict_dataset: Dataset or sequence of datasets for prediction.
        batch_size: Batch size used by every created ``DataLoader``.
        num_workers: Number of worker processes per ``DataLoader``.
        **datamodule_kwargs: Extra keyword arguments forwarded to
            ``LightningDataModule``.

    Returns:
        A ``LightningDataModule`` with a ``*_dataloader`` hook attached for
        each dataset that was supplied.
    """

    def dataloader(ds: Dataset, shuffle: bool = False) -> DataLoader:
        # Shuffling via a sampler is not supported for iterable-style datasets.
        shuffle &= not isinstance(ds, IterableDataset)
        # prefetch_factor / persistent_workers are only legal when worker
        # processes exist; passing them with num_workers=0 makes DataLoader
        # raise a ValueError, so only supply them in the multiprocessing case.
        worker_kwargs: Dict[str, Any] = (
            {"prefetch_factor": 4, "persistent_workers": True} if num_workers > 0 else {}
        )
        return DataLoader(
            ds,
            batch_size=batch_size,
            shuffle=shuffle,
            num_workers=num_workers,
            pin_memory=True,
            **worker_kwargs,
        )

    def train_dataloader() -> TRAIN_DATALOADERS:
        assert train_dataset
        if isinstance(train_dataset, Mapping):
            return {key: dataloader(ds, shuffle=True) for key, ds in train_dataset.items()}
        if isinstance(train_dataset, Sequence):
            return [dataloader(ds, shuffle=True) for ds in train_dataset]
        return dataloader(train_dataset, shuffle=True)

    def val_dataloader() -> EVAL_DATALOADERS:
        assert val_dataset
        if isinstance(val_dataset, Sequence):
            return [dataloader(ds) for ds in val_dataset]
        return dataloader(val_dataset)

    def test_dataloader() -> EVAL_DATALOADERS:
        assert test_dataset
        if isinstance(test_dataset, Sequence):
            return [dataloader(ds) for ds in test_dataset]
        return dataloader(test_dataset)

    def predict_dataloader() -> EVAL_DATALOADERS:
        assert predict_dataset
        if isinstance(predict_dataset, Sequence):
            return [dataloader(ds) for ds in predict_dataset]
        return dataloader(predict_dataset)

    # Forward batch_size/num_workers to the datamodule constructor only if its
    # __init__ accepts them (explicitly or via **kwargs) — mirrors the upstream
    # LightningDataModule.from_datasets behavior.
    candidate_kwargs = {"batch_size": batch_size, "num_workers": num_workers}
    accepted_params = inspect.signature(LightningDataModule.__init__).parameters
    accepts_kwargs = any(param.kind == param.VAR_KEYWORD for param in accepted_params.values())
    if accepts_kwargs:
        special_kwargs = candidate_kwargs
    else:
        accepted_param_names = set(accepted_params)
        accepted_param_names.discard("self")
        special_kwargs = {k: v for k, v in candidate_kwargs.items() if k in accepted_param_names}
    datamodule = LightningDataModule(**datamodule_kwargs, **special_kwargs)
    # Attach only the hooks for which a dataset was actually supplied.
    if train_dataset is not None:
        datamodule.train_dataloader = train_dataloader  # type: ignore[method-assign]
    if val_dataset is not None:
        datamodule.val_dataloader = val_dataloader  # type: ignore[method-assign]
    if test_dataset is not None:
        datamodule.test_dataloader = test_dataloader  # type: ignore[method-assign]
    if predict_dataset is not None:
        datamodule.predict_dataloader = predict_dataloader  # type: ignore[method-assign]
    return datamodule
class BaseSourceSeparationDataset(data.Dataset, ABC):
    """Abstract base dataset for audio source separation.

    Subclasses supply the actual stem loading (:meth:`get_stem`) and the
    mapping from an integer index to an identifier dict
    (:meth:`get_identifier`); this base class handles stem bookkeeping and
    assembling the full audio dict, optionally recomputing the mixture as the
    sum of the individual stems.
    """

    def __init__(
        self,
        split: str,
        stems: List[str],
        files: List[str],
        data_path: str,
        fs: int,
        npy_memmap: bool,
        recompute_mixture: bool,
    ):
        # Guarantee "mixture" is always present (and first) in the stem list.
        if "mixture" not in stems:
            stems = ["mixture", *stems]
        self.split = split
        self.stems = stems
        self.stems_no_mixture = [name for name in stems if name != "mixture"]
        self.files = files
        self.data_path = data_path
        self.fs = fs
        self.npy_memmap = npy_memmap
        self.recompute_mixture = recompute_mixture

    @abstractmethod
    def get_stem(self, *, stem: str, identifier: Dict[str, Any]) -> torch.Tensor:
        """Load and return the audio tensor for a single stem."""
        raise NotImplementedError

    def _get_audio(self, stems, identifier: Dict[str, Any]):
        """Fetch each requested stem into a dict keyed by stem name."""
        return {name: self.get_stem(stem=name, identifier=identifier) for name in stems}

    def get_audio(self, identifier: Dict[str, Any]):
        """Return all stems; recompute the mixture from the parts if configured."""
        if not self.recompute_mixture:
            return self._get_audio(self.stems, identifier=identifier)
        audio = self._get_audio(self.stems_no_mixture, identifier=identifier)
        audio["mixture"] = self.compute_mixture(audio)
        return audio

    @abstractmethod
    def get_identifier(self, index: int) -> Dict[str, Any]:
        """Map an integer sample index to the identifier dict used by get_stem."""
        pass

    def compute_mixture(self, audio) -> torch.Tensor:
        """Sum all non-mixture stems into a single mixture signal."""
        return sum(audio[name] for name in audio if name != "mixture")