import glob
import random
from pathlib import Path
from typing import Dict, List, Optional, Tuple, Union

from PIL import Image
from torch import zeros_like
from torch.utils.data import Dataset
from torchvision import transforms

from .preprocess_files import face_mask_google_mediapipe

OBJECT_TEMPLATE = [
    "a photo of a {}",
    "a rendering of a {}",
    "a cropped photo of the {}",
    "the photo of a {}",
    "a photo of a clean {}",
    "a photo of a dirty {}",
    "a dark photo of the {}",
    "a photo of my {}",
    "a photo of the cool {}",
    "a close-up photo of a {}",
    "a bright photo of the {}",
    "a cropped photo of a {}",
    "a photo of the {}",
    "a good photo of the {}",
    "a photo of one {}",
    "a close-up photo of the {}",
    "a rendition of the {}",
    "a photo of the clean {}",
    "a rendition of a {}",
    "a photo of a nice {}",
    "a good photo of a {}",
    "a photo of the nice {}",
    "a photo of the small {}",
    "a photo of the weird {}",
    "a photo of the large {}",
    "a photo of a cool {}",
    "a photo of a small {}",
]

STYLE_TEMPLATE = [
    "a painting in the style of {}",
    "a rendering in the style of {}",
    "a cropped painting in the style of {}",
    "the painting in the style of {}",
    "a clean painting in the style of {}",
    "a dirty painting in the style of {}",
    "a dark painting in the style of {}",
    "a picture in the style of {}",
    "a cool painting in the style of {}",
    "a close-up painting in the style of {}",
    "a bright painting in the style of {}",
    "a cropped painting in the style of {}",
    "a good painting in the style of {}",
    "a close-up painting in the style of {}",
    "a rendition in the style of {}",
    "a nice painting in the style of {}",
    "a small painting in the style of {}",
    "a weird painting in the style of {}",
    "a large painting in the style of {}",
]

NULL_TEMPLATE = ["{}"]

TEMPLATE_MAP = {
    "object": OBJECT_TEMPLATE,
    "style": STYLE_TEMPLATE,
    "null": NULL_TEMPLATE,
}
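
# At train time one template is sampled and formatted with the placeholder
# token, e.g. random.choice(TEMPLATE_MAP["object"]).format("<s1>") might
# yield "a photo of a <s1>" (the "<s1>" token here is only illustrative).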


def _randomset(lis):
    # Keep each element independently with probability 0.5.
    ret = []
    for item in lis:
        if random.random() < 0.5:
            ret.append(item)
    return ret


def _shuffle(lis):
    # Return a shuffled copy; random.sample leaves the input list untouched.
    return random.sample(lis, len(lis))
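
# Note: _randomset and _shuffle are not called in this file; they appear to
# be small caption-augmentation helpers kept for use elsewhere (an
# assumption, not documented in the source).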


def _get_cutout_holes(
    height,
    width,
    min_holes=8,
    max_holes=32,
    min_height=16,
    max_height=128,
    min_width=16,
    max_width=128,
):
    holes = []
    for _n in range(random.randint(min_holes, max_holes)):
        hole_height = random.randint(min_height, max_height)
        hole_width = random.randint(min_width, max_width)
        y1 = random.randint(0, height - hole_height)
        x1 = random.randint(0, width - hole_width)
        y2 = y1 + hole_height
        x2 = x1 + hole_width
        holes.append((x1, y1, x2, y2))
    return holes
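
# Each hole is an (x1, y1, x2, y2) box in pixel coordinates; the defaults
# produce 8-32 boxes, each 16-128 px per side, anywhere inside the canvas.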


def _generate_random_mask(image):
    mask = zeros_like(image[:1])
    holes = _get_cutout_holes(mask.shape[1], mask.shape[2])
    for (x1, y1, x2, y2) in holes:
        mask[:, y1:y2, x1:x2] = 1.0
    if random.uniform(0, 1) < 0.25:
        mask.fill_(1.0)
    masked_image = image * (mask < 0.5)
    return mask, masked_image
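
# Shape sketch (illustrative, assumes `import torch`): for a (3, 512, 512)
# image tensor, `mask` is (1, 512, 512) with 1.0 inside the cutout holes
# (or everywhere, 25% of the time), and `masked_image` is the image with
# those regions zeroed:
#
#   mask, masked = _generate_random_mask(torch.randn(3, 512, 512))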


class PivotalTuningDatasetCapation(Dataset):
    """
    A dataset that prepares the instance images and their prompts for
    fine-tuning the model. It pre-processes the images and tokenizes the
    prompts.
    """

    def __init__(
        self,
        instance_data_root,
        tokenizer,
        token_map: Optional[dict] = None,
        use_template: Optional[str] = None,
        size=512,
        h_flip=True,
        color_jitter=False,
        resize=True,
        use_mask_captioned_data=False,
        use_face_segmentation_condition=False,
        train_inpainting=False,
        blur_amount: int = 70,
    ):
        self.size = size
        self.tokenizer = tokenizer
        self.resize = resize
        self.train_inpainting = train_inpainting

        instance_data_root = Path(instance_data_root)
        if not instance_data_root.exists():
            raise ValueError("Instance images root doesn't exist.")

        self.instance_images_path = []
        self.mask_path = []

        assert not (
            use_mask_captioned_data and use_template
        ), "Can't use both mask-captioned data and a template."

        # Prepare the instance images.
        if use_mask_captioned_data:
            src_imgs = glob.glob(str(instance_data_root) + "/*src.jpg")
            for f in src_imgs:
                idx = int(str(Path(f).stem).split(".")[0])
                mask_path = f"{instance_data_root}/{idx}.mask.png"

                if Path(mask_path).exists():
                    self.instance_images_path.append(f)
                    self.mask_path.append(mask_path)
                else:
                    print(f"Mask not found for {f}")

            with open(f"{instance_data_root}/caption.txt") as fp:
                self.captions = fp.readlines()

        else:
            possibly_src_images = (
                glob.glob(str(instance_data_root) + "/*.jpg")
                + glob.glob(str(instance_data_root) + "/*.png")
                + glob.glob(str(instance_data_root) + "/*.jpeg")
            )
            possibly_src_images = (
                set(possibly_src_images)
                - set(glob.glob(str(instance_data_root) + "/*mask.png"))
                - set([str(instance_data_root) + "/caption.txt"])
            )

            self.instance_images_path = list(set(possibly_src_images))
            # Without a caption file, fall back to each file's stem as its caption.
            self.captions = [
                x.split("/")[-1].split(".")[0] for x in self.instance_images_path
            ]

        assert (
            len(self.instance_images_path) > 0
        ), "No images found in the instance data root."

        self.instance_images_path = sorted(self.instance_images_path)

        self.use_mask = use_face_segmentation_condition or use_mask_captioned_data
        self.use_mask_captioned_data = use_mask_captioned_data

        if use_face_segmentation_condition:
            # If any mask is missing, (re)generate face masks for all images.
            for idx in range(len(self.instance_images_path)):
                targ = f"{instance_data_root}/{idx}.mask.png"
                # See if the mask exists.
                if not Path(targ).exists():
                    print(f"Mask not found for {targ}")
                    print(
                        "Warning: this will pre-process all the images in the instance data root."
                    )

                    if len(self.mask_path) > 0:
                        print(
                            "Warning: masks already exist, but they will be overwritten."
                        )

                    masks = face_mask_google_mediapipe(
                        [
                            Image.open(f).convert("RGB")
                            for f in self.instance_images_path
                        ]
                    )
                    for idx, mask in enumerate(masks):
                        mask.save(f"{instance_data_root}/{idx}.mask.png")

                    break

            for idx in range(len(self.instance_images_path)):
                self.mask_path.append(f"{instance_data_root}/{idx}.mask.png")

        self.num_instance_images = len(self.instance_images_path)
        self.token_map = token_map

        self.use_template = use_template
        if use_template is not None:
            self.templates = TEMPLATE_MAP[use_template]

        self._length = self.num_instance_images

        self.h_flip = h_flip
        self.image_transforms = transforms.Compose(
            [
                transforms.Resize(
                    size, interpolation=transforms.InterpolationMode.BILINEAR
                )
                if resize
                else transforms.Lambda(lambda x: x),
                transforms.ColorJitter(0.1, 0.1)
                if color_jitter
                else transforms.Lambda(lambda x: x),
                transforms.CenterCrop(size),
                transforms.ToTensor(),
                transforms.Normalize([0.5], [0.5]),
            ]
        )
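        # The Normalize([0.5], [0.5]) step above maps ToTensor's [0, 1]
        # output to [-1, 1], the input range the diffusion VAE expects.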

        self.blur_amount = blur_amount

    def __len__(self):
        return self._length

    def __getitem__(self, index):
        example = {}
        instance_image = Image.open(
            self.instance_images_path[index % self.num_instance_images]
        )
        if instance_image.mode != "RGB":
            instance_image = instance_image.convert("RGB")
        example["instance_images"] = self.image_transforms(instance_image)

        if self.train_inpainting:
            (
                example["instance_masks"],
                example["instance_masked_images"],
            ) = _generate_random_mask(example["instance_images"])

        if self.use_template:
            assert self.token_map is not None
            input_tok = list(self.token_map.values())[0]

            text = random.choice(self.templates).format(input_tok)
        else:
            text = self.captions[index % self.num_instance_images].strip()

            if self.token_map is not None:
                for token, value in self.token_map.items():
                    text = text.replace(token, value)

        # Log the final prompt for this sample.
        print(text)

        if self.use_mask:
            example["mask"] = (
                self.image_transforms(
                    Image.open(self.mask_path[index % self.num_instance_images])
                )
                * 0.5
                + 1.0
            )

        if self.h_flip and random.random() > 0.5:
            hflip = transforms.RandomHorizontalFlip(p=1)

            example["instance_images"] = hflip(example["instance_images"])
            if self.use_mask:
                example["mask"] = hflip(example["mask"])

        example["instance_prompt_ids"] = self.tokenizer(
            text,
            padding="do_not_pad",
            truncation=True,
            max_length=self.tokenizer.model_max_length,
        ).input_ids

        return example
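

# Minimal usage sketch (the tokenizer checkpoint, folder path, and token_map
# below are illustrative assumptions, not part of this module):
#
#   from transformers import CLIPTokenizer
#
#   tokenizer = CLIPTokenizer.from_pretrained("openai/clip-vit-large-patch14")
#   dataset = PivotalTuningDatasetCapation(
#       instance_data_root="./instance_images",
#       tokenizer=tokenizer,
#       token_map={"TOKEN": "<s1>"},
#       use_template="object",
#   )
#   example = dataset[0]
#   # keys: "instance_images", "instance_prompt_ids", plus "mask" /
#   # "instance_masks" / "instance_masked_images" depending on the flags.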