Spaces:
Runtime error
Runtime error
| import torch | |
| import torch.nn.functional as F | |
| from torch import nn | |
| from .deform_conv import ModulatedDeformConv | |
| from .dyrelu import h_sigmoid, DYReLU | |
class Conv3x3Norm(torch.nn.Module):
    """A 3x3 convolution optionally followed by group normalisation.

    The convolution always uses ``kernel_size=3`` and ``padding=1``; only the
    stride is configurable.

    Args:
        in_channels: Number of channels of the input tensor.
        out_channels: Number of channels produced by the convolution.
        stride: Stride of the convolution.
        deformable: When truthy, ``ModulatedDeformConv`` is used instead of
            ``nn.Conv2d``.
        use_gn: When truthy, the convolution output is passed through
            ``nn.GroupNorm`` with 16 groups.
    """

    def __init__(self,
                 in_channels,
                 out_channels,
                 stride,
                 deformable=False,
                 use_gn=False):
        super().__init__()

        # Both convolution flavours are constructed with the same arguments,
        # so only the class differs.
        conv_cls = ModulatedDeformConv if deformable else nn.Conv2d
        self.conv = conv_cls(in_channels, out_channels,
                             kernel_size=3, stride=stride, padding=1)

        # The attribute is called ``bn`` but holds a GroupNorm layer (or None).
        self.bn = (nn.GroupNorm(num_groups=16, num_channels=out_channels)
                   if use_gn else None)

    def forward(self, input, **kwargs):
        """Apply the convolution and, if configured, the normalisation.

        Args:
            input: Tensor handed to the convolution.
            **kwargs: Forwarded unchanged to the convolution's call.

        Returns:
            The convolution output, normalised when ``self.bn`` is set.
        """
        out = self.conv(input, **kwargs)
        if self.bn is not None:
            out = self.bn(out)
        return out
