EGNet 中 MergeLayer1 的详细介绍

class MergeLayer1(nn.Module):
    """Top-down merge of multi-level backbone features into saliency features.

    ``list_k`` holds one config entry ``ik`` per feature level. From the way
    ``__init__`` reads each entry:

    * ``ik[0]`` - input channels of the level's ``up`` block (and output
      channels of its ``trans`` projection, when one is built);
    * ``ik[1]`` - if > 0, input channels of a 1x1 ``trans`` projection down to
      ``ik[0]`` channels; 0 means no projection is built for that level;
    * ``ik[2]`` - output channels of the level's ``up`` block and input
      channels of its ``score`` conv;
    * ``ik[3]`` - kernel size of the ``up`` convolutions;
    * ``ik[4]`` - padding of the ``up`` convolutions.

    Comments below refer to the example configuration quoted in ``__init__``
    and use the original author's level labels C2..C6 for list indices 0..4.
    """

    def __init__(self, list_k):
        """Build the per-level ``trans`` / ``up`` / ``score`` modules.

        Args:
            list_k: list of per-level configs ``[ik0, ik1, ik2, ik3, ik4]``
                (field meanings in the class docstring).
        """
        super(MergeLayer1, self).__init__()
        self.list_k = list_k
        # Example configuration (list_k = config['merge1']):
        #   0 'merge1': [[128, 256, 128, 3, 1],
        #   1            [256, 512, 256, 3, 1],
        #   2            [512, 0,   512, 5, 2],
        #   3            [512, 0,   512, 5, 2],
        #   4            [512, 0,   512, 7, 3]]
        # NOTE: nn.Conv2d draws its initial weights when it is constructed, so
        # the construction order below fixes both the parameter registration
        # order and the initial weights for a given seed - keep it stable.
        trans, up, score, trans_F2_sal, up_s2, score_s2 = [], [], [], [], [], []  # wrapped in nn.ModuleList below
        # Fixed 128-channel block (three 3x3 conv + ReLU) for the "_s2" modules.
        # NOTE(review): up_s2 / score_s2 / trans_F2_sal are not referenced by the
        # forward shown here; presumably used by code outside this excerpt.
        up_s2.append(nn.Sequential(nn.Conv2d(128, 128, 3, 1, 1), nn.ReLU(inplace=True),
                                   nn.Conv2d(128, 128, 3, 1, 1), nn.ReLU(inplace=True),
                                   nn.Conv2d(128, 128, 3, 1, 1), nn.ReLU(inplace=True)))
        # Single-channel 3x3 score conv on a 128-channel map.
        score_s2.append(nn.Conv2d(128, 1, 3, 1, 1))

        for ik in list_k:
            if ik[1] > 0:  # only levels with ik[1] > 0 get a 1x1 projection
                # Example config: levels 0 and 1 give
                #   Conv1x1(256 -> 128) and Conv1x1(512 -> 256).
                trans.append(nn.Sequential(nn.Conv2d(ik[1], ik[0], 1, 1, bias=False), nn.ReLU(inplace=True)))
                # NOTE(review): trans is indexed by append order, so self.trans[i]
                # only corresponds to level i while every ik[1] > 0 entry comes
                # before the ik[1] == 0 entries (true for the example config).
            # up[k]: three (ik[3] x ik[3] conv, stride 1, padding ik[4]) + ReLU,
            # mapping ik[0] -> ik[2] -> ik[2] -> ik[2] channels.
            # Example config, per level:      ik0  ik1  ik2 k  p
            #   0 C2                          [128, 256, 128, 3, 1],
            #   1 C3                          [256, 512, 256, 3, 1],
            #   2 C4                          [512, 0,   512, 5, 2],
            #   3 C5                          [512, 0,   512, 5, 2],
            #   4 C6                          [512, 0,   512, 7, 3]
            up.append(nn.Sequential(nn.Conv2d(ik[0], ik[2], ik[3], 1, ik[4]), nn.ReLU(inplace=True),
                                    nn.Conv2d(ik[2], ik[2], ik[3], 1, ik[4]), nn.ReLU(inplace=True),
                                    nn.Conv2d(ik[2], ik[2], ik[3], 1, ik[4]), nn.ReLU(inplace=True)))
            # score[k]: 3x3 conv from ik[2] channels to a single-channel map.
            score.append(nn.Conv2d(ik[2], 1, 3, 1, 1))
        # One more 1x1 projection (512 -> 128, no bias) appended last; with the
        # example config trans = [256->128, 512->256, 512->128].
        # NOTE(review): the visible forward never explicitly picks this entry.
        trans.append(nn.Sequential(nn.Conv2d(512, 128, 1, 1, bias=False), nn.ReLU(inplace=True)))
        # 1x1 projection 256 -> 128 (no bias) + ReLU; unused in the visible forward.
        trans_F2_sal.append(nn.Sequential(nn.Conv2d(256, 128, 1, 1, bias=False), nn.ReLU(inplace=True)))
        # Wrap in nn.ModuleList so the sub-modules are registered with this module.
        self.trans, self.up, self.score = nn.ModuleList(trans), nn.ModuleList(up), nn.ModuleList(score)

        self.trans_F2_sal = nn.ModuleList(trans_F2_sal)
        self.up_s2 = nn.ModuleList(up_s2)
        self.score_s2 = nn.ModuleList(score_s2)

        self.relu = nn.ReLU()  # not used in the visible forward

    def forward(self, list_x, x_size):
        """Merge features top-down, from the last entry of ``list_x`` to ``list_x[1]``.

        Args:
            list_x: indexable sequence of per-level feature maps. The code reads
                ``.size()[1]`` as channels and ``.size()[2:]`` as spatial size,
                so NCHW tensors are assumed. Per the original comment the caller
                passes ``self.base(x)`` (VGG16 or ResNet-50 features).
            x_size: target spatial size for the upsampled score maps; per the
                original comment the caller passes ``x.size()[2:]``.

        Note:
            As shown, this method fills ``sal_feature``, ``up_sal`` and
            ``sal_feature_s2`` but never returns them. ``c2``, ``up_edge`` and
            ``edge_feature`` are also unused, so the method returns None.
            NOTE(review): this looks like a truncated excerpt of a longer forward
            (e.g. a branch on list_x[0]) - confirm against the full source.
        """
        up_edge, up_sal, edge_feature, sal_feature, sal_feature_s2 = [], [], [], [], []  # output accumulators

        num_f = len(list_x)
        c2 = list_x[0]  # level 0 (C2); not used further in this excerpt
        # Coarsest level: refine the last feature map with its own up block.
        tmp = self.up[num_f - 1](list_x[num_f - 1])  # C6 -> F6
        sal_feature.append(tmp)  # sal_feature = [F6]
        U_tmp = tmp  # running top-down feature, starts as the refined F6 (not raw C6)
        # Single-channel score map, bilinearly resized to x_size.
        up_sal.append(F.interpolate(self.score[num_f - 1](tmp), x_size, mode='bilinear', align_corners=True))
        # up_sal = [D6]

        for j in range(2, num_f):  # j = 2 .. num_f - 1
            i = num_f - j  # i = num_f - 2 down to 1 (3, 2, 1 for five levels); level 0 is not merged here
            # Author's level labels: list_x[0] C2, list_x[1] C3, list_x[2] C4,
            # list_x[3] C5, list_x[4] C6.
            # The string below is a no-op statement kept as debug notes: the
            # feature sizes the author printed for one input, and which branch
            # each i took with those sizes.
            '''
            print('=>list_x[0].size():', list_x[0].size())
            =>list_x[0].size() torch.Size([1, 128, 134, 200])  C2  
            =>list_x[1].size() torch.Size([1, 256, 68, 101])   C3
            =>list_x[2].size() torch.Size([1, 512, 34, 51])    C4 
            =>list_x[3].size() torch.Size([1, 512, 17, 26])    C5 
            =>list_x[4].size() torch.Size([1, 512, 17, 26])    C6
            
            =>if list_x[i].size()[1] >= U_tmp.size()[1] 
                                 512 >= 512
            =>if list_x[i].size()[1] >= U_tmp.size()[1] 
                                 512 >= 512
            =>if list_x[i].size()[1] < U_tmp.size()[1] 
                                 256 < 512
            '''
            # With those sizes: i=3 and i=2 see equal channels (512 vs 512);
            # i=1 sees 256 < 512 and takes the projection branch.
            if list_x[i].size()[1] < U_tmp.size()[1]:
                # Fewer channels than the running feature: project U_tmp with
                # trans[i] first (example config: only i=1, trans[1] is
                # Conv1x1(512 -> 256)), resize to list_x[i]'s spatial size, add.
                # The result is the merged level-1 map (C3 + resized coarser feature).
                U_tmp = list_x[i] + F.interpolate((self.trans[i](U_tmp)), list_x[i].size()[2:], mode='bilinear',
                                                  align_corners=True)
            else:
                # Same channel count (i=3, 2 with the example sizes): resize the
                # running feature to list_x[i]'s spatial size and add directly.
                U_tmp = list_x[i] + F.interpolate((U_tmp), list_x[i].size()[2:], mode='bilinear', align_corners=True)

            # Refine the merged map with this level's up block
            # (F5, F4, F3 for i = 3, 2, 1).
            tmp = self.up[i](U_tmp)
            # At level 1, keep the merged-but-unrefined map for the "_s2" outputs.
            if i == 1:
                sal_feature_s2.append(U_tmp)
            U_tmp = tmp  # the refined map feeds the next (finer) level
            sal_feature.append(tmp)  # sal_feature = [F6, F5, F4, F3]
            # score[i]: 3x3 conv (ik[2] -> 1 channel), resized to x_size; with the
            # example config ik[2] is 512, 512, 256 for i = 3, 2, 1.
            # up_sal = [D6, D5, D4, D3]
            up_sal.append(F.interpolate(self.score[i](tmp), x_size, mode='bilinear', align_corners=True))

 

你可能感兴趣的：EGNet