import torch
import torch.nn as nn
import torch.nn.functional as F


def weight_init(m):
    if isinstance(m, (nn.Conv2d,)):
        torch.nn.init.xavier_normal_(m.weight, gain=1.0)
        # Single-channel convs get a plain normal init instead. Note: the
        # original compared the int shape[1] against torch.Size([1]), which
        # is always False; `== 1` is the intended test.
        if m.weight.data.shape[1] == 1:
            torch.nn.init.normal_(m.weight, mean=0.0)
        if m.bias is not None:
            torch.nn.init.zeros_(m.bias)

    # for the fusion layer
    if isinstance(m, (nn.ConvTranspose2d,)):
        torch.nn.init.xavier_normal_(m.weight, gain=1.0)
        if m.weight.data.shape[1] == 1:
            torch.nn.init.normal_(m.weight, std=0.1)
        if m.bias is not None:
            torch.nn.init.zeros_(m.bias)


class CoFusion(nn.Module):
    """Attention-based fusion of the stacked side outputs (CATS-style)."""

    def __init__(self, in_ch, out_ch):
        super(CoFusion, self).__init__()
        self.conv1 = nn.Conv2d(in_ch, 64, kernel_size=3, stride=1, padding=1)
        self.conv2 = nn.Conv2d(64, 64, kernel_size=3, stride=1, padding=1)
        self.conv3 = nn.Conv2d(64, out_ch, kernel_size=3, stride=1, padding=1)
        self.relu = nn.ReLU()
        self.norm_layer1 = nn.GroupNorm(4, 64)
        self.norm_layer2 = nn.GroupNorm(4, 64)

    def forward(self, x):
        # x is the Bx6xHxW stack of side outputs; attn re-weights it per pixel.
        attn = self.relu(self.norm_layer1(self.conv1(x)))
        attn = self.relu(self.norm_layer2(self.conv2(attn)))
        attn = F.softmax(self.conv3(attn), dim=1)
        return ((x * attn).sum(1)).unsqueeze(1)


class _DenseLayer(nn.Sequential):
    def __init__(self, input_features, out_features):
        super(_DenseLayer, self).__init__()
        # padding=2 on conv1 grows the map by 2 pixels per side and the
        # unpadded conv2 shrinks it back, so the spatial size is preserved.
        self.add_module('conv1', nn.Conv2d(input_features, out_features,
                                           kernel_size=3, stride=1,
                                           padding=2, bias=True))
        self.add_module('norm1', nn.BatchNorm2d(out_features))
        self.add_module('relu1', nn.ReLU(inplace=True))
        self.add_module('conv2', nn.Conv2d(out_features, out_features,
                                           kernel_size=3, stride=1, bias=True))
        self.add_module('norm2', nn.BatchNorm2d(out_features))

    def forward(self, x):
        x1, x2 = x
        new_features = super(_DenseLayer, self).forward(F.relu(x1))
        # Average with the skip tensor x2 and pass x2 along unchanged, so
        # every layer in the block sees the same skip connection.
        return 0.5 * (new_features + x2), x2


class _DenseBlock(nn.Sequential):
    def __init__(self, num_layers, input_features, out_features):
        super(_DenseBlock, self).__init__()
        for i in range(num_layers):
            layer = _DenseLayer(input_features, out_features)
            self.add_module('denselayer%d' % (i + 1), layer)
            input_features = out_features


class UpConvBlock(nn.Module):
    def __init__(self, in_features, up_scale):
        super(UpConvBlock, self).__init__()
        self.up_factor = 2
        self.constant_features = 16
        layers = self.make_deconv_layers(in_features, up_scale)
        assert layers is not None, layers
        self.features = nn.Sequential(*layers)

    def make_deconv_layers(self, in_features, up_scale):
        layers = []
        all_pads = [0, 0, 1, 3, 7]
        for i in range(up_scale):
            kernel_size = 2 ** up_scale
            pad = all_pads[up_scale]  # i.e. kernel_size // 2 - 1
            out_features = self.compute_out_features(i, up_scale)
            layers.append(nn.Conv2d(in_features, out_features, 1))
            layers.append(nn.ReLU(inplace=True))
            layers.append(nn.ConvTranspose2d(
                out_features, out_features, kernel_size,
                stride=2, padding=pad))
            in_features = out_features
        return layers

    def compute_out_features(self, idx, up_scale):
        # Intermediate stages keep 16 channels; the final stage emits the
        # single-channel edge map.
        return 1 if idx == up_scale - 1 else self.constant_features

    def forward(self, x):
        return self.features(x)
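
# Illustrative sanity-check helper (not part of the original file): for every
# configured up_scale, the (kernel_size, padding) pair gives an output size of
# (H - 1) * 2 - 2 * pad + kernel_size = 2 * H, i.e. each deconv stage doubles
# the spatial resolution exactly. Call it by hand to verify.
def _upconv_shape_check():
    up = UpConvBlock(512, 3)  # three doubling stages: 8x total upsampling
    y = up(torch.rand(1, 512, 44, 44))
    assert y.shape == (1, 1, 352, 352)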
class SingleConvBlock(nn.Module):
    def __init__(self, in_features, out_features, stride, use_bn=True):
        # (the keyword was originally misspelled `use_bs`)
        super(SingleConvBlock, self).__init__()
        self.use_bn = use_bn
        self.conv = nn.Conv2d(in_features, out_features, 1, stride=stride,
                              bias=True)
        self.bn = nn.BatchNorm2d(out_features)

    def forward(self, x):
        x = self.conv(x)
        if self.use_bn:
            x = self.bn(x)
        return x


class DoubleConvBlock(nn.Module):
    def __init__(self, in_features, mid_features,
                 out_features=None, stride=1, use_act=True):
        super(DoubleConvBlock, self).__init__()
        self.use_act = use_act
        if out_features is None:
            out_features = mid_features
        self.conv1 = nn.Conv2d(in_features, mid_features, 3,
                               padding=1, stride=stride)
        self.bn1 = nn.BatchNorm2d(mid_features)
        self.conv2 = nn.Conv2d(mid_features, out_features, 3, padding=1)
        self.bn2 = nn.BatchNorm2d(out_features)
        self.relu = nn.ReLU(inplace=True)

    def forward(self, x):
        x = self.conv1(x)
        x = self.bn1(x)
        x = self.relu(x)
        x = self.conv2(x)
        x = self.bn2(x)
        if self.use_act:
            x = self.relu(x)
        return x
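
# Illustrative sanity-check helper (not part of the original file): a
# _DenseBlock consumes and returns a [features, skip] pair, and the skip
# tensor fixes the output channel count. The shapes below are assumptions
# chosen to match dblock_3 in DexiNed; call the helper by hand to verify.
def _denseblock_pair_check():
    blk = _DenseBlock(2, 128, 256)
    feats = torch.rand(1, 128, 100, 100)
    skip = torch.rand(1, 256, 100, 100)
    out, _ = blk([feats, skip])
    assert out.shape == (1, 256, 100, 100)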
class DexiNed(nn.Module):
    """DexiNed: Dense EXtreme Inception Network for Edge Detection."""

    def __init__(self):
        super(DexiNed, self).__init__()
        self.block_1 = DoubleConvBlock(3, 32, 64, stride=2)
        self.block_2 = DoubleConvBlock(64, 128, use_act=False)
        self.dblock_3 = _DenseBlock(2, 128, 256)  # [128, 256, 100, 100]
        self.dblock_4 = _DenseBlock(3, 256, 512)
        self.dblock_5 = _DenseBlock(3, 512, 512)
        self.dblock_6 = _DenseBlock(3, 512, 256)
        self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)

        # left skip connections, figure in the journal paper
        self.side_1 = SingleConvBlock(64, 128, 2)
        self.side_2 = SingleConvBlock(128, 256, 2)
        self.side_3 = SingleConvBlock(256, 512, 2)
        self.side_4 = SingleConvBlock(512, 512, 1)
        # side_5 is never used in forward(); the author meant to comment it
        # out. It is kept so that existing checkpoints still load.
        self.side_5 = SingleConvBlock(512, 256, 1)

        # right skip connections, figure in the journal paper
        self.pre_dense_2 = SingleConvBlock(128, 256, 2)
        self.pre_dense_3 = SingleConvBlock(128, 256, 1)
        self.pre_dense_4 = SingleConvBlock(256, 512, 1)
        self.pre_dense_5 = SingleConvBlock(512, 512, 1)
        self.pre_dense_6 = SingleConvBlock(512, 256, 1)

        self.up_block_1 = UpConvBlock(64, 1)
        self.up_block_2 = UpConvBlock(128, 1)
        self.up_block_3 = UpConvBlock(256, 2)
        self.up_block_4 = UpConvBlock(512, 3)
        self.up_block_5 = UpConvBlock(512, 4)
        self.up_block_6 = UpConvBlock(256, 4)
        self.block_cat = SingleConvBlock(6, 1, stride=1,
                                         use_bn=False)  # HED fusion method
        # self.block_cat = CoFusion(6, 6)  # CATS fusion method

        self.apply(weight_init)

    def slice(self, tensor, slice_shape):
        # Resize a side output to the target shape; needed when the input
        # sides are not multiples of 16 and the deconv stages overshoot.
        t_shape = tensor.shape
        height, width = slice_shape
        if t_shape[-1] != slice_shape[-1]:
            new_tensor = F.interpolate(
                tensor, size=(height, width), mode='bicubic',
                align_corners=False)
        else:
            new_tensor = tensor
        # (plain cropping, tensor[..., :height, :width], would also work)
        return new_tensor

    def forward(self, x):
        assert x.ndim == 4, x.shape

        # Block 1
        block_1 = self.block_1(x)
        block_1_side = self.side_1(block_1)

        # Block 2
        block_2 = self.block_2(block_1)
        block_2_down = self.maxpool(block_2)
        block_2_add = block_2_down + block_1_side
        block_2_side = self.side_2(block_2_add)

        # Block 3
        block_3_pre_dense = self.pre_dense_3(block_2_down)
        block_3, _ = self.dblock_3([block_2_add, block_3_pre_dense])
        block_3_down = self.maxpool(block_3)  # [128, 256, 50, 50]
        block_3_add = block_3_down + block_2_side
        block_3_side = self.side_3(block_3_add)

        # Block 4
        block_2_resize_half = self.pre_dense_2(block_2_down)
        block_4_pre_dense = self.pre_dense_4(
            block_3_down + block_2_resize_half)
        block_4, _ = self.dblock_4([block_3_add, block_4_pre_dense])
        block_4_down = self.maxpool(block_4)
        block_4_add = block_4_down + block_3_side
        block_4_side = self.side_4(block_4_add)

        # Block 5
        block_5_pre_dense = self.pre_dense_5(block_4_down)
        block_5, _ = self.dblock_5([block_4_add, block_5_pre_dense])
        block_5_add = block_5 + block_4_side

        # Block 6
        block_6_pre_dense = self.pre_dense_6(block_5)
        block_6, _ = self.dblock_6([block_5_add, block_6_pre_dense])

        # upsampling blocks
        out_1 = self.up_block_1(block_1)
        out_2 = self.up_block_2(block_2)
        out_3 = self.up_block_3(block_3)
        out_4 = self.up_block_4(block_4)
        out_5 = self.up_block_5(block_5)
        out_6 = self.up_block_6(block_6)
        results = [out_1, out_2, out_3, out_4, out_5, out_6]

        # concatenate the multiscale outputs and fuse them
        block_cat = torch.cat(results, dim=1)  # Bx6xHxW
        block_cat = self.block_cat(block_cat)  # Bx1xHxW

        results.append(block_cat)
        return results


if __name__ == '__main__':
    batch_size = 8
    img_height = 352
    img_width = 352

    # device = "cuda" if torch.cuda.is_available() else "cpu"
    device = "cpu"
    input = torch.rand(batch_size, 3, img_height, img_width).to(device)
    print(f"input shape: {input.shape}")
    model = DexiNed().to(device)
    output = model(input)
    print(f"output shapes: {[t.shape for t in output]}")
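
    # A minimal training-step sketch, completing the loop that was left
    # commented out in the original file. Adam and MSELoss are illustrative
    # assumptions used to show the mechanics; they are not the recipe used
    # for the published DexiNed weights.
    target = torch.rand(batch_size, 1, img_height, img_width).to(device)
    optimizer = torch.optim.Adam(model.parameters(), lr=1e-4)
    for i in range(2):  # two steps are enough to demonstrate the loop
        optimizer.zero_grad()
        output = model(input)
        loss = nn.MSELoss()(output[-1], target)
        loss.backward()
        optimizer.step()
        print(f"step {i}: loss {loss.item():.4f}")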