Spaces:
Runtime error
Runtime error
| import logging | |
| import os | |
| import os.path | |
| import math | |
| from PIL import Image, ImageDraw | |
| import random | |
| import numpy as np | |
| import torch | |
| import torchvision | |
| import torch.utils.data as data | |
| from pycocotools import mask as coco_mask | |
| from maskrcnn_benchmark.structures.bounding_box import BoxList | |
| from maskrcnn_benchmark.structures.segmentation_mask import SegmentationMask | |
| from maskrcnn_benchmark.data.datasets.coco import has_valid_annotation | |
| from .od_to_grounding import convert_od_to_grounding_simple, check_for_positive_overflow, sanity_check_target_after_processing, convert_object_detection_to_grounding_optimized_for_od | |
| import pdb | |
| import json | |
class CocoGrounding(torchvision.datasets.CocoDetection):
    """COCO-style detection dataset whose samples are converted to grounding form.

    Each sample is turned into a caption plus per-box annotations by one of the
    ``od_to_grounding`` converters, then passed through ``ConvertCocoPolysToMask``.
    ``__getitem__`` returns ``(image, BoxList, idx)``.
    """

    def __init__(self,
                 img_folder,
                 ann_file,
                 transforms,
                 return_masks,
                 return_tokens,
                 is_train=False,
                 tokenizer=None,
                 disable_shuffle=False,
                 add_detection_prompt=False,
                 one_hot=False,
                 disable_clip_to_image=False,
                 no_minus_one_for_one_hot=False,
                 separation_tokens=" ",
                 few_shot=0,
                 no_mask_for_od=False,
                 override_category=None,
                 use_caption_prompt=False,
                 caption_prompt=None,
                 max_query_len=256,
                 special_safeguard_for_coco_grounding=False,
                 random_sample_negative=-1,
                 **kwargs
                 ):
        """Index the annotation file and store the conversion options.

        Images without a valid annotation (per ``has_valid_annotation``) are
        dropped. When ``few_shot`` is non-zero the remaining images are
        further sub-sampled (see the inline comments). ``override_category``,
        when given, replaces the category list of the loaded COCO dataset.
        The remaining arguments are stored on the instance and consumed by
        ``__getitem__``; extra ``kwargs`` are accepted and ignored.
        """
        super(CocoGrounding, self).__init__(img_folder, ann_file)

        def load_annotations(img_id):
            # String image ids are wrapped in a list before being passed to getAnnIds.
            query = [img_id] if isinstance(img_id, str) else img_id
            return self.coco.loadAnns(self.coco.getAnnIds(imgIds=query, iscrowd=None))

        # Keep, in sorted id order, only the images that have a valid annotation.
        self.ids = [
            img_id for img_id in sorted(self.ids)
            if has_valid_annotation(load_annotations(img_id))
        ]

        if few_shot:
            # Per-category budget, indexed by (category_id - 1).
            remaining = [few_shot] * max(list(self.coco.cats.keys()))
            selected = []
            for img_id in self.ids:
                # A set counts every category once per image (image level, not instance level).
                present = {ann['category_id'] for ann in load_annotations(img_id)}
                if any(remaining[c - 1] > 0 for c in present):
                    # The image is kept as soon as one of its categories still has
                    # budget; every category it contains is then charged.
                    selected.append(img_id)
                    for c in present:
                        remaining[c - 1] -= 1
            self.ids = selected

        # JSON category ids -> contiguous labels starting at 1, and the inverse.
        self.json_category_id_to_contiguous_id = {
            json_id: position + 1 for position, json_id in enumerate(self.coco.getCatIds())
        }
        self.contiguous_category_id_to_json_id = {
            contiguous: json_id
            for json_id, contiguous in self.json_category_id_to_contiguous_id.items()
        }

        # The override happens after the id mapping was built from the original categories.
        if override_category is not None:
            self.coco.dataset["categories"] = override_category

        self.use_caption_prompt = use_caption_prompt
        self.caption_prompt = caption_prompt
        self.special_safeguard_for_coco_grounding = special_safeguard_for_coco_grounding
        self.random_sample_negative = random_sample_negative
        self.ind_to_class = self.categories(no_background=False)
        self.id_to_img_map = dict(enumerate(self.ids))
        self._transforms = transforms
        self.max_query_len = max_query_len
        # Masks are attached in __getitem__, so the converter is built with return_masks=False.
        self.prepare = ConvertCocoPolysToMask(False, return_tokens, tokenizer=tokenizer, max_query_len=max_query_len)
        self.tokenizer = tokenizer
        self.is_train = is_train
        self.disable_shuffle = disable_shuffle
        self.add_detection_prompt = add_detection_prompt
        self.one_hot = one_hot
        self.no_minus_one_for_one_hot = no_minus_one_for_one_hot
        self.disable_clip_to_image = disable_clip_to_image
        self.separation_tokens = separation_tokens
        self.no_mask_for_od = no_mask_for_od
        self.return_masks = return_masks

    def categories(self, no_background=True):
        """Return ``{contiguous label: category name}`` for the dataset categories.

        With ``no_background=True`` entries named ``"__background__"`` or with
        id 0 are left out.
        """
        names_by_label = {}
        for entry in self.coco.dataset["categories"]:
            if no_background and (entry["name"] == "__background__" or entry['id'] == 0):
                continue
            names_by_label[self.json_category_id_to_contiguous_id[entry["id"]]] = entry["name"]
        return names_by_label

    def get_box_mask(self, rect, img_size, mode="poly"):
        """Return the rectangle ``rect`` (x1, y1, x2, y2) as a one-polygon mask."""
        assert mode=="poly", "Only support poly mask right now!"
        left, top, right, bottom = rect[0], rect[1], rect[2], rect[3]
        corners = [left, top, left, bottom, right, bottom, right, top]
        return [corners]

    def __getitem__(self, idx):
        """Build sample ``idx`` and return ``(img, target, idx)``."""
        img, raw_annos = super(CocoGrounding, self).__getitem__(idx)
        image_id = self.ids[idx]

        # Crowd annotations are discarded.
        objects = [obj for obj in raw_annos if obj["iscrowd"] == 0]

        xywh = torch.as_tensor([obj["bbox"] for obj in objects]).reshape(-1, 4)  # guard against no boxes
        target = BoxList(xywh, img.size, mode="xywh").convert("xyxy")

        classes = torch.tensor(
            [self.json_category_id_to_contiguous_id[obj["category_id"]] for obj in objects]
        )
        target.add_field("labels", classes)

        if self.return_masks:
            polygons = []
            box_mask_flags = []
            for obj, bbox in zip(objects, target.bbox):
                if "segmentation" in obj:
                    polygons.append(obj["segmentation"])
                    box_mask_flags.append(0)
                else:
                    # No segmentation available: fall back to the box outline.
                    polygons.append(self.get_box_mask(bbox, img.size, mode="poly"))
                    box_mask_flags.append(1)
            seg_masks = SegmentationMask(polygons, img.size, mode="poly")
            box_mask_flags = torch.tensor(box_mask_flags)
            target.add_field("masks", seg_masks)
            target.add_field("is_box_mask", box_mask_flags)

        if not self.disable_clip_to_image:
            target = target.clip_to_image(remove_empty=True)

        if self.special_safeguard_for_coco_grounding:
            # Intended for LVIS
            assert(not self.use_caption_prompt)

            box_count_before = len(target)
            # leave some space for the special tokens
            target, positive_caption_length = check_for_positive_overflow(
                target, self.ind_to_class, self.tokenizer, self.max_query_len-2)
            if len(target) < box_count_before:
                print("WARNING: removed {} boxes due to positive caption overflow".format(box_count_before - len(target)))

            # The fourth return value (label_to_positions) is not used further here.
            annotations, caption, greenlight_spans, _label_to_positions = convert_object_detection_to_grounding_optimized_for_od(
                target=target,
                image_id=image_id,
                ind_to_class=self.ind_to_class,
                disable_shuffle=self.disable_shuffle,
                add_detection_prompt=False,
                add_detection_prompt_advanced=False,
                random_sample_negative=self.random_sample_negative,
                control_probabilities=(0.0, 0.0, 1.0, 0.0),  # always try to add a lot of negatives
                restricted_negative_list=None,
                separation_tokens=self.separation_tokens,
                max_num_labels=-1,
                positive_caption_length=positive_caption_length,
                tokenizer=self.tokenizer,
                max_seq_length=self.max_query_len-2
            )
        else:
            # Intended for COCO / ODinW
            prompt = self.caption_prompt if self.use_caption_prompt else None
            annotations, caption, greenlight_spans = convert_od_to_grounding_simple(
                target=target,
                image_id=image_id,
                ind_to_class=self.ind_to_class,
                disable_shuffle=self.disable_shuffle,
                add_detection_prompt=self.add_detection_prompt,
                separation_tokens=self.separation_tokens,
                caption_prompt=prompt,
            )

        anno = {"image_id": image_id, "annotations": annotations, "caption": caption}
        anno["greenlight_span_for_masked_lm_objective"] = greenlight_spans
        if self.no_mask_for_od:
            # Three-element marker understood by create_greenlight_map.
            anno["greenlight_span_for_masked_lm_objective"].append((-1, -1, -1))

        img, anno = self.prepare(img, anno, box_format="xyxy")

        # for equivalence check
        if self.one_hot:
            logging.info("using one hot for equivalence check.")
            one_hot_map = torch.zeros_like(anno["positive_map"], dtype=torch.float)
            text_mask = torch.zeros(anno["positive_map"].shape[1], dtype=torch.int64)
            # create one hot mapping
            column_shift = 0 if self.no_minus_one_for_one_hot else 1
            for row, cls in enumerate(classes):
                one_hot_map[row, cls - column_shift] = 1.0
            if self.no_minus_one_for_one_hot:
                text_mask[:] = 1
            else:
                text_mask[:len(self.ind_to_class)] = 1
            anno["positive_map"] = one_hot_map
            anno["text_mask"] = text_mask

        if self._transforms is not None:
            img, target = self._transforms(img, target)

        # add additional property
        for key, value in anno.items():
            target.add_field(key, value)

        sanity_check_target_after_processing(target)
        return img, target, idx

    def get_img_info(self, index):
        """Return the COCO image record for dataset position ``index``."""
        return self.coco.imgs[self.id_to_img_map[index]]
