Spaces:
Runtime error
Runtime error
| from typing import Optional, Union | |
| import torch | |
| from mmdet.models.task_modules.assigners.match_cost import BaseMatchCost | |
| from mmengine.structures import InstanceData | |
| from torch import Tensor | |
| from mmdet.registry import TASK_UTILS | |
| class FlexibleClassificationCost(BaseMatchCost): | |
| def __init__(self, weight: Union[float, int] = 1) -> None: | |
| super().__init__(weight=weight) | |
| def __call__(self, | |
| pred_instances: InstanceData, | |
| gt_instances: InstanceData, | |
| img_meta: Optional[dict] = None, | |
| **kwargs) -> Tensor: | |
| """Compute match cost. | |
| Args: | |
| pred_instances (:obj:`InstanceData`): ``scores`` inside is | |
| predicted classification logits, of shape | |
| (num_queries, num_class). | |
| gt_instances (:obj:`InstanceData`): ``labels`` inside should have | |
| shape (num_gt, ). | |
| img_meta (Optional[dict]): _description_. Defaults to None. | |
| Returns: | |
| Tensor: Match Cost matrix of shape (num_preds, num_gts). | |
| """ | |
| _pred_scores = pred_instances.scores | |
| gt_labels = gt_instances.labels | |
| pred_scores = _pred_scores[..., :-1] | |
| iou_score = _pred_scores[..., -1:] | |
| pred_scores = pred_scores.softmax(-1) | |
| iou_score = iou_score.sigmoid() | |
| pred_scores = torch.cat([pred_scores, iou_score], dim=-1) | |
| cls_cost = -pred_scores[:, gt_labels] | |
| return cls_cost * self.weight | |