Spaces:
Running
Running
| import glob | |
| import json | |
| import os | |
| import pickle | |
| import random | |
| import shutil | |
| import tarfile | |
| from functools import partial | |
| import albumentations | |
| import cv2 | |
| import numpy as np | |
| import PIL | |
| import torchvision.transforms.functional as TF | |
| import yaml | |
| from decord import VideoReader | |
| from func_timeout import FunctionTimedOut, func_set_timeout | |
| from omegaconf import OmegaConf | |
| from PIL import Image | |
| from torch.utils.data import BatchSampler, Dataset, Sampler | |
| from tqdm import tqdm | |
| from ..modules.image_degradation import (degradation_fn_bsr, | |
| degradation_fn_bsr_light) | |
class ImageVideoSampler(BatchSampler):
    """A batch sampler that groups dataset indices by content type.

    Indices from the wrapped sampler are accumulated into separate buckets
    for ``'image'`` and ``'video'`` samples, so every yielded batch contains
    only one content type.

    Args:
        sampler (Sampler): Base sampler producing dataset indices.
        dataset (Dataset): Dataset whose ``data.get_type(idx)`` reports the
            content type (``'image'`` or ``'video'``) of each index.
        batch_size (int): Size of mini-batch.
        drop_last (bool): If ``True``, incomplete batches remaining in the
            buckets at the end of an epoch are dropped; otherwise they are
            yielded as smaller batches.
    """

    def __init__(self,
                 sampler: Sampler,
                 dataset: Dataset,
                 batch_size: int,
                 drop_last: bool = False
                 ) -> None:
        if not isinstance(sampler, Sampler):
            raise TypeError('sampler should be an instance of ``Sampler``, '
                            f'but got {sampler}')
        if not isinstance(batch_size, int) or batch_size <= 0:
            raise ValueError('batch_size should be a positive integer value, '
                             f'but got batch_size={batch_size}')
        self.sampler = sampler
        self.dataset = dataset
        self.batch_size = batch_size
        self.drop_last = drop_last

        # Bookkeeping used to resume mid-epoch after a checkpoint reload.
        self.sampler_pos_start = 0
        self.sampler_pos_reload = 0
        self.num_samples_random = len(self.sampler)

        # One bucket per content type.
        self.bucket = {'image': [], 'video': []}

    def set_epoch(self, epoch):
        """Forward the epoch to the wrapped sampler (distributed shuffling)."""
        if hasattr(self.sampler, "set_epoch"):
            self.sampler.set_epoch(epoch)

    def __iter__(self):
        for index_sampler, idx in enumerate(self.sampler):
            # Skip indices that were already consumed before a reload.
            if self.sampler_pos_reload != 0 and self.sampler_pos_reload < self.num_samples_random:
                if index_sampler < self.sampler_pos_reload:
                    self.sampler_pos_start = (self.sampler_pos_start + 1) % self.num_samples_random
                    continue
                elif index_sampler == self.sampler_pos_reload:
                    self.sampler_pos_reload = 0

            content_type = self.dataset.data.get_type(idx)
            self.bucket[content_type].append(idx)

            # Yield a batch as soon as one bucket is full.
            if len(self.bucket['video']) == self.batch_size:
                yield self.bucket['video']
                self.bucket['video'] = []
            elif len(self.bucket['image']) == self.batch_size:
                yield self.bucket['image']
                self.bucket['image'] = []
            self.sampler_pos_start = (self.sampler_pos_start + 1) % self.num_samples_random

        # Fix: honour ``drop_last`` -- previously leftover indices were
        # silently discarded at epoch end regardless of the flag.
        if not self.drop_last:
            for content_type in ('video', 'image'):
                if self.bucket[content_type]:
                    yield self.bucket[content_type]
                    self.bucket[content_type] = []
