import math

import einops
import torch
import torch.nn as nn
import torch.nn.functional as F
from einops import rearrange
from timm.models.layers import trunc_normal_
from torch.nn import Parameter

from model.MobileNetV2 import mobilenet_v2

# NOTE(review): `rearrange`, `trunc_normal_` and `Parameter` are not referenced in this module.


class BasicConv2d(nn.Module):
    """Conv2d (no bias) -> BatchNorm2d -> in-place ReLU."""

    def __init__(self, in_planes, out_planes, kernel_size, stride=1, padding=0, dilation=1):
        super(BasicConv2d, self).__init__()
        self.conv = nn.Conv2d(in_planes, out_planes, kernel_size=kernel_size, stride=stride,
                              padding=padding, dilation=dilation, bias=False)
        self.bn = nn.BatchNorm2d(out_planes)
        self.relu = nn.ReLU(inplace=True)

    def forward(self, x):
        return self.relu(self.bn(self.conv(x)))


class Reduction(nn.Module):
    """Project `in_channel` maps to `out_channel` (1x1 conv), then two 3x3 BasicConv2d layers."""

    def __init__(self, in_channel, out_channel):
        super(Reduction, self).__init__()
        self.reduce = nn.Sequential(
            BasicConv2d(in_channel, out_channel, 1),
            BasicConv2d(out_channel, out_channel, 3, padding=1),
            BasicConv2d(out_channel, out_channel, 3, padding=1)
        )

    def forward(self, x):
        return self.reduce(x)


class TopDownLayer(nn.Module):
    """Refine `x`, resize it to `x2`'s spatial size, concatenate with `x2` and compress to `channel`."""

    def __init__(self, channel):
        super(TopDownLayer, self).__init__()
        self.conv = nn.Sequential(nn.Conv2d(channel, channel, 3, 1, 1, bias=False),
                                  nn.BatchNorm2d(channel))
        self.relu = nn.ReLU()
        self.channel_compress = nn.Sequential(
            nn.Conv2d(channel * 2, channel, 1, bias=False),
            nn.BatchNorm2d(channel),
            nn.ReLU()
        )

    def forward(self, x, x2):
        coarse = self.relu(self.conv(x))
        coarse = F.interpolate(coarse, x2.size()[2:], mode='bilinear', align_corners=True)
        return self.channel_compress(torch.cat((coarse, x2), dim=1))


class MultiHeadAttention(nn.Module):
    """Scaled dot-product attention across the spatial positions of 4-D feature maps.

    Despite the name, no head splitting happens: `head` is only validated against
    `d_model` and stored. `linear_query`, `linear_key` and `linear_value` are created
    but unused by `forward`; query, key and value all go through the shared `inb`
    projection.

    Args:
        head: stored as `self.head`; `d_model` must be divisible by it.
        d_model: output size of the `inb` projection (channels of the result).
        dropout: dropout probability applied to the attention weights.
        in_features: channel count of the incoming feature maps (input size of
            `inb`). Defaults to 32, the value that used to be hard-coded.
    """

    def __init__(self, head=8, d_model=32, dropout=0.1, in_features=32):
        super(MultiHeadAttention, self).__init__()
        assert (d_model % head == 0)
        self.d_k = d_model // head
        self.head = head
        self.d_model = d_model
        self.linear_query = nn.Linear(d_model, d_model)
        self.linear_key = nn.Linear(d_model, d_model)
        self.linear_value = nn.Linear(d_model, d_model)
        self.dropout = nn.Dropout(p=dropout)
        # Attention weights from the most recent forward pass (overwritten each call).
        self.attn = None
        self.inb = nn.Linear(in_features, d_model)

    def self_attention(self, query, key, value, mask=None):
        """Return (softmax(Q K^T / sqrt(d)) V, attention weights after dropout).

        `mask` is accepted for API compatibility but ignored.
        """
        d_k = query.shape[-1]
        scores = torch.matmul(query, key.transpose(-2, -1)) / math.sqrt(d_k)
        self_attn = F.softmax(scores, dim=-1)
        if self.dropout is not None:
            self_attn = self.dropout(self_attn)
        return torch.matmul(self_attn, value), self_attn

    def forward(self, query, key, value, mask=None):
        """Attend from every query position to every key position.

        query/key/value are 4-D (B, in_features, H, W) maps; key and value must share
        a token count. Returns (B, d_model, H, W) with the query's H and W.
        `mask` is ignored.
        """
        _, _, height, width = query.shape
        query = self.inb(query.flatten(start_dim=2).permute(0, 2, 1))
        key = self.inb(key.flatten(start_dim=2).permute(0, 2, 1))
        value = self.inb(value.flatten(start_dim=2).permute(0, 2, 1))
        x, self.attn = self.self_attention(query, key, value, mask=mask)
        # (B, H*W, d_model) -> (B, d_model, H*W) -> (B, d_model, H, W). The query's own
        # H and W are used, so non-square maps are restored correctly.
        x = x.permute(0, 2, 1)
        return einops.rearrange(x, 'b n (hh ww) -> b n hh ww', hh=height, ww=width)


