Spaces:
Runtime error
Runtime error
| # Copyright (c) Aishwarya Kamath & Nicolas Carion. Licensed under the Apache License 2.0. All Rights Reserved | |
| # Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved | |
| import json | |
| import os | |
| import time | |
| from collections import defaultdict | |
| import pycocotools.mask as mask_utils | |
| import torchvision | |
| from PIL import Image | |
| # from .coco import ConvertCocoPolysToMask, make_coco_transforms | |
| from .modulated_coco import ConvertCocoPolysToMask | |
def _isArrayLike(obj):
    """Return True when *obj* is both iterable and sized.

    Used to decide whether an id argument is a collection of ids or a single
    id. Note that strings also satisfy this check.
    """
    required = ("__iter__", "__len__")
    return all(hasattr(obj, attr) for attr in required)
class LVIS:
    """In-memory index over an LVIS-style annotation JSON file.

    The loaded document must be a dict; ``_create_index`` reads its
    ``"images"``, ``"categories"`` and ``"annotations"`` lists. After indexing:

    * ``anns`` / ``cats`` / ``imgs`` map an id to the raw record dict,
    * ``img_ann_map`` maps an image id to the list of its annotation dicts,
    * ``cat_img_map`` maps a category id to the image ids of its annotations
      (one entry per annotation, so an image id may repeat).
    """

    def __init__(self, annotation_path=None):
        """Load and index an annotation file.

        Args:
            annotation_path (str): location of annotation file. When ``None``
                an empty index is created and nothing is read from disk.

        Raises:
            AssertionError: if the top-level JSON value is not an object.
        """
        self.anns = {}
        self.cats = {}
        self.imgs = {}
        self.img_ann_map = defaultdict(list)
        self.cat_img_map = defaultdict(list)
        self.dataset = {}

        if annotation_path is not None:
            print("Loading annotations.")
            tic = time.time()
            self.dataset = self._load_json(annotation_path)
            print("Done (t={:0.2f}s)".format(time.time() - tic))
            # Kept as an assert so the exception type callers may rely on is unchanged.
            assert isinstance(self.dataset, dict), "Annotation file format {} not supported.".format(
                type(self.dataset)
            )
            self._create_index()

    def _load_json(self, path):
        """Parse and return the JSON document stored at ``path``."""
        # JSON is UTF-8 by specification; don't depend on the platform locale.
        with open(path, "r", encoding="utf-8") as f:
            return json.load(f)

    def _create_index(self):
        """(Re)build all lookup tables from ``self.dataset``."""
        print("Creating index.")

        self.img_ann_map = defaultdict(list)
        self.cat_img_map = defaultdict(list)
        self.anns = {}
        self.cats = {}
        self.imgs = {}

        # A single pass fills every annotation-derived table.
        for ann in self.dataset["annotations"]:
            self.img_ann_map[ann["image_id"]].append(ann)
            self.anns[ann["id"]] = ann
            self.cat_img_map[ann["category_id"]].append(ann["image_id"])

        for img in self.dataset["images"]:
            self.imgs[img["id"]] = img

        for cat in self.dataset["categories"]:
            self.cats[cat["id"]] = cat

        print("Index created.")

    def get_ann_ids(self, img_ids=None, cat_ids=None, area_rng=None):
        """Get ann ids that satisfy given filter conditions.

        Every filter is optional; the filters that are given are combined with AND.

        Args:
            img_ids (int or int array): keep only anns of these images.
            cat_ids (int or int array): keep only anns of these categories.
            area_rng (float array): ``[lo, hi]``; keep only anns with
                ``lo < area < hi`` (both bounds exclusive), e.g. [0, inf].

        Returns:
            ids (list of int): matching annotation ids, in annotation-file
            order (grouped per image, in ``img_ids`` order, when ``img_ids``
            is given).
        """
        if img_ids is not None:
            img_ids = img_ids if _isArrayLike(img_ids) else [img_ids]
        if cat_ids is not None:
            cat_ids = cat_ids if _isArrayLike(cat_ids) else [cat_ids]

        if img_ids is not None:
            # .get() avoids inserting empty lists into the defaultdict for
            # image ids that have no annotations.
            anns = [ann for img_id in img_ids for ann in self.img_ann_map.get(img_id, [])]
        else:
            anns = self.dataset["annotations"]

        # return early if no more filtering required
        if cat_ids is None and area_rng is None:
            return [ann["id"] for ann in anns]

        # cat_ids may still be None here (only area_rng given): that means
        # "no category filter", not an error.
        cat_set = set(cat_ids) if cat_ids is not None else None
        if area_rng is None:
            area_rng = [0, float("inf")]
        lo, hi = area_rng[0], area_rng[1]

        return [
            ann["id"]
            for ann in anns
            if (cat_set is None or ann["category_id"] in cat_set) and lo < ann["area"] < hi
        ]

    def get_cat_ids(self):
        """Get all category ids.

        Returns:
            ids (int array): integer array of category ids
        """
        return list(self.cats.keys())

    def get_img_ids(self):
        """Get all img ids.

        Returns:
            ids (int array): integer array of image ids
        """
        return list(self.imgs.keys())

    def _load_helper(self, _dict, ids):
        """Look up ``ids`` in ``_dict``.

        ``None`` returns every value; an array-like returns one value per id
        in the given order; a scalar returns a one-element list. Unknown ids
        raise KeyError.
        """
        if ids is None:
            return list(_dict.values())
        if _isArrayLike(ids):
            return [_dict[key] for key in ids]
        return [_dict[ids]]

    def load_anns(self, ids=None):
        """Load anns with the specified ids. If ids=None load all anns.

        Args:
            ids (int array): integer array of annotation ids

        Returns:
            anns (dict array) : loaded annotation objects
        """
        return self._load_helper(self.anns, ids)

    def load_cats(self, ids=None):
        """Load categories with the specified ids. If ids=None load all
        categories.

        Args:
            ids (int array): integer array of category ids

        Returns:
            cats (dict array) : loaded category dicts
        """
        return self._load_helper(self.cats, ids)

    def load_imgs(self, ids=None):
        """Load images with the specified ids. If ids=None load all images.

        Args:
            ids (int array): integer array of image ids

        Returns:
            imgs (dict array) : loaded image dicts
        """
        return self._load_helper(self.imgs, ids)

    def download(self, save_dir, img_ids=None):
        """Download images from each image record's ``coco_url``.

        Files that already exist in ``save_dir`` are skipped.

        Args:
            save_dir (str): dir to save downloaded images (created if missing)
            img_ids (int array): img ids of images to download; None = all
        """
        from urllib.request import urlretrieve

        imgs = self.load_imgs(img_ids)
        os.makedirs(save_dir, exist_ok=True)

        for img in imgs:
            # NOTE(review): image records may lack "file_name" (the dataset
            # classes below derive paths from coco_url instead), so fall back
            # to the URL's basename.
            name = img.get("file_name") or img["coco_url"].split("/")[-1]
            file_name = os.path.join(save_dir, name)
            if not os.path.exists(file_name):
                urlretrieve(img["coco_url"], file_name)

    def ann_to_rle(self, ann):
        """Convert annotation which can be polygons, uncompressed RLE to RLE.

        Args:
            ann (dict) : annotation object

        Returns:
            ann (rle)
        """
        img_data = self.imgs[ann["image_id"]]
        h, w = img_data["height"], img_data["width"]
        segm = ann["segmentation"]
        if isinstance(segm, list):
            # polygon -- a single object might consist of multiple parts
            # we merge all parts into one mask rle code
            rles = mask_utils.frPyObjects(segm, h, w)
            rle = mask_utils.merge(rles)
        elif isinstance(segm["counts"], list):
            # uncompressed RLE
            rle = mask_utils.frPyObjects(segm, h, w)
        else:
            # already compressed RLE
            rle = segm
        return rle

    def ann_to_mask(self, ann):
        """Convert annotation which can be polygons, uncompressed RLE, or RLE
        to binary mask.

        Args:
            ann (dict) : annotation object

        Returns:
            binary mask, as returned by pycocotools ``mask.decode``
        """
        rle = self.ann_to_rle(ann)
        return mask_utils.decode(rle)
