|
|
import pathlib |
|
|
from typing import Any, Callable, Dict, Iterable, List, Optional, Union |
|
|
|
|
|
import torch |
|
|
from tqdm.auto import tqdm |
|
|
|
|
|
from finetrainers.logging import get_logger |
|
|
from finetrainers.utils import delete_files |
|
|
|
|
|
|
|
|
# Module-level logger shared by every preprocessor/iterable in this module.
logger = get_logger()


# Name of the subdirectory (created under the user-provided save_dir) where
# precomputed items are written as `<data_type>-<index>.pt` files.
PRECOMPUTED_DATA_DIR = "finetrainers-precomputed-data"
|
|
|
|
|
|
|
|
def initialize_preprocessor(
    rank: int,
    world_size: int,
    num_items: int,
    processor_fn: Dict[str, Callable[[Dict[str, Any]], Dict[str, Any]]],
    save_dir: Optional[str] = None,
    enable_precomputation: bool = False,
    enable_reuse: bool = False,
) -> Union["InMemoryDistributedDataPreprocessor", "PrecomputedDistributedDataPreprocessor"]:
    """Create the appropriate preprocessor for the current configuration.

    Args:
        rank: Distributed rank of the current process.
        world_size: Total number of distributed processes.
        num_items: Number of items each `consume*` call should process.
        processor_fn: Mapping from data-type key to the processing callable.
        save_dir: Directory for precomputed data (only used when precomputation is enabled).
        enable_precomputation: When True, process to disk via a
            `PrecomputedDistributedDataPreprocessor`; otherwise keep results in memory.
        enable_reuse: When True, reuse previously precomputed data instead of recomputing.

    Returns:
        An in-memory preprocessor, or a disk-backed one when precomputation is enabled.
    """
    if not enable_precomputation:
        return InMemoryDistributedDataPreprocessor(rank, num_items, processor_fn)
    return PrecomputedDistributedDataPreprocessor(rank, world_size, num_items, processor_fn, save_dir, enable_reuse)
|
|
|
|
|
|
|
|
class DistributedDataProcessorMixin:
    """Interface that all distributed data preprocessors must implement.

    Subclasses provide `consume` (repeatable processing of a batch of items),
    `consume_once` (process once, then cycle the results) and the `requires_data`
    property (whether the caller must feed more data before iterating again).
    """

    def consume(self, *args, **kwargs):
        """Process a batch of items; must be overridden."""
        raise NotImplementedError("DistributedDataProcessorMixin::consume must be implemented by the subclass.")

    def consume_once(self, *args, **kwargs):
        """Process a batch of items exactly once; must be overridden."""
        raise NotImplementedError("DistributedDataProcessorMixin::consume_once must be implemented by the subclass.")

    @property
    def requires_data(self):
        """Whether more input data is required; must be overridden."""
        raise NotImplementedError("DistributedDataProcessorMixin::requires_data must be implemented by the subclass.")