class ImageVideoDataset(Dataset):
    # update __getitem__() from ImageNetSR. If timeout for Pandas70M, throw exception.
    # If caught exception(timeout or others), try another index until successful and return.
    def __init__(self, size=None, video_size=128, video_len=25,
                 degradation=None, downscale_f=4, random_crop=True, min_crop_f=0.25, max_crop_f=1.,
                 s_t=None, slice_interval=None, data_root=None
                 ):
        """
        Imagenet Superresolution Dataloader
        Performs following ops in order:
        1. crops a crop of size s from image either as random or center crop
        2. resizes crop to size with cv2.area_interpolation
        3. degrades resized crop with degradation_fn
        :param size: resizing to size after cropping
        :param video_size: smallest-side length video frames are rescaled to
        :param video_len: number of frames sampled per video clip
        :param degradation: degradation_fn, e.g. cv_bicubic or bsrgan_light
        :param downscale_f: Low Resolution Downsample factor
        :param random_crop: if True use a random crop, otherwise a center crop
        :param min_crop_f: determines crop size s,
            where s = c * min_img_side_len with c sampled from interval (min_crop_f, max_crop_f)
        :param max_crop_f: ""
        :param s_t: stored but unused here -- kept for config compatibility
        :param slice_interval: frame stride for video clips; "rand" samples from
            {1, 2, 3}; None falls back to a stride of 1
        :param data_root: optional root directory prepended to every file path
        """
        # Subclasses must implement get_base() returning the annotation list.
        self.base = self.get_base()
        assert size
        assert (size / downscale_f).is_integer()
        self.size = size
        self.LR_size = int(size / downscale_f)
        self.min_crop_f = min_crop_f
        self.max_crop_f = max_crop_f
        assert(max_crop_f <= 1.)
        self.center_crop = not random_crop
        self.s_t = s_t
        self.slice_interval = slice_interval

        # Rescalers keep aspect ratio: only the smallest side is constrained.
        self.image_rescaler = albumentations.SmallestMaxSize(max_size=size, interpolation=cv2.INTER_AREA)
        self.video_rescaler = albumentations.SmallestMaxSize(max_size=video_size, interpolation=cv2.INTER_AREA)
        self.video_len = video_len
        self.video_size = video_size
        self.data_root = data_root

        self.pil_interpolation = False  # gets reset later if incase interp_op is from pillow
        if degradation == "bsrgan":
            self.degradation_process = partial(degradation_fn_bsr, sf=downscale_f)
        elif degradation == "bsrgan_light":
            self.degradation_process = partial(degradation_fn_bsr_light, sf=downscale_f)
        else:
            interpolation_fn = {
                "cv_nearest": cv2.INTER_NEAREST,
                "cv_bilinear": cv2.INTER_LINEAR,
                "cv_bicubic": cv2.INTER_CUBIC,
                "cv_area": cv2.INTER_AREA,
                "cv_lanczos": cv2.INTER_LANCZOS4,
                "pil_nearest": PIL.Image.NEAREST,
                "pil_bilinear": PIL.Image.BILINEAR,
                "pil_bicubic": PIL.Image.BICUBIC,
                "pil_box": PIL.Image.BOX,
                "pil_hamming": PIL.Image.HAMMING,
                "pil_lanczos": PIL.Image.LANCZOS,
            }[degradation]
            self.pil_interpolation = degradation.startswith("pil_")
            if self.pil_interpolation:
                self.degradation_process = partial(TF.resize, size=self.LR_size, interpolation=interpolation_fn)
            else:
                self.degradation_process = albumentations.SmallestMaxSize(max_size=self.LR_size,
                                                                          interpolation=interpolation_fn)

    def __len__(self):
        return len(self.base)

    def get_type(self, index):
        """Return 'image' or 'video' for the annotation at ``index``.

        NOTE(review): ImageVideoSampler calls ``dataset.data.get_type(idx)``,
        not ``dataset.get_type(idx)`` -- verify the wiring between the two.
        """
        return self.base[index].get('type', 'image')

    def __getitem__(self, i):
        # Hard 3-second cap on video decoding; decord can hang on bad files.
        # Fix: the decorator was missing, which made the FunctionTimedOut
        # handler below dead code (the comment "time wait 3 seconds" and the
        # handler show this was the original intent).
        @func_set_timeout(3)
        def get_video_item(example):
            if self.data_root is not None:
                video_reader = VideoReader(os.path.join(self.data_root, example['file_path']))
            else:
                video_reader = VideoReader(example['file_path'])
            video_length = len(video_reader)

            if self.slice_interval == "rand":
                slice_interval = np.random.choice([1, 2, 3])
            elif self.slice_interval is None:
                # Fix: int(None) raised TypeError which the retry loop
                # swallowed, producing an endless retry; default stride is 1.
                slice_interval = 1
            else:
                slice_interval = int(self.slice_interval)

            # Sample video_len frame indices evenly from a random clip.
            clip_length = min(video_length, (self.video_len - 1) * slice_interval + 1)
            start_idx = random.randint(0, video_length - clip_length)
            batch_index = np.linspace(start_idx, start_idx + clip_length - 1, self.video_len, dtype=int)
            pixel_values = video_reader.get_batch(batch_index).asnumpy()
            del video_reader  # release decord resources promptly

            out_images = []
            LR_out_images = []

            # One crop geometry shared by all frames of the clip.
            min_side_len = min(pixel_values[0].shape[:2])
            crop_side_len = min_side_len * np.random.uniform(self.min_crop_f, self.max_crop_f, size=None)
            crop_side_len = int(crop_side_len)
            if self.center_crop:
                self.cropper = albumentations.CenterCrop(height=crop_side_len, width=crop_side_len)
            else:
                self.cropper = albumentations.RandomCrop(height=crop_side_len, width=crop_side_len)

            # Move the frame axis into the channel slot so a single cropper
            # call crops all frames identically, then move it back.
            imgs = np.transpose(pixel_values, (1, 2, 3, 0))
            imgs = self.cropper(image=imgs)["image"]
            imgs = np.transpose(imgs, (3, 0, 1, 2))

            for img in imgs:
                image = self.video_rescaler(image=img)["image"]
                out_images.append(image[None, :, :, :])
                if self.pil_interpolation:
                    image_pil = PIL.Image.fromarray(image)
                    LR_image = self.degradation_process(image_pil)
                    LR_image = np.array(LR_image).astype(np.uint8)
                else:
                    LR_image = self.degradation_process(image=image)["image"]
                LR_out_images.append(LR_image[None, :, :, :])

            example = {}
            # Normalize uint8 [0, 255] -> float32 [-1, 1].
            example['image'] = (np.concatenate(out_images) / 127.5 - 1.0).astype(np.float32)
            example['LR_image'] = (np.concatenate(LR_out_images) / 127.5 - 1.0).astype(np.float32)
            return example

        example = self.base[i]
        if example.get('type', 'image') == 'video':
            # Retry with random indices until a sample decodes successfully.
            while True:
                try:
                    example = self.base[i]
                    return get_video_item(example)
                except FunctionTimedOut:
                    print("stt catch: Function 'extract failed' timed out.")
                    i = random.randint(0, self.__len__() - 1)
                except Exception as e:
                    print('stt catch', e)
                    i = random.randint(0, self.__len__() - 1)
        elif example.get('type', 'image') == 'image':
            while True:
                try:
                    example = self.base[i]
                    path = example['file_path']
                    if self.data_root is not None:
                        path = os.path.join(self.data_root, path)
                    # Fix: context manager closes the file handle -- Image.open
                    # is lazy and previously leaked a descriptor per retry.
                    with Image.open(path) as img:
                        image = np.array(img.convert("RGB")).astype(np.uint8)

                    min_side_len = min(image.shape[:2])
                    crop_side_len = min_side_len * np.random.uniform(self.min_crop_f, self.max_crop_f, size=None)
                    crop_side_len = int(crop_side_len)
                    if self.center_crop:
                        self.cropper = albumentations.CenterCrop(height=crop_side_len, width=crop_side_len)
                    else:
                        self.cropper = albumentations.RandomCrop(height=crop_side_len, width=crop_side_len)
                    image = self.cropper(image=image)["image"]
                    image = self.image_rescaler(image=image)["image"]

                    if self.pil_interpolation:
                        image_pil = PIL.Image.fromarray(image)
                        LR_image = self.degradation_process(image_pil)
                        LR_image = np.array(LR_image).astype(np.uint8)
                    else:
                        LR_image = self.degradation_process(image=image)["image"]

                    example = {}
                    # Normalize uint8 [0, 255] -> float32 [-1, 1].
                    example["image"] = (image/127.5 - 1.0).astype(np.float32)
                    example["LR_image"] = (LR_image/127.5 - 1.0).astype(np.float32)
                    return example
                except Exception as e:
                    print("catch", e)
                    i = random.randint(0, self.__len__() - 1)