class ModulatedDataset(torchvision.datasets.CocoDetection):
    """COCO-format dataset whose image records carry their own ``caption``.

    ``__getitem__`` returns ``(image, BoxList, idx)``; the BoxList carries the
    fields produced by ``ConvertCocoPolysToMask`` plus a few per-image extras.
    """

    def __init__(self,
                 img_folder,
                 ann_file,
                 transforms,
                 return_masks,
                 return_tokens,
                 is_train=False,
                 tokenizer=None,
                 disable_clip_to_image=False,
                 no_mask_for_gold=False,
                 max_query_len=256,
                 **kwargs):
        """Index the annotation file, dropping images without a valid annotation.

        Extra ``kwargs`` are accepted and ignored.
        """
        super(ModulatedDataset, self).__init__(img_folder, ann_file)

        def has_usable_annotations(img_id):
            # String image ids are wrapped in a list before being passed to getAnnIds.
            query = [img_id] if isinstance(img_id, str) else img_id
            ann_ids = self.coco.getAnnIds(imgIds=query, iscrowd=None)
            return has_valid_annotation(self.coco.loadAnns(ann_ids))

        self.ids = [img_id for img_id in sorted(self.ids) if has_usable_annotations(img_id)]
        self.id_to_img_map = dict(enumerate(self.ids))

        self._transforms = transforms
        self.max_query_len = max_query_len
        self.prepare = ConvertCocoPolysToMask(return_masks, return_tokens, tokenizer=tokenizer, max_query_len=max_query_len)
        self.is_train = is_train
        self.disable_clip_to_image = disable_clip_to_image
        self.no_mask_for_gold = no_mask_for_gold

    def __getitem__(self, idx):
        """Build sample ``idx`` and return ``(img, target, idx)``."""
        img, raw_annos = super(ModulatedDataset, self).__getitem__(idx)
        image_id = self.ids[idx]
        coco_img = self.coco.loadImgs(image_id)[0]
        caption = coco_img["caption"]
        dataset_name = coco_img.get("dataset_name")

        anno = {"image_id": image_id, "annotations": raw_annos, "caption": caption}
        # This dataset is used for Flickr & Mixed, so the sequence is maskable
        anno["greenlight_span_for_masked_lm_objective"] = [(0, len(caption))]
        if self.no_mask_for_gold:
            # Three-element marker understood by create_greenlight_map.
            anno["greenlight_span_for_masked_lm_objective"].append((-1, -1, -1))

        img, anno = self.prepare(img, anno)

        # convert to BoxList (bboxes, labels)
        xyxy = torch.as_tensor(anno["boxes"]).reshape(-1, 4)  # guard against no boxes
        target = BoxList(xyxy, img.size, mode="xyxy")
        target.add_field("labels", anno["labels"])

        if self.prepare.return_masks:
            # Popped so they are not added a second time with the remaining fields.
            target.add_field("masks", anno.pop("masks"))
            target.add_field("is_box_mask", anno.pop("is_box_mask"))

        if not self.disable_clip_to_image:
            box_count_before = len(target.bbox)
            target = target.clip_to_image(remove_empty=True)
            assert box_count_before == len(target.bbox), "Box got removed in MixedDataset!!!"

        if self._transforms is not None:
            img, target = self._transforms(img, target)

        # add additional property
        for key, value in anno.items():
            target.add_field(key, value)

        target.add_field("dataset_name", dataset_name)
        for extra_key in ("sentence_id", "original_img_id", "original_id", "task_id"):
            if extra_key in coco_img:
                target.add_field(extra_key, coco_img[extra_key])

        if "tokens_positive_eval" in coco_img and not self.is_train:
            # The caption is tokenized again here, without the length limit used by self.prepare.
            tokenized = self.prepare.tokenizer(caption, return_tensors="pt")
            eval_map = create_positive_map(tokenized, coco_img["tokens_positive_eval"])
            target.add_field("positive_map_eval", eval_map)
            target.add_field("nb_eval", len(target.get_field("positive_map_eval")))

        sanity_check_target_after_processing(target)
        return img, target, idx

    def get_img_info(self, index):
        """Return the COCO image record for dataset position ``index``."""
        return self.coco.imgs[self.id_to_img_map[index]]