|
|
|
|
|
|
|
|
class InMemoryDistributedDataPreprocessor(DistributedDataProcessorMixin):
    """Runs processor functions over incoming items and keeps the results in an
    in-memory FIFO buffer, handing iterators over that buffer back to the caller.

    Args:
        rank: Distributed rank of the current process (used to tag the returned iterators).
        num_items: Number of items to pull and process per `consume*` call; also the
            buffer capacity.
        processor_fn: Mapping from data-type key to the callable that turns one raw
            item dict into its processed form.
    """

    def __init__(
        self, rank: int, num_items: int, processor_fn: Dict[str, Callable[[Dict[str, Any]], Dict[str, Any]]]
    ) -> None:
        super().__init__()

        self._rank = rank
        self._num_items = num_items
        self._processor_fn = processor_fn

        # Raw (unprocessed) samples retained across calls when cache_samples=True.
        self._cached_samples = []
        self._buffer = InMemoryDataBuffer(num_items)
        self._preprocessed_iterator: Optional[Union["InMemoryDataIterable", "InMemoryOnceDataIterable"]] = None

    def consume(
        self,
        data_type: str,
        components: Dict[str, Any],
        data_iterator,
        generator: Optional[torch.Generator] = None,
        cache_samples: bool = False,
        use_cached_samples: bool = False,
        drop_samples: bool = False,
    ) -> Iterable[Dict[str, Any]]:
        """Process `num_items` items and return an exhaustible iterator over the results.

        The returned iterator sets `requires_data` once drained, signalling that
        `consume` should be called again with fresh data.

        Args:
            data_type: Key selecting which processor function to run.
            components: Extra keyword arguments (e.g. model components) forwarded to the processor.
            data_iterator: Iterator yielding raw item dicts.
            generator: Optional RNG forwarded to the processor.
            cache_samples: Store each raw item for later reuse via `use_cached_samples`.
            use_cached_samples: Read raw items from the cache instead of `data_iterator`.
            drop_samples: Clear the raw-sample cache after processing.

        Raises:
            ValueError: On an unknown `data_type` or contradictory caching flags.
        """
        self._check_args(data_type, cache_samples, use_cached_samples, drop_samples)
        self._process_into_buffer(
            data_type, components, data_iterator, generator, cache_samples, use_cached_samples, drop_samples
        )
        self._preprocessed_iterator = InMemoryDataIterable(self._rank, data_type, self._buffer)
        return iter(self._preprocessed_iterator)

    def consume_once(
        self,
        data_type: str,
        components: Dict[str, Any],
        data_iterator,
        generator: Optional[torch.Generator] = None,
        cache_samples: bool = False,
        use_cached_samples: bool = False,
        drop_samples: bool = False,
    ) -> Iterable[Dict[str, Any]]:
        """Process `num_items` items and return an iterator that cycles the results forever.

        Unlike `consume`, the returned iterator never sets `requires_data`; parameters
        and errors are the same as for `consume`.
        """
        self._check_args(data_type, cache_samples, use_cached_samples, drop_samples)
        self._process_into_buffer(
            data_type, components, data_iterator, generator, cache_samples, use_cached_samples, drop_samples
        )
        self._preprocessed_iterator = InMemoryOnceDataIterable(self._rank, data_type, self._buffer)
        return iter(self._preprocessed_iterator)

    @property
    def requires_data(self):
        """True when the caller must invoke `consume*` again before iterating."""
        if self._preprocessed_iterator is None:
            return True
        return self._preprocessed_iterator.requires_data

    def _check_args(self, data_type: str, cache_samples: bool, use_cached_samples: bool, drop_samples: bool) -> None:
        # Fail fast on unknown data types and contradictory caching flags.
        if data_type not in self._processor_fn.keys():
            raise ValueError(f"Invalid data type: {data_type}. Supported types: {list(self._processor_fn.keys())}")
        if cache_samples:
            if use_cached_samples:
                raise ValueError("Cannot cache and use cached samples at the same time.")
            if drop_samples:
                raise ValueError("Cannot cache and drop samples at the same time.")

    def _process_into_buffer(
        self,
        data_type: str,
        components: Dict[str, Any],
        data_iterator,
        generator: Optional[torch.Generator],
        cache_samples: bool,
        use_cached_samples: bool,
        drop_samples: bool,
    ) -> None:
        # Pull raw items (from the cache or the iterator), process, and buffer the results.
        for i in range(self._num_items):
            item = self._cached_samples[i] if use_cached_samples else next(data_iterator)
            if cache_samples:
                self._cached_samples.append(item)
            item = self._processor_fn[data_type](**item, **components, generator=generator)
            self._buffer.add(data_type, item)

        if drop_samples:
            del self._cached_samples
            self._cached_samples = []
