import copy
import random
from concurrent.futures import ProcessPoolExecutor
from typing import Callable, List

import numpy as np
import torch
import torchvision.transforms.functional as tvtf
from PIL import Image

from src.diffusion.base.training import *
from src.diffusion.base.scheduling import BaseScheduler
from src.models.vae import uint82fp
def center_crop_arr(pil_image, width, height):
    """
    Cropping implementation adapted from ADM:
    https://github.com/openai/guided-diffusion/blob/8fb3ad9197f16bbc40620447b2742e13458d2831/guided_diffusion/image_datasets.py#L126
    Unlike the ADM original, the final crop offset here is sampled uniformly
    at random rather than centered.
    """
    # Repeatedly halve with a box filter while the image is at least twice
    # the target size, then finish with a single bicubic resize.
    while pil_image.size[0] >= 2 * width and pil_image.size[1] >= 2 * height:
        pil_image = pil_image.resize(
            tuple(x // 2 for x in pil_image.size), resample=Image.BOX
        )
    scale = max(width / pil_image.size[0], height / pil_image.size[1])
    pil_image = pil_image.resize(
        tuple(round(x * scale) for x in pil_image.size), resample=Image.BICUBIC
    )
    arr = np.array(pil_image)
    crop_y = random.randint(0, arr.shape[0] - height)
    crop_x = random.randint(0, arr.shape[1] - width)
    return Image.fromarray(arr[crop_y: crop_y + height, crop_x: crop_x + width])
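

# Hedged sketch (not part of the original pipeline): a minimal probe of
# `center_crop_arr` using a synthetic 480x360 image, which is an assumption
# for illustration only. It shows the two-stage behavior: BOX halving while
# the image is at least 2x the target, then a BICUBIC resize and a randomly
# offset crop to exactly (width, height).
def _demo_center_crop():
    probe = Image.fromarray(np.zeros((360, 480, 3), dtype=np.uint8))  # PIL size is (480, 360)
    crop = center_crop_arr(probe, width=256, height=256)
    assert crop.size == (256, 256)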
def process_fn(width, height, data, hflip=0.5):
    image, label = data
    if random.uniform(0, 1) > hflip:  # horizontal flip with probability 1 - hflip
        image = tvtf.hflip(image)
    image = center_crop_arr(image, width, height)  # resize + random crop
    image = np.array(image).transpose(2, 0, 1)  # HWC -> CHW
    return image, label
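

# Hedged sketch: `process_fn` normally runs inside worker processes; this
# direct call uses a synthetic image and a dummy label, both assumptions for
# illustration. The result is a CHW uint8 array plus the untouched label.
def _demo_process_fn():
    probe = Image.fromarray(np.full((360, 480, 3), 127, dtype=np.uint8))
    image, label = process_fn(width=256, height=192, data=(probe, 0))
    assert image.shape == (3, 192, 256) and label == 0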
class VARCandidate:
    """An aspect-ratio bucket holding pending crop futures."""

    def __init__(self, aspect_ratio, width, height, buffer, max_buffer_size=1024):
        self.aspect_ratio = aspect_ratio
        self.width = int(width)
        self.height = int(height)
        self.buffer = buffer
        self.max_buffer_size = max_buffer_size

    def add_sample(self, data):
        # Keep only the newest max_buffer_size futures.
        self.buffer.append(data)
        self.buffer = self.buffer[-self.max_buffer_size:]

    def ready(self, batch_size):
        return len(self.buffer) >= batch_size

    def get_batch(self, batch_size):
        # Pop the oldest batch_size futures and block on their results.
        batch = self.buffer[:batch_size]
        self.buffer = self.buffer[batch_size:]
        batch = [copy.deepcopy(b.result()) for b in batch]
        x, y = zip(*batch)
        x = torch.stack([torch.from_numpy(im).cuda() for im in x], dim=0)
        x = list(map(uint82fp, x))
        return x, y
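

# Hedged sketch: the buffer stores concurrent.futures.Future objects, so
# `ready()` only checks queue depth, not completion, while `get_batch`
# blocks on `.result()` (and needs a CUDA device). The executor and integer
# payloads below are assumptions for illustration only.
def _demo_var_candidate():
    cand = VARCandidate(aspect_ratio=1.0, width=256, height=256, buffer=[])
    with ProcessPoolExecutor(max_workers=1) as pool:
        for _ in range(4):
            cand.add_sample(pool.submit(int, 1))
        assert cand.ready(batch_size=4)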
class VARTransformEngine:
    """Buckets incoming samples by aspect ratio and crops them asynchronously."""

    def __init__(self,
                 base_image_size,
                 num_aspect_ratios,
                 min_aspect_ratio,
                 max_aspect_ratio,
                 num_workers=8,
                 ):
        self.base_image_size = base_image_size
        self.num_aspect_ratios = num_aspect_ratios
        self.min_aspect_ratio = min_aspect_ratio
        self.max_aspect_ratio = max_aspect_ratio
        self.aspect_ratios = np.linspace(self.min_aspect_ratio, self.max_aspect_ratio, self.num_aspect_ratios)
        self.aspect_ratios = self.aspect_ratios.tolist()
        self.candidates_pool = []
        for i in range(self.num_aspect_ratios):
            # Sides keep the area near base_image_size**2 and are snapped
            # down to multiples of 16.
            candidate = VARCandidate(
                aspect_ratio=self.aspect_ratios[i],
                width=int(self.base_image_size * self.aspect_ratios[i] ** 0.5 // 16 * 16),
                height=int(self.base_image_size * self.aspect_ratios[i] ** -0.5 // 16 * 16),
                buffer=[],
                max_buffer_size=1024,
            )
            self.candidates_pool.append(candidate)
        # Square fallback bucket used while the ratio buckets warm up.
        self.default_candidate = VARCandidate(
            aspect_ratio=1.0,
            width=self.base_image_size,
            height=self.base_image_size,
            buffer=[],
            max_buffer_size=1024,
        )
        self.executor_pool = ProcessPoolExecutor(max_workers=num_workers)
        self._prefill_count = 100
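    # Illustrative bucket geometry (hedged, assuming base_image_size=256 and
    # an aspect ratio of 0.5): width = int(256 * 0.5**0.5 // 16 * 16) = 176
    # and height = int(256 * 0.5**-0.5 // 16 * 16) = 352, so the bucket area
    # (61952) stays close to 256**2 while both sides remain multiples of 16.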
    def find_candidate(self, data):
        # Pick the bucket whose aspect ratio is closest to the image's.
        image = data[0]
        aspect_ratio = image.size[0] / image.size[1]
        min_distance = float("inf")
        min_candidate = None
        for candidate in self.candidates_pool:
            dis = abs(aspect_ratio - candidate.aspect_ratio)
            if dis < min_distance:
                min_distance = dis
                min_candidate = candidate
        return min_candidate
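    # Hedged example of the matching above: for an engine built with
    # min_aspect_ratio=0.5, max_aspect_ratio=2.0, num_aspect_ratios=4
    # (buckets at 0.5, 1.0, 1.5, 2.0), a 300x200 image has aspect ratio 1.5
    # and maps to the 1.5 bucket.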
    def __call__(self, batch_data):
        self._prefill_count -= 1
        if isinstance(batch_data[0], torch.Tensor):
            batch_data[0] = batch_data[0].unbind(0)
        batch_data = list(zip(*batch_data))
        for data in batch_data:
            candidate = self.find_candidate(data)
            future = self.executor_pool.submit(process_fn, candidate.width, candidate.height, data)
            candidate.add_sample(future)
            # During the first 100 calls, also feed the default bucket so a
            # fallback batch is always available.
            if self._prefill_count >= 0:
                future = self.executor_pool.submit(process_fn,
                                                   self.default_candidate.width,
                                                   self.default_candidate.height,
                                                   data)
                self.default_candidate.add_sample(future)
        batch_size = len(batch_data)
        # Shuffle so no bucket is systematically drained first.
        random.shuffle(self.candidates_pool)
        for candidate in self.candidates_pool:
            if candidate.ready(batch_size=batch_size):
                return candidate.get_batch(batch_size=batch_size)
        # No ratio bucket is full yet: fall back to the square default bucket.
        for data in batch_data:
            future = self.executor_pool.submit(process_fn,
                                               self.default_candidate.width,
                                               self.default_candidate.height,
                                               data)
            self.default_candidate.add_sample(future)
        return self.default_candidate.get_batch(batch_size=batch_size)
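

# Hedged end-to-end sketch: wiring the engine into a loader loop. The batch
# layout (a list of PIL images plus integer labels) and all sizes are
# assumptions for illustration; `VARCandidate.get_batch` additionally needs
# a CUDA device.
def _demo_engine():
    engine = VARTransformEngine(
        base_image_size=256,
        num_aspect_ratios=4,
        min_aspect_ratio=0.5,
        max_aspect_ratio=2.0,
        num_workers=2,
    )
    images = [Image.fromarray(np.zeros((360, 480, 3), dtype=np.uint8)) for _ in range(8)]
    labels = list(range(8))
    x, y = engine([images, labels])  # list of per-sample float tensors, tuple of labels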