Spaces:
Running
Running
| # Copyright (c) Tencent Inc. All rights reserved. | |
| import os.path as osp | |
| from typing import List, Union | |
| from mmengine.fileio import get_local_path, join_path | |
| from mmengine.utils import is_abs | |
| from mmdet.datasets.coco import CocoDataset | |
| from mmyolo.registry import DATASETS | |
| from .utils import RobustBatchShapePolicyDataset | |
class YOLOv5MixedGroundingDataset(RobustBatchShapePolicyDataset, CocoDataset):
    """Mixed grounding dataset.

    Each image record in the COCO-style annotation file carries a
    ``caption`` string, and each annotation carries ``tokens_positive``:
    a list of ``(start, end)`` character spans into that caption. The
    phrase obtained by joining those spans is used as the text label of
    the box, and labels are numbered per image in order of first
    appearance.

    NOTE(review): ``DATASETS`` is imported by this file but no
    ``@DATASETS.register_module()`` decorator is visible on this class;
    confirm the dataset is registered elsewhere before relying on
    config-based construction.
    """

    METAINFO = {
        'classes': ('object',),
        'palette': [(220, 20, 60)]}

    def load_data_list(self) -> List[dict]:
        """Load annotations from the annotation file ``self.ann_file``.

        Returns:
            List[dict]: One parsed data-info dict per image id reported by
            the COCO API.

        Raises:
            AssertionError: If ``self.ANN_ID_UNIQUE`` is set and the
                annotation ids in the file are not unique.
        """
        with get_local_path(
                self.ann_file, backend_args=self.backend_args) as local_path:
            self.coco = self.COCOAPI(local_path)

        data_list = []
        total_ann_ids = []
        for img_id in self.coco.get_img_ids():
            raw_img_info = self.coco.load_imgs([img_id])[0]
            raw_img_info['img_id'] = img_id

            ann_ids = self.coco.get_ann_ids(img_ids=[img_id])
            total_ann_ids.extend(ann_ids)

            data_list.append(
                self.parse_data_info({
                    'raw_ann_info': self.coco.load_anns(ann_ids),
                    'raw_img_info': raw_img_info,
                }))

        if self.ANN_ID_UNIQUE:
            assert len(set(total_ann_ids)) == len(
                total_ann_ids
            ), f"Annotation ids in '{self.ann_file}' are not unique!"

        # The COCO API object is only needed while parsing; drop the
        # reference once the data list has been built.
        del self.coco
        return data_list

    @staticmethod
    def _phrase_of(img_info: dict, ann: dict) -> str:
        """Return the caption phrase that an annotation is grounded on.

        The phrase is the space-joined concatenation of the caption
        substrings selected by the annotation's ``tokens_positive`` spans.
        """
        return ' '.join(img_info['caption'][span[0]:span[1]]
                        for span in ann['tokens_positive'])

    def _resolve_img_path(self, file_name: str) -> Union[str, None]:
        """Locate ``file_name`` under the configured image prefix(es).

        A string prefix is joined unconditionally. For a list/tuple of
        prefixes, the first one under which the file exists on disk wins.
        Returns ``None`` when no path could be determined.
        """
        img_prefix = self.data_prefix.get('img', None)
        if isinstance(img_prefix, str):
            return osp.join(img_prefix, file_name)
        if isinstance(img_prefix, (list, tuple)):
            for prefix in img_prefix:
                candidate = osp.join(prefix, file_name)
                if osp.exists(candidate):
                    return candidate
        return None

    def parse_data_info(self, raw_data_info: dict) -> Union[dict, List[dict]]:
        """Parse raw annotation to target format.

        Args:
            raw_data_info (dict): Raw data information loaded from
                ``ann_file``, with keys ``raw_img_info`` and
                ``raw_ann_info``.

        Returns:
            Union[dict, List[dict]]: Parsed annotation with keys
            ``img_path``, ``img_id``, ``seg_map_path``, ``height``,
            ``width``, ``texts``, ``is_detection`` and ``instances``.

        Raises:
            AssertionError: If the image path cannot be resolved from
                ``self.data_prefix['img']``.
        """
        img_info = raw_data_info['raw_img_info']
        ann_info = raw_data_info['raw_ann_info']
        file_name = img_info['file_name']

        img_path = self._resolve_img_path(file_name)
        assert img_path is not None, (
            f'Image path {file_name} not found in '
            f'{self.data_prefix.get("img", None)}')

        if self.data_prefix.get('seg', None):
            seg_map_path = osp.join(
                self.data_prefix['seg'],
                file_name.rsplit('.', 1)[0] + self.seg_map_suffix)
        else:
            seg_map_path = None

        img_height = float(img_info['height'])
        img_width = float(img_info['width'])

        data_info = {
            'img_path': img_path,
            'img_id': img_info['img_id'],
            'seg_map_path': seg_map_path,
            'height': img_height,
            'width': img_width,
        }

        # Build the per-image vocabulary from *all* annotations (including
        # ones that are dropped below) so label ids follow first appearance
        # in the annotation list. Phrases are computed once and reused.
        phrases = [self._phrase_of(img_info, ann) for ann in ann_info]
        cat2id = {}
        texts = []
        for phrase in phrases:
            if phrase not in cat2id:
                cat2id[phrase] = len(cat2id)
                texts.append([phrase])
        data_info['texts'] = texts

        instances = []
        for ann, phrase in zip(ann_info, phrases):
            if ann.get('ignore', False):
                continue

            x1, y1, w, h = ann['bbox']
            # Skip boxes that have no overlap with the image area.
            inter_w = max(0, min(x1 + w, img_width) - max(x1, 0))
            inter_h = max(0, min(y1 + h, img_height) - max(y1, 0))
            if inter_w * inter_h == 0:
                continue
            # Skip degenerate boxes.
            if ann['area'] <= 0 or w < 1 or h < 1:
                continue

            instance = {
                # Crowd regions are kept but flagged as ignored.
                'ignore_flag': 1 if ann.get('iscrowd', False) else 0,
                # COCO xywh -> xyxy.
                'bbox': [x1, y1, x1 + w, y1 + h],
                'bbox_label': cat2id[phrase],
            }
            if ann.get('segmentation', None):
                instance['mask'] = ann['segmentation']
            instances.append(instance)

        # NOTE: for detection task, we set `is_detection` to 1
        data_info['is_detection'] = 1
        data_info['instances'] = instances
        return data_info

    def filter_data(self) -> List[dict]:
        """Filter annotations according to ``filter_cfg``.

        Nothing is filtered in test mode or when ``filter_cfg`` is
        ``None``. Otherwise an image is kept when its shorter side is at
        least ``min_size`` and, if ``filter_empty_gt`` is enabled, it has
        at least one parsed instance.

        Returns:
            List[dict]: Filtered results.
        """
        if self.test_mode or self.filter_cfg is None:
            return self.data_list

        filter_empty_gt = self.filter_cfg.get('filter_empty_gt', False)
        min_size = self.filter_cfg.get('min_size', 0)

        valid_data_infos = []
        for data_info in self.data_list:
            # The previous implementation collected the ids of *every*
            # image as "images with annotation", which made
            # `filter_empty_gt` a no-op. Check the parsed instances
            # instead so the flag actually drops empty images.
            if filter_empty_gt and not data_info.get('instances'):
                continue
            shorter_side = min(int(data_info['width']),
                               int(data_info['height']))
            if shorter_side >= min_size:
                valid_data_infos.append(data_info)
        return valid_data_infos

    def _join_prefix(self):
        """Join ``self.data_root`` with ``self.data_prefix`` and
        ``self.ann_file``.

        Unlike a plain string prefix, a list or tuple of prefixes is also
        accepted for each key; every relative entry is joined with
        ``self.data_root`` and the result is stored back as a list.

        Raises:
            TypeError: If a prefix is neither a string, a list nor a tuple.
        """
        # Automatically join annotation file path with `self.data_root` if
        # `self.ann_file` is not an absolute path.
        if self.ann_file and not is_abs(self.ann_file) and self.data_root:
            self.ann_file = join_path(self.data_root, self.ann_file)

        # Automatically join data directory with `self.data_root` if a path
        # value in `self.data_prefix` is not an absolute path.
        for data_key, prefix in self.data_prefix.items():
            if isinstance(prefix, (list, tuple)):
                self.data_prefix[data_key] = [
                    join_path(self.data_root, p)
                    if not is_abs(p) and self.data_root else p
                    for p in prefix
                ]
            elif isinstance(prefix, str):
                if not is_abs(prefix) and self.data_root:
                    self.data_prefix[data_key] = join_path(
                        self.data_root, prefix)
            else:
                raise TypeError('prefix should be a string, tuple or list, '
                                f'but got {type(prefix)}')