Spaces:
Runtime error
Runtime error
| import torch | |
| from torchvision.ops.boxes import box_iou | |
def nms(bboxes: torch.Tensor, scores: torch.Tensor, iou_threshold: float) -> torch.Tensor:
    """Greedy non-maximum suppression.

    Repeatedly keeps the highest-scoring remaining box and discards every
    remaining box whose IoU with it is greater than ``iou_threshold``.

    Args:
        bboxes: ``(N, 4)`` tensor of boxes, in the ``(x1, y1, x2, y2)`` layout
            that ``torchvision.ops.boxes.box_iou`` expects.
        scores: ``(N,)`` tensor with one score per box.
        iou_threshold: a remaining box whose IoU with a kept box is strictly
            greater than this value is suppressed.

    Returns:
        1-D ``int64`` tensor on ``bboxes.device`` holding the indices of the
        kept boxes, ordered by decreasing score (ties broken by lower index
        first). When nothing is kept the result is an empty ``int64`` tensor,
        so it can always be used to index ``bboxes`` / ``scores``.

    Raises:
        ValueError: if ``bboxes`` is not ``(N, 4)`` or ``scores`` is not ``(N,)``.
    """
    if (
        bboxes.dim() != 2
        or bboxes.shape[1] != 4
        or scores.dim() != 1
        or scores.shape[0] != bboxes.shape[0]
    ):
        raise ValueError(
            "expected bboxes of shape (N, 4) and scores of shape (N,), got "
            f"{tuple(bboxes.shape)} and {tuple(scores.shape)}"
        )

    # Stable descending sort: equal scores keep their original index order,
    # so the result is deterministic (argsort(-scores) gives no such guarantee).
    order = torch.sort(scores, descending=True, stable=True).indices
    keep: list[int] = []
    while order.numel() > 0:
        best = order[0]
        keep.append(int(best))
        rest = order[1:]
        if rest.numel() == 0:
            break
        ious = box_iou(bboxes[best].unsqueeze(0), bboxes[rest])[0]
        # Suppress only IoU > threshold. Written as ~(>) rather than <= so a
        # NaN IoU (0/0 for two zero-area boxes) keeps the box, matching the
        # suppression rule of torchvision.ops.nms instead of dropping it.
        order = rest[~(ious > iou_threshold)]
    # Explicit dtype: torch.tensor([]) would otherwise be float32, which is
    # not a valid index tensor when no box was kept.
    return torch.tensor(keep, dtype=torch.long, device=bboxes.device)