|
|
|
|
|
|
|
|
class PrecomputedDistributedDataPreprocessor(DistributedDataProcessorMixin):
    """Runs processor functions over incoming items, serializes the results to disk,
    and hands back iterators that stream the precomputed files.

    Args:
        rank: Distributed rank of the current process; used to shard file indices.
        world_size: Total number of distributed processes.
        num_items: Number of items to pull and process per `consume*` call.
        processor_fn: Mapping from data-type key to the callable that turns one raw
            item dict into its processed form.
        save_dir: Base directory; precomputed files go under
            `save_dir/PRECOMPUTED_DATA_DIR`.
        enable_reuse: When True, existing precomputed files are reused instead of
            being recomputed (the directory must already exist).

    Raises:
        RuntimeError: When reuse is requested but the directory is missing, or when
            the directory contains unexpected subdirectories.
    """

    def __init__(
        self,
        rank: int,
        world_size: int,
        num_items: int,
        processor_fn: Dict[str, Callable[[Dict[str, Any]], Dict[str, Any]]],
        save_dir: str,
        enable_reuse: bool = False,
    ) -> None:
        super().__init__()

        self._rank = rank
        self._world_size = world_size
        self._num_items = num_items
        self._processor_fn = processor_fn
        self._save_dir = pathlib.Path(save_dir) / PRECOMPUTED_DATA_DIR
        self._enable_reuse = enable_reuse

        # Raw (unprocessed) samples retained across calls when cache_samples=True.
        self._cached_samples = []
        self._preprocessed_iterator: Optional[Union["PrecomputedDataIterable", "PrecomputedOnceDataIterable"]] = None

        if enable_reuse:
            if not self._save_dir.exists() or not self._save_dir.is_dir():
                raise RuntimeError(
                    f"The directory '{self._save_dir}' does not exist or is not a directory, but is required when enabling reuse of precomputed data."
                )
            logger.info(f"Reusing precomputed data from {self._save_dir}.")
        else:
            # Refuse to clobber a directory that holds anything other than flat precomputed files.
            subdirectories = [] if not self._save_dir.exists() else [f for f in self._save_dir.iterdir() if f.is_dir()]
            if len(subdirectories) > 0:
                raise RuntimeError(
                    "The current directory contains subdirectories other than the saved precomputed files. Please remove them or enable precomputation reuse."
                )

            delete_files([self._save_dir])
            self._save_dir.mkdir(parents=True, exist_ok=True)
            logger.info(f"Cleaned up any existing precomputed data in {self._save_dir} and created a fresh directory.")

    def consume(
        self,
        data_type: str,
        components: Dict[str, Any],
        data_iterator,
        generator: Optional[torch.Generator] = None,
        cache_samples: bool = False,
        use_cached_samples: bool = False,
        drop_samples: bool = False,
    ) -> Iterable[Dict[str, Any]]:
        """Precompute `num_items` items to disk (unless reusing) and return an iterator.

        When reuse is enabled the existing files are streamed indefinitely; otherwise
        the returned iterator is exhaustible and sets `requires_data` when drained.

        Args:
            data_type: Key selecting which processor function to run.
            components: Extra keyword arguments (e.g. model components) forwarded to the processor.
            data_iterator: Iterator yielding raw item dicts.
            generator: Optional RNG forwarded to the processor.
            cache_samples: Store each raw item for later reuse via `use_cached_samples`.
            use_cached_samples: Read raw items from the cache instead of `data_iterator`.
            drop_samples: Clear the raw-sample cache after processing.

        Raises:
            ValueError: On an unknown `data_type` or contradictory caching flags.
        """
        self._check_args(data_type, cache_samples, use_cached_samples, drop_samples)

        if not self._enable_reuse:
            self._precompute_to_disk(
                data_type,
                components,
                data_iterator,
                generator,
                cache_samples,
                use_cached_samples,
                drop_samples,
                desc=f"Rank {self._rank}",
            )

        if self._enable_reuse:
            data_iterator = PrecomputedOnceDataIterable(self._rank, self._world_size, self._save_dir, data_type)
        else:
            data_iterator = PrecomputedDataIterable(self._rank, self._world_size, self._save_dir, data_type)
        self._preprocessed_iterator = data_iterator
        return iter(data_iterator)

    def consume_once(
        self,
        data_type: str,
        components: Dict[str, Any],
        data_iterator,
        generator: Optional[torch.Generator] = None,
        cache_samples: bool = False,
        use_cached_samples: bool = False,
        drop_samples: bool = False,
    ) -> Iterable[Dict[str, Any]]:
        """Precompute `num_items` items to disk (unless reusing) and return an
        iterator that cycles the precomputed files forever.

        Parameters and errors are the same as for `consume`.
        """
        self._check_args(data_type, cache_samples, use_cached_samples, drop_samples)

        if not self._enable_reuse:
            self._precompute_to_disk(
                data_type,
                components,
                data_iterator,
                generator,
                cache_samples,
                use_cached_samples,
                drop_samples,
                desc=f"Processing data on rank {self._rank}",
            )

        self._preprocessed_iterator = PrecomputedOnceDataIterable(
            self._rank, self._world_size, self._save_dir, data_type
        )
        return iter(self._preprocessed_iterator)

    @property
    def requires_data(self):
        """True when the caller must invoke `consume*` again before iterating."""
        if self._preprocessed_iterator is None:
            return True
        return self._preprocessed_iterator.requires_data

    def _check_args(self, data_type: str, cache_samples: bool, use_cached_samples: bool, drop_samples: bool) -> None:
        # Fail fast on unknown data types and contradictory caching flags.
        if data_type not in self._processor_fn.keys():
            raise ValueError(f"Invalid data type: {data_type}. Supported types: {list(self._processor_fn.keys())}")
        if cache_samples:
            if use_cached_samples:
                raise ValueError("Cannot cache and use cached samples at the same time.")
            if drop_samples:
                raise ValueError("Cannot cache and drop samples at the same time.")

    def _precompute_to_disk(
        self,
        data_type: str,
        components: Dict[str, Any],
        data_iterator,
        generator: Optional[torch.Generator],
        cache_samples: bool,
        use_cached_samples: bool,
        drop_samples: bool,
        desc: str,
    ) -> None:
        # Pull raw items (from the cache or the iterator), process, and save each to disk.
        for i in tqdm(range(self._num_items), desc=desc, total=self._num_items):
            item = self._cached_samples[i] if use_cached_samples else next(data_iterator)
            if cache_samples:
                self._cached_samples.append(item)
            item = self._processor_fn[data_type](**item, **components, generator=generator)
            # Global index keeps per-rank files disjoint across the distributed run.
            index = self._rank * self._num_items + i
            _save_item(item, index, self._save_dir, data_type)

        if drop_samples:
            del self._cached_samples
            self._cached_samples = []