class LvisDetectionBase(torchvision.datasets.VisionDataset):
    """Minimal LVIS dataset yielding ``(RGB PIL image, list of annotation dicts)``.

    Each image file is located under ``root`` using the last two path
    components of the image record's ``coco_url``.
    """

    def __init__(self, root, annFile, transform=None, target_transform=None, transforms=None):
        super().__init__(root, transforms, transform, target_transform)
        self.lvis = LVIS(annFile)
        # Sorted so integer indices map to image ids deterministically.
        self.ids = sorted(self.lvis.imgs.keys())

    def __getitem__(self, index):
        """
        Args:
            index (int): Index

        Returns:
            tuple: Tuple (image, target). target is the list of annotation
            dicts returned by ``LVIS.load_anns`` for that image.
        """
        lvis = self.lvis
        img_id = self.ids[index]
        ann_ids = lvis.get_ann_ids(img_ids=img_id)
        target = lvis.load_anns(ann_ids)

        coco_url = lvis.load_imgs(img_id)[0]["coco_url"]
        path = "/".join(coco_url.split("/")[-2:])
        # Close the underlying file handle; convert() returns an independent,
        # fully loaded image.
        with Image.open(os.path.join(self.root, path)) as raw:
            img = raw.convert("RGB")

        if self.transforms is not None:
            img, target = self.transforms(img, target)
        return img, target

    def __len__(self):
        return len(self.ids)
class LvisDetection(LvisDetectionBase):
    """LVIS dataset returning ``(image, target, idx)``.

    ``target`` is produced by ``ConvertCocoPolysToMask`` from
    ``{"image_id": ..., "annotations": [...]}``.
    """

    def __init__(self, img_folder, ann_file, transforms, return_masks=False, **kwargs):
        # transforms are deliberately not forwarded to the base class: they are
        # applied in __getitem__ after ``prepare``. Extra **kwargs are accepted
        # and ignored (presumably for signature compatibility with other
        # dataset builders — confirm against callers).
        super().__init__(img_folder, ann_file)
        self.ann_file = ann_file
        self._transforms = transforms
        self.prepare = ConvertCocoPolysToMask(return_masks)

    def __getitem__(self, idx):
        img, annotations = super().__getitem__(idx)
        target = {"image_id": self.ids[idx], "annotations": annotations}
        img, target = self.prepare(img, target)
        if self._transforms is not None:
            # Only the image is transformed; target is returned as built by prepare.
            img = self._transforms(img)
        return img, target, idx

    def get_raw_image(self, idx):
        """Return the untransformed RGB image at ``idx``."""
        img, _ = super().__getitem__(idx)
        return img

    def categories(self):
        """Return ``{category_id: name}`` ordered by ascending category id."""
        id2name = {cat["id"]: cat["name"] for cat in self.lvis.dataset["categories"]}
        return {cat_id: id2name[cat_id] for cat_id in sorted(id2name)}