class Upsample(nn.Module):
    """Bilinearly resize `x` to the spatial size of `x2` (align_corners=True)."""

    def __init__(self):
        super(Upsample, self).__init__()

    def forward(self, x, x2):
        return F.interpolate(x, size=x2.size()[2:], mode='bilinear', align_corners=True)


class MultiScaleAttention(nn.Module):
    """Spatially gate six branch outputs, resize them to `x0`, concatenate and 1x1-reduce."""

    def __init__(self, channel):
        super(MultiScaleAttention, self).__init__()
        # NOTE(review): five SpatialAttention modules are built, but forward only ever
        # uses index 0 (shared across all six inputs). Branches 1-4 never receive
        # gradients. They are kept so that state_dict keys and trained behavior are unchanged.
        self.attention_branches = nn.ModuleList([SpatialAttention() for _ in range(5)])
        self.upsample = Upsample()
        self.conv_reduce = nn.Conv2d(channel * 6, channel, kernel_size=1)

    def forward(self, x0, x1, x2, x3, x4, x5):
        attend = self.attention_branches[0]
        gated = [attend(x0) * x0]
        for feat in (x1, x2, x3, x4, x5):
            gated.append(self.upsample(attend(feat) * feat, x0))
        return self.conv_reduce(torch.cat(gated, dim=1))


class Basic2(nn.Module):
    """Six-branch receptive-field block with attention fusion and a residual connection.

    branch0 is a 1x1 conv. Branch k (k = 1..5) is a (1, 2k+1) conv, a (2k+1, 1) conv
    and a 3x3 conv with dilation 2k+1. The branch outputs are fused by
    MultiScaleAttention, refined by a 3x3 conv back to `in_channel`, and added to the input.
    """

    def __init__(self, in_channel, out_channel):
        super(Basic2, self).__init__()
        self.relu = nn.ReLU(True)
        # NOTE(review): the ChannelAttention is immediately replaced by a SpatialAttention,
        # and neither is used in forward. Both lines are kept so that the state_dict keys
        # and the seeded initialization order stay the same.
        self.channel_attention = ChannelAttention(out_channel)
        self.channel_attention = SpatialAttention()
        self.branch0 = nn.Sequential(
            BasicConv2d(in_channel, out_channel, 1),
        )
        self.branch1 = nn.Sequential(
            BasicConv2d(in_channel, out_channel, kernel_size=(1, 3), padding=(0, 1)),
            BasicConv2d(out_channel, out_channel, kernel_size=(3, 1), padding=(1, 0)),
            BasicConv2d(out_channel, out_channel, 3, padding=3, dilation=3)
        )
        self.branch2 = nn.Sequential(
            BasicConv2d(in_channel, out_channel, kernel_size=(1, 5), padding=(0, 2)),
            BasicConv2d(out_channel, out_channel, kernel_size=(5, 1), padding=(2, 0)),
            BasicConv2d(out_channel, out_channel, 3, padding=5, dilation=5)
        )
        self.branch3 = nn.Sequential(
            BasicConv2d(in_channel, out_channel, kernel_size=(1, 7), padding=(0, 3)),
            BasicConv2d(out_channel, out_channel, kernel_size=(7, 1), padding=(3, 0)),
            BasicConv2d(out_channel, out_channel, 3, padding=7, dilation=7)
        )
        self.branch4 = nn.Sequential(
            BasicConv2d(in_channel, out_channel, kernel_size=(1, 9), padding=(0, 4)),
            BasicConv2d(out_channel, out_channel, kernel_size=(9, 1), padding=(4, 0)),
            BasicConv2d(out_channel, out_channel, 3, padding=9, dilation=9)
        )
        self.branch5 = nn.Sequential(
            BasicConv2d(in_channel, out_channel, kernel_size=(1, 11), padding=(0, 5)),
            BasicConv2d(out_channel, out_channel, kernel_size=(11, 1), padding=(5, 0)),
            BasicConv2d(out_channel, out_channel, 3, padding=11, dilation=11)
        )
        self.multi_scale_attention = MultiScaleAttention(out_channel)
        # The fused map has `out_channel` channels and must return to `in_channel`
        # for the residual add. (Previously built as in->in, which broke when in != out.)
        self.conv_combine = BasicConv2d(out_channel, in_channel, kernel_size=3, padding=1)

    def forward(self, x):
        x0 = self.branch0(x)
        x1 = self.branch1(x)
        x2 = self.branch2(x)
        x3 = self.branch3(x)
        x4 = self.branch4(x)
        x5 = self.branch5(x)
        fused = self.multi_scale_attention(x0, x1, x2, x3, x4, x5)
        return self.conv_combine(fused) + x


