Spaces:
Runtime error
Runtime error
| import inspect | |
| import os | |
| from abc import ABC, abstractmethod | |
| from typing import Any, Dict, List, Mapping, Optional, Sequence, Union | |
| import numpy as np | |
| import torch | |
| import torchaudio as ta | |
| from pytorch_lightning.utilities.types import EVAL_DATALOADERS, TRAIN_DATALOADERS | |
| from torch.utils import data | |
| from pytorch_lightning import LightningDataModule | |
| from torch.utils.data import Dataset, DataLoader, IterableDataset | |
def from_datasets(
    train_dataset: Optional[Union[Dataset, Sequence[Dataset], Mapping[str, Dataset]]] = None,
    val_dataset: Optional[Union[Dataset, Sequence[Dataset]]] = None,
    test_dataset: Optional[Union[Dataset, Sequence[Dataset]]] = None,
    predict_dataset: Optional[Union[Dataset, Sequence[Dataset]]] = None,
    batch_size: int = 1,
    num_workers: int = 0,
    **datamodule_kwargs: Any,
) -> "LightningDataModule":
    """Build a ``LightningDataModule`` whose dataloader hooks wrap the given datasets.

    Only the hooks whose dataset argument is not ``None`` are attached to the
    returned datamodule.

    Args:
        train_dataset: A dataset, a sequence of datasets, or a mapping of
            name -> dataset. Each one is wrapped in a shuffling ``DataLoader``
            (shuffling is skipped for ``IterableDataset`` instances).
        val_dataset: A dataset or sequence of datasets (never shuffled).
        test_dataset: A dataset or sequence of datasets (never shuffled).
        predict_dataset: A dataset or sequence of datasets (never shuffled).
        batch_size: Batch size for every ``DataLoader``. It is also forwarded to
            the ``LightningDataModule`` constructor if that constructor accepts it.
        num_workers: Worker processes for every ``DataLoader``. It is forwarded
            to the constructor under the same rule as ``batch_size``.
        **datamodule_kwargs: Extra keyword arguments for the
            ``LightningDataModule`` constructor.

    Returns:
        The constructed ``LightningDataModule`` with the relevant
        ``*_dataloader`` hooks replaced.
    """

    def dataloader(ds: Dataset, shuffle: bool = False) -> DataLoader:
        # DataLoader does not accept shuffle=True for iterable-style datasets.
        shuffle &= not isinstance(ds, IterableDataset)
        loader_kwargs: Dict[str, Any] = {}
        if num_workers > 0:
            # prefetch_factor and persistent_workers only apply to worker processes.
            # DataLoader raises ValueError if they are set while num_workers == 0.
            loader_kwargs.update(prefetch_factor=4, persistent_workers=True)
        return DataLoader(
            ds,
            batch_size=batch_size,
            shuffle=shuffle,
            num_workers=num_workers,
            pin_memory=True,
            **loader_kwargs,
        )

    def train_dataloader() -> TRAIN_DATALOADERS:
        assert train_dataset
        if isinstance(train_dataset, Mapping):
            return {key: dataloader(ds, shuffle=True) for key, ds in train_dataset.items()}
        if isinstance(train_dataset, Sequence):
            return [dataloader(ds, shuffle=True) for ds in train_dataset]
        return dataloader(train_dataset, shuffle=True)

    def val_dataloader() -> EVAL_DATALOADERS:
        assert val_dataset
        if isinstance(val_dataset, Sequence):
            return [dataloader(ds) for ds in val_dataset]
        return dataloader(val_dataset)

    def test_dataloader() -> EVAL_DATALOADERS:
        assert test_dataset
        if isinstance(test_dataset, Sequence):
            return [dataloader(ds) for ds in test_dataset]
        return dataloader(test_dataset)

    def predict_dataloader() -> EVAL_DATALOADERS:
        assert predict_dataset
        if isinstance(predict_dataset, Sequence):
            return [dataloader(ds) for ds in predict_dataset]
        return dataloader(predict_dataset)

    # Forward batch_size / num_workers to the constructor only if its signature
    # can take them: either through **kwargs or as explicitly named parameters.
    candidate_kwargs = {"batch_size": batch_size, "num_workers": num_workers}
    accepted_params = inspect.signature(LightningDataModule.__init__).parameters
    accepts_kwargs = any(param.kind == param.VAR_KEYWORD for param in accepted_params.values())
    if accepts_kwargs:
        special_kwargs = candidate_kwargs
    else:
        accepted_param_names = set(accepted_params)
        accepted_param_names.discard("self")
        special_kwargs = {k: v for k, v in candidate_kwargs.items() if k in accepted_param_names}

    datamodule = LightningDataModule(**datamodule_kwargs, **special_kwargs)
    if train_dataset is not None:
        datamodule.train_dataloader = train_dataloader  # type: ignore[method-assign]
    if val_dataset is not None:
        datamodule.val_dataloader = val_dataloader  # type: ignore[method-assign]
    if test_dataset is not None:
        datamodule.test_dataloader = test_dataloader  # type: ignore[method-assign]
    if predict_dataset is not None:
        datamodule.predict_dataloader = predict_dataloader  # type: ignore[method-assign]
    return datamodule
