Spaces:
Runtime error
Runtime error
| import os | |
| import torch | |
| import torch.utils.data | |
| from PIL import Image | |
| import sys | |
| if sys.version_info[0] == 2: | |
| import xml.etree.cElementTree as ET | |
| else: | |
| import xml.etree.ElementTree as ET | |
| from maskrcnn_benchmark.structures.bounding_box import BoxList | |
class PascalVOCDataset(torch.utils.data.Dataset):
    """Pascal VOC detection dataset producing ``(image, BoxList, index)`` items.

    The dataset root is expected to contain ``Annotations/<id>.xml``,
    ``JPEGImages/<id>.jpg`` and ``ImageSets/Main/<split>.txt`` (one image id
    per line).

    Args:
        data_dir: Root directory of the VOC-style dataset.
        split: Name of the image set file (without ``.txt``) listing the ids.
        use_difficult: When false, objects flagged ``<difficult>1</difficult>``
            are dropped from the ground truth.
        transforms: Optional callable invoked as ``transforms(img, target)``
            and returning the transformed ``(img, target)`` pair.
    """

    # NOTE: the trailing space in "__background__ " is kept verbatim because
    # callers may compare against the exact string returned by
    # map_class_id_to_class_name(0).
    CLASSES = (
        "__background__ ",
        "aeroplane",
        "bicycle",
        "bird",
        "boat",
        "bottle",
        "bus",
        "car",
        "cat",
        "chair",
        "cow",
        "diningtable",
        "dog",
        "horse",
        "motorbike",
        "person",
        "pottedplant",
        "sheep",
        "sofa",
        "train",
        "tvmonitor",
    )

    def __init__(self, data_dir, split, use_difficult=False, transforms=None):
        self.root = data_dir
        self.image_set = split
        self.keep_difficult = use_difficult
        self.transforms = transforms

        # Path templates; "%s" is filled with an image id (or the split name).
        self._annopath = os.path.join(self.root, "Annotations", "%s.xml")
        self._imgpath = os.path.join(self.root, "JPEGImages", "%s.jpg")
        self._imgsetpath = os.path.join(self.root, "ImageSets", "Main", "%s.txt")

        with open(self._imgsetpath % self.image_set) as f:
            # strip() removes any surrounding whitespace (not only "\n"), and
            # blank lines are skipped so they cannot become empty image ids.
            stripped = (line.strip() for line in f)
            self.ids = [img_id for img_id in stripped if img_id]
        self.id_to_img_map = dict(enumerate(self.ids))

        cls = PascalVOCDataset.CLASSES
        self.class_to_ind = dict(zip(cls, range(len(cls))))

    def __getitem__(self, index):
        """Return ``(img, target, index)`` for the image at position *index*.

        The image is loaded as RGB, the ground-truth boxes are clipped to the
        image (empty boxes removed) and ``self.transforms`` is applied when set.
        """
        img_id = self.ids[index]
        img = Image.open(self._imgpath % img_id).convert("RGB")

        target = self.get_groundtruth(index)
        target = target.clip_to_image(remove_empty=True)

        if self.transforms is not None:
            img, target = self.transforms(img, target)

        return img, target, index

    def __len__(self):
        """Return the number of image ids listed in the split file."""
        return len(self.ids)

    def get_groundtruth(self, index):
        """Parse the annotation of image *index* into a ``BoxList``.

        The returned ``BoxList`` is built in ``xyxy`` mode with image size
        ``(width, height)`` and carries the fields ``labels`` and ``difficult``.
        """
        img_id = self.ids[index]
        anno = ET.parse(self._annopath % img_id).getroot()
        anno = self._preprocess_annotation(anno)

        height, width = anno["im_info"]
        target = BoxList(anno["boxes"], (width, height), mode="xyxy")
        target.add_field("labels", anno["labels"])
        target.add_field("difficult", anno["difficult"])
        return target

    @staticmethod
    def _read_im_info(anno):
        """Return ``(height, width)`` as ints from the ``<size>`` element of *anno*."""
        size = anno.find("size")
        return tuple(map(int, (size.find("height").text, size.find("width").text)))

    def _preprocess_annotation(self, target):
        """Convert a parsed VOC ``<annotation>`` element into tensors.

        Args:
            target: Root XML element of one annotation file.

        Returns:
            A dict with ``boxes`` (float32, shape ``(N, 4)``, 0-based
            ``xmin, ymin, xmax, ymax``), ``labels`` (int64, shape ``(N,)``),
            ``difficult`` (bool, shape ``(N,)``) and ``im_info``
            (``(height, width)`` ints). ``N`` may be 0.

        Raises:
            KeyError: If an object's name is not one of ``CLASSES``.
        """
        boxes = []
        gt_classes = []
        difficult_boxes = []
        TO_REMOVE = 1

        for obj in target.iter("object"):
            # A missing or empty <difficult> tag is treated as "not difficult"
            # instead of raising.
            difficult = int(obj.findtext("difficult") or 0) == 1
            if not self.keep_difficult and difficult:
                continue
            name = obj.find("name").text.lower().strip()
            bb = obj.find("bndbox")
            # Make pixel indexes 0-based
            # Refer to "https://github.com/rbgirshick/py-faster-rcnn/blob/master/lib/datasets/pascal_voc.py#L208-L211"
            # Coordinates are parsed as float so non-integer text such as
            # "10.5" is accepted; integer text yields the same float32 values
            # as before.
            bndbox = tuple(
                float(bb.find(tag).text) - TO_REMOVE
                for tag in ("xmin", "ymin", "xmax", "ymax")
            )

            boxes.append(bndbox)
            gt_classes.append(self.class_to_ind[name])
            difficult_boxes.append(difficult)

        res = {
            # reshape keeps the (N, 4) layout when N == 0; an empty list would
            # otherwise give a 1-D tensor of shape (0,).
            # NOTE(review): BoxList is expected to reject non-2-D boxes -
            # confirm against maskrcnn_benchmark.structures.bounding_box.
            "boxes": torch.tensor(boxes, dtype=torch.float32).reshape(-1, 4),
            # Explicit dtypes so the empty case does not fall back to float32.
            "labels": torch.tensor(gt_classes, dtype=torch.int64),
            "difficult": torch.tensor(difficult_boxes, dtype=torch.bool),
            "im_info": self._read_im_info(target),
        }
        return res

    def get_img_info(self, index):
        """Return ``{"height": h, "width": w}`` read from the annotation file."""
        img_id = self.ids[index]
        anno = ET.parse(self._annopath % img_id).getroot()
        height, width = self._read_im_info(anno)
        return {"height": height, "width": width}

    def map_class_id_to_class_name(self, class_id):
        """Return the entry of ``CLASSES`` at index *class_id*."""
        return PascalVOCDataset.CLASSES[class_id]