import os
import os.path as op
import json
# import logging
import base64
import yaml
import errno
import io
import math
from PIL import Image, ImageDraw
from maskrcnn_benchmark.structures.bounding_box import BoxList

from .box_label_loader import LabelLoader

def load_linelist_file(linelist_file):
    if linelist_file is not None:
        line_list = []
        with open(linelist_file, 'r') as fp:
            for i in fp:
                line_list.append(int(i.strip()))
        return line_list
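
# Linelist format sketch (illustrative, not required by this module): a plain
# text file with one integer row index per line, e.g. "0", "3", "3", "7" on
# separate lines selects rows 0, 3 (twice) and 7 of the image TSV.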

def img_from_base64(imagestring):
    try:
        img = Image.open(io.BytesIO(base64.b64decode(imagestring)))
        return img.convert('RGB')
    except (ValueError, OSError):
        # Corrupt base64 or image data; callers are expected to handle None.
        return None

def load_from_yaml_file(yaml_file):
    with open(yaml_file, 'r') as fp:
        return yaml.load(fp, Loader=yaml.CLoader)


def find_file_path_in_yaml(fname, root):
    if fname is not None:
        if op.isfile(fname):
            return fname
        elif op.isfile(op.join(root, fname)):
            return op.join(root, fname)
        else:
            raise FileNotFoundError(
                errno.ENOENT, os.strerror(errno.ENOENT), op.join(root, fname)
            )

def create_lineidx(filein, idxout):
    # Write the byte offset of every row in `filein` to `idxout`, one offset
    # per line, so that rows can later be read with random access.
    idxout_tmp = idxout + '.tmp'
    with open(filein, 'r') as tsvin, open(idxout_tmp, 'w') as tsvout:
        fsize = os.fstat(tsvin.fileno()).st_size
        fpos = 0
        while fpos != fsize:
            tsvout.write(str(fpos) + "\n")
            tsvin.readline()
            fpos = tsvin.tell()
    os.rename(idxout_tmp, idxout)
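
# Lineidx sketch (illustrative file names): if data.tsv contains three rows
# whose first bytes sit at offsets 0, 137 and 4096, then
# create_lineidx('data.tsv', 'data.lineidx') writes "0", "137" and "4096",
# one per line, to data.lineidx.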

def read_to_character(fp, c):
    result = []
    while True:
        s = fp.read(32)
        assert s != ''
        if c in s:
            result.append(s[: s.index(c)])
            break
        else:
            result.append(s)
    return ''.join(result)

class TSVFile(object):
    def __init__(self, tsv_file, generate_lineidx=False):
        self.tsv_file = tsv_file
        self.lineidx = op.splitext(tsv_file)[0] + '.lineidx'
        self._fp = None
        self._lineidx = None
        # Remember the pid of the process that opened the file handle.
        # If the current pid differs (e.g. after a fork in a dataloader
        # worker), the file is re-opened in _ensure_tsv_opened.
        self.pid = None
        # generate lineidx if it does not exist
        if not op.isfile(self.lineidx) and generate_lineidx:
            create_lineidx(self.tsv_file, self.lineidx)

    def __del__(self):
        if self._fp:
            self._fp.close()

    def __str__(self):
        return "TSVFile(tsv_file='{}')".format(self.tsv_file)

    def __repr__(self):
        return str(self)

    def num_rows(self):
        self._ensure_lineidx_loaded()
        return len(self._lineidx)

    def seek(self, idx):
        self._ensure_tsv_opened()
        self._ensure_lineidx_loaded()
        try:
            pos = self._lineidx[idx]
        except Exception:
            # logging.info('{}-{}'.format(self.tsv_file, idx))
            raise
        self._fp.seek(pos)
        return [s.strip() for s in self._fp.readline().split('\t')]

    def seek_first_column(self, idx):
        self._ensure_tsv_opened()
        self._ensure_lineidx_loaded()
        pos = self._lineidx[idx]
        self._fp.seek(pos)
        return read_to_character(self._fp, '\t')

    def get_key(self, idx):
        return self.seek_first_column(idx)

    def __getitem__(self, index):
        return self.seek(index)

    def __len__(self):
        return self.num_rows()

    def _ensure_lineidx_loaded(self):
        if self._lineidx is None:
            # logging.info('loading lineidx: {}'.format(self.lineidx))
            with open(self.lineidx, 'r') as fp:
                self._lineidx = [int(i.strip()) for i in fp.readlines()]

    def _ensure_tsv_opened(self):
        if self._fp is None:
            self._fp = open(self.tsv_file, 'r')
            self.pid = os.getpid()

        if self.pid != os.getpid():
            # logging.info('re-open {} because the process id changed'.format(self.tsv_file))
            self._fp = open(self.tsv_file, 'r')
            self.pid = os.getpid()
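
# Usage sketch (hypothetical file name; assumes rows of "key\t<base64 image>"):
#   tsv = TSVFile('train.img.tsv', generate_lineidx=True)
#   key, b64 = tsv.seek(0)          # random access to row 0 via the lineidx
#   img = img_from_base64(b64)      # decode the base64 column into a PIL image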

class CompositeTSVFile():
    def __init__(self, file_list, seq_file, root='.'):
        if isinstance(file_list, str):
            self.file_list = load_list_file(file_list)
        else:
            assert isinstance(file_list, list)
            self.file_list = file_list

        self.seq_file = seq_file
        self.root = root
        self.initialized = False
        self.initialize()

    def get_key(self, index):
        idx_source, idx_row = self.seq[index]
        k = self.tsvs[idx_source].get_key(idx_row)
        return '_'.join([self.file_list[idx_source], k])

    def num_rows(self):
        return len(self.seq)

    def __getitem__(self, index):
        idx_source, idx_row = self.seq[index]
        return self.tsvs[idx_source].seek(idx_row)

    def __len__(self):
        return len(self.seq)

    def initialize(self):
        '''
        This function has to be called in __init__ if cache_policy is
        enabled, so we always call it in __init__ to keep things simple.
        '''
        if self.initialized:
            return
        self.seq = []
        with open(self.seq_file, 'r') as fp:
            for line in fp:
                parts = line.strip().split('\t')
                self.seq.append([int(parts[0]), int(parts[1])])
        self.tsvs = [TSVFile(op.join(self.root, f)) for f in self.file_list]
        self.initialized = True
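
# Usage sketch (illustrative paths): the seq file maps each composite row to a
# tab-separated pair (index into file_list, row index in that TSV).
#   composite = CompositeTSVFile('imgs.file_list.txt', 'imgs.seq.tsv', root='data')
#   row = composite[0]             # seeks into the corresponding TSVFile
#   key = composite.get_key(0)     # '<source tsv name>_<row key>'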

def load_list_file(fname):
    with open(fname, 'r') as fp:
        lines = fp.readlines()
    result = [line.strip() for line in lines]
    if len(result) > 0 and result[-1] == '':
        result = result[:-1]
    return result

class TSVDataset(object):
    def __init__(self, img_file, label_file=None, hw_file=None,
                 linelist_file=None, imageid2idx_file=None):
        """Constructor.
        Args:
            img_file: Image file with image key and base64-encoded image string.
            label_file: An optional label file with image key and label information.
                A label_file is required for training and optional for testing.
            hw_file: An optional file with image key and image height/width info.
            linelist_file: An optional file with a list of line indexes to load samples.
                It is useful to select a subset of samples or to duplicate samples.
            imageid2idx_file: An optional JSON file mapping image id to the row
                index in img_file.
        """
        self.img_file = img_file
        self.label_file = label_file
        self.hw_file = hw_file
        self.linelist_file = linelist_file

        self.img_tsv = TSVFile(img_file)
        self.label_tsv = None if label_file is None else TSVFile(label_file, generate_lineidx=True)
        self.hw_tsv = None if hw_file is None else TSVFile(hw_file)
        self.line_list = load_linelist_file(linelist_file)
        self.imageid2idx = None
        if imageid2idx_file is not None:
            self.imageid2idx = json.load(open(imageid2idx_file, 'r'))

        self.transforms = None

    def __len__(self):
        if self.line_list is None:
            if self.imageid2idx is not None:
                assert self.label_tsv is not None, "label_tsv is None!!!"
                return self.label_tsv.num_rows()
            return self.img_tsv.num_rows()
        else:
            return len(self.line_list)

    def __getitem__(self, idx):
        img = self.get_image(idx)
        img_size = img.size  # (w, h)
        annotations = self.get_annotations(idx)
        target = self.get_target_from_annotations(annotations, img_size, idx)
        img, target = self.apply_transforms(img, target)

        if self.transforms is None:
            return img, target, idx, 1.0
        else:
            new_img_size = img.shape[1:]
            scale = math.sqrt(float(new_img_size[0] * new_img_size[1]) / float(img_size[0] * img_size[1]))
            return img, target, idx, scale

    def get_line_no(self, idx):
        return idx if self.line_list is None else self.line_list[idx]

    def get_image(self, idx):
        line_no = self.get_line_no(idx)
        if self.imageid2idx is not None:
            assert self.label_tsv is not None, "label_tsv is None!!!"
            row = self.label_tsv.seek(line_no)
            annotations = json.loads(row[1])
            imageid = annotations["img_id"]
            line_no = self.imageid2idx[imageid]

        row = self.img_tsv.seek(line_no)
        # use -1 to support the old format with multiple columns.
        img = img_from_base64(row[-1])
        return img

    def get_annotations(self, idx):
        line_no = self.get_line_no(idx)
        if self.label_tsv is not None:
            row = self.label_tsv.seek(line_no)
            annotations = json.loads(row[1])
            return annotations
        else:
            return []

    def get_target_from_annotations(self, annotations, img_size, idx):
        # This function will be overwritten by each dataset to
        # decode the labels to specific formats for each task.
        return annotations

    def apply_transforms(self, image, target=None):
        # This function will be overwritten by each dataset to
        # apply transforms to the image and targets.
        return image, target

    def get_img_info(self, idx):
        if self.imageid2idx is not None:
            assert self.label_tsv is not None, "label_tsv is None!!!"
            line_no = self.get_line_no(idx)
            row = self.label_tsv.seek(line_no)
            annotations = json.loads(row[1])
            return {"height": int(annotations["img_h"]), "width": int(annotations["img_w"])}

        if self.hw_tsv is not None:
            line_no = self.get_line_no(idx)
            row = self.hw_tsv.seek(line_no)
            try:
                # JSON string format with "height" and "width" as keys
                data = json.loads(row[1])
                if type(data) == list:
                    return data[0]
                elif type(data) == dict:
                    return data
            except ValueError:
                # space-separated string with height and width, in that order
                hw_str = row[1].split(' ')
                hw_dict = {"height": int(hw_str[0]), "width": int(hw_str[1])}
                return hw_dict

    def get_img_key(self, idx):
        line_no = self.get_line_no(idx)
        # Prefer the smaller label/hw TSVs over the image TSV,
        # based on the overhead of reading each row.
        if self.imageid2idx is not None:
            assert self.label_tsv is not None, "label_tsv is None!!!"
            row = self.label_tsv.seek(line_no)
            annotations = json.loads(row[1])
            return annotations["img_id"]
        if self.hw_tsv:
            return self.hw_tsv.seek(line_no)[0]
        elif self.label_tsv:
            return self.label_tsv.seek(line_no)[0]
        else:
            return self.img_tsv.seek(line_no)[0]
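
# Usage sketch (hypothetical file names; label/hw/linelist files are optional):
#   dataset = TSVDataset('train.img.tsv', label_file='train.label.tsv',
#                        hw_file='train.hw.tsv')
#   img, target, idx, scale = dataset[0]
#   print(dataset.get_img_key(0), dataset.get_img_info(0))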

class TSVYamlDataset(TSVDataset):
    """TSVDataset configured from a YAML file that lists the underlying TSV files."""
    def __init__(self, yaml_file, root=None, replace_clean_label=False):
        print("Reading {}".format(yaml_file))
        self.cfg = load_from_yaml_file(yaml_file)
        if root:
            self.root = root
        else:
            self.root = op.dirname(yaml_file)
        img_file = find_file_path_in_yaml(self.cfg['img'], self.root)
        label_file = find_file_path_in_yaml(self.cfg.get('label', None),
                                            self.root)
        hw_file = find_file_path_in_yaml(self.cfg.get('hw', None), self.root)
        linelist_file = find_file_path_in_yaml(self.cfg.get('linelist', None),
                                               self.root)
        imageid2idx_file = find_file_path_in_yaml(self.cfg.get('imageid2idx', None),
                                                  self.root)
        if replace_clean_label:
            assert ("raw_label" in label_file)
            label_file = label_file.replace("raw_label", "clean_label")

        super(TSVYamlDataset, self).__init__(
            img_file, label_file, hw_file, linelist_file, imageid2idx_file)
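
# Example of the YAML layout this class expects (only 'img' is required; paths
# are resolved relative to the YAML file's directory unless a root is given):
#   img: train.img.tsv
#   label: train.label.tsv
#   hw: train.hw.tsv
#   linelist: train.linelist.txt
#   imageid2idx: train.imageid2idx.json
#   labelmap: labelmap.json      # used by ODTSVDataset below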

class ODTSVDataset(TSVYamlDataset):
    """
    Generic TSV dataset format for Object Detection.
    """
    def __init__(self, yaml_file, extra_fields=(), transforms=None,
                 is_load_label=True, **kwargs):
        if yaml_file is None:
            return
        super(ODTSVDataset, self).__init__(yaml_file)

        self.transforms = transforms
        self.is_load_label = is_load_label
        self.attribute_on = False
        # self.attribute_on = kwargs['args'].MODEL.ATTRIBUTE_ON if "args" in kwargs else False

        if self.is_load_label:
            # construct the label-to-index maps
            jsondict_file = find_file_path_in_yaml(
                self.cfg.get("labelmap", None), self.root
            )
            if jsondict_file is None:
                jsondict_file = find_file_path_in_yaml(
                    self.cfg.get("jsondict", None), self.root
                )
            if "json" in jsondict_file:
                jsondict = json.load(open(jsondict_file, 'r'))
                if "label_to_idx" not in jsondict:
                    jsondict = {'label_to_idx': jsondict}
            elif "tsv" in jsondict_file:
                label_to_idx = {}
                counter = 1
                with open(jsondict_file) as f:
                    for line in f:
                        label_to_idx[line.strip()] = counter
                        counter += 1
                jsondict = {'label_to_idx': label_to_idx}
            else:
                raise ValueError(
                    "unsupported labelmap file: {}".format(jsondict_file))

            self.labelmap = {}
            self.class_to_ind = jsondict['label_to_idx']
            self.class_to_ind['__background__'] = 0
            self.ind_to_class = {v: k for k, v in self.class_to_ind.items()}
            self.labelmap['class_to_ind'] = self.class_to_ind

            if self.attribute_on:
                self.attribute_to_ind = jsondict['attribute_to_idx']
                self.attribute_to_ind['__no_attribute__'] = 0
                self.ind_to_attribute = {v: k for k, v in self.attribute_to_ind.items()}
                self.labelmap['attribute_to_ind'] = self.attribute_to_ind

            self.label_loader = LabelLoader(
                labelmap=self.labelmap,
                extra_fields=extra_fields,
            )

    def get_target_from_annotations(self, annotations, img_size, idx):
        if isinstance(annotations, list):
            annotations = {"objects": annotations}
        if self.is_load_label:
            return self.label_loader(annotations['objects'], img_size)

    def apply_transforms(self, img, target=None):
        if self.transforms is not None:
            img, target = self.transforms(img, target)
        return img, target
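
# Usage sketch (hypothetical YAML path; transforms follow the
# maskrcnn_benchmark convention of a callable taking (image, target)):
#   dataset = ODTSVDataset('data/train.yaml', transforms=None, is_load_label=True)
#   img, target, idx, scale = dataset[0]
#   # target is typically a BoxList produced by LabelLoader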