class CocoDetection(data.Dataset):
    """`MS Coco Detection <http://mscoco.org/dataset/#detections-challenge2016>`_ Dataset.

    Args:
        root (string): Root directory where images are downloaded to.
        annFile (string): Path to json annotation file.
        transform (callable, optional): A function/transform that takes in an PIL image
            and returns a transformed version. E.g, ``transforms.ToTensor``
        target_transform (callable, optional): A function/transform that takes in the
            target and transforms it.
    """

    def __init__(self, root, annFile, transform=None, target_transform=None):
        # Imported lazily so the module can be loaded without pycocotools.coco.
        from pycocotools.coco import COCO
        self.root = root
        self.coco = COCO(annFile)
        self.ids = list(self.coco.imgs.keys())
        self.transform = transform
        self.target_transform = target_transform

    def __getitem__(self, index, return_meta=False):
        """
        Args:
            index (int): Index
            return_meta (bool): also return the COCO image record.

        Returns:
            tuple: Tuple (image, target), or (image, target, meta) when
            ``return_meta`` is set. target is the object returned by ``coco.loadAnns``.
        """
        img_id = self.ids[index]
        if isinstance(img_id, str):
            # String ids are wrapped in a list for the COCO API calls below.
            img_id = [img_id]

        target = self.coco.loadAnns(self.coco.getAnnIds(imgIds=img_id))
        meta = self.coco.loadImgs(img_id)[0]
        img = pil_loader(os.path.join(self.root, meta['file_name']))

        if self.transform is not None:
            img = self.transform(img)
        if self.target_transform is not None:
            target = self.target_transform(target)

        if return_meta:
            return img, target, meta
        return img, target

    def __len__(self):
        return len(self.ids)

    def __repr__(self):
        def labelled(label, obj):
            # Continuation lines of a multi-line repr are aligned under the label.
            return '{0}{1}'.format(label, obj.__repr__().replace('\n', '\n' + ' ' * len(label)))

        parts = [
            'Dataset ' + self.__class__.__name__,
            ' Number of datapoints: {}'.format(self.__len__()),
            ' Root Location: {}'.format(self.root),
            labelled(' Transforms (if any): ', self.transform),
            labelled(' Target Transforms (if any): ', self.target_transform),
        ]
        return '\n'.join(parts)