|
|
|
|
|
|
|
|
class InMemoryDataIterable:
    """Iterator over items held in an in-memory buffer.

    Drains the buffer one item at a time. Just before the final item is yielded,
    `requires_data` flips to True to signal that the preprocessor's consume method
    must be called again to refill the buffer.
    """

    def __init__(self, rank: int, data_type: str, buffer: "InMemoryDataBuffer") -> None:
        self._rank = rank
        self._data_type = data_type
        self._buffer = buffer
        self._requires_data = False

    def __iter__(self) -> Iterable[Dict[str, Any]]:
        while True:
            remaining = self._buffer.get_length(self._data_type)
            if remaining == 0:
                break
            if remaining == 1:
                # Last item about to go out: caller should refill before iterating again.
                self._requires_data = True
            yield self._buffer.get(self._data_type)

    def __len__(self) -> int:
        return self._buffer.get_length(self._data_type)

    @property
    def requires_data(self):
        return self._requires_data
|
|
|
|
|
|
|
|
class InMemoryOnceDataIterable: |
|
|
""" |
|
|
An iterator that loads data items from an in-memory buffer. This iterator will never set |
|
|
`requires_data` to True, as it is assumed that all the data was configured to be preprocessed |
|
|
by the user. The data will indefinitely be cycled from the buffer. |
|
|
""" |
|
|
|
|
|
def __init__(self, rank: int, data_type: str, buffer: "InMemoryDataBuffer") -> None: |
|
|
self._rank = rank |
|
|
self._data_type = data_type |
|
|
self._buffer = buffer |
|
|
|
|
|
self._requires_data = False |
|
|
|
|
|
def __iter__(self) -> Iterable[Dict[str, Any]]: |
|
|
assert len(self) > 0, "No data available in the buffer." |
|
|
while True: |
|
|
item = self._buffer.get(self._data_type) |
|
|
yield item |
|
|
self._buffer.add(self._data_type, item) |
|
|
|
|
|
def __len__(self) -> int: |
|
|
return self._buffer.get_length(self._data_type) |
|
|
|
|
|
@property |
|
|
def requires_data(self): |
|
|
return self._requires_data |
|
|
|
|
|
|
|
|
class PrecomputedDataIterable:
    """Iterator over a fixed number of precomputed items stored on disk.

    Counts the `<data_type>-*.pt` files in `save_dir` at construction time and
    yields each exactly once. When the final item is reached, `requires_data`
    flips to True to signal that the preprocessor's consume method must be called
    again.
    """

    def __init__(self, rank: int, world_size: int, save_dir: str, data_type: str) -> None:
        self._rank = rank
        self._world_size = world_size
        self._save_dir = pathlib.Path(save_dir)
        self._data_type = data_type
        self._requires_data = False

        # One file per precomputed item, named `<data_type>-<index>.pt`.
        self._num_items = len(list(self._save_dir.glob(f"{data_type}-*.pt")))

    def __iter__(self) -> Iterable[Dict[str, Any]]:
        device = torch.device(self._rank)
        base = self._rank * self._num_items
        last = self._num_items - 1
        for offset in range(self._num_items):
            if offset == last:
                # Final item about to go out: caller should trigger another consume.
                self._requires_data = True
            yield _load_item(base + offset, self._save_dir, self._data_type, device)

    def __len__(self) -> int:
        return self._num_items

    @property
    def requires_data(self):
        return self._requires_data