class CustomSRTrain(ImageVideoDataset):
    """Training split: loads the full annotation list from a JSON file."""

    def __init__(self, data_json_path, **kwargs):
        # Must be set before super().__init__(), which calls self.get_base().
        self.data_json_path = data_json_path
        super().__init__(**kwargs)

    def get_base(self):
        """Return the list of annotation dicts parsed from the JSON file."""
        # Fix: context manager closes the file (previously the handle from
        # ``open()`` leaked), and the no-op ``[ann for ann in ...]`` copy of
        # the freshly parsed list is dropped.
        with open(self.data_json_path) as f:
            return json.load(f)
class CustomSRValidation(ImageVideoDataset):
    """Validation split: first 100 and last 100 annotations from a JSON file."""

    def __init__(self, data_json_path, **kwargs):
        # Must be set before super().__init__(), which calls self.get_base().
        # (The original redundantly re-assigned it again after super().__init__.)
        self.data_json_path = data_json_path
        super().__init__(**kwargs)

    def get_base(self):
        """Return the first 100 plus last 100 annotations.

        NOTE: if the file holds fewer than 200 entries the two slices
        overlap, duplicating items -- same as the original behavior.
        """
        # Fix: parse the JSON once with a context manager instead of opening
        # (and leaking) the file twice and parsing it twice.
        with open(self.data_json_path) as f:
            anns = json.load(f)
        return anns[:100] + anns[-100:]