class BaseSourceSeparationDataset(data.Dataset, ABC):
    """Base dataset that assembles a dict of per-stem audio tensors for one item.

    Subclasses provide ``get_stem`` (load a single stem for a given identifier)
    and ``get_identifier`` (map a dataset index to such an identifier).

    The stem list always contains ``"mixture"``. When ``recompute_mixture`` is
    true, the mixture is not loaded; it is rebuilt by summing the other stems.
    """

    def __init__(
        self,
        split: str,
        stems: List[str],
        files: List[str],
        data_path: str,
        fs: int,
        npy_memmap: bool,
        recompute_mixture: bool,
    ):
        """Store the dataset configuration.

        Args:
            split: Split name; stored as ``self.split`` and not interpreted here.
            stems: Stem names to load. ``"mixture"`` is prepended when absent.
            files: Stored as ``self.files`` (presumably per-item entries used
                by subclasses; not interpreted here).
            data_path: Stored as ``self.data_path``; not interpreted here.
            fs: Stored as ``self.fs`` (presumably the sample rate; not
                interpreted here).
            npy_memmap: Stored flag; not interpreted here.
            recompute_mixture: If true, ``get_audio`` sums the non-mixture
                stems instead of loading ``"mixture"``.
        """
        if "mixture" not in stems:
            # Build a new list so the caller's list is not mutated.
            stems = ["mixture"] + stems
        self.split = split
        self.stems = stems
        self.stems_no_mixture = [s for s in stems if s != "mixture"]
        self.files = files
        self.data_path = data_path
        self.fs = fs
        self.npy_memmap = npy_memmap
        self.recompute_mixture = recompute_mixture

    def get_stem(self, *, stem: str, identifier: Dict[str, Any]) -> torch.Tensor:
        """Load one stem for the item described by ``identifier``. Must be overridden."""
        raise NotImplementedError

    def _get_audio(self, stems: List[str], identifier: Dict[str, Any]) -> Dict[str, torch.Tensor]:
        """Return ``{stem: get_stem(stem=stem, identifier=identifier)}`` for each stem, in order."""
        audio = {}
        for stem in stems:
            audio[stem] = self.get_stem(stem=stem, identifier=identifier)
        return audio

    def get_audio(self, identifier: Dict[str, Any]) -> Dict[str, torch.Tensor]:
        """Return all stems (including ``"mixture"``) for one item.

        If ``recompute_mixture`` is set, only the non-mixture stems are loaded.
        ``"mixture"`` is then added as their sum, via ``compute_mixture``.
        """
        if self.recompute_mixture:
            audio = self._get_audio(self.stems_no_mixture, identifier=identifier)
            audio["mixture"] = self.compute_mixture(audio)
            return audio
        else:
            return self._get_audio(self.stems, identifier=identifier)

    def get_identifier(self, index: int) -> Dict[str, Any]:
        """Map a dataset index to the identifier dict passed to ``get_stem``. Must be overridden."""
        # Raise instead of implicitly returning None, so a missing override
        # fails loudly (same contract as get_stem).
        raise NotImplementedError

    def compute_mixture(self, audio: Dict[str, torch.Tensor]) -> torch.Tensor:
        """Sum every entry of ``audio`` except ``"mixture"``.

        ``sum`` starts from the integer 0. If ``audio`` has no non-mixture
        entries, the result is therefore the int ``0`` rather than a tensor.
        """
        return sum(audio[stem] for stem in audio if stem != "mixture")