# Copyright (c) OpenMMLab. All rights reserved.
import random
from typing import List, Tuple

import torch
import torch.nn as nn
import torch.nn.functional as F
from mmengine import MessageHub
from mmengine.dist import barrier, broadcast, get_dist_info
from mmengine.structures import PixelData
from torch import Tensor

from mmpose.registry import MODELS
from mmpose.structures import PoseDataSample


@MODELS.register_module()
class BatchSyncRandomResize(nn.Module):
    """Batch random resize which synchronizes the random size across ranks.

    Args:
        random_size_range (tuple): The multi-scale random range used during
            multi-scale training.
        interval (int): The iteration interval at which the image size is
            changed. Defaults to 10.
        size_divisor (int): Image size divisible factor. Defaults to 32.
    """

    def __init__(self,
                 random_size_range: Tuple[int, int],
                 interval: int = 10,
                 size_divisor: int = 32) -> None:
        super().__init__()
        self.rank, self.world_size = get_dist_info()
        self._input_size = None
        self._random_size_range = (round(random_size_range[0] / size_divisor),
                                   round(random_size_range[1] / size_divisor))
        self._interval = interval
        self._size_divisor = size_divisor

    def forward(self, inputs: Tensor, data_samples: List[PoseDataSample]
                ) -> Tuple[Tensor, List[PoseDataSample]]:
        """resize a batch of images and bboxes to shape ``self._input_size``"""
        h, w = inputs.shape[-2:]
        if self._input_size is None:
            self._input_size = (h, w)
        scale_y = self._input_size[0] / h
        scale_x = self._input_size[1] / w
        if scale_x != 1 or scale_y != 1:
            inputs = F.interpolate(
                inputs,
                size=self._input_size,
                mode='bilinear',
                align_corners=False)
            for data_sample in data_samples:
                img_shape = (int(data_sample.img_shape[0] * scale_y),
                             int(data_sample.img_shape[1] * scale_x))
                pad_shape = (int(data_sample.pad_shape[0] * scale_y),
                             int(data_sample.pad_shape[1] * scale_x))
                data_sample.set_metainfo({
                    'img_shape': img_shape,
                    'pad_shape': pad_shape,
                    'batch_input_shape': self._input_size
                })

                if 'gt_instance_labels' not in data_sample:
                    continue
                if 'bboxes' in data_sample.gt_instance_labels:
                    data_sample.gt_instance_labels.bboxes[..., 0::2] *= scale_x
                    data_sample.gt_instance_labels.bboxes[..., 1::2] *= scale_y
                if 'keypoints' in data_sample.gt_instance_labels:
                    data_sample.gt_instance_labels.keypoints[..., 0] *= scale_x
                    data_sample.gt_instance_labels.keypoints[..., 1] *= scale_y
                if 'areas' in data_sample.gt_instance_labels:
                    data_sample.gt_instance_labels.areas *= scale_x * scale_y
                if 'gt_fields' in data_sample \
                        and 'heatmap_mask' in data_sample.gt_fields:
                    mask = data_sample.gt_fields.heatmap_mask.unsqueeze(0)
                    gt_fields = PixelData()
                    gt_fields.set_field(
                        F.interpolate(
                            mask.float(),
                            size=self._input_size,
                            mode='bilinear',
                            align_corners=False).squeeze(0), 'heatmap_mask')
                    data_sample.gt_fields = gt_fields

        message_hub = MessageHub.get_current_instance()
        if (message_hub.get_info('iter') + 1) % self._interval == 0:
            self._input_size = self._get_random_size(
                aspect_ratio=float(w / h), device=inputs.device)
        return inputs, data_samples

    def _get_random_size(self, aspect_ratio: float,
                         device: torch.device) -> Tuple[int, int]:
        """Randomly generate a shape in ``_random_size_range`` and broadcast
        it to all ranks."""
        tensor = torch.LongTensor(2).to(device)
        if self.rank == 0:
            size = random.randint(*self._random_size_range)
            size = (self._size_divisor * size,
                    self._size_divisor * int(aspect_ratio * size))
            tensor[0] = size[0]
            tensor[1] = size[1]
        barrier()
        broadcast(tensor, 0)
        input_size = (tensor[0].item(), tensor[1].item())
        return input_size
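

# A minimal, hypothetical usage sketch (not part of the original module). It
# assumes mmengine and mmpose are importable, runs single-process on CPU
# (where ``barrier``/``broadcast`` are no-ops), and uses made-up shapes and
# metainfo purely for illustration.
if __name__ == '__main__':
    # The module reads the current iteration from the MessageHub to decide
    # when to draw a new target size, so set it explicitly for this demo.
    MessageHub.get_current_instance().update_info('iter', 9)

    resize = BatchSyncRandomResize(
        random_size_range=(480, 800), interval=10, size_divisor=32)

    # A fake batch of two 640x640 images with matching data samples.
    inputs = torch.rand(2, 3, 640, 640)
    data_samples = []
    for _ in range(2):
        sample = PoseDataSample()
        sample.set_metainfo(dict(img_shape=(640, 640), pad_shape=(640, 640)))
        data_samples.append(sample)

    outputs, data_samples = resize(inputs, data_samples)
    print('resized batch shape:', tuple(outputs.shape))
    # After an interval boundary a new target size has been sampled and will
    # be applied to the next batch (internal attribute, shown for the demo).
    print('next input size:', resize._input_size)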