class MergeLayer1(nn.Module): # list_k: [[64, 512, 64], [128, 512, 128], [256, 0, 256] ... ]
def __init__(self, list_k):
super(MergeLayer1, self).__init__()
self.list_k = list_k
# list_k = config['merge1']
# 1 'merge1': [[128, 256, 128, 3, 1],
# 2 [256, 512, 256, 3, 1],
# 3 [512, 0, 512, 5, 2],
# 4 [512, 0, 512, 5, 2],
# 5 [512, 0, 512, 7, 3]]
trans, up, score, trans_F2_sal, up_s2, score_s2 = [], [], [], [], [], [] # todo
# todo
up_s2.append(nn.Sequential(nn.Conv2d(128, 128, 3, 1, 1), nn.ReLU(inplace=True),
nn.Conv2d(128, 128, 3, 1, 1), nn.ReLU(inplace=True),
nn.Conv2d(128, 128, 3, 1, 1), nn.ReLU(inplace=True)))
# todo
score_s2.append(nn.Conv2d(128, 1, 3, 1, 1))
for ik in list_k:
if ik[1] > 0: # first two 1,2:
# [128, 256, 128, 3, 1],
# [256, 512, 256, 3, 1]
trans.append(nn.Sequential(nn.Conv2d(ik[1], ik[0], 1, 1, bias=False), nn.ReLU(inplace=True)))
# trans = [(256, 128, 1, 1), (512, 128, 1, 1)]
# 0 1 2 3 4
# 0 C2 [128, 256, 128, 3, 1],
# 1 C3 [256, 512, 256, 3, 1]
# 2 C4 [512, 0, 512, 5, 2],
# 4 C5 [512, 0, 512, 5, 2],
# 5 C6 [512, 0, 512, 7, 3]]
up.append(nn.Sequential(nn.Conv2d(ik[0], ik[2], ik[3], 1, ik[4]), nn.ReLU(inplace=True),
nn.Conv2d(ik[2], ik[2], ik[3], 1, ik[4]), nn.ReLU(inplace=True),
nn.Conv2d(ik[2], ik[2], ik[3], 1, ik[4]), nn.ReLU(inplace=True)))
# 512
# 512
# 512
score.append(nn.Conv2d(ik[2], 1, 3, 1, 1))
# trans = [(256, 128, 1, 1), (512, 256, 1, 1), (512, 128, 1, 1)]
trans.append(nn.Sequential(nn.Conv2d(512, 128, 1, 1, bias=False), nn.ReLU(inplace=True)))
trans_F2_sal.append(nn.Sequential(nn.Conv2d(256, 128, 1, 1, bias=False), nn.ReLU(inplace=True))) # todo
self.trans, self.up, self.score = nn.ModuleList(trans), nn.ModuleList(up), nn.ModuleList(score)
self.trans_F2_sal = nn.ModuleList(trans_F2_sal) # todo
self.up_s2 = nn.ModuleList(up_s2) # todo
self.score_s2 = nn.ModuleList(score_s2) # todo
self.relu = nn.ReLU()
def forward(self, list_x, x_size): # self.merge1(conv2merge, x_size)
# list_x: conv2merge = self.base(x) -> self.base = base -> vgg -> vgg16() or resnet50()
# x_size: x_size = x.size()[2:] -> NCHW -> HW -> [H, W]
up_edge, up_sal, edge_feature, sal_feature, sal_feature_s2 = [], [], [], [], [] # todo
num_f = len(list_x)
c2 = list_x[0] # C2 # todo
tmp = self.up[num_f - 1](list_x[num_f - 1]) # c(6)->F^(6)
sal_feature.append(tmp) # sal_feature = [F^(6)]
U_tmp = tmp # U_tmp = C6
up_sal.append(F.interpolate(self.score[num_f - 1](tmp), x_size, mode='bilinear', align_corners=True))
# up_sal = [D6]
for j in range(2, num_f): # 2->5: 2 3 4
i = num_f - j # i=5-j: 3 2 1
# list[0] C2
# list[1] C3
# list[2] C4
# list[3] C5
# list[4] C6
'''
print('=>list_x[0].size():', list_x[0].size())
=>list_x[0].size() torch.Size([1, 128, 134, 200]) C2
=>list_x[1].size() torch.Size([1, 256, 68, 101]) C3
=>list_x[2].size() torch.Size([1, 512, 34, 51]) C4
=>list_x[3].size() torch.Size([1, 512, 17, 26]) C5
=>list_x[4].size() torch.Size([1, 512, 17, 26]) C6
=>if list_x[i].size()[1] >= U_tmp.size()[1]
512 >= 512
=>if list_x[i].size()[1] >= U_tmp.size()[1]
512 >= 512
=>if list_x[i].size()[1] < U_tmp.size()[1]
256 < 512
'''
# i=3, C5.channel=512 = U_tmp.size()[1]=C6.channel=512
# i=2, C4.channel=512 = 512 ...
# i=1, C3.channel=256 < 512 ...
if list_x[i].size()[1] < U_tmp.size()[1]: # i=1
# trans = [(256, 128, 1, 1), (512, 256, 1, 1), (512, 128, 1, 1)]
# trans[1] = (512, 256, 1 ,1)
# F3 = C3 + C6 -> channel 512 to 256
# -> size to C3.size
U_tmp = list_x[i] + F.interpolate((self.trans[i](U_tmp)), list_x[i].size()[2:], mode='bilinear',
align_corners=True) # U_tmp = F^(6)
else: # i=3,2
# C5.channel=512 = C6.channel=512
# C4.channel=512 = C6.channel=512
# F5 = C5 + C6 -> size to C5
# F4 = C4 + C6 -> size to C4
U_tmp = list_x[i] + F.interpolate((U_tmp), list_x[i].size()[2:], mode='bilinear', align_corners=True)
# i=3,2,1: merged_C5, merged_C4, merged_C3
# tmp= F^(5), F^(4), F^(3)
tmp = self.up[i](U_tmp)
# todo
if i == 1: # list[1] C3, U_tmp=F3
sal_feature_s2.append(U_tmp)
U_tmp = tmp
sal_feature.append(tmp) # sal_feature = [F^(6), F^(5), F^(4), F^(3)]
# => transition layer
# in_channels, out_channels, kernel_size, stride, padding
# i=3 s5 512, 1, 3, 1, 1
# i=2 s4 512, 1, 3, 1, 1
# i=1 s3 256, 1, 3, 1, 1
# up_sal = [D6, D5, D4, D3]
up_sal.append(F.interpolate(self.score[i](tmp), x_size, mode='bilinear', align_corners=True))