class DyConv(nn.Module):
    """Fuse every level of a feature pyramid with its neighbouring levels.

    For each level, up to three feature maps are produced: a stride-1
    convolution of the level itself, a stride-2 convolution of the previous
    level, and a stride-1 convolution of the next level bilinearly resized to
    the current level's spatial size. They are averaged (optionally weighted
    by a per-map attention scalar) and passed through an activation.

    Args:
        in_channels: Channels of every input feature map.
        out_channels: Channels of every output feature map.
        conv_func: Callable ``(in_channels, out_channels, stride)`` returning
            the convolution module to use.
        use_dyfuse: When truthy, weight each candidate feature map by an
            attention value before averaging.
        use_dyrelu: When truthy, use ``DYReLU`` as the activation instead of
            ``nn.ReLU``.
        use_deform: When truthy, predict an offset/mask tensor from each level
            and pass it to the convolutions as ``offset=`` / ``mask=``.
    """

    def __init__(self,
                 in_channels=256,
                 out_channels=256,
                 conv_func=Conv3x3Norm,
                 use_dyfuse=True,
                 use_dyrelu=False,
                 use_deform=False
                 ):
        super(DyConv, self).__init__()

        # Index 0: applied to the next level (then resized),
        # index 1: applied to the current level,
        # index 2: stride-2, applied to the previous level.
        self.DyConv = nn.ModuleList()
        self.DyConv.append(conv_func(in_channels, out_channels, 1))
        self.DyConv.append(conv_func(in_channels, out_channels, 1))
        self.DyConv.append(conv_func(in_channels, out_channels, 2))

        if use_dyfuse:
            # The attention branch consumes the *outputs* of the convolutions
            # above, so its 1x1 conv must accept ``out_channels`` inputs.
            # (Using ``in_channels`` fails whenever the two differ.)
            self.AttnConv = nn.Sequential(
                nn.AdaptiveAvgPool2d(1),
                nn.Conv2d(out_channels, 1, kernel_size=1),
                nn.ReLU(inplace=True))
            self.h_sigmoid = h_sigmoid()
        else:
            self.AttnConv = None

        if use_dyrelu:
            self.relu = DYReLU(in_channels, out_channels)
        else:
            self.relu = nn.ReLU()

        if use_deform:
            # 27 channels are split in forward() into 18 offset channels and
            # 9 mask channels (presumably 2 coordinates + 1 modulation weight
            # per tap of a 3x3 kernel -- confirm against ModulatedDeformConv).
            self.offset = nn.Conv2d(in_channels, 27, kernel_size=3, stride=1, padding=1)
        else:
            self.offset = None

        self.init_weights()

    def init_weights(self):
        """Re-initialise the ``nn.Conv2d`` layers of DyConv and AttnConv.

        Weights are drawn from a normal distribution (mean 0, std 0.01) and
        biases are zeroed. ``self.offset`` is not touched here.
        """
        containers = [self.DyConv]
        if self.AttnConv is not None:
            containers.append(self.AttnConv)

        for container in containers:
            for m in container.modules():
                if isinstance(m, nn.Conv2d):
                    nn.init.normal_(m.weight.data, 0, 0.01)
                    if m.bias is not None:
                        m.bias.data.zero_()

    def forward(self, x):
        """Compute the fused feature map for every pyramid level.

        Args:
            x: Indexable sequence of 4-D feature maps. Level ``i - 1`` is
                downsampled with a stride-2 conv and level ``i + 1`` is
                upsampled to match level ``i``, so the sequence is presumably
                ordered from highest to lowest resolution -- confirm with the
                caller.

        Returns:
            A list with one fused feature map per input level.
        """
        next_x = []
        last_level = len(x) - 1

        for level, feature in enumerate(x):
            conv_args = dict()
            if self.offset is not None:
                offset_mask = self.offset(feature)
                offset = offset_mask[:, :18, :, :]
                mask = offset_mask[:, 18:, :, :].sigmoid()
                # The offset/mask predicted from the current level is reused
                # for the convolutions over the neighbouring levels as well.
                conv_args = dict(offset=offset, mask=mask)

            temp_fea = [self.DyConv[1](feature, **conv_args)]

            if level > 0:
                temp_fea.append(self.DyConv[2](x[level - 1], **conv_args))

            if level < last_level:
                # F.interpolate(mode='bilinear', align_corners=True) is the
                # documented replacement for the deprecated
                # F.upsample_bilinear and yields the same result.
                temp_fea.append(F.interpolate(
                    self.DyConv[0](x[level + 1], **conv_args),
                    size=[feature.size(2), feature.size(3)],
                    mode='bilinear',
                    align_corners=True))

            stacked = torch.stack(temp_fea)
            if self.AttnConv is not None:
                # One attention value per candidate map and per sample,
                # broadcast over channels and spatial positions.
                attn = torch.stack([self.AttnConv(fea) for fea in temp_fea])
                stacked = stacked * self.h_sigmoid(attn)

            mean_fea = torch.mean(stacked, dim=0, keepdim=False)
            next_x.append(self.relu(mean_fea))

        return next_x
class DyHead(nn.Module):
    """A stack of ``DyConv`` blocks configured from ``cfg.MODEL.DYHEAD``.

    Args:
        cfg: Configuration object; the fields ``CHANNELS``, ``USE_GN``,
            ``USE_DYRELU``, ``USE_DYFUSE``, ``USE_DFCONV`` and ``NUM_CONVS``
            are read from ``cfg.MODEL.DYHEAD``.
        in_channels: Input channel count of the first ``DyConv`` block; every
            later block takes ``CHANNELS`` input channels.
    """

    def __init__(self, cfg, in_channels):
        super().__init__()
        self.cfg = cfg

        head_cfg = cfg.MODEL.DYHEAD
        channels = head_cfg.CHANNELS
        use_gn = head_cfg.USE_GN
        use_dyrelu = head_cfg.USE_DYRELU
        use_dyfuse = head_cfg.USE_DYFUSE
        use_deform = head_cfg.USE_DFCONV

        def conv_func(n_in, n_out, stride):
            # Factory handed to every DyConv; binds the deform/GN settings.
            return Conv3x3Norm(n_in, n_out, stride,
                               deformable=use_deform, use_gn=use_gn)

        blocks = [
            DyConv(
                channels if idx else in_channels,
                channels,
                conv_func=conv_func,
                use_dyrelu=use_dyrelu,
                use_dyfuse=use_dyfuse,
                use_deform=use_deform,
            )
            for idx in range(head_cfg.NUM_CONVS)
        ]

        self.add_module('dyhead_tower', nn.Sequential(*blocks))

    def forward(self, x):
        """Run ``x`` through the tower and return its output."""
        return self.dyhead_tower(x)