import inspect
import os
from abc import ABC, abstractmethod
from typing import Any, Dict, List, Mapping, Optional, Sequence, Union

import numpy as np
import torch
from pytorch_lightning import LightningDataModule
from pytorch_lightning.utilities.types import EVAL_DATALOADERS, TRAIN_DATALOADERS
from torch.utils.data import DataLoader, Dataset, IterableDataset


def from_datasets(
    train_dataset: Optional[Union[Dataset, Sequence[Dataset], Mapping[str, Dataset]]] = None,
    val_dataset: Optional[Union[Dataset, Sequence[Dataset]]] = None,
    test_dataset: Optional[Union[Dataset, Sequence[Dataset]]] = None,
    predict_dataset: Optional[Union[Dataset, Sequence[Dataset]]] = None,
    batch_size: int = 1,
    num_workers: int = 0,
    **datamodule_kwargs: Any,
) -> "LightningDataModule":
    """Build a ``LightningDataModule`` whose dataloaders wrap the given datasets.

    Adapted from ``LightningDataModule.from_datasets``, with memory pinning
    enabled and, when worker processes are used, prefetching and persistent
    workers turned on.
    """

    def dataloader(ds: Dataset, shuffle: bool = False) -> DataLoader:
        # Iterable datasets define their own iteration order; sampler-based
        # shuffling is not supported for them.
        shuffle &= not isinstance(ds, IterableDataset)
        # prefetch_factor and persistent_workers are only valid when
        # num_workers > 0; passing them with num_workers=0 makes DataLoader
        # raise a ValueError.
        worker_kwargs: Dict[str, Any] = (
            {"prefetch_factor": 4, "persistent_workers": True} if num_workers > 0 else {}
        )
        return DataLoader(
            ds,
            batch_size=batch_size,
            shuffle=shuffle,
            num_workers=num_workers,
            pin_memory=True,
            **worker_kwargs,
        )

    def train_dataloader() -> TRAIN_DATALOADERS:
        assert train_dataset is not None

        if isinstance(train_dataset, Mapping):
            return {key: dataloader(ds, shuffle=True) for key, ds in train_dataset.items()}
        if isinstance(train_dataset, Sequence):
            return [dataloader(ds, shuffle=True) for ds in train_dataset]
        return dataloader(train_dataset, shuffle=True)

    def val_dataloader() -> EVAL_DATALOADERS:
        assert val_dataset is not None

        if isinstance(val_dataset, Sequence):
            return [dataloader(ds) for ds in val_dataset]
        return dataloader(val_dataset)

    def test_dataloader() -> EVAL_DATALOADERS:
        assert test_dataset is not None

        if isinstance(test_dataset, Sequence):
            return [dataloader(ds) for ds in test_dataset]
        return dataloader(test_dataset)

    def predict_dataloader() -> EVAL_DATALOADERS:
        assert predict_dataset is not None

        if isinstance(predict_dataset, Sequence):
            return [dataloader(ds) for ds in predict_dataset]
        return dataloader(predict_dataset)

    # Forward batch_size/num_workers to the LightningDataModule constructor
    # only if it accepts them (explicitly or via **kwargs).
    candidate_kwargs = {"batch_size": batch_size, "num_workers": num_workers}
    accepted_params = inspect.signature(LightningDataModule.__init__).parameters
    accepts_kwargs = any(param.kind == param.VAR_KEYWORD for param in accepted_params.values())
    if accepts_kwargs:
        special_kwargs = candidate_kwargs
    else:
        accepted_param_names = set(accepted_params)
        accepted_param_names.discard("self")
        special_kwargs = {k: v for k, v in candidate_kwargs.items() if k in accepted_param_names}

    datamodule = LightningDataModule(**datamodule_kwargs, **special_kwargs)
    if train_dataset is not None:
        datamodule.train_dataloader = train_dataloader  # type: ignore[method-assign]
    if val_dataset is not None:
        datamodule.val_dataloader = val_dataloader  # type: ignore[method-assign]
    if test_dataset is not None:
        datamodule.test_dataloader = test_dataloader  # type: ignore[method-assign]
    if predict_dataset is not None:
        datamodule.predict_dataloader = predict_dataloader  # type: ignore[method-assign]

    return datamodule
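

# Usage sketch (hypothetical: ``train_set``/``val_set`` stand in for concrete
# Dataset instances, and ``trainer``/``model`` are assumed to exist):
#
#     datamodule = from_datasets(
#         train_dataset=train_set,
#         val_dataset=val_set,
#         batch_size=16,
#         num_workers=4,
#     )
#     trainer.fit(model, datamodule=datamodule)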


class BaseSourceSeparationDataset(Dataset, ABC):
    def __init__(
        self,
        split: str,
        stems: List[str],
        files: List[str],
        data_path: str,
        fs: int,
        npy_memmap: bool,
        recompute_mixture: bool,
    ):
        """Shared state for source-separation datasets.

        The stem list always includes "mixture"; when ``recompute_mixture`` is
        True, the mixture is summed on the fly from the other stems instead of
        being loaded from disk.
        """
        # Make sure the mixture stem is always present, listed first.
        if "mixture" not in stems:
            stems = ["mixture"] + stems

        self.split = split
        self.stems = stems
        self.stems_no_mixture = [s for s in stems if s != "mixture"]
        self.files = files
        self.data_path = data_path
        self.fs = fs
        self.npy_memmap = npy_memmap
        self.recompute_mixture = recompute_mixture

    @abstractmethod
    def get_stem(self, *, stem: str, identifier: Dict[str, Any]) -> torch.Tensor:
        raise NotImplementedError

    def _get_audio(
        self, stems: List[str], identifier: Dict[str, Any]
    ) -> Dict[str, torch.Tensor]:
        audio = {}
        for stem in stems:
            audio[stem] = self.get_stem(stem=stem, identifier=identifier)

        return audio

    def get_audio(self, identifier: Dict[str, Any]) -> Dict[str, torch.Tensor]:
        if self.recompute_mixture:
            # Load only the individual stems and derive the mixture from them.
            audio = self._get_audio(self.stems_no_mixture, identifier=identifier)
            audio["mixture"] = self.compute_mixture(audio)
            return audio
        return self._get_audio(self.stems, identifier=identifier)

    @abstractmethod
    def get_identifier(self, index: int) -> Dict[str, Any]:
        raise NotImplementedError

    def compute_mixture(self, audio: Dict[str, torch.Tensor]) -> torch.Tensor:
        # The mixture is modeled as the elementwise sum of all non-mixture stems.
        return sum(audio[stem] for stem in audio if stem != "mixture")
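

# A minimal concrete subclass, shown only as an illustrative sketch: the class
# name, the ``<data_path>/<track>/<stem>.npy`` file layout, and returning whole
# tracks from ``__getitem__`` are assumptions, not part of the abstract
# interface above.
class NpyFolderSourceSeparationDataset(BaseSourceSeparationDataset):
    def get_identifier(self, index: int) -> Dict[str, Any]:
        # A sample is identified by the track (directory) it comes from.
        return {"track": self.files[index]}

    def get_stem(self, *, stem: str, identifier: Dict[str, Any]) -> torch.Tensor:
        path = os.path.join(self.data_path, identifier["track"], f"{stem}.npy")
        # Optionally memory-map the array so large files are read lazily.
        audio = np.load(path, mmap_mode="r" if self.npy_memmap else None)
        return torch.as_tensor(np.ascontiguousarray(audio))

    def __len__(self) -> int:
        return len(self.files)

    def __getitem__(self, index: int) -> Dict[str, torch.Tensor]:
        return self.get_audio(self.get_identifier(index))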