import torch
import torch.nn as nn
import torchvision
from . import resnet, resnext
try:
    from lib.nn import SynchronizedBatchNorm2d
except ImportError:
    from torch.nn import BatchNorm2d as SynchronizedBatchNorm2d
class SegmentationModuleBase(nn.Module):
    def __init__(self):
        super(SegmentationModuleBase, self).__init__()

    def pixel_acc(self, pred, label):
        _, preds = torch.max(pred, dim=1)
        valid = (label >= 0).long()
        acc_sum = torch.sum(valid * (preds == label).long())
        pixel_sum = torch.sum(valid)
        acc = acc_sum.float() / (pixel_sum.float() + 1e-10)
        return acc
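
# Illustration (not part of the original class): pixel_acc treats every label
# value below 0 as "ignore"; those pixels count toward neither the numerator
# nor the denominator. A minimal sketch with made-up tensors:
#
#   logits = torch.randn(1, 150, 4, 4)           # (N, num_class, H, W) scores
#   labels = torch.randint(0, 150, (1, 4, 4))    # ground-truth class indices
#   labels[0, 0, 0] = -1                         # mark one pixel as ignored
#   acc = SegmentationModuleBase().pixel_acc(logits, labels)
#   # acc = correctly predicted pixels / 15 valid pixels (ignored one excluded)
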
class SegmentationModule(SegmentationModuleBase):
    def __init__(self, net_enc, net_dec, crit, deep_sup_scale=None):
        super(SegmentationModule, self).__init__()
        self.encoder = net_enc
        self.decoder = net_dec
        self.crit = crit
        self.deep_sup_scale = deep_sup_scale

    def forward(self, feed_dict, *, segSize=None):
        if segSize is None:  # training
            if self.deep_sup_scale is not None:  # use deep supervision technique
                (pred, pred_deepsup) = self.decoder(self.encoder(feed_dict['img_data'], return_feature_maps=True))
            else:
                pred = self.decoder(self.encoder(feed_dict['img_data'], return_feature_maps=True))

            loss = self.crit(pred, feed_dict['seg_label'])
            if self.deep_sup_scale is not None:
                loss_deepsup = self.crit(pred_deepsup, feed_dict['seg_label'])
                loss = loss + loss_deepsup * self.deep_sup_scale

            acc = self.pixel_acc(pred, feed_dict['seg_label'])
            return loss, acc
        else:  # inference
            pred = self.decoder(self.encoder(feed_dict['img_data'], return_feature_maps=True), segSize=segSize)
            return pred
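
# Sketch of the forward contract above (illustrative; `net_enc`, `net_dec` and
# `crit` are placeholders for an encoder, a decoder and a criterion such as
# nn.NLLLoss(ignore_index=-1), which matches the log_softmax outputs the
# decoders below produce during training):
#
#   model = SegmentationModule(net_enc, net_dec, crit, deep_sup_scale=0.4)
#   # training: no segSize -> returns (loss, pixel accuracy)
#   loss, acc = model({'img_data': imgs, 'seg_label': labels})
#   # inference: segSize given -> returns per-pixel class probabilities
#   prob = model({'img_data': imgs}, segSize=(height, width))
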
def conv3x3(in_planes, out_planes, stride=1, has_bias=False):
    "3x3 convolution with padding"
    return nn.Conv2d(in_planes, out_planes, kernel_size=3, stride=stride,
                     padding=1, bias=has_bias)


def conv3x3_bn_relu(in_planes, out_planes, stride=1):
    return nn.Sequential(
        conv3x3(in_planes, out_planes, stride),
        SynchronizedBatchNorm2d(out_planes),
        nn.ReLU(inplace=True),
    )
class ModelBuilder():
    # custom weights initialization
    def weights_init(self, m):
        classname = m.__class__.__name__
        if classname.find('Conv') != -1:
            nn.init.kaiming_normal_(m.weight.data)
        elif classname.find('BatchNorm') != -1:
            m.weight.data.fill_(1.)
            m.bias.data.fill_(1e-4)
        #elif classname.find('Linear') != -1:
        #    m.weight.data.normal_(0.0, 0.0001)

    def build_encoder(self, arch='resnet50_dilated8', fc_dim=512, weights=''):
        pretrained = True if len(weights) == 0 else False
        if arch == 'resnet34':
            raise NotImplementedError
            orig_resnet = resnet.__dict__['resnet34'](pretrained=pretrained)
            net_encoder = Resnet(orig_resnet)
        elif arch == 'resnet34_dilated8':
            raise NotImplementedError
            orig_resnet = resnet.__dict__['resnet34'](pretrained=pretrained)
            net_encoder = ResnetDilated(orig_resnet,
                                        dilate_scale=8)
        elif arch == 'resnet34_dilated16':
            raise NotImplementedError
            orig_resnet = resnet.__dict__['resnet34'](pretrained=pretrained)
            net_encoder = ResnetDilated(orig_resnet,
                                        dilate_scale=16)
        elif arch == 'resnet50':
            orig_resnet = resnet.__dict__['resnet50'](pretrained=pretrained)
            net_encoder = Resnet(orig_resnet)
        elif arch == 'resnet50_dilated8':
            orig_resnet = resnet.__dict__['resnet50'](pretrained=pretrained)
            net_encoder = ResnetDilated(orig_resnet,
                                        dilate_scale=8)
        elif arch == 'resnet50_dilated16':
            orig_resnet = resnet.__dict__['resnet50'](pretrained=pretrained)
            net_encoder = ResnetDilated(orig_resnet,
                                        dilate_scale=16)
        elif arch == 'resnet101':
            orig_resnet = resnet.__dict__['resnet101'](pretrained=pretrained)
            net_encoder = Resnet(orig_resnet)
        elif arch == 'resnet101_dilated8':
            orig_resnet = resnet.__dict__['resnet101'](pretrained=pretrained)
            net_encoder = ResnetDilated(orig_resnet,
                                        dilate_scale=8)
        elif arch == 'resnet101_dilated16':
            orig_resnet = resnet.__dict__['resnet101'](pretrained=pretrained)
            net_encoder = ResnetDilated(orig_resnet,
                                        dilate_scale=16)
        elif arch == 'resnext101':
            orig_resnext = resnext.__dict__['resnext101'](pretrained=pretrained)
            net_encoder = Resnet(orig_resnext)  # we can still use class Resnet
        else:
            raise Exception('Architecture undefined!')

        # net_encoder.apply(self.weights_init)
        if len(weights) > 0:
            # print('Loading weights for net_encoder')
            net_encoder.load_state_dict(
                torch.load(weights, map_location=lambda storage, loc: storage), strict=False)
        return net_encoder
    def build_decoder(self, arch='ppm_bilinear_deepsup',
                      fc_dim=512, num_class=150,
                      weights='', inference=False, use_softmax=False):
        if arch == 'c1_bilinear_deepsup':
            net_decoder = C1BilinearDeepSup(
                num_class=num_class,
                fc_dim=fc_dim,
                inference=inference,
                use_softmax=use_softmax)
        elif arch == 'c1_bilinear':
            net_decoder = C1Bilinear(
                num_class=num_class,
                fc_dim=fc_dim,
                inference=inference,
                use_softmax=use_softmax)
        elif arch == 'ppm_bilinear':
            net_decoder = PPMBilinear(
                num_class=num_class,
                fc_dim=fc_dim,
                inference=inference,
                use_softmax=use_softmax)
        elif arch == 'ppm_bilinear_deepsup':
            net_decoder = PPMBilinearDeepsup(
                num_class=num_class,
                fc_dim=fc_dim,
                inference=inference,
                use_softmax=use_softmax)
        elif arch == 'upernet_lite':
            net_decoder = UPerNet(
                num_class=num_class,
                fc_dim=fc_dim,
                inference=inference,
                use_softmax=use_softmax,
                fpn_dim=256)
        elif arch == 'upernet':
            net_decoder = UPerNet(
                num_class=num_class,
                fc_dim=fc_dim,
                inference=inference,
                use_softmax=use_softmax,
                fpn_dim=512)
        elif arch == 'upernet_tmp':
            net_decoder = UPerNetTmp(
                num_class=num_class,
                fc_dim=fc_dim,
                inference=inference,
                use_softmax=use_softmax,
                fpn_dim=512)
        else:
            raise Exception('Architecture undefined!')

        net_decoder.apply(self.weights_init)
        if len(weights) > 0:
            # print('Loading weights for net_decoder')
            net_decoder.load_state_dict(
                torch.load(weights, map_location=lambda storage, loc: storage), strict=False)
        return net_decoder
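
# Typical use of ModelBuilder (a sketch; weights='' keeps the default
# initialization, otherwise pass a checkpoint path -- load_state_dict is called
# with strict=False above, so partially matching checkpoints are tolerated):
#
#   builder = ModelBuilder()
#   net_encoder = builder.build_encoder(arch='resnet50_dilated8', fc_dim=2048)
#   net_decoder = builder.build_decoder(arch='ppm_bilinear_deepsup',
#                                       fc_dim=2048, num_class=150,
#                                       use_softmax=False)  # True at inference
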
class Resnet(nn.Module):
    def __init__(self, orig_resnet):
        super(Resnet, self).__init__()

        # take pretrained resnet, except AvgPool and FC
        self.conv1 = orig_resnet.conv1
        self.bn1 = orig_resnet.bn1
        self.relu1 = orig_resnet.relu1
        self.conv2 = orig_resnet.conv2
        self.bn2 = orig_resnet.bn2
        self.relu2 = orig_resnet.relu2
        self.conv3 = orig_resnet.conv3
        self.bn3 = orig_resnet.bn3
        self.relu3 = orig_resnet.relu3
        self.maxpool = orig_resnet.maxpool
        self.layer1 = orig_resnet.layer1
        self.layer2 = orig_resnet.layer2
        self.layer3 = orig_resnet.layer3
        self.layer4 = orig_resnet.layer4

    def forward(self, x, return_feature_maps=False):
        conv_out = []

        x = self.relu1(self.bn1(self.conv1(x)))
        x = self.relu2(self.bn2(self.conv2(x)))
        x = self.relu3(self.bn3(self.conv3(x)))
        x = self.maxpool(x)

        x = self.layer1(x); conv_out.append(x)
        x = self.layer2(x); conv_out.append(x)
        x = self.layer3(x); conv_out.append(x)
        x = self.layer4(x); conv_out.append(x)

        if return_feature_maps:
            return conv_out
        return [x]
class ResnetDilated(nn.Module):
    def __init__(self, orig_resnet, dilate_scale=8):
        super(ResnetDilated, self).__init__()
        from functools import partial

        if dilate_scale == 8:
            orig_resnet.layer3.apply(
                partial(self._nostride_dilate, dilate=2))
            orig_resnet.layer4.apply(
                partial(self._nostride_dilate, dilate=4))
        elif dilate_scale == 16:
            orig_resnet.layer4.apply(
                partial(self._nostride_dilate, dilate=2))

        # take pretrained resnet, except AvgPool and FC
        self.conv1 = orig_resnet.conv1
        self.bn1 = orig_resnet.bn1
        self.relu1 = orig_resnet.relu1
        self.conv2 = orig_resnet.conv2
        self.bn2 = orig_resnet.bn2
        self.relu2 = orig_resnet.relu2
        self.conv3 = orig_resnet.conv3
        self.bn3 = orig_resnet.bn3
        self.relu3 = orig_resnet.relu3
        self.maxpool = orig_resnet.maxpool
        self.layer1 = orig_resnet.layer1
        self.layer2 = orig_resnet.layer2
        self.layer3 = orig_resnet.layer3
        self.layer4 = orig_resnet.layer4
    def _nostride_dilate(self, m, dilate):
        classname = m.__class__.__name__
        if classname.find('Conv') != -1:
            # the convolution with stride
            if m.stride == (2, 2):
                m.stride = (1, 1)
                if m.kernel_size == (3, 3):
                    m.dilation = (dilate // 2, dilate // 2)
                    m.padding = (dilate // 2, dilate // 2)
            # other convolutions
            else:
                if m.kernel_size == (3, 3):
                    m.dilation = (dilate, dilate)
                    m.padding = (dilate, dilate)
    def forward(self, x, return_feature_maps=False):
        conv_out = []

        x = self.relu1(self.bn1(self.conv1(x)))
        x = self.relu2(self.bn2(self.conv2(x)))
        x = self.relu3(self.bn3(self.conv3(x)))
        x = self.maxpool(x)

        x = self.layer1(x); conv_out.append(x)
        x = self.layer2(x); conv_out.append(x)
        x = self.layer3(x); conv_out.append(x)
        x = self.layer4(x); conv_out.append(x)

        if return_feature_maps:
            return conv_out
        return [x]
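
# What _nostride_dilate does, shown on standalone convolutions (a sketch, not
# used by the classes above): the stride-2 3x3 convolution of a stage becomes
# stride-1, and the remaining 3x3 convolutions are dilated, so the feature map
# stops shrinking while the pretrained receptive field is preserved.
#
#   x = torch.randn(1, 64, 32, 32)
#   conv_a = nn.Conv2d(64, 64, 3, stride=2, padding=1)    # first conv of a stage
#   conv_a.stride, conv_a.dilation, conv_a.padding = (1, 1), (1, 1), (1, 1)
#   conv_b = nn.Conv2d(64, 64, 3, stride=1, padding=1)    # remaining convs
#   conv_b.dilation, conv_b.padding = (2, 2), (2, 2)      # the dilate=2 branch
#   assert conv_a(x).shape[-2:] == conv_b(x).shape[-2:] == (32, 32)
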
# last conv, bilinear upsample
class C1BilinearDeepSup(nn.Module):
    def __init__(self, num_class=150, fc_dim=2048, inference=False, use_softmax=False):
        super(C1BilinearDeepSup, self).__init__()
        self.use_softmax = use_softmax
        self.inference = inference

        self.cbr = conv3x3_bn_relu(fc_dim, fc_dim // 4, 1)
        self.cbr_deepsup = conv3x3_bn_relu(fc_dim // 2, fc_dim // 4, 1)

        # last conv
        self.conv_last = nn.Conv2d(fc_dim // 4, num_class, 1, 1, 0)
        self.conv_last_deepsup = nn.Conv2d(fc_dim // 4, num_class, 1, 1, 0)

    def forward(self, conv_out, segSize=None):
        conv5 = conv_out[-1]

        x = self.cbr(conv5)
        x = self.conv_last(x)

        if self.inference or self.use_softmax:  # is True during inference
            x = nn.functional.interpolate(
                x, size=segSize, mode='bilinear', align_corners=False)
            if self.use_softmax:
                x = nn.functional.softmax(x, dim=1)
            return x

        # deep sup
        conv4 = conv_out[-2]
        _ = self.cbr_deepsup(conv4)
        _ = self.conv_last_deepsup(_)

        x = nn.functional.log_softmax(x, dim=1)
        _ = nn.functional.log_softmax(_, dim=1)

        return (x, _)
# last conv, bilinear upsample
class C1Bilinear(nn.Module):
    def __init__(self, num_class=150, fc_dim=2048, inference=False, use_softmax=False):
        super(C1Bilinear, self).__init__()
        self.use_softmax = use_softmax
        self.inference = inference

        self.cbr = conv3x3_bn_relu(fc_dim, fc_dim // 4, 1)

        # last conv
        self.conv_last = nn.Conv2d(fc_dim // 4, num_class, 1, 1, 0)

    def forward(self, conv_out, segSize=None):
        conv5 = conv_out[-1]
        x = self.cbr(conv5)
        x = self.conv_last(x)

        if self.inference or self.use_softmax:  # is True during inference
            x = nn.functional.interpolate(
                x, size=segSize, mode='bilinear', align_corners=False)
            if self.use_softmax:
                x = nn.functional.softmax(x, dim=1)
        else:
            x = nn.functional.log_softmax(x, dim=1)

        return x
# pyramid pooling, bilinear upsample
class PPMBilinear(nn.Module):
    def __init__(self, num_class=150, fc_dim=4096,
                 inference=False, use_softmax=False, pool_scales=(1, 2, 3, 6)):
        super(PPMBilinear, self).__init__()
        self.use_softmax = use_softmax
        self.inference = inference

        self.ppm = []
        for scale in pool_scales:
            self.ppm.append(nn.Sequential(
                nn.AdaptiveAvgPool2d(scale),
                nn.Conv2d(fc_dim, 512, kernel_size=1, bias=False),
                SynchronizedBatchNorm2d(512),
                nn.ReLU(inplace=True)
            ))
        self.ppm = nn.ModuleList(self.ppm)

        self.conv_last = nn.Sequential(
            nn.Conv2d(fc_dim + len(pool_scales)*512, 512,
                      kernel_size=3, padding=1, bias=False),
            SynchronizedBatchNorm2d(512),
            nn.ReLU(inplace=True),
            nn.Dropout2d(0.1),
            nn.Conv2d(512, num_class, kernel_size=1)
        )

    def forward(self, conv_out, segSize=None):
        conv5 = conv_out[-1]

        input_size = conv5.size()
        ppm_out = [conv5]
        for pool_scale in self.ppm:
            ppm_out.append(nn.functional.interpolate(
                pool_scale(conv5),
                (input_size[2], input_size[3]),
                mode='bilinear', align_corners=False))
        ppm_out = torch.cat(ppm_out, 1)

        x = self.conv_last(ppm_out)

        if self.inference or self.use_softmax:  # is True during inference
            x = nn.functional.interpolate(
                x, size=segSize, mode='bilinear', align_corners=False)
            if self.use_softmax:
                x = nn.functional.softmax(x, dim=1)
        else:
            x = nn.functional.log_softmax(x, dim=1)
        return x
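
# Quick shape check for the pyramid pooling head (illustrative; fc_dim=2048
# matches the dilated ResNet-50/101 encoders above):
#
#   ppm = PPMBilinear(num_class=150, fc_dim=2048, use_softmax=True).eval()
#   conv5 = torch.randn(1, 2048, 16, 16)           # stride-8 feature map
#   with torch.no_grad():
#       prob = ppm([conv5], segSize=(128, 128))
#   # internally: conv5 (2048 ch) + 4 pooled branches (4 * 512 ch) = 4096 ch
#   # prob.shape == (1, 150, 128, 128), softmax probabilities over classes
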
# pyramid pooling, bilinear upsample
class PPMBilinearDeepsup(nn.Module):
    def __init__(self, num_class=150, fc_dim=4096,
                 inference=False, use_softmax=False, pool_scales=(1, 2, 3, 6)):
        super(PPMBilinearDeepsup, self).__init__()
        self.use_softmax = use_softmax
        self.inference = inference

        self.ppm = []
        for scale in pool_scales:
            self.ppm.append(nn.Sequential(
                nn.AdaptiveAvgPool2d(scale),
                nn.Conv2d(fc_dim, 512, kernel_size=1, bias=False),
                SynchronizedBatchNorm2d(512),
                nn.ReLU(inplace=True)
            ))
        self.ppm = nn.ModuleList(self.ppm)
        self.cbr_deepsup = conv3x3_bn_relu(fc_dim // 2, fc_dim // 4, 1)

        self.conv_last = nn.Sequential(
            nn.Conv2d(fc_dim + len(pool_scales)*512, 512,
                      kernel_size=3, padding=1, bias=False),
            SynchronizedBatchNorm2d(512),
            nn.ReLU(inplace=True),
            nn.Dropout2d(0.1),
            nn.Conv2d(512, num_class, kernel_size=1)
        )
        self.conv_last_deepsup = nn.Conv2d(fc_dim // 4, num_class, 1, 1, 0)
        self.dropout_deepsup = nn.Dropout2d(0.1)

    def forward(self, conv_out, segSize=None):
        conv5 = conv_out[-1]

        input_size = conv5.size()
        ppm_out = [conv5]
        for pool_scale in self.ppm:
            ppm_out.append(nn.functional.interpolate(
                pool_scale(conv5),
                (input_size[2], input_size[3]),
                mode='bilinear', align_corners=False))
        ppm_out = torch.cat(ppm_out, 1)

        x = self.conv_last(ppm_out)

        if self.inference or self.use_softmax:  # is True during inference
            x = nn.functional.interpolate(
                x, size=segSize, mode='bilinear', align_corners=False)
            if self.use_softmax:
                x = nn.functional.softmax(x, dim=1)
            return x

        # deep sup
        conv4 = conv_out[-2]
        _ = self.cbr_deepsup(conv4)
        _ = self.dropout_deepsup(_)
        _ = self.conv_last_deepsup(_)

        x = nn.functional.log_softmax(x, dim=1)
        _ = nn.functional.log_softmax(_, dim=1)

        return (x, _)
# upernet
class UPerNet(nn.Module):
    def __init__(self, num_class=150, fc_dim=4096,
                 inference=False, use_softmax=False, pool_scales=(1, 2, 3, 6),
                 fpn_inplanes=(256, 512, 1024, 2048), fpn_dim=256):
        super(UPerNet, self).__init__()
        self.use_softmax = use_softmax
        self.inference = inference

        # PPM Module
        self.ppm_pooling = []
        self.ppm_conv = []
        for scale in pool_scales:
            self.ppm_pooling.append(nn.AdaptiveAvgPool2d(scale))
            self.ppm_conv.append(nn.Sequential(
                nn.Conv2d(fc_dim, 512, kernel_size=1, bias=False),
                SynchronizedBatchNorm2d(512),
                nn.ReLU(inplace=True)
            ))
        self.ppm_pooling = nn.ModuleList(self.ppm_pooling)
        self.ppm_conv = nn.ModuleList(self.ppm_conv)
        self.ppm_last_conv = conv3x3_bn_relu(fc_dim + len(pool_scales)*512, fpn_dim, 1)

        # FPN Module
        self.fpn_in = []
        for fpn_inplane in fpn_inplanes[:-1]:  # skip the top layer
            self.fpn_in.append(nn.Sequential(
                nn.Conv2d(fpn_inplane, fpn_dim, kernel_size=1, bias=False),
                SynchronizedBatchNorm2d(fpn_dim),
                nn.ReLU(inplace=True)
            ))
        self.fpn_in = nn.ModuleList(self.fpn_in)

        self.fpn_out = []
        for i in range(len(fpn_inplanes) - 1):  # skip the top layer
            self.fpn_out.append(nn.Sequential(
                conv3x3_bn_relu(fpn_dim, fpn_dim, 1),
            ))
        self.fpn_out = nn.ModuleList(self.fpn_out)

        self.conv_last = nn.Sequential(
            conv3x3_bn_relu(len(fpn_inplanes) * fpn_dim, fpn_dim, 1),
            nn.Conv2d(fpn_dim, num_class, kernel_size=1)
        )
    def forward(self, conv_out, segSize=None):
        conv5 = conv_out[-1]

        input_size = conv5.size()
        ppm_out = [conv5]
        for pool_scale, pool_conv in zip(self.ppm_pooling, self.ppm_conv):
            ppm_out.append(pool_conv(nn.functional.interpolate(
                pool_scale(conv5),
                (input_size[2], input_size[3]),
                mode='bilinear', align_corners=False)))
        ppm_out = torch.cat(ppm_out, 1)
        f = self.ppm_last_conv(ppm_out)

        fpn_feature_list = [f]
        for i in reversed(range(len(conv_out) - 1)):
            conv_x = conv_out[i]
            conv_x = self.fpn_in[i](conv_x)  # lateral branch

            f = nn.functional.interpolate(
                f, size=conv_x.size()[2:], mode='bilinear', align_corners=False)  # top-down branch
            f = conv_x + f

            fpn_feature_list.append(self.fpn_out[i](f))

        fpn_feature_list.reverse()  # [P2 - P5]
        output_size = fpn_feature_list[0].size()[2:]
        fusion_list = [fpn_feature_list[0]]
        for i in range(1, len(fpn_feature_list)):
            fusion_list.append(nn.functional.interpolate(
                fpn_feature_list[i],
                output_size,
                mode='bilinear', align_corners=False))
        fusion_out = torch.cat(fusion_list, 1)
        x = self.conv_last(fusion_out)

        if self.inference or self.use_softmax:  # is True during inference
            x = nn.functional.interpolate(
                x, size=segSize, mode='bilinear', align_corners=False)
            if self.use_softmax:
                x = nn.functional.softmax(x, dim=1)
            return x

        x = nn.functional.log_softmax(x, dim=1)

        return x
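
# Illustrative forward pass for UPerNet (a sketch with random feature maps;
# channel counts follow fpn_inplanes=(256, 512, 1024, 2048), i.e. the four
# stages of a non-dilated ResNet-50/101 encoder at strides 4, 8, 16 and 32):
#
#   decoder = UPerNet(num_class=150, fc_dim=2048, use_softmax=True, fpn_dim=256).eval()
#   conv_out = [torch.randn(1, 256, 56, 56),
#               torch.randn(1, 512, 28, 28),
#               torch.randn(1, 1024, 14, 14),
#               torch.randn(1, 2048, 7, 7)]
#   with torch.no_grad():
#       prob = decoder(conv_out, segSize=(224, 224))   # (1, 150, 224, 224)
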