class ConvertCocoPolysToMask(object):
    """Callable that turns a raw annotation dict into a dict of tensors.

    Depending on the constructor flags it also attaches masks and the
    token-level maps derived from the caption.
    """

    def __init__(self, return_masks=False, return_tokens=False, tokenizer=None, max_query_len=256):
        self.return_masks = return_masks
        self.return_tokens = return_tokens
        self.tokenizer = tokenizer
        self.max_query_len = max_query_len

    def get_box_mask(self, rect, img_size, mode="poly"):
        """Return the rectangle ``rect`` (x1, y1, x2, y2) as a one-polygon mask."""
        assert mode=="poly", "Only support poly mask right now!"
        left, top, right, bottom = rect[0], rect[1], rect[2], rect[3]
        corners = [left, top, left, bottom, right, bottom, right, top]
        return [corners]

    def __call__(self, image, target, ignore_box_screen=False, box_format="xywh"):
        """Convert ``target`` for ``image`` and return ``(image, converted_dict)``.

        ``target`` must provide ``"image_id"`` and ``"annotations"``;
        ``"caption"``, ``"label_to_positions"`` and
        ``"greenlight_span_for_masked_lm_objective"`` are optional. Boxes are
        converted from ``xywh`` when ``box_format`` says so, clamped to the
        image, and degenerate boxes are dropped. With ``ignore_box_screen``
        the token spans of dropped boxes are kept anyway.
        """
        width, height = image.size

        image_id = torch.tensor([target["image_id"]])
        raw_annotations = target["annotations"]
        caption = target["caption"] if "caption" in target else None
        label_to_positions = target.get("label_to_positions", {})
        greenlight_spans = target.get("greenlight_span_for_masked_lm_objective", None)

        # Crowd annotations are dropped; a missing "iscrowd" key counts as non-crowd.
        anno = [obj for obj in raw_annotations if obj.get("iscrowd", 0) == 0]

        # guard against no boxes via resizing
        boxes = torch.as_tensor([obj["bbox"] for obj in anno], dtype=torch.float32).reshape(-1, 4)
        if box_format == "xywh":
            boxes[:, 2:] += boxes[:, :2] - 1  # TO_REMOVE = 1
        boxes[:, 0::2].clamp_(min=0, max=width-1)  # TO_REMOVE = 1
        boxes[:, 1::2].clamp_(min=0, max=height-1)  # TO_REMOVE = 1

        classes = torch.tensor([obj["category_id"] for obj in anno], dtype=torch.int64)

        if self.return_masks:
            polygons = []
            box_mask_flags = []
            for obj, bbox in zip(anno, boxes):
                if "segmentation" in obj:
                    polygons.append(obj["segmentation"])
                    box_mask_flags.append(0)
                else:
                    # No segmentation available: fall back to the box outline.
                    polygons.append(self.get_box_mask(bbox, image.size, mode='poly'))
                    box_mask_flags.append(1)
            masks = SegmentationMask(polygons, image.size, mode='poly')
            is_box_mask = torch.tensor(box_mask_flags)

        # The optional per-object fields are detected from the first annotation only.
        keypoints = None
        if anno and "keypoints" in anno[0]:
            keypoints = torch.as_tensor([obj["keypoints"] for obj in anno], dtype=torch.float32)
            num_keypoints = keypoints.shape[0]
            if num_keypoints:
                keypoints = keypoints.view(num_keypoints, -1, 3)

        isfinal = None
        if anno and "isfinal" in anno[0]:
            isfinal = torch.as_tensor([obj["isfinal"] for obj in anno], dtype=torch.float)

        tokens_positive = [] if self.return_tokens else None
        if self.return_tokens and anno:
            if "tokens" in anno[0]:
                tokens_positive = [obj["tokens"] for obj in anno]
            elif "tokens_positive" in anno[0]:
                tokens_positive = [obj["tokens_positive"] for obj in anno]

        # Only boxes with positive width and height survive.
        keep = (boxes[:, 3] > boxes[:, 1]) & (boxes[:, 2] > boxes[:, 0])

        boxes = boxes[keep]
        classes = classes[keep]
        if self.return_masks:
            masks = masks[keep]
            is_box_mask = is_box_mask[keep]
        if keypoints is not None:
            keypoints = keypoints[keep]

        # Insertion order is kept stable; callers iterate over this dict.
        converted = {}
        converted["boxes"] = boxes
        converted["labels"] = classes
        if caption is not None:
            converted["caption"] = caption
        if self.return_masks:
            converted["masks"] = masks
            converted["is_box_mask"] = is_box_mask
        converted["image_id"] = image_id
        if keypoints is not None:
            converted["keypoints"] = keypoints

        if tokens_positive is not None:
            converted["tokens_positive"] = [
                tokens_positive[i] for i, kept in enumerate(keep) if kept or ignore_box_screen
            ]

        if isfinal is not None:
            converted["isfinal"] = isfinal

        # for conversion to coco api
        area = torch.tensor([obj["area"] for obj in anno])
        iscrowd = torch.tensor([obj.get("iscrowd", 0) for obj in anno])
        converted["area"] = area[keep]
        converted["iscrowd"] = iscrowd[keep]

        converted["orig_size"] = torch.as_tensor([int(height), int(width)])
        converted["size"] = torch.as_tensor([int(height), int(width)])

        if self.return_tokens and self.tokenizer is not None:
            if not ignore_box_screen:
                assert len(converted["boxes"]) == len(converted["tokens_positive"])
            tokenized = self.tokenizer(caption, return_tensors="pt",
                                       max_length=self.max_query_len,
                                       truncation=True)
            converted["positive_map"] = create_positive_map(tokenized, converted["tokens_positive"])
            converted['greenlight_map'] = create_greenlight_map(greenlight_spans, tokenized)
            converted["positive_map_for_od_labels"] = create_positive_map_for_od_labels(tokenized, label_to_positions)

        # NOTE: The padding value has to be not the same as -1 or -100
        original_od_label = [obj.get("original_od_label", -10) for obj in anno]
        converted["original_od_label"] = torch.as_tensor(original_od_label)

        return image, converted