class ChannelAttention(nn.Module):
    """Sigmoid channel weights from a shared 2-layer 1x1-conv MLP over avg- and max-pooled input."""

    def __init__(self, in_planes):
        super(ChannelAttention, self).__init__()
        self.avg_pool = nn.AdaptiveAvgPool2d(1)
        self.max_pool = nn.AdaptiveMaxPool2d(1)
        self.fc = nn.Sequential(nn.Conv2d(in_planes, in_planes // 2, 1, bias=False),
                                nn.ReLU(),
                                nn.Conv2d(in_planes // 2, in_planes, 1, bias=False))
        self.sigmoid = nn.Sigmoid()

    def forward(self, x):
        avg_out = self.fc(self.avg_pool(x))
        max_out = self.fc(self.max_pool(x))
        return self.sigmoid(avg_out + max_out)


class SpatialAttention(nn.Module):
    """Single-channel sigmoid map from a conv over the channel-wise mean and max of `x`."""

    def __init__(self, kernel_size=7):
        super(SpatialAttention, self).__init__()
        self.conv1 = nn.Conv2d(2, 1, kernel_size, padding=kernel_size // 2, bias=False)
        self.sigmoid = nn.Sigmoid()

    def forward(self, x):
        avg_out = torch.mean(x, dim=1, keepdim=True)
        max_out, _ = torch.max(x, dim=1, keepdim=True)
        return self.sigmoid(self.conv1(torch.cat([avg_out, max_out], dim=1)))


class MModule(nn.Module):
    """Basic2 followed by residual channel then spatial attention, plus an outer skip."""

    def __init__(self, channel):
        super(MModule, self).__init__()
        self.basic = Basic2(channel, channel)
        self.SA = SpatialAttention()
        self.CA = ChannelAttention(channel)

    def forward(self, x):
        x_mix = self.basic(x)
        x_mix = x_mix * self.CA(x_mix) + x_mix
        x_mix1 = x_mix * self.SA(x_mix) + x_mix
        return x_mix1 + x


class MNodule(nn.Module):
    """Three cascaded stages, each fusing an asymmetric-conv branch with a dilated conv.

    Each stage feeds the next and carries dense residuals. The result passes through
    channel then spatial attention, and all stage outputs plus the input are added.
    """

    def __init__(self, channel):
        super(MNodule, self).__init__()
        self.atrconv1 = BasicConv2d(channel, channel, 3, padding=3, dilation=3)
        self.atrconv2 = BasicConv2d(channel, channel, 3, padding=5, dilation=5)
        self.atrconv3 = BasicConv2d(channel, channel, 3, padding=7, dilation=7)
        self.branch1 = nn.Sequential(
            BasicConv2d(channel, channel, 1),
            BasicConv2d(channel, channel, kernel_size=(1, 3), padding=(0, 1)),
            BasicConv2d(channel, channel, kernel_size=(3, 1), padding=(1, 0))
        )
        self.branch2 = nn.Sequential(
            BasicConv2d(channel, channel, 1),
            BasicConv2d(channel, channel, kernel_size=(1, 5), padding=(0, 2)),
            BasicConv2d(channel, channel, kernel_size=(5, 1), padding=(2, 0))
        )
        self.branch3 = nn.Sequential(
            BasicConv2d(channel, channel, 1),
            BasicConv2d(channel, channel, kernel_size=(1, 7), padding=(0, 3)),
            BasicConv2d(channel, channel, kernel_size=(7, 1), padding=(3, 0))
        )
        self.conv_cat1 = BasicConv2d(2 * channel, channel, 3, padding=1)
        self.conv_cat2 = BasicConv2d(2 * channel, channel, 3, padding=1)
        self.conv_cat3 = BasicConv2d(2 * channel, channel, 3, padding=1)
        self.conv1_1 = BasicConv2d(channel, channel, 1)
        self.SA = SpatialAttention()
        self.CA = ChannelAttention(channel)
        self.sal_conv = nn.Sequential(
            BasicConv2d(channel, channel, 3, padding=1),
            BasicConv2d(channel, channel, 3, padding=1)
        )
        self.sigmoid = nn.Sigmoid()

    def forward(self, x):
        x1 = self.branch1(x)
        x_atr1 = self.atrconv1(x)
        s_mfeb1 = self.conv_cat1(torch.cat((x1, x_atr1), 1)) + x
        x2 = self.branch2(s_mfeb1)
        x_atr2 = self.atrconv2(s_mfeb1)
        s_mfeb2 = self.conv_cat2(torch.cat((x2, x_atr2), 1)) + s_mfeb1 + x
        x3 = self.branch3(s_mfeb2)
        x_atr3 = self.atrconv3(s_mfeb2)
        s_mfeb3 = self.conv_cat3(torch.cat((x3, x_atr3), 1)) + s_mfeb1 + s_mfeb2 + x
        x_m = self.conv1_1(s_mfeb3)
        x_ca = self.CA(x_m) * x_m
        return self.sal_conv(self.SA(x_ca) * x_ca) + s_mfeb1 + s_mfeb2 + s_mfeb3 + x


class TransBasicConv2d(nn.Module):
    """ConvTranspose2d -> BatchNorm2d -> in-place ReLU (defaults: kernel 2, stride 2)."""

    def __init__(self, in_planes, out_planes, kernel_size=2, stride=2, padding=0, dilation=1,
                 bias=False):
        super(TransBasicConv2d, self).__init__()
        self.Deconv = nn.ConvTranspose2d(in_planes, out_planes, kernel_size=kernel_size,
                                         stride=stride, padding=padding, dilation=dilation,
                                         bias=bias)
        self.bn = nn.BatchNorm2d(out_planes)
        self.relu = nn.ReLU(inplace=True)

    def forward(self, x):
        return self.relu(self.bn(self.Deconv(x)))


class features(nn.Module):
    """Five independent 1x1 BasicConv2d projections, one per pyramid level."""

    def __init__(self, channel):
        super(features, self).__init__()
        self.conv1 = BasicConv2d(channel, channel, 1)
        self.conv2 = BasicConv2d(channel, channel, 1)
        self.conv3 = BasicConv2d(channel, channel, 1)
        self.conv4 = BasicConv2d(channel, channel, 1)
        self.conv5 = BasicConv2d(channel, channel, 1)

    def forward(self, x1, x2, x3, x4, x5):
        return self.conv1(x1), self.conv2(x2), self.conv3(x3), self.conv4(x4), self.conv5(x5)


class conv_upsamle(nn.Module):
    """Resize `x` to `target`'s spatial size when they differ, then apply a 3x3 BasicConv2d."""

    def __init__(self, channel):
        super(conv_upsamle, self).__init__()
        self.conv = BasicConv2d(channel, channel, 3, padding=1)

    def forward(self, x, target):
        if x.size()[2:] != target.size()[2:]:
            x = F.interpolate(x, size=target.size()[2:], mode='bilinear', align_corners=True)
        return self.conv(x)


class AP_MP(nn.Module):
    """Per-pixel L2 norm, over channels, of |avgpool(x1) - maxpool(x2)| (kernel = stride)."""

    def __init__(self, stride=2):
        super(AP_MP, self).__init__()
        self.sz = stride
        self.gapLayer = nn.AvgPool2d(kernel_size=self.sz, stride=self.sz)
        self.gmpLayer = nn.MaxPool2d(kernel_size=self.sz, stride=self.sz)

    def forward(self, x1, x2):
        apimg = self.gapLayer(x1)
        mpimg = self.gmpLayer(x2)
        # keepdim=True on dim 1 -> a single-channel output map.
        return torch.norm(abs(apimg - mpimg), p=2, dim=1, keepdim=True)


class MOM(nn.Module):
    """Fuse two same-shape maps with residual channel attention and cross spatial gating.

    Returns the fused map three times, as (saliency, edge, skeleton) placeholders.
    """

    def __init__(self, channel):
        super(MOM, self).__init__()
        self.channel = channel
        self.conv1 = BasicConv2d(channel, channel, 3, padding=1)
        self.conv2 = BasicConv2d(channel, channel, 3, padding=1)
        self.CA1 = ChannelAttention(self.channel)
        self.CA2 = ChannelAttention(self.channel)
        self.SA1 = SpatialAttention()
        self.SA2 = SpatialAttention()
        # NOTE(review): glbamp, upsample2 and upSA are not used in forward.
        self.glbamp = AP_MP()
        self.upsample2 = nn.Upsample(scale_factor=2, mode='bilinear', align_corners=True)
        self.conv = BasicConv2d(channel * 2, channel, kernel_size=1, stride=1)
        self.upSA = SpatialAttention()

    def forward(self, x1, x2):
        x1 = self.conv1(x1)
        x2 = self.conv2(x2)
        x1 = x1 + x1 * self.CA1(x1)
        x2 = x2 + x2 * self.CA2(x2)
        # Cross gating: each stream is modulated by the *other* stream's spatial attention.
        nx1 = x1 + x1 * self.SA2(x2)
        nx2 = x2 + x2 * self.SA1(x1)
        res = self.conv(torch.cat([nx1, nx2], dim=1)) + x1
        return res, res, res


class AFM(nn.Module):
    """Gate `x1` by sigmoid(global-max-pool(`x2`)) and refine with channel and spatial attention.

    Returns the refined map three times, as (saliency, edge, skeleton) placeholders.
    """

    def __init__(self, channel):
        super(AFM, self).__init__()
        self.max_pool = nn.AdaptiveMaxPool2d(1)
        self.sigmoid = nn.Sigmoid()
        self.conv1_1 = nn.Conv2d(channel, channel, kernel_size=1)
        self.ca1 = ChannelAttention(channel)
        # NOTE(review): ca2 is constructed but its output was never used; kept for state_dict.
        self.ca2 = ChannelAttention(channel)
        self.sa = SpatialAttention()
        self.sal_conv = nn.Sequential(
            BasicConv2d(channel, channel, 3, padding=1),
            BasicConv2d(channel, channel, 3, padding=1)
        )

    def forward(self, x1, x2):
        gate = self.sigmoid(self.max_pool(x2))  # (B, C, 1, 1) per-channel gate
        xb = gate * x1
        x = self.conv1_1(xb)
        x_c = self.ca1(x) * x
        s_mea = self.sal_conv(self.sa(x_c) * x_c) + x1 + gate + xb
        return s_mea, s_mea, s_mea


class DummyMOM(nn.Module):
    """Ablation stand-in for MOM/AFM/YUEM: concatenate both inputs, then apply a 1x1 conv."""

    def __init__(self, channel):
        super(DummyMOM, self).__init__()
        self.conv1 = nn.Identity()  # pass-through, keeps input unchanged
        self.conv2 = nn.Identity()  # pass-through, keeps input unchanged
        # 1x1 conv mapping the concatenated 2*channel inputs back to `channel`
        # (previously hard-coded as 64 -> 32, i.e. only valid for channel=32).
        self.conv = nn.Conv2d(2 * channel, channel, kernel_size=1)

    def forward(self, x1, x2):
        # Concatenate first, then reduce the channel count.
        res = self.conv(torch.cat([x1, x2], dim=1))
        return res, res, res


class YUEM(nn.Module):
    """MModule on `x1` and MNodule on `x2`, fused by attention (query=m1(x1), key=m2(x2), value=x2).

    Returns the attention output three times, as (saliency, edge, skeleton) placeholders.
    """

    def __init__(self, channel):
        super(YUEM, self).__init__()
        self.channel = channel
        self.m1 = MModule(self.channel)
        self.m2 = MNodule(self.channel)
        # `channel` is passed as `head` (as before). Projection sizes now follow `channel`
        # instead of the previous fixed 32; this is identical when channel == 32.
        self.mha = MultiHeadAttention(channel, d_model=channel, in_features=channel)

    def forward(self, x1, x2):
        x1 = self.m1(x1)
        x21 = self.m2(x2)
        res = self.mha(x1, x21, x2)
        return res, res, res


class MTG(nn.Module):
    """Fuse (sal, edg, ske) by concatenation and two 3x3 BasicConv2d layers (uses ccs[0] only)."""

    def __init__(self, channel):
        super(MTG, self).__init__()
        # NOTE(review): five fusion stacks are built but only ccs[0] is used.
        self.ccs = nn.ModuleList([nn.Sequential(
            BasicConv2d(3 * channel, channel, kernel_size=3, padding=1),
            BasicConv2d(channel, channel, kernel_size=3, padding=1)
        ) for i in range(5)])

    def forward(self, x_sal, x_edg, x_ske):
        return self.ccs[0](torch.cat((x_sal, x_edg, x_ske), dim=1))


class MMS(nn.Module):
    """MobileNetV2-based network predicting saliency, edge and skeleton maps.

    Five backbone stages are reduced to `channel` channels (the Reduction inputs
    expect 16/24/32/96/320 channels). They are fused pairwise with the upsampled
    next-coarser stage (AFM at level 5, YUEM at levels 4-3, MOM at levels 2-1) and
    decoded top-down. At each level, three 1-channel heads produce logits.
    """

    def __init__(self, pretrained=True, channel=32):
        super(MMS, self).__init__()
        self.backbone = mobilenet_v2(pretrained)
        self.Translayer1 = Reduction(16, channel)
        self.Translayer2 = Reduction(24, channel)
        self.Translayer3 = Reduction(32, channel)
        self.Translayer4 = Reduction(96, channel)
        self.Translayer5 = Reduction(320, channel)
        self.trans_conv1 = TransBasicConv2d(channel, channel, kernel_size=2, stride=2,
                                            padding=0, dilation=1, bias=False)
        self.trans_conv2 = TransBasicConv2d(channel, channel, kernel_size=2, stride=2,
                                            padding=0, dilation=1, bias=False)
        self.trans_conv3 = TransBasicConv2d(channel, channel, kernel_size=2, stride=2,
                                            padding=0, dilation=1, bias=False)
        self.trans_conv4 = TransBasicConv2d(channel, channel, kernel_size=2, stride=2,
                                            padding=0, dilation=1, bias=False)
        # For ablations, any of the three fusion modules below can be swapped for DummyMOM(channel).
        self.mom = MOM(channel)
        self.afm = AFM(channel)
        self.yuem = YUEM(channel)
        self.sigmoid = nn.Sigmoid()
        self.sal_features = features(channel)
        self.edg_features = features(channel)
        self.ske_features = features(channel)
        self.MTG = MTG(channel)
        # NOTE(review): ccs/cme/cms hold five stacks each, but only index 0 is used.
        self.ccs = nn.ModuleList([nn.Sequential(
            BasicConv2d(3 * channel, channel, kernel_size=3, padding=1),
            BasicConv2d(channel, channel, kernel_size=3, padding=1)
        ) for i in range(5)])
        self.cme = nn.ModuleList([nn.Sequential(
            BasicConv2d(3 * channel, channel, kernel_size=3, padding=1),
            BasicConv2d(channel, channel, kernel_size=3, padding=1)
        ) for i in range(5)])
        self.cms = nn.ModuleList([nn.Sequential(
            BasicConv2d(3 * channel, channel, kernel_size=3, padding=1),
            BasicConv2d(channel, channel, kernel_size=3, padding=1)
        ) for i in range(5)])
        # Decoder fusions: entries 3*k .. 3*k+2 serve (sal, edg, ske) at decoder step k.
        self.conv_cats = nn.ModuleList([nn.Sequential(
            BasicConv2d(2 * channel, channel, kernel_size=3, padding=1),
            BasicConv2d(channel, channel, kernel_size=3, padding=1)
        ) for i in range(12)])
        self.cus = nn.ModuleList([conv_upsamle(channel) for i in range(12)])
        # Prediction heads shared across levels: [0] saliency, [1] edge, [2] skeleton.
        self.prediction = nn.ModuleList([
            nn.Sequential(
                BasicConv2d(channel, channel, kernel_size=3, padding=1),
                nn.Conv2d(channel, 1, kernel_size=1)
            ) for i in range(3)
        ])
        # NOTE(review): S1-S5 are not used in forward.
        self.S1 = nn.Sequential(
            BasicConv2d(channel, channel, 3, padding=1),
            nn.Conv2d(channel, 1, 1)
        )
        self.S2 = nn.Sequential(
            BasicConv2d(channel, channel, 3, padding=1),
            nn.Conv2d(channel, 1, 1)
        )
        self.S3 = nn.Sequential(
            BasicConv2d(channel, channel, 3, padding=1),
            nn.Conv2d(channel, 1, 1)
        )
        self.S4 = nn.Sequential(
            BasicConv2d(channel, channel, 3, padding=1),
            nn.Conv2d(channel, 1, 1)
        )
        self.S5 = nn.Sequential(
            BasicConv2d(channel, channel, 3, padding=1),
            nn.Conv2d(channel, 1, 1)
        )

    def _refine_level(self, step, sal, edg, ske, sal_up, edg_up, ske_up):
        """Fuse one decoder level with the coarser level's refined maps.

        Uses conv_cats/cus entries 3*step .. 3*step+2 plus the shared MTG block.
        Returns (sal_n, edg_n, ske_n).
        """
        base = 3 * step
        sal = self.conv_cats[base](torch.cat((sal, self.cus[base](sal_up, sal)), 1))
        edg = self.conv_cats[base + 1](torch.cat((edg, self.cus[base + 1](edg_up, edg)), 1))
        ske = self.conv_cats[base + 2](torch.cat((ske, self.cus[base + 2](ske_up, ske)), 1))
        # NOTE(review): MTG is called three times with identical inputs. One call would
        # suffice, but in training mode every call updates BatchNorm running stats, so
        # the three calls are kept to preserve trained behavior.
        sal_n = self.MTG(sal, edg, ske) + sal
        edg_n = self.MTG(sal, edg, ske) + edg
        ske_n = self.MTG(sal, edg, ske) + ske
        return sal_n, edg_n, ske_n

    def _predict(self, sal, edg, ske, size):
        """Apply the (sal, edg, ske) heads in that order; bilinearly resize the logits to `size`."""
        return tuple(
            F.interpolate(head(feat), size=size, mode='bilinear', align_corners=True)
            for head, feat in zip(self.prediction, (sal, edg, ske))
        )

    def forward(self, x):
        """Run the network on image batch `x`.

        Returns a 31-tuple. The first element is the refined level-1 saliency feature
        map. The rest are logits resized to x's spatial size, each paired with its
        sigmoid, in this order:

        - sal_out, σ, edg_out, σ
        - per level 2..5: sal, edg, σ(sal), σ(edg)
        - ske_out, σ, then per level 2..5: ske, σ(ske)
        """
        size = x.size()[2:]
        conv1, conv2, conv3, conv4, conv5 = self.backbone(x)
        conv1 = self.Translayer1(conv1)
        conv2 = self.Translayer2(conv2)
        conv3 = self.Translayer3(conv3)
        conv4 = self.Translayer4(conv4)
        conv5 = self.Translayer5(conv5)

        # Encoder fusion: each level paired with the 2x-upsampled next-coarser level.
        rgc5, enc_edg5, enc_ske5 = self.afm(conv5, conv5)
        rgc4, enc_edg4, enc_ske4 = self.yuem(conv4, self.trans_conv4(conv5))
        rgc3, enc_edg3, enc_ske3 = self.yuem(conv3, self.trans_conv3(conv4))
        rgc2, enc_edg2, enc_ske2 = self.mom(conv2, self.trans_conv2(conv3))
        rgc1, enc_edg1, enc_ske1 = self.mom(conv1, self.trans_conv1(conv2))

        x_sal1, x_sal2, x_sal3, x_sal4, x_sal5 = self.sal_features(rgc1, rgc2, rgc3, rgc4, rgc5)
        x_edg1, x_edg2, x_edg3, x_edg4, x_edg5 = self.edg_features(
            enc_edg1, enc_edg2, enc_edg3, enc_edg4, enc_edg5)
        x_ske1, x_ske2, x_ske3, x_ske4, x_ske5 = self.ske_features(
            enc_ske1, enc_ske2, enc_ske3, enc_ske4, enc_ske5)

        # NOTE(review): ccs[0] and cme[0] receive x_sal5 twice and never x_ske5, while
        # cms[0] receives (sal, edg, ske). This looks like a copy-paste slip, but it is
        # kept because changing it would alter the outputs of trained checkpoints.
        x_sal5_n = self.ccs[0](torch.cat((x_sal5, x_edg5, x_sal5), 1)) + x_sal5
        x_edg5_n = self.cme[0](torch.cat((x_sal5, x_edg5, x_sal5), 1)) + x_edg5
        x_ske5_n = self.cms[0](torch.cat((x_sal5, x_edg5, x_ske5), 1)) + x_ske5

        # Top-down decoding, coarse (level 4) to fine (level 1).
        x_sal4_n, x_edg4_n, x_ske4_n = self._refine_level(
            0, x_sal4, x_edg4, x_ske4, x_sal5_n, x_edg5_n, x_ske5_n)
        x_sal3_n, x_edg3_n, x_ske3_n = self._refine_level(
            1, x_sal3, x_edg3, x_ske3, x_sal4_n, x_edg4_n, x_ske4_n)
        x_sal2_n, x_edg2_n, x_ske2_n = self._refine_level(
            2, x_sal2, x_edg2, x_ske2, x_sal3_n, x_edg3_n, x_ske3_n)
        x_sal1_n, x_edg1_n, x_ske1_n = self._refine_level(
            3, x_sal1, x_edg1, x_ske1, x_sal2_n, x_edg2_n, x_ske2_n)

        # Heads are applied fine-to-coarse (same order as before; matters for BN stats).
        sal_out, edg_out, ske_out = self._predict(x_sal1_n, x_edg1_n, x_ske1_n, size)
        sal2, edg2, ske2 = self._predict(x_sal2_n, x_edg2_n, x_ske2_n, size)
        sal3, edg3, ske3 = self._predict(x_sal3_n, x_edg3_n, x_ske3_n, size)
        sal4, edg4, ske4 = self._predict(x_sal4_n, x_edg4_n, x_ske4_n, size)
        sal5, edg5, ske5 = self._predict(x_sal5_n, x_edg5_n, x_ske5_n, size)

        sig = self.sigmoid
        return (x_sal1_n, sal_out, sig(sal_out), edg_out, sig(edg_out),
                sal2, edg2, sig(sal2), sig(edg2),
                sal3, edg3, sig(sal3), sig(edg3),
                sal4, edg4, sig(sal4), sig(edg4),
                sal5, edg5, sig(sal5), sig(edg5),
                ske_out, sig(ske_out),
                ske2, sig(ske2), ske3, sig(ske3),
                ske4, sig(ske4), ske5, sig(ske5))