Spaces:
Runtime error
Runtime error
| """ Split Attention Conv2d (for ResNeSt Models) | |
| Paper: `ResNeSt: Split-Attention Networks` - https://arxiv.org/abs/2004.08955 | |
| Adapted from original PyTorch impl at https://github.com/zhanghang1989/ResNeSt | |
| Modified for torchscript compat, performance, and consistency with timm by Ross Wightman | |
| """ | |
| import torch | |
| import torch.nn.functional as F | |
| from torch import nn | |
| from .helpers import make_divisible | |
class RadixSoftmax(nn.Module):
    """Normalise attention logits across the radix splits.

    When ``radix > 1`` the input is viewed as
    ``(batch, cardinality, radix, -1)``, the cardinality and radix axes are
    swapped, a softmax is taken over the radix axis, and the result is
    flattened to ``(batch, -1)``. Otherwise an element-wise sigmoid is
    applied and the input shape is kept.

    Args:
        radix: number of splits the softmax competes over.
        cardinality: number of groups the channel axis is divided into.
    """

    def __init__(self, radix, cardinality):
        super().__init__()
        self.radix = radix
        self.cardinality = cardinality

    def forward(self, x):
        # Single split: nothing to compete against, so gate with a sigmoid.
        if self.radix <= 1:
            return torch.sigmoid(x)

        n = x.size(0)
        # (n, cardinality, radix, c) -> (n, radix, cardinality, c) so that
        # dim 1 is the radix axis the softmax runs over.
        grouped = x.view(n, self.cardinality, self.radix, -1).transpose(1, 2)
        return grouped.softmax(dim=1).reshape(n, -1)
class SplitAttn(nn.Module):
    """Split-Attention (aka Splat).

    A grouped convolution produces ``radix`` splits of ``out_channels``
    feature maps each. The splits are summed, globally average-pooled over
    the spatial axes, passed through two 1x1 convolutions (``fc1`` / ``fc2``)
    and normalised by ``RadixSoftmax``; the resulting weights are used to
    recombine the splits into a single ``out_channels`` output.

    Args:
        in_channels: channels of the input feature map.
        out_channels: channels of the output; falls back to ``in_channels``
            when falsy.
        kernel_size, stride, dilation, bias: forwarded to the main conv.
        padding: conv padding; ``kernel_size // 2`` when ``None``.
        groups: group count; the main conv uses ``groups * radix`` groups,
            ``fc1`` / ``fc2`` use ``groups``.
        radix: number of splits.
        rd_ratio, rd_divisor: used with ``make_divisible`` (minimum 32) to
            derive the attention width when ``rd_channels`` is ``None``.
        rd_channels: when given, the attention width is
            ``rd_channels * radix``.
        act_layer: activation class, instantiated with ``inplace=True``.
        norm_layer: normalisation class called with a channel count;
            ``nn.Identity`` is used when falsy.
        drop_block: optional callable applied after ``bn0`` and before the
            first activation.
        **kwargs: extra keyword arguments forwarded to the main ``nn.Conv2d``.
    """

    def __init__(self, in_channels, out_channels=None, kernel_size=3, stride=1, padding=None,
                 dilation=1, groups=1, bias=False, radix=2, rd_ratio=0.25, rd_channels=None, rd_divisor=8,
                 act_layer=nn.ReLU, norm_layer=None, drop_block=None, **kwargs):
        super().__init__()
        if not out_channels:
            out_channels = in_channels
        self.radix = radix
        self.drop_block = drop_block

        # Width of the stacked splits produced by the main convolution.
        split_chs = out_channels * radix
        # Width of the attention bottleneck.
        if rd_channels is not None:
            attn_chs = rd_channels * radix
        else:
            attn_chs = make_divisible(in_channels * radix * rd_ratio, min_value=32, divisor=rd_divisor)
        if padding is None:
            padding = kernel_size // 2

        # NOTE: sub-modules are assigned in a fixed order (conv, bn0, act0,
        # fc1, bn1, act1, fc2, rsoftmax) to keep registration order stable.
        self.conv = nn.Conv2d(
            in_channels, split_chs, kernel_size, stride, padding, dilation,
            groups=groups * radix, bias=bias, **kwargs)
        if norm_layer:
            self.bn0 = norm_layer(split_chs)
        else:
            self.bn0 = nn.Identity()
        self.act0 = act_layer(inplace=True)
        self.fc1 = nn.Conv2d(out_channels, attn_chs, 1, groups=groups)
        if norm_layer:
            self.bn1 = norm_layer(attn_chs)
        else:
            self.bn1 = nn.Identity()
        self.act1 = act_layer(inplace=True)
        self.fc2 = nn.Conv2d(attn_chs, split_chs, 1, groups=groups)
        self.rsoftmax = RadixSoftmax(radix, groups)

    def forward(self, x):
        feats = self.bn0(self.conv(x))
        if self.drop_block is not None:
            feats = self.drop_block(feats)
        feats = self.act0(feats)

        batch, split_chs, height, width = feats.shape
        multi_split = self.radix > 1
        if multi_split:
            # Expose the radix axis: (B, radix, C, H, W), then fuse the splits.
            feats = feats.reshape((batch, self.radix, split_chs // self.radix, height, width))
            pooled = feats.sum(dim=1)
        else:
            pooled = feats

        # Global average pool over H, W -> (B, C, 1, 1).
        pooled = pooled.mean((2, 3), keepdim=True)
        hidden = self.act1(self.bn1(self.fc1(pooled)))
        attn = self.rsoftmax(self.fc2(hidden)).view(batch, -1, 1, 1)

        if multi_split:
            # Weight each split by its attention and sum the splits back together.
            attn = attn.reshape((batch, self.radix, split_chs // self.radix, 1, 1))
            out = (feats * attn).sum(dim=1)
        else:
            out = feats * attn
        return out.contiguous()