def create_greenlight_map(tok_list, tokenized, max_len=256):
    """Build a 1-D float map marking which token positions lie in the given spans.

    Args:
        tok_list: iterable of ``(beg, end)`` character spans, e.g.
            ``[(0, 5), (10, 13), (-1, -1, -1)]``. A three-element item is a
            special marker: the whole map is set to -1 and processing stops.
            ``None`` is treated as "no spans".
        tokenized: object exposing ``char_to_token(char_index)`` that returns a
            token index or ``None``.
        max_len: length of the returned map (256 by default, the previously
            hard-coded value).

    Returns:
        Float tensor of shape ``(max_len,)``: 1 inside located spans, 0
        elsewhere, or all -1 when the marker item is present.
    """
    def locate(primary, fallbacks):
        # The primary lookup is not guarded; only the fallback probes tolerate
        # lookup errors, in which case the span is treated as not found.
        pos = tokenized.char_to_token(primary)
        if pos is not None:
            return pos
        try:
            for char_index in fallbacks:
                pos = tokenized.char_to_token(char_index)
                if pos is not None:
                    return pos
        except Exception:
            return None
        return None

    greenlight_map = torch.zeros(max_len, dtype=torch.float)
    if tok_list is None:
        return greenlight_map

    for item in tok_list:
        if len(item) != 2:
            assert(len(item) == 3)
            # Special marker: flag every position with -1 and stop.
            greenlight_map[:] = -1
            break

        beg, end = item
        # ``end`` is exclusive, so the last character of the span is end - 1.
        beg_pos = locate(beg, (beg + 1, beg + 2))
        end_pos = locate(end - 1, (end - 2, end - 3))
        if beg_pos is None or end_pos is None:
            # Span could not be mapped to tokens; skip it.
            continue

        greenlight_map[beg_pos: end_pos + 1].fill_(1)
    return greenlight_map
