# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved.
"""
Miscellaneous utility functions
"""
import torch


def cat(tensors, dim=0):
    """
    Efficient version of torch.cat that avoids a copy if there is only a single element in the list.
    """
    assert isinstance(tensors, (list, tuple))
    if len(tensors) == 1:
        return tensors[0]
    return torch.cat(tensors, dim)
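
# Usage sketch (tensors below are made up for illustration): a single-element
# list is returned as-is, so no copy is made; longer lists fall back to torch.cat.
#   >>> a, b = torch.zeros(2, 3), torch.ones(2, 3)
#   >>> cat([a]) is a              # True -- no copy
#   >>> cat([a, b], dim=0).shape   # torch.Size([4, 3])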


def permute_and_flatten(layer, N, A, C, H, W):
    # Reshape a (N, A*C, H, W) prediction map into (N, H*W*A, C), so that the
    # per-location, per-anchor predictions line up with the concatenated labels.
    layer = layer.view(N, -1, C, H, W)
    layer = layer.permute(0, 3, 4, 1, 2)
    layer = layer.reshape(N, -1, C)
    return layer
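
# Shape sketch (values assumed for illustration, not from the original source):
# with N=2 images, A=3 anchors, C=80 classes and an 8x8 feature map, a
# (2, 240, 8, 8) input becomes (2, 8*8*3, 80) == (2, 192, 80), i.e. one row of
# C scores per (location, anchor) pair.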


def concat_box_prediction_layers(box_regression, box_cls=None, token_logits=None):
    """
    Flatten per-level box regression, classification and (optionally) token
    predictions so they line up with labels computed over all feature levels
    concatenated.
    """
    box_regression_flattened = []
    box_cls_flattened = []
    token_logit_flattened = []
    # For each feature level, permute the outputs so that they are in the
    # same format as the labels. Note that the labels are computed for
    # all feature levels concatenated, so we keep the same representation
    # for the objectness and the box_regression.
    for box_cls_per_level, box_regression_per_level in zip(
        box_cls, box_regression
    ):
        N, AxC, H, W = box_cls_per_level.shape
        Ax4 = box_regression_per_level.shape[1]
        A = Ax4 // 4
        C = AxC // A
        box_cls_per_level = permute_and_flatten(
            box_cls_per_level, N, A, C, H, W
        )
        box_cls_flattened.append(box_cls_per_level)

        box_regression_per_level = permute_and_flatten(
            box_regression_per_level, N, A, 4, H, W
        )
        box_regression_flattened.append(box_regression_per_level)

    if token_logits is not None:
        # A (anchors per location) is taken from the last classification level
        # above and is assumed to be the same across all feature levels.
        for token_logit_per_level in token_logits:
            N, AxT, H, W = token_logit_per_level.shape
            T = AxT // A
            token_logit_per_level = permute_and_flatten(
                token_logit_per_level, N, A, T, H, W
            )
            token_logit_flattened.append(token_logit_per_level)

    # Concatenate on the first dimension (representing the feature levels), to
    # take into account the way the labels were generated (with all feature maps
    # being concatenated as well).
    box_cls = cat(box_cls_flattened, dim=1).reshape(-1, C)
    box_regression = cat(box_regression_flattened, dim=1).reshape(-1, 4)

    token_logits_stacked = None
    if token_logits is not None:
        token_logits_stacked = cat(token_logit_flattened, dim=1)

    return box_regression, box_cls, token_logits_stacked
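
# Rough shape walkthrough (illustrative numbers, not from the original source):
# with N=2, A=3, C=80 and two levels of size 8x8 and 4x4, the box_cls inputs are
# (2, 240, 8, 8) and (2, 240, 4, 4); after flattening and concatenating, the
# returned box_cls is (2 * (64 + 16) * 3, 80) == (480, 80) and box_regression
# is (480, 4), while token logits keep the batch dimension: (2, 240, T).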


def round_channels(channels, divisor=8):
    """Round a channel count to the nearest multiple of ``divisor``, never
    returning less than ``divisor`` and never shrinking the requested count
    by more than 10%."""
    rounded_channels = max(int(channels + divisor / 2.0) // divisor * divisor, divisor)
    if float(rounded_channels) < 0.9 * channels:
        rounded_channels += divisor
    return rounded_channels
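

if __name__ == "__main__":
    # Minimal smoke test, not part of the original file: the sizes below
    # (2 images, 3 anchors, 80 classes, a 256-token vocabulary, two feature
    # levels) are made up purely for illustration.
    N, A, C, T = 2, 3, 80, 256
    levels = [(8, 8), (4, 4)]
    box_cls = [torch.randn(N, A * C, h, w) for h, w in levels]
    box_regression = [torch.randn(N, A * 4, h, w) for h, w in levels]
    token_logits = [torch.randn(N, A * T, h, w) for h, w in levels]

    regr, cls_scores, tok = concat_box_prediction_layers(
        box_regression, box_cls, token_logits
    )
    # One row per (image, location, anchor): 2 * (64 + 16) * 3 = 480 rows.
    print(regr.shape, cls_scores.shape, tok.shape)  # (480, 4), (480, 80), (2, 240, 256)

    # round_channels snaps to the nearest multiple of 8 without dropping more
    # than 10% below the requested count, e.g. 30 -> 32 and 3 -> 8.
    print(round_channels(30), round_channels(3))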