# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved.
import torch
from torch import nn
import torch.distributed as dist
import maskrcnn_benchmark.utils.comm as comm
from torch.autograd.function import Function


class FrozenBatchNorm2d(nn.Module):
    """
    BatchNorm2d where the batch statistics and the affine parameters
    are fixed
    """

    def __init__(self, n):
        super(FrozenBatchNorm2d, self).__init__()
        self.register_buffer("weight", torch.ones(n))
        self.register_buffer("bias", torch.zeros(n))
        self.register_buffer("running_mean", torch.zeros(n))
        self.register_buffer("running_var", torch.ones(n))

    def forward(self, x):
        scale = self.weight * self.running_var.rsqrt()
        bias = self.bias - self.running_mean * scale
        scale = scale.reshape(1, -1, 1, 1)
        bias = bias.reshape(1, -1, 1, 1)
        return x * scale + bias
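

# NOTE: _demo_frozen_batchnorm is a hypothetical usage sketch, not part of the library.
# It illustrates how FrozenBatchNorm2d stands in for an eval-mode nn.BatchNorm2d whose
# buffers have been copied over. Outputs agree only up to the eps term that
# nn.BatchNorm2d adds to the variance and FrozenBatchNorm2d omits.
def _demo_frozen_batchnorm(num_channels=8):
    bn = nn.BatchNorm2d(num_channels).eval()
    frozen = FrozenBatchNorm2d(num_channels)
    # weight/bias/running stats are loaded; num_batches_tracked has no counterpart here
    frozen.load_state_dict(bn.state_dict(), strict=False)
    x = torch.randn(2, num_channels, 16, 16)
    assert torch.allclose(frozen(x), bn(x), atol=1e-3)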


class AllReduce(Function):
    @staticmethod
    def forward(ctx, input):
        input_list = [torch.zeros_like(input) for k in range(dist.get_world_size())]
        # Use allgather instead of allreduce since I don't trust in-place operations ..
        dist.all_gather(input_list, input, async_op=False)
        inputs = torch.stack(input_list, dim=0)
        return torch.sum(inputs, dim=0)

    @staticmethod
    def backward(ctx, grad_output):
        dist.all_reduce(grad_output, async_op=False)
        return grad_output
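

# NOTE: _demo_all_reduce is a hypothetical usage sketch, not a library API.
# AllReduce.apply sums a tensor over all workers in the default process group while
# staying differentiable: the backward pass all-reduces the incoming gradient, so every
# worker receives the gradient of the global sum. The sketch assumes torch.distributed
# has already been initialized (e.g. launched with torchrun).
def _demo_all_reduce(local_stats):
    # every rank contributes its local statistics; the result is identical on all ranks
    global_sum = AllReduce.apply(local_stats)
    # divide by the number of workers to turn the cross-worker sum into a mean
    return global_sum * (1.0 / dist.get_world_size())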


class NaiveSyncBatchNorm2d(nn.BatchNorm2d):
    """
    In PyTorch<=1.5, ``nn.SyncBatchNorm`` has incorrect gradient
    when the batch size on each worker is different
    (e.g., when scale augmentation is used, or when it is applied to a mask head).

    This is a slower but correct alternative to ``nn.SyncBatchNorm``.

    Note:
        There isn't a single definition of Sync BatchNorm.

        When ``stats_mode==""``, this module computes overall statistics by using
        the statistics of each worker with equal weight. The result is the true statistics
        of all samples (as if they were all on one worker) only when all workers
        have the same (N, H, W). This mode does not support inputs with zero batch size.

        When ``stats_mode=="N"``, this module computes overall statistics by weighting
        the statistics of each worker by their ``N``. The result is the true statistics
        of all samples (as if they were all on one worker) only when all workers
        have the same (H, W). It is slower than ``stats_mode==""``.

        Even though the result of this module may not be the true statistics of all samples,
        it may still be reasonable because it might be preferable to assign equal weights
        to all workers, regardless of their (H, W) dimension, instead of putting larger weight
        on larger images. From preliminary experiments, little difference is found between such
        a simplified implementation and an accurate computation of overall mean & variance.
    """

    def __init__(self, *args, stats_mode="", **kwargs):
        super().__init__(*args, **kwargs)
        assert stats_mode in ["", "N"]
        self._stats_mode = stats_mode

    def forward(self, input):
        if comm.get_world_size() == 1 or not self.training:
            return super().forward(input)

        B, C = input.shape[0], input.shape[1]

        mean = torch.mean(input, dim=[0, 2, 3])
        meansqr = torch.mean(input * input, dim=[0, 2, 3])

        if self._stats_mode == "":
            assert B > 0, 'SyncBatchNorm(stats_mode="") does not support zero batch size.'
            vec = torch.cat([mean, meansqr], dim=0)
            vec = AllReduce.apply(vec) * (1.0 / dist.get_world_size())
            mean, meansqr = torch.split(vec, C)
            momentum = self.momentum
        else:
            if B == 0:
                vec = torch.zeros([2 * C + 1], device=mean.device, dtype=mean.dtype)
                vec = vec + input.sum()  # make sure there is gradient w.r.t input
            else:
                vec = torch.cat(
                    [mean, meansqr, torch.ones([1], device=mean.device, dtype=mean.dtype)], dim=0
                )
            vec = AllReduce.apply(vec * B)

            total_batch = vec[-1].detach()
            momentum = total_batch.clamp(max=1) * self.momentum  # no update if total_batch is 0
            total_batch = torch.max(total_batch, torch.ones_like(total_batch))  # avoid div-by-zero
            mean, meansqr, _ = torch.split(vec / total_batch, C)

        var = meansqr - mean * mean
        invstd = torch.rsqrt(var + self.eps)
        scale = self.weight * invstd
        bias = self.bias - mean * scale
        scale = scale.reshape(1, -1, 1, 1)
        bias = bias.reshape(1, -1, 1, 1)

        self.running_mean += momentum * (mean.detach() - self.running_mean)
        self.running_var += momentum * (var.detach() - self.running_var)
        return input * scale + bias
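

# NOTE: _build_head_norm is a hypothetical helper, not part of the library; it
# illustrates using NaiveSyncBatchNorm2d as a drop-in replacement for nn.BatchNorm2d.
# In ``stats_mode=="N"`` the aggregation follows the usual pooled-moment identities:
# with per-worker batch sizes B_i, means m_i and second moments s_i,
#     mean = sum_i(B_i * m_i) / sum_i(B_i)
#     var  = sum_i(B_i * s_i) / sum_i(B_i) - mean ** 2
# which is what ``AllReduce.apply(vec * B)`` followed by the division by
# ``total_batch`` computes above.
def _build_head_norm(channels, use_sync_bn=True):
    # per-worker batch sizes can differ (and can be zero) in e.g. a mask head,
    # hence stats_mode="N"; with a single process, plain BatchNorm2d is equivalent
    if use_sync_bn and comm.get_world_size() > 1:
        return NaiveSyncBatchNorm2d(channels, stats_mode="N")
    return nn.BatchNorm2d(channels)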