def create_positive_map_for_od_labels(tokenized, label_to_positions, max_len=256):
    """Construct a map such that positive_map[i] = j, where j is the object detection label of token i.

    Example: with ``{3: (1, 5)}`` the tokens covering characters 1..4 receive
    the value 3 and every other position stays at -1.

    Args:
        tokenized: object exposing ``char_to_token(char_index)`` that returns a
            token index or ``None``.
        label_to_positions: mapping ``label -> (beg, end)`` character span; one
            label only maps to one location.
        max_len: length of the returned map (256 by default, the previously
            hard-coded value).

    Returns:
        Float tensor of shape ``(max_len,)``; -1 means no match.
    """
    def locate(primary, fallbacks):
        # The primary lookup is not guarded; only the fallback probes tolerate
        # lookup errors, in which case the span is treated as not found.
        pos = tokenized.char_to_token(primary)
        if pos is not None:
            return pos
        try:
            for char_index in fallbacks:
                pos = tokenized.char_to_token(char_index)
                if pos is not None:
                    return pos
        except Exception:
            return None
        return None

    positive_map = torch.ones(max_len, dtype=torch.float) * -1  # -1 means no match
    for label, span in label_to_positions.items():
        beg, end = span
        # ``end`` is exclusive, so the last character of the span is end - 1.
        beg_pos = locate(beg, (beg + 1, beg + 2))
        end_pos = locate(end - 1, (end - 2, end - 3))
        if beg_pos is None or end_pos is None:
            # Span could not be mapped to tokens; skip this label.
            continue

        positive_map[beg_pos: end_pos + 1].fill_(label)
    return positive_map
