Spaces:
Runtime error
Runtime error
| import torch.nn as nn | |
| import torch | |
| from maskrcnn_benchmark.layers import swish | |
class BiFPN(nn.Module):
    """One bidirectional feature-pyramid (BiFPN) layer over levels P3..P7.

    Data flow of a single layer (``_0`` = input, ``_1`` = top-down
    intermediate, ``_2`` = output)::

        P7_0 -------------------------> P7_2 -------->
           |-------------|                ↑
                         ↓                |
        P6_0 ---------> P6_1 ---------> P6_2 -------->
           |-------------|--------------↑ ↑
                         ↓                |
        P5_0 ---------> P5_1 ---------> P5_2 -------->
           |-------------|--------------↑ ↑
                         ↓                |
        P4_0 ---------> P4_1 ---------> P4_2 -------->
           |-------------|--------------↑ ↑
                         |--------------↓ |
        P3_0 -------------------------> P3_2 -------->

    Coarser levels are brought to a finer level's resolution with 2x
    nearest-neighbour upsampling; finer levels are brought to a coarser
    level's resolution with a 3x3 / stride-2 max-pool.  Every fusion node sums
    its inputs, applies ``self.swish`` and then a depthwise-separable
    conv + BatchNorm block.

    The 3x3 depthwise convolutions and the max-pools use ``padding=1`` so that
    the convolutions preserve the spatial size and the pools halve it
    (rounding up).  Without that padding the tensors added at a fusion node
    never have matching sizes.  Because the top-down path upsamples by exactly
    2x, each level must be exactly twice as tall/wide as the next coarser one.

    Args:
        in_channels_list: Channel counts of the raw inputs; entries 0, 1 and 2
            are used for P3, P4 and P5.  Only read when ``first_time`` is True.
        out_channels: Channel count of every pyramid level inside the layer
            and of all five outputs.
        first_time: If True, ``forward`` receives raw features, uses the last
            three as P3/P4/P5, projects them to ``out_channels`` with 1x1
            convs and derives P6/P7 by max-pooling.  If False, ``forward``
            receives the five levels ``(P3, P4, P5, P6, P7)`` directly.
        epsilon: Added to the sum of the fusion weights before dividing, so
            the normalisation never divides by zero.
        attention: If True, fusion inputs are combined with learned,
            ReLU-clamped, sum-normalised weights; if False they are added
            unweighted.
    """

    # Attribute names of the post-fusion conv blocks, in registration order.
    _FUSION_CONV_NAMES = (
        "conv6_up", "conv5_up", "conv4_up", "conv3_up",
        "conv4_down", "conv5_down", "conv6_down", "conv7_down",
    )
    # (attribute name, number of fused inputs) of each learnable weight
    # vector, in registration order.
    _FUSION_WEIGHT_SPECS = (
        ("p6_w1", 2), ("p5_w1", 2), ("p4_w1", 2), ("p3_w1", 2),
        ("p4_w2", 3), ("p5_w2", 3), ("p6_w2", 3), ("p7_w2", 2),
    )

    def __init__(self, in_channels_list, out_channels, first_time=False, epsilon=1e-4, attention=True):
        super(BiFPN, self).__init__()
        self.epsilon = epsilon
        self.first_time = first_time
        self.attention = attention

        # Attribute names, nn.Sequential layouts and registration order are
        # kept as they were so state_dict keys and parameter order do not
        # change.
        self._build_fusion_convs(out_channels)
        self._build_resamplers()
        self.swish = swish()
        if self.first_time:
            self._build_input_projections(in_channels_list, out_channels)
        self._build_fusion_weights()

    @staticmethod
    def _separable_conv(channels):
        """Return a size-preserving depthwise 3x3 -> pointwise 1x1 -> BatchNorm block."""
        return nn.Sequential(
            # padding=1 keeps H and W unchanged for the 3x3 kernel.
            nn.Conv2d(channels, channels, 3, padding=1, groups=channels, bias=False),
            nn.Conv2d(channels, channels, 1),
            nn.BatchNorm2d(channels, momentum=0.01, eps=1e-3),
        )

    @staticmethod
    def _channel_projection(in_channels, out_channels):
        """Return a 1x1 conv -> BatchNorm block mapping in_channels to out_channels."""
        return nn.Sequential(
            nn.Conv2d(in_channels, out_channels, 1),
            nn.BatchNorm2d(out_channels, momentum=0.01, eps=1e-3),
        )

    @staticmethod
    def _halving_pool():
        """Return a 3x3 / stride-2 max-pool padded so the output is ceil(size / 2)."""
        return nn.MaxPool2d(3, 2, padding=1)

    def _build_fusion_convs(self, out_channels):
        """Register the eight post-fusion conv blocks (conv*_up / conv*_down)."""
        for name in self._FUSION_CONV_NAMES:
            setattr(self, name, self._separable_conv(out_channels))

    def _build_resamplers(self):
        """Register the per-level 2x upsamplers and stride-2 max-pool downsamplers."""
        for level in (6, 5, 4, 3):
            setattr(self, "p%d_upsample" % level, nn.Upsample(scale_factor=2, mode='nearest'))
        for level in (4, 5, 6, 7):
            setattr(self, "p%d_downsample" % level, self._halving_pool())

    def _build_input_projections(self, in_channels_list, out_channels):
        """Register the layers that turn raw P3/P4/P5 features into the five pyramid inputs."""
        c3, c4, c5 = in_channels_list[0], in_channels_list[1], in_channels_list[2]
        self.p5_down_channel = self._channel_projection(c5, out_channels)
        self.p4_down_channel = self._channel_projection(c4, out_channels)
        self.p3_down_channel = self._channel_projection(c3, out_channels)
        # P6 is a projected, pooled copy of raw P5; P7 is P6 pooled once more.
        self.p5_to_p6 = nn.Sequential(
            nn.Conv2d(c5, out_channels, 1),
            nn.BatchNorm2d(out_channels, momentum=0.01, eps=1e-3),
            self._halving_pool(),
        )
        self.p6_to_p7 = nn.Sequential(
            self._halving_pool(),
        )
        # Second, independent projections of raw P4/P5 used by the bottom-up pass.
        self.p4_down_channel_2 = self._channel_projection(c4, out_channels)
        self.p5_down_channel_2 = self._channel_projection(c5, out_channels)

    def _build_fusion_weights(self):
        """Register one all-ones weight vector (plus its ReLU) per fusion node."""
        for name, num_inputs in self._FUSION_WEIGHT_SPECS:
            setattr(self, name, nn.Parameter(torch.ones(num_inputs, dtype=torch.float32), requires_grad=True))
            setattr(self, name + "_relu", nn.ReLU())

    def forward(self, inputs):
        """Run one BiFPN layer.

        Args:
            inputs: Sequence of feature maps.  With ``first_time`` set, the
                last three entries are taken as raw P3, P4, P5; otherwise it
                must hold exactly the five levels ``(P3, P4, P5, P6, P7)``.

        Returns:
            Tuple ``(p3_out, p4_out, p5_out, p6_out, p7_out)``.
        """
        if self.attention:
            return self._forward_fast_attention(inputs)
        return self._forward(inputs)

    def _forward_fast_attention(self, inputs):
        """Fuse the pyramid using the learned, normalised fusion weights."""
        return self._run(inputs, weighted=True)

    def _forward(self, inputs):
        """Fuse the pyramid using plain, unweighted sums."""
        return self._run(inputs, weighted=False)

    def _fuse(self, weight_name, weighted, *features):
        """Combine same-shaped feature maps for one fusion node and apply swish.

        When ``weighted`` is True the parameter ``weight_name`` is passed
        through its ReLU and divided by ``sum + epsilon``; each feature is
        scaled by its weight before summation.  Otherwise the features are
        simply added.
        """
        if weighted:
            raw = getattr(self, weight_name + "_relu")(getattr(self, weight_name))
            weight = raw / (torch.sum(raw, dim=0) + self.epsilon)
            fused = weight[0] * features[0]
            for index in range(1, len(features)):
                fused = fused + weight[index] * features[index]
        else:
            fused = features[0]
            for feature in features[1:]:
                fused = fused + feature
        return self.swish(fused)

    def _run(self, inputs, weighted):
        """Shared top-down + bottom-up pass behind both forward variants."""
        if self.first_time:
            # Only the last three feature maps are used, so callers may pass
            # extra, finer levels in front of them.
            p3, p4, p5 = inputs[-3:]
            p6_in = self.p5_to_p6(p5)
            p7_in = self.p6_to_p7(p6_in)
            p3_in = self.p3_down_channel(p3)
            p4_in = self.p4_down_channel(p4)
            p5_in = self.p5_down_channel(p5)
        else:
            # P3_0, P4_0, P5_0, P6_0 and P7_0
            p3_in, p4_in, p5_in, p6_in, p7_in = inputs

        # Top-down pass: P7_0 -> P6_1 -> P5_1 -> P4_1 -> P3_2.
        p6_up = self.conv6_up(self._fuse("p6_w1", weighted, p6_in, self.p6_upsample(p7_in)))
        p5_up = self.conv5_up(self._fuse("p5_w1", weighted, p5_in, self.p5_upsample(p6_up)))
        p4_up = self.conv4_up(self._fuse("p4_w1", weighted, p4_in, self.p4_upsample(p5_up)))
        p3_out = self.conv3_up(self._fuse("p3_w1", weighted, p3_in, self.p3_upsample(p4_up)))

        if self.first_time:
            # The bottom-up pass uses separately projected copies of raw P4/P5.
            p4_in = self.p4_down_channel_2(p4)
            p5_in = self.p5_down_channel_2(p5)

        # Bottom-up pass: P3_2 -> P4_2 -> P5_2 -> P6_2 -> P7_2.
        p4_out = self.conv4_down(
            self._fuse("p4_w2", weighted, p4_in, p4_up, self.p4_downsample(p3_out)))
        p5_out = self.conv5_down(
            self._fuse("p5_w2", weighted, p5_in, p5_up, self.p5_downsample(p4_out)))
        p6_out = self.conv6_down(
            self._fuse("p6_w2", weighted, p6_in, p6_up, self.p6_downsample(p5_out)))
        p7_out = self.conv7_down(
            self._fuse("p7_w2", weighted, p7_in, self.p7_downsample(p6_out)))

        return p3_out, p4_out, p5_out, p6_out, p7_out