|  |  | 
					
						
						|  |  | 
					
						
						|  |  | 
					
						
						|  |  | 
					
						
						|  |  | 
					
						
						|  |  | 
					
						
						|  | from __future__ import absolute_import, division, print_function | 
					
						
						|  |  | 
					
						
						|  | import warnings | 
					
						
						|  |  | 
					
						
						|  | import torch | 
					
						
						|  | import torch.nn.functional as F | 
					
						
						|  | from torch import nn | 
					
						
						|  | from torch.nn.init import constant_, xavier_uniform_ | 
					
						
						|  |  | 
					
						
						|  | from .dcnv3_func import DCNv3Function, dcnv3_core_pytorch, has_cuda_kernel | 
					
						
						|  |  | 
					
						
						|  |  | 
					
						
						|  | class to_channels_first(nn.Module): | 
					
						
						|  |  | 
					
						
						|  | def __init__(self): | 
					
						
						|  | super().__init__() | 
					
						
						|  |  | 
					
						
						|  | def forward(self, x): | 
					
						
						|  | return x.permute(0, 3, 1, 2) | 
					
						
						|  |  | 
					
						
						|  |  | 
					
						
						|  | class to_channels_last(nn.Module): | 
					
						
						|  |  | 
					
						
						|  | def __init__(self): | 
					
						
						|  | super().__init__() | 
					
						
						|  |  | 
					
						
						|  | def forward(self, x): | 
					
						
						|  | return x.permute(0, 2, 3, 1) | 
					
						
						|  |  | 
					
						
						|  |  | 
					
						
						|  | def build_norm_layer(dim, | 
					
						
						|  | norm_layer, | 
					
						
						|  | in_format='channels_last', | 
					
						
						|  | out_format='channels_last', | 
					
						
						|  | eps=1e-6): | 
					
						
						|  | layers = [] | 
					
						
						|  | if norm_layer == 'BN': | 
					
						
						|  | if in_format == 'channels_last': | 
					
						
						|  | layers.append(to_channels_first()) | 
					
						
						|  | layers.append(nn.BatchNorm2d(dim)) | 
					
						
						|  | if out_format == 'channels_last': | 
					
						
						|  | layers.append(to_channels_last()) | 
					
						
						|  | elif norm_layer == 'LN': | 
					
						
						|  | if in_format == 'channels_first': | 
					
						
						|  | layers.append(to_channels_last()) | 
					
						
						|  | layers.append(nn.LayerNorm(dim, eps=eps)) | 
					
						
						|  | if out_format == 'channels_first': | 
					
						
						|  | layers.append(to_channels_first()) | 
					
						
						|  | else: | 
					
						
						|  | raise NotImplementedError( | 
					
						
						|  | f'build_norm_layer does not support {norm_layer}') | 
					
						
						|  | return nn.Sequential(*layers) | 
					
						
						|  |  | 
					
						
						|  |  | 
					
						
						|  | def build_act_layer(act_layer): | 
					
						
						|  | if act_layer == 'ReLU': | 
					
						
						|  | return nn.ReLU(inplace=True) | 
					
						
						|  | elif act_layer == 'SiLU': | 
					
						
						|  | return nn.SiLU(inplace=True) | 
					
						
						|  | elif act_layer == 'GELU': | 
					
						
						|  | return nn.GELU() | 
					
						
						|  |  | 
					
						
						|  | raise NotImplementedError(f'build_act_layer does not support {act_layer}') | 
					
						
						|  |  | 
					
						
						|  |  | 
					
						
						|  | def _is_power_of_2(n): | 
					
						
						|  | if (not isinstance(n, int)) or (n < 0): | 
					
						
						|  | raise ValueError( | 
					
						
						|  | 'invalid input for _is_power_of_2: {} (type: {})'.format(n, type(n))) | 
					
						
						|  |  | 
					
						
						|  | return (n & (n - 1) == 0) and n != 0 | 
					
						
						|  |  | 
					
						
						|  |  | 
					
						
						|  | class CenterFeatureScaleModule(nn.Module): | 
					
						
						|  | def forward(self, | 
					
						
						|  | query, | 
					
						
						|  | center_feature_scale_proj_weight, | 
					
						
						|  | center_feature_scale_proj_bias): | 
					
						
						|  | center_feature_scale = F.linear(query, | 
					
						
						|  | weight=center_feature_scale_proj_weight, | 
					
						
						|  | bias=center_feature_scale_proj_bias).sigmoid() | 
					
						
						|  | return center_feature_scale | 
					
						
						|  |  | 
					
						
						|  |  | 
					
						
						|  | class DCNv3_pytorch(nn.Module): | 
					
						
						|  | def __init__( | 
					
						
						|  | self, | 
					
						
						|  | channels=64, | 
					
						
						|  | kernel_size=3, | 
					
						
						|  | dw_kernel_size=None, | 
					
						
						|  | stride=1, | 
					
						
						|  | pad=1, | 
					
						
						|  | dilation=1, | 
					
						
						|  | group=4, | 
					
						
						|  | offset_scale=1.0, | 
					
						
						|  | act_layer='GELU', | 
					
						
						|  | norm_layer='LN', | 
					
						
						|  | center_feature_scale=False, | 
					
						
						|  | remove_center=False, | 
					
						
						|  | ): | 
					
						
						|  | """ | 
					
						
						|  | DCNv3 Module | 
					
						
						|  | :param channels | 
					
						
						|  | :param kernel_size | 
					
						
						|  | :param stride | 
					
						
						|  | :param pad | 
					
						
						|  | :param dilation | 
					
						
						|  | :param group | 
					
						
						|  | :param offset_scale | 
					
						
						|  | :param act_layer | 
					
						
						|  | :param norm_layer | 
					
						
						|  | """ | 
					
						
						|  | super().__init__() | 
					
						
						|  | if channels % group != 0: | 
					
						
						|  | raise ValueError( | 
					
						
						|  | f'channels must be divisible by group, but got {channels} and {group}') | 
					
						
						|  | _d_per_group = channels // group | 
					
						
						|  | dw_kernel_size = dw_kernel_size if dw_kernel_size is not None else kernel_size | 
					
						
						|  |  | 
					
						
						|  | if not _is_power_of_2(_d_per_group): | 
					
						
						|  | warnings.warn( | 
					
						
						|  | "You'd better set channels in DCNv3 to make the dimension of each attention head a power of 2 " | 
					
						
						|  | 'which is more efficient in our CUDA implementation.') | 
					
						
						|  |  | 
					
						
						|  | self.offset_scale = offset_scale | 
					
						
						|  | self.channels = channels | 
					
						
						|  | self.kernel_size = kernel_size | 
					
						
						|  | self.dw_kernel_size = dw_kernel_size | 
					
						
						|  | self.stride = stride | 
					
						
						|  | self.dilation = dilation | 
					
						
						|  | self.pad = pad | 
					
						
						|  | self.group = group | 
					
						
						|  | self.group_channels = channels // group | 
					
						
						|  | self.offset_scale = offset_scale | 
					
						
						|  | self.center_feature_scale = center_feature_scale | 
					
						
						|  | self.remove_center = int(remove_center) | 
					
						
						|  |  | 
					
						
						|  | self.dw_conv = nn.Sequential( | 
					
						
						|  | nn.Conv2d( | 
					
						
						|  | channels, | 
					
						
						|  | channels, | 
					
						
						|  | kernel_size=dw_kernel_size, | 
					
						
						|  | stride=1, | 
					
						
						|  | padding=(dw_kernel_size - 1) // 2, | 
					
						
						|  | groups=channels), | 
					
						
						|  | build_norm_layer( | 
					
						
						|  | channels, | 
					
						
						|  | norm_layer, | 
					
						
						|  | 'channels_first', | 
					
						
						|  | 'channels_last'), | 
					
						
						|  | build_act_layer(act_layer)) | 
					
						
						|  | self.offset = nn.Linear( | 
					
						
						|  | channels, | 
					
						
						|  | group * (kernel_size * kernel_size - remove_center) * 2) | 
					
						
						|  | self.mask = nn.Linear( | 
					
						
						|  | channels, | 
					
						
						|  | group * (kernel_size * kernel_size - remove_center)) | 
					
						
						|  | self.input_proj = nn.Linear(channels, channels) | 
					
						
						|  | self.output_proj = nn.Linear(channels, channels) | 
					
						
						|  | self._reset_parameters() | 
					
						
						|  |  | 
					
						
						|  | if center_feature_scale: | 
					
						
						|  | self.center_feature_scale_proj_weight = nn.Parameter( | 
					
						
						|  | torch.zeros((group, channels), dtype=torch.float)) | 
					
						
						|  | self.center_feature_scale_proj_bias = nn.Parameter( | 
					
						
						|  | torch.tensor(0.0, dtype=torch.float).view((1,)).repeat(group, )) | 
					
						
						|  | self.center_feature_scale_module = CenterFeatureScaleModule() | 
					
						
						|  |  | 
					
						
						|  | def _reset_parameters(self): | 
					
						
						|  | constant_(self.offset.weight.data, 0.) | 
					
						
						|  | constant_(self.offset.bias.data, 0.) | 
					
						
						|  | constant_(self.mask.weight.data, 0.) | 
					
						
						|  | constant_(self.mask.bias.data, 0.) | 
					
						
						|  | xavier_uniform_(self.input_proj.weight.data) | 
					
						
						|  | constant_(self.input_proj.bias.data, 0.) | 
					
						
						|  | xavier_uniform_(self.output_proj.weight.data) | 
					
						
						|  | constant_(self.output_proj.bias.data, 0.) | 
					
						
						|  |  | 
					
						
						|  | def forward(self, input): | 
					
						
						|  | """ | 
					
						
						|  | :param query                       (N, H, W, C) | 
					
						
						|  | :return output                     (N, H, W, C) | 
					
						
						|  | """ | 
					
						
						|  | N, H, W, _ = input.shape | 
					
						
						|  |  | 
					
						
						|  | x = self.input_proj(input) | 
					
						
						|  | x_proj = x | 
					
						
						|  |  | 
					
						
						|  | x1 = input.permute(0, 3, 1, 2) | 
					
						
						|  | x1 = self.dw_conv(x1) | 
					
						
						|  | offset = self.offset(x1) | 
					
						
						|  | mask = self.mask(x1).reshape(N, H, W, self.group, -1) | 
					
						
						|  | mask = F.softmax(mask, -1).reshape(N, H, W, -1) | 
					
						
						|  |  | 
					
						
						|  | x = dcnv3_core_pytorch( | 
					
						
						|  | x, offset, mask, | 
					
						
						|  | self.kernel_size, self.kernel_size, | 
					
						
						|  | self.stride, self.stride, | 
					
						
						|  | self.pad, self.pad, | 
					
						
						|  | self.dilation, self.dilation, | 
					
						
						|  | self.group, self.group_channels, | 
					
						
						|  | self.offset_scale, self.remove_center) | 
					
						
						|  | if self.center_feature_scale: | 
					
						
						|  | center_feature_scale = self.center_feature_scale_module( | 
					
						
						|  | x1, self.center_feature_scale_proj_weight, self.center_feature_scale_proj_bias) | 
					
						
						|  |  | 
					
						
						|  | center_feature_scale = center_feature_scale[..., None].repeat( | 
					
						
						|  | 1, 1, 1, 1, self.channels // self.group).flatten(-2) | 
					
						
						|  | x = x * (1 - center_feature_scale) + x_proj * center_feature_scale | 
					
						
						|  | x = self.output_proj(x) | 
					
						
						|  |  | 
					
						
						|  | return x | 
					
						
						|  |  | 
					
						
						|  |  | 
					
						
						|  | class DCNv3(nn.Module): | 
					
						
						|  | def __init__( | 
					
						
						|  | self, | 
					
						
						|  | channels=64, | 
					
						
						|  | kernel_size=3, | 
					
						
						|  | dw_kernel_size=None, | 
					
						
						|  | stride=1, | 
					
						
						|  | pad=1, | 
					
						
						|  | dilation=1, | 
					
						
						|  | group=4, | 
					
						
						|  | offset_scale=1.0, | 
					
						
						|  | act_layer='GELU', | 
					
						
						|  | norm_layer='LN', | 
					
						
						|  | center_feature_scale=False, | 
					
						
						|  | remove_center=False, | 
					
						
						|  | ): | 
					
						
						|  | """ | 
					
						
						|  | DCNv3 Module | 
					
						
						|  | :param channels | 
					
						
						|  | :param kernel_size | 
					
						
						|  | :param stride | 
					
						
						|  | :param pad | 
					
						
						|  | :param dilation | 
					
						
						|  | :param group | 
					
						
						|  | :param offset_scale | 
					
						
						|  | :param act_layer | 
					
						
						|  | :param norm_layer | 
					
						
						|  | """ | 
					
						
						|  | super().__init__() | 
					
						
						|  | if channels % group != 0: | 
					
						
						|  | raise ValueError( | 
					
						
						|  | f'channels must be divisible by group, but got {channels} and {group}') | 
					
						
						|  | _d_per_group = channels // group | 
					
						
						|  | dw_kernel_size = dw_kernel_size if dw_kernel_size is not None else kernel_size | 
					
						
						|  |  | 
					
						
						|  | if not _is_power_of_2(_d_per_group): | 
					
						
						|  | warnings.warn( | 
					
						
						|  | "You'd better set channels in DCNv3 to make the dimension of each attention head a power of 2 " | 
					
						
						|  | 'which is more efficient in our CUDA implementation.') | 
					
						
						|  |  | 
					
						
						|  | self.offset_scale = offset_scale | 
					
						
						|  | self.channels = channels | 
					
						
						|  | self.kernel_size = kernel_size | 
					
						
						|  | self.dw_kernel_size = dw_kernel_size | 
					
						
						|  | self.stride = stride | 
					
						
						|  | self.dilation = dilation | 
					
						
						|  | self.pad = pad | 
					
						
						|  | self.group = group | 
					
						
						|  | self.group_channels = channels // group | 
					
						
						|  | self.offset_scale = offset_scale | 
					
						
						|  | self.center_feature_scale = center_feature_scale | 
					
						
						|  | self.remove_center = int(remove_center) | 
					
						
						|  |  | 
					
						
						|  | if self.remove_center and self.kernel_size % 2 == 0: | 
					
						
						|  | raise ValueError('remove_center is only compatible with odd kernel size.') | 
					
						
						|  |  | 
					
						
						|  | self.dw_conv = nn.Sequential( | 
					
						
						|  | nn.Conv2d( | 
					
						
						|  | channels, | 
					
						
						|  | channels, | 
					
						
						|  | kernel_size=dw_kernel_size, | 
					
						
						|  | stride=1, | 
					
						
						|  | padding=(dw_kernel_size - 1) // 2, | 
					
						
						|  | groups=channels), | 
					
						
						|  | build_norm_layer( | 
					
						
						|  | channels, | 
					
						
						|  | norm_layer, | 
					
						
						|  | 'channels_first', | 
					
						
						|  | 'channels_last'), | 
					
						
						|  | build_act_layer(act_layer)) | 
					
						
						|  | self.offset = nn.Linear( | 
					
						
						|  | channels, | 
					
						
						|  | group * (kernel_size * kernel_size - remove_center) * 2) | 
					
						
						|  | self.mask = nn.Linear( | 
					
						
						|  | channels, | 
					
						
						|  | group * (kernel_size * kernel_size - remove_center)) | 
					
						
						|  | self.input_proj = nn.Linear(channels, channels) | 
					
						
						|  | self.output_proj = nn.Linear(channels, channels) | 
					
						
						|  | self._reset_parameters() | 
					
						
						|  |  | 
					
						
						|  | if center_feature_scale: | 
					
						
						|  | self.center_feature_scale_proj_weight = nn.Parameter( | 
					
						
						|  | torch.zeros((group, channels), dtype=torch.float)) | 
					
						
						|  | self.center_feature_scale_proj_bias = nn.Parameter( | 
					
						
						|  | torch.tensor(0.0, dtype=torch.float).view((1,)).repeat(group, )) | 
					
						
						|  | self.center_feature_scale_module = CenterFeatureScaleModule() | 
					
						
						|  |  | 
					
						
						|  | def _reset_parameters(self): | 
					
						
						|  | constant_(self.offset.weight.data, 0.) | 
					
						
						|  | constant_(self.offset.bias.data, 0.) | 
					
						
						|  | constant_(self.mask.weight.data, 0.) | 
					
						
						|  | constant_(self.mask.bias.data, 0.) | 
					
						
						|  | xavier_uniform_(self.input_proj.weight.data) | 
					
						
						|  | constant_(self.input_proj.bias.data, 0.) | 
					
						
						|  | xavier_uniform_(self.output_proj.weight.data) | 
					
						
						|  | constant_(self.output_proj.bias.data, 0.) | 
					
						
						|  |  | 
					
						
						|  | def forward(self, input): | 
					
						
						|  | """ | 
					
						
						|  | :param query                       (N, H, W, C) | 
					
						
						|  | :return output                     (N, H, W, C) | 
					
						
						|  | """ | 
					
						
						|  | N, H, W, _ = input.shape | 
					
						
						|  |  | 
					
						
						|  | x = self.input_proj(input) | 
					
						
						|  | x_proj = x | 
					
						
						|  | dtype = x.dtype | 
					
						
						|  |  | 
					
						
						|  | x1 = input.permute(0, 3, 1, 2) | 
					
						
						|  | x1 = self.dw_conv(x1) | 
					
						
						|  | offset = self.offset(x1) | 
					
						
						|  | mask = self.mask(x1).reshape(N, H, W, self.group, -1) | 
					
						
						|  | mask = F.softmax(mask, -1) | 
					
						
						|  | mask = mask.reshape(N, H, W, -1).type(dtype) | 
					
						
						|  |  | 
					
						
						|  | x = DCNv3Function.apply( | 
					
						
						|  | x, offset, mask, | 
					
						
						|  | self.kernel_size, self.kernel_size, | 
					
						
						|  | self.stride, self.stride, | 
					
						
						|  | self.pad, self.pad, | 
					
						
						|  | self.dilation, self.dilation, | 
					
						
						|  | self.group, self.group_channels, | 
					
						
						|  | self.offset_scale, | 
					
						
						|  | 256, | 
					
						
						|  | self.remove_center) | 
					
						
						|  |  | 
					
						
						|  | if self.center_feature_scale: | 
					
						
						|  | center_feature_scale = self.center_feature_scale_module( | 
					
						
						|  | x1, self.center_feature_scale_proj_weight, self.center_feature_scale_proj_bias) | 
					
						
						|  |  | 
					
						
						|  | center_feature_scale = center_feature_scale[..., None].repeat( | 
					
						
						|  | 1, 1, 1, 1, self.channels // self.group).flatten(-2) | 
					
						
						|  | x = x * (1 - center_feature_scale) + x_proj * center_feature_scale | 
					
						
						|  | x = self.output_proj(x) | 
					
						
						|  |  | 
					
						
						|  | return x | 
					
						
						|  |  |