def convert_coco_poly_to_mask(segmentations, height, width):
    """Rasterize each entry of ``segmentations`` into one mask per object.

    Every entry is encoded with ``coco_mask.frPyObjects``, decoded, and its
    parts are merged with ``any`` over the last axis. The per-object masks are
    stacked along dimension 0; with no segmentations an empty
    ``(0, height, width)`` uint8 tensor is returned.
    """
    def rasterize(polygons):
        decoded = coco_mask.decode(coco_mask.frPyObjects(polygons, height, width))
        if len(decoded.shape) < 3:
            # Add the trailing "part" axis when the decoder returned a 2-D array.
            decoded = decoded[..., None]
        return torch.as_tensor(decoded, dtype=torch.uint8).any(dim=2)

    per_object = [rasterize(polygons) for polygons in segmentations]
    if not per_object:
        return torch.zeros((0, height, width), dtype=torch.uint8)
    return torch.stack(per_object, dim=0)
def create_positive_map(tokenized, tokens_positive, max_len=256):
    """Construct a map such that positive_map[i,j] is non-zero iff box i is associated to token j.

    Args:
        tokenized: object exposing ``char_to_token(char_index)`` that returns a
            token index or ``None``.
        tokens_positive: one list of ``(beg, end)`` character spans per box.
        max_len: number of token columns (256 by default, the previously
            hard-coded value).

    Returns:
        Float tensor of shape ``(len(tokens_positive), max_len)``. Each row is
        divided by its sum (plus 1e-6), so rows with no located span stay zero.
    """
    def locate(primary, fallbacks):
        # The primary lookup is not guarded; only the fallback probes tolerate
        # lookup errors, in which case the span is treated as not found.
        pos = tokenized.char_to_token(primary)
        if pos is not None:
            return pos
        try:
            for char_index in fallbacks:
                pos = tokenized.char_to_token(char_index)
                if pos is not None:
                    return pos
        except Exception:
            return None
        return None

    positive_map = torch.zeros((len(tokens_positive), max_len), dtype=torch.float)
    for row, spans in enumerate(tokens_positive):
        for (beg, end) in spans:
            # ``end`` is exclusive, so the last character of the span is end - 1.
            beg_pos = locate(beg, (beg + 1, beg + 2))
            end_pos = locate(end - 1, (end - 2, end - 3))
            if beg_pos is None or end_pos is None:
                # Span could not be mapped to tokens; skip it.
                continue

            positive_map[row, beg_pos: end_pos + 1].fill_(1)
    # The epsilon keeps all-zero rows from dividing by zero.
    return positive_map / (positive_map.sum(-1)[:, None] + 1e-6)
def pil_loader(path, retry=5):
    """Open the file at ``path`` and return it converted to an RGB image.

    The load is attempted up to ``retry`` times. If every attempt fails, the
    error from the last attempt is re-raised instead of silently returning
    ``None``. With ``retry <= 0`` no attempt is made and ``None`` is returned.

    Args:
        path: location of the image file.
        retry: maximum number of attempts.

    Returns:
        The result of ``Image.open(f).convert('RGB')``.

    Raises:
        Exception: whatever the final failed attempt raised (for example
            ``FileNotFoundError`` for a missing file).
    """
    last_error = None
    attempt = 0
    while attempt < retry:
        try:
            # open path as file to avoid ResourceWarning (https://github.com/python-pillow/Pillow/issues/835)
            with open(path, 'rb') as f:
                img = Image.open(f)
                return img.convert('RGB')
        except Exception as err:
            # KeyboardInterrupt / SystemExit are deliberately not caught here.
            last_error = err
            attempt += 1
    if last_error is not None:
        raise last_error
    return None