Spaces:
Runtime error
Runtime error
| import copy | |
| import logging | |
| import os | |
| import os.path as osp | |
| from os.path import join | |
| import torch | |
| from torch.utils.data import ConcatDataset, DataLoader | |
| from utils.optimizer import create_optimizer | |
| from utils.scheduler import create_scheduler | |
| logger = logging.getLogger(__name__) | |
def get_media_types(datasources):
    """Get the media type of every dataloader or dataset.

    Each entry is resolved independently: a ``DataLoader`` is unwrapped to
    its ``.dataset`` first, and a ``ConcatDataset`` is represented by the
    ``media_type`` of its first member dataset.

    Args:
        datasources (List): List of dataloaders or datasets. Entries may be
            a mix of both; an empty list is allowed.

    Returns:
        List: The media types, one per entry, in the same order as
            ``datasources``. Empty when ``datasources`` is empty.
    """
    media_types = []
    for source in datasources:
        # Decide per entry rather than from datasources[0], so empty and
        # mixed inputs are handled instead of raising.
        dataset = source.dataset if isinstance(source, DataLoader) else source
        if isinstance(dataset, ConcatDataset):
            # Only the first member is consulted; presumably all members of
            # a ConcatDataset share one media type -- verify against callers.
            dataset = dataset.datasets[0]
        media_types.append(dataset.media_type)
    return media_types