# Copyright (c) OpenMMLab. All rights reserved.
import types
from typing import Dict, Optional

import numpy as np
import torch
import torch.nn as nn
from mmcv.cnn import ConvModule, build_activation_layer, build_norm_layer
from mmengine.model import BaseModule
from torch import Tensor

from mmpose.utils.typing import OptConfigType


class RepVGGBlock(BaseModule):
    """A block in the RepVGG architecture, supporting optional normalization
    in the identity branch.

    This block consists of 3x3 and 1x1 convolutions, with an optional
    identity shortcut branch that includes normalization.

    Args:
        in_channels (int): The input channels of the block.
        out_channels (int): The output channels of the block.
        stride (int): The stride of the block. Defaults to 1.
        padding (int): The padding of the block. Defaults to 1.
        dilation (int): The dilation of the block. Defaults to 1.
        groups (int): The groups of the block. Defaults to 1.
        padding_mode (str): The padding mode of the block. Defaults to
            'zeros'.
        norm_cfg (dict): The config dict for normalization layers.
            Defaults to ``dict(type='BN')``.
        act_cfg (dict): The config dict for activation layers.
            Defaults to ``dict(type='ReLU')``.
        without_branch_norm (bool): Whether to skip ``branch_norm``.
            Defaults to True.
        init_cfg (dict, optional): The config dict for initialization.
            Defaults to None.
    """

    def __init__(self,
                 in_channels: int,
                 out_channels: int,
                 stride: int = 1,
                 padding: int = 1,
                 dilation: int = 1,
                 groups: int = 1,
                 padding_mode: str = 'zeros',
                 norm_cfg: OptConfigType = dict(type='BN'),
                 act_cfg: OptConfigType = dict(type='ReLU'),
                 without_branch_norm: bool = True,
                 init_cfg: OptConfigType = None):
        super(RepVGGBlock, self).__init__(init_cfg)

        self.in_channels = in_channels
        self.out_channels = out_channels
        self.stride = stride
        self.padding = padding
        self.dilation = dilation
        self.groups = groups
        self.norm_cfg = norm_cfg
        self.act_cfg = act_cfg

        # Judge if the input and output shapes are the same.
        # If so, add a normalized identity shortcut.
        self.branch_norm = None
        if out_channels == in_channels and stride == 1 and \
                padding == dilation and not without_branch_norm:
            self.branch_norm = build_norm_layer(norm_cfg, in_channels)[1]

        self.branch_3x3 = ConvModule(
            self.in_channels,
            self.out_channels,
            3,
            stride=self.stride,
            padding=self.padding,
            groups=self.groups,
            dilation=self.dilation,
            norm_cfg=self.norm_cfg,
            act_cfg=None)

        self.branch_1x1 = ConvModule(
            self.in_channels,
            self.out_channels,
            1,
            groups=self.groups,
            norm_cfg=self.norm_cfg,
            act_cfg=None)

        self.act = build_activation_layer(act_cfg)

    def forward(self, x: Tensor) -> Tensor:
        """Forward pass through the RepVGG block.

        The output is the sum of the 3x3 and 1x1 convolution outputs and the
        normalized identity branch output, followed by activation.

        Args:
            x (Tensor): The input tensor.

        Returns:
            Tensor: The output tensor.
        """
        if self.branch_norm is None:
            branch_norm_out = 0
        else:
            branch_norm_out = self.branch_norm(x)

        out = self.branch_3x3(x) + self.branch_1x1(x) + branch_norm_out
        out = self.act(out)

        return out

    def _pad_1x1_to_3x3_tensor(self, kernel1x1):
        """Pad a 1x1 kernel to 3x3.

        Args:
            kernel1x1 (Tensor): The 1x1 kernel that needs to be padded.

        Returns:
            Tensor: The padded 3x3 kernel.
        """
        if kernel1x1 is None:
            return 0
        else:
            return torch.nn.functional.pad(kernel1x1, [1, 1, 1, 1])

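    # The method below folds BatchNorm into the preceding convolution using
    # the standard identity: for a conv with kernel W followed by
    # BatchNorm(running_mean, running_var, gamma, beta, eps),
    #     BN(W * x) = (gamma / std) * (W * x - running_mean) + beta
    #               = (W * gamma / std) * x + (beta - running_mean * gamma / std),
    # where std = sqrt(running_var + eps). The fused kernel is therefore
    # W * (gamma / std) and the fused bias is beta - running_mean * gamma / std.
    # A bare BN identity branch is handled by treating it as a 3x3 convolution
    # whose kernel is an identity mapping.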
    def _fuse_bn_tensor(self, branch: nn.Module) -> tuple:
        """Derives the equivalent kernel and bias of a specific branch layer.

        Args:
            branch (nn.Module): The layer that needs to be equivalently
                transformed, which can be a ``ConvModule`` (conv + BN) or a
                ``nn.BatchNorm2d`` / ``nn.SyncBatchNorm`` layer.

        Returns:
            tuple: Equivalent kernel and bias.
        """
        if branch is None:
            return 0, 0

        if isinstance(branch, ConvModule):
            kernel = branch.conv.weight
            running_mean = branch.bn.running_mean
            running_var = branch.bn.running_var
            gamma = branch.bn.weight
            beta = branch.bn.bias
            eps = branch.bn.eps
        else:
            assert isinstance(branch, (nn.SyncBatchNorm, nn.BatchNorm2d))
            if not hasattr(self, 'id_tensor'):
                input_dim = self.in_channels // self.groups
                kernel_value = np.zeros((self.in_channels, input_dim, 3, 3),
                                        dtype=np.float32)
                for i in range(self.in_channels):
                    kernel_value[i, i % input_dim, 1, 1] = 1
                self.id_tensor = torch.from_numpy(kernel_value).to(
                    branch.weight.device)
            kernel = self.id_tensor
            running_mean = branch.running_mean
            running_var = branch.running_var
            gamma = branch.weight
            beta = branch.bias
            eps = branch.eps

        std = (running_var + eps).sqrt()
        t = (gamma / std).reshape(-1, 1, 1, 1)
        return kernel * t, beta - running_mean * gamma / std

    def get_equivalent_kernel_bias(self):
        """Derives the equivalent kernel and bias in a differentiable way.

        Returns:
            tuple: Equivalent kernel and bias.
        """
        kernel3x3, bias3x3 = self._fuse_bn_tensor(self.branch_3x3)
        kernel1x1, bias1x1 = self._fuse_bn_tensor(self.branch_1x1)
        kernelid, biasid = (0, 0) if self.branch_norm is None else \
            self._fuse_bn_tensor(self.branch_norm)

        return (kernel3x3 + self._pad_1x1_to_3x3_tensor(kernel1x1) + kernelid,
                bias3x3 + bias1x1 + biasid)

    def switch_to_deploy(self, test_cfg: Optional[Dict] = None):
        """Switches the block to deployment mode.

        In deployment mode, the block uses a single convolution operation
        derived from the equivalent kernel and bias, replacing the original
        branches. This reduces computational complexity during inference.
        """
        if getattr(self, 'deploy', False):
            return

        kernel, bias = self.get_equivalent_kernel_bias()
        self.conv_reparam = nn.Conv2d(
            in_channels=self.branch_3x3.conv.in_channels,
            out_channels=self.branch_3x3.conv.out_channels,
            kernel_size=self.branch_3x3.conv.kernel_size,
            stride=self.branch_3x3.conv.stride,
            padding=self.branch_3x3.conv.padding,
            dilation=self.branch_3x3.conv.dilation,
            groups=self.branch_3x3.conv.groups,
            bias=True)
        self.conv_reparam.weight.data = kernel
        self.conv_reparam.bias.data = bias

        for para in self.parameters():
            para.detach_()
        self.__delattr__('branch_3x3')
        self.__delattr__('branch_1x1')
        if hasattr(self, 'branch_norm'):
            self.__delattr__('branch_norm')

        def _forward(self, x):
            return self.act(self.conv_reparam(x))

        self.forward = types.MethodType(_forward, self)

        self.deploy = True
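

# A minimal usage sketch, not part of the original module: assuming mmcv and
# mmengine are installed, it builds a block, switches it to deployment mode
# and checks that the fused convolution reproduces the multi-branch output.
# The input shape (1, 32, 16, 16) is an arbitrary illustrative choice.
if __name__ == '__main__':
    block = RepVGGBlock(32, 32, without_branch_norm=False)
    block.eval()  # use BN running statistics so both paths are comparable

    x = torch.randn(1, 32, 16, 16)
    with torch.no_grad():
        out_train = block(x)   # sum of 3x3, 1x1 and identity branches
        block.switch_to_deploy()
        out_deploy = block(x)  # single reparameterized 3x3 convolution

    print(torch.allclose(out_train, out_deploy, atol=1e-5))  # expected: True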