import torch
import torch.nn as nn


class SpaceToDepth(nn.Module):
    def __init__(self, block_size):
        super(SpaceToDepth, self).__init__()
        self.block_size = block_size
        self.block_size_sq = block_size * block_size

    def forward(self, input):
        # NCHW -> NHWC
        output = input.permute(0, 2, 3, 1)
        (batch_size, s_height, s_width, s_depth) = output.size()
        d_depth = s_depth * self.block_size_sq
        d_width = int(s_width / self.block_size)
        d_height = int(s_height / self.block_size)
        # split along the width into d_width chunks of block_size columns
        t_1 = output.split(self.block_size, 2)
        # fold each chunk's block rows and columns into the depth dimension
        stack = [t_t.reshape(batch_size, d_height, d_depth) for t_t in t_1]
        output = torch.stack(stack, 1)
        output = output.permute(0, 2, 1, 3)
        # NHWC -> NCHW
        output = output.permute(0, 3, 1, 2)
        return output
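A minimal usage sketch (the random tensor just stands in for real labels; the 64x1x120x160 shape matches the walkthrough below):

import torch

space2depth = SpaceToDepth(block_size=8)
labels = torch.randn(64, 1, 120, 160)   # (N, C, H, W)
out = space2depth(labels)
print(out.shape)                        # torch.Size([64, 64, 15, 20])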
labels = space2depth(labels)
Passing labels into the space2depth object invokes its forward method; input: (64, 1, 120, 160)
output = input.permute(0, 2, 3, 1)
Swap the dimensions (NCHW → NHWC); output: (64, 120, 160, 1)
(batch_size, s_height, s_width, s_depth) = output.size()
batch_size=64,s_height=120,s_width=160,s_depth=1
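This step in isolation (a sketch with a random tensor):

import torch

x = torch.randn(64, 1, 120, 160)               # NCHW
y = x.permute(0, 2, 3, 1)                      # NHWC
batch_size, s_height, s_width, s_depth = y.size()
print(y.shape)                                 # torch.Size([64, 120, 160, 1])
print(batch_size, s_height, s_width, s_depth)  # 64 120 160 1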
d_depth = s_depth * self.block_size_sq
where self.block_size_sq = block_size * block_size = 8 * 8 = 64, so
d_depth = 1 * 64 = 64
d_width = int(s_width / self.block_size)
d_height = int(s_height / self.block_size)
d_width=20
d_height=15
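The same arithmetic spelled out (for positive sizes, int(s / b) is just floor division, so // is equivalent):

block_size = 8
block_size_sq = block_size * block_size    # 64
s_height, s_width, s_depth = 120, 160, 1

d_depth = s_depth * block_size_sq          # 1 * 64 = 64
d_width = s_width // block_size            # 160 // 8 = 20
d_height = s_height // block_size          # 120 // 8 = 15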
t_1 = output.split(self.block_size, 2)
output:(64,120,160,1)
Split output along dimension 2 (the width) into chunks of block_size = 8.
len(t_1)=20
t_1[0].shape=(64, 120, 8, 1)
Imagining the batch size as 1 makes this split easier to picture.
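The split in isolation (a sketch with a random tensor):

import torch

output = torch.randn(64, 120, 160, 1)
t_1 = output.split(8, dim=2)   # chunks of 8 columns along the width
print(len(t_1))                # 20
print(t_1[0].shape)            # torch.Size([64, 120, 8, 1])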
stack = [t_t.reshape(batch_size, d_height, d_depth) for t_t in t_1]
Each of the 20 chunks in t_1, of shape (64, 120, 8, 1), is reshaped to (64, 15, 64).
Again, imagining batch_size = 1 helps in following this reshape.
len(stack)=20
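One chunk's reshape in isolation; the element count is preserved (120 * 8 * 1 == 15 * 64):

import torch

chunk = torch.randn(64, 120, 8, 1)        # one element of t_1
reshaped = chunk.reshape(64, 15, 64)      # (batch_size, d_height, d_depth)
print(reshaped.shape)                     # torch.Size([64, 15, 64])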
output = torch.stack(stack, 1)
Stack the 20 chunks of stack along dimension 1:
output:(64,20,15,64)
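The stack step in isolation:

import torch

stack = [torch.randn(64, 15, 64) for _ in range(20)]
out = torch.stack(stack, dim=1)   # new dimension inserted at position 1
print(out.shape)                  # torch.Size([64, 20, 15, 64])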
output = output.permute(0, 2, 1, 3)
output:(64,15,20,64)
output = output.permute(0, 3, 1, 2)
output:(64,64,15,20)
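The two permutes compose into a single permute(0, 3, 2, 1), as this sketch verifies:

import torch

x = torch.randn(64, 20, 15, 64)
y = x.permute(0, 2, 1, 3)    # (64, 15, 20, 64)
z = y.permute(0, 3, 1, 2)    # (64, 64, 15, 20)
print(z.shape)               # torch.Size([64, 64, 15, 20])
assert torch.equal(z, x.permute(0, 3, 2, 1))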
return hands the result back to the call site, labels = space2depth(labels).
In other words, space2depth transforms the (64, 1, 120, 160) labels into a (64, 64, 15, 20) tensor.
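As a sanity check: recent PyTorch (1.8+) ships torch.nn.functional.pixel_unshuffle, which performs the same space-to-depth rearrangement. For a single-channel input like this one the two agree element for element (with multiple input channels, the channel ordering of this hand-rolled version can differ):

import torch
import torch.nn.functional as F

x = torch.randn(64, 1, 120, 160)
s2d = SpaceToDepth(block_size=8)
assert torch.equal(s2d(x), F.pixel_unshuffle(x, downscale_factor=8))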