|
|
|
|
|
|
|
|
class PrecomputedOnceDataIterable:
    """Infinite iterator over precomputed items stored on disk.

    Splits the available `<data_type>-*.pt` files evenly across ranks and cycles
    this rank's share forever. `requires_data` never becomes True: all data is
    assumed to have been preprocessed up front.

    Raises:
        ValueError: When the directory holds fewer items than `rank + 1`.
    """

    def __init__(self, rank: int, world_size: int, save_dir: str, data_type: str) -> None:
        self._rank = rank
        self._world_size = world_size
        self._save_dir = pathlib.Path(save_dir)
        self._data_type = data_type
        self._requires_data = False

        self._num_items = len(list(self._save_dir.glob(f"{data_type}-*.pt")))
        if self._num_items <= self._rank:
            raise ValueError(
                f"Precomputed data directory is empty or does not contain enough items (required {self._rank + 1}, found {self._num_items})."
            )
        # Each rank cycles its own contiguous slice; at least one item per rank.
        self._num_items_per_rank = max(1, self._num_items // world_size)

    def __iter__(self) -> Iterable[Dict[str, Any]]:
        device = torch.device(self._rank)
        base = self._rank * self._num_items_per_rank
        position = 0
        while True:
            yield _load_item(base + position, self._save_dir, self._data_type, device)
            # Wrap around within this rank's slice.
            position = (position + 1) % self._num_items_per_rank

    def __len__(self) -> int:
        return self._num_items_per_rank

    @property
    def requires_data(self):
        return self._requires_data
|
|
|
|
|
|
|
|
class InMemoryDataBuffer:
    """A FIFO buffer of processed item dicts, keyed by data type.

    Args:
        max_limit: Maximum number of items retained per data type; ``-1`` means unbounded.
            When full, the oldest item is evicted to make room for a new one.
    """

    def __init__(self, max_limit: int = -1) -> None:
        self.max_limit = max_limit
        # Maps data-type key -> FIFO list of processed item dicts.
        # (Fixed annotation: values are item dicts, not strings.)
        self.buffer: Dict[str, List[Dict[str, Any]]] = {}

    def add(self, data_type: str, item: Dict[str, Any]) -> None:
        """Append `item` for `data_type`, evicting the oldest entry when full."""
        items = self.buffer.setdefault(data_type, [])
        if self.max_limit != -1 and len(items) >= self.max_limit:
            logger.log_freq(
                "WARN",
                "IN_MEMORY_DATA_BUFFER_FULL",
                "Buffer is full. Dropping the oldest item. This message will be logged every 64th time this happens.",
                64,
            )
            items.pop(0)
        items.append(item)

    def get(self, data_type: str) -> Dict[str, Any]:
        """Pop and return the oldest item for `data_type`.

        Raises:
            KeyError: If no item of `data_type` was ever added.
            IndexError: If the buffer for `data_type` is empty.
        """
        return self.buffer[data_type].pop(0)

    def get_length(self, data_type: str) -> int:
        """Return the number of buffered items for `data_type` (0 for an unseen type,
        rather than raising KeyError as before — callers may query before any add)."""
        return len(self.buffer.get(data_type, []))
|
|
|
|
|
|
|
|
def _save_item(item: Dict[str, Any], index: int, directory: pathlib.Path, data_type: str) -> None: |
|
|
filename = directory / f"{data_type}-{index}.pt" |
|
|
torch.save(item, filename.as_posix()) |
|
|
|
|
|
|
|
|
def _load_item(index: int, directory: pathlib.Path, data_type: str, map_location=None) -> Dict[str, Any]: |
|
|
filename = directory / f"{data_type}-{index}.pt" |
|
|
return torch.load(filename.as_posix(), map_location=map_location, weights_only=True) |
|
|
|