Spaces:
Running
on
Zero
Running
on
Zero
| # Copyright (c) OpenMMLab. All rights reserved. | |
| from functools import partial | |
| import torch | |
| from six.moves import map, zip | |
def multi_apply(func, *args, **kwargs):
    """Call ``func`` once per input and regroup its outputs by position.

    The positional ``args`` are iterated in lockstep, so the i-th call
    receives the i-th element of every iterable. Each call is expected to
    return a sequence of outputs; the k-th outputs of all calls are
    gathered into the k-th returned list.

    Args:
        func (Function): The callable applied to each group of inputs.
        *args: Iterables whose elements are passed positionally to
            ``func``, one element from each iterable per call.
        **kwargs: Keyword arguments forwarded unchanged to every call.

    Returns:
        tuple(list): One list per output position of ``func``. The tuple
            is empty when the input iterables yield no elements.
    """
    if kwargs:
        # Bind the shared keyword arguments once, up front.
        func = partial(func, **kwargs)
    per_call_outputs = map(func, *args)
    # Transpose: from one tuple of outputs per call to one list per output.
    return tuple(list(group) for group in zip(*per_call_outputs))
def filter_scores_and_topk(scores, score_thr, topk, results=None):
    """Keep the highest-scoring entries that exceed a score threshold.

    Args:
        scores (Tensor): The scores, shape (num_bboxes, K).
        score_thr (float): Entries must be strictly greater than this
            value to be kept.
        topk (int): Upper bound on the number of entries kept.
        results (dict or list or Tensor, Optional): Extra data to be
            filtered with the same row selection. The shape of each
            item is (num_bboxes, N).

    Returns:
        tuple: Filtered results

            - scores (Tensor): The kept scores in descending order, \
                shape (num_bboxes_filtered, ).
            - labels (Tensor): Column index in ``scores`` of each kept \
                entry, shape (num_bboxes_filtered, ).
            - keep_idxs (Tensor): Row index in ``scores`` of each kept \
                entry, shape (num_bboxes_filtered, ).
            - filtered_results (dict or list or Tensor, Optional): \
                ``results`` indexed by ``keep_idxs``, or None when \
                ``results`` is None.

    Raises:
        NotImplementedError: If ``results`` is given but is not a dict,
            a list or a Tensor.
    """
    above_thr = scores > score_thr
    # Each row is the (row, column) position of one surviving score, in the
    # same order as the flattened boolean-mask selection below.
    candidate_idxs = torch.nonzero(above_thr)
    candidate_scores = scores[above_thr]

    keep_count = min(topk, candidate_idxs.size(0))
    # torch.sort is actually faster than .topk (at least on GPUs)
    ranked_scores, ranking = candidate_scores.sort(descending=True)
    top_scores = ranked_scores[:keep_count]
    keep_idxs, labels = candidate_idxs[ranking[:keep_count]].unbind(dim=1)

    if results is None:
        return top_scores, labels, keep_idxs, None

    if isinstance(results, dict):
        filtered_results = {
            name: value[keep_idxs] for name, value in results.items()
        }
    elif isinstance(results, list):
        filtered_results = [item[keep_idxs] for item in results]
    elif isinstance(results, torch.Tensor):
        filtered_results = results[keep_idxs]
    else:
        raise NotImplementedError(f'Only supports dict or list or Tensor, '
                                  f'but get {type(results)}.')
    return top_scores, labels, keep_idxs, filtered_results