[Deep Learning] The SpaceToDepth Class

Class code

import torch
import torch.nn as nn

class SpaceToDepth(nn.Module):
    def __init__(self, block_size):
        super(SpaceToDepth, self).__init__()
        self.block_size = block_size
        self.block_size_sq = block_size * block_size

    def forward(self, input):
        # (B, C, H, W) -> (B, H, W, C)
        output = input.permute(0, 2, 3, 1)
        (batch_size, s_height, s_width, s_depth) = output.size()
        d_depth = s_depth * self.block_size_sq
        d_width = int(s_width / self.block_size)
        d_height = int(s_height / self.block_size)
        # Split the width dimension into chunks of width block_size
        t_1 = output.split(self.block_size, 2)
        # Fold each chunk's height blocks into the depth dimension
        stack = [t_t.reshape(batch_size, d_height, d_depth) for t_t in t_1]
        output = torch.stack(stack, 1)
        output = output.permute(0, 2, 1, 3)
        # (B, H/bs, W/bs, C*bs^2) -> (B, C*bs^2, H/bs, W/bs)
        output = output.permute(0, 3, 1, 2)
        return output
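A quick way to see what forward does is to mirror each step in NumPy on a tiny tensor. This is a sketch with a hypothetical 4×6 single-channel input and block_size = 2, not the article's shapes; note that np.split takes a number of sections where torch.split takes a chunk size:

```python
import numpy as np

bs = 2                                        # block_size (the article uses 8)
x = np.arange(24).reshape(1, 1, 4, 6)         # (B, C, H, W) = (1, 1, 4, 6)

out = x.transpose(0, 2, 3, 1)                 # permute(0, 2, 3, 1) -> (1, 4, 6, 1)
b, sh, sw, sd = out.shape
t_1 = np.split(out, sw // bs, axis=2)         # 3 chunks of shape (1, 4, 2, 1)
stack = [t.reshape(b, sh // bs, sd * bs * bs) for t in t_1]   # each (1, 2, 4)
out = np.stack(stack, 1)                      # (1, 3, 2, 4)
out = out.transpose(0, 2, 1, 3)               # (1, 2, 3, 4)
out = out.transpose(0, 3, 1, 2)               # (1, 4, 2, 3) = (B, C*bs*bs, H/bs, W/bs)

print(out.shape)                              # (1, 4, 2, 3)
```

Each 2×2 spatial block of the input ends up as 4 channels of a single output pixel, which is exactly the space-to-depth rearrangement.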

Line-by-line walkthrough

labels = space2depth(labels)

labels is fed to the space2depth object, which invokes its forward method. input: (64, 1, 120, 160)

output = input.permute(0, 2, 3, 1)

Permute the dimensions (NCHW → NHWC). output: (64, 120, 160, 1)

(batch_size, s_height, s_width, s_depth) = output.size()

batch_size = 64, s_height = 120, s_width = 160, s_depth = 1

d_depth = s_depth * self.block_size_sq

Here self.block_size_sq = block_size * block_size = 8 * 8 = 64, and since s_depth = 1,
d_depth = 64

d_width = int(s_width / self.block_size)
d_height = int(s_height / self.block_size)

d_width = 160 / 8 = 20
d_height = 120 / 8 = 15

t_1 = output.split(self.block_size, 2)

output: (64, 120, 160, 1)

Split output along dimension 2 (the width) into chunks of block_size = 8:

len(t_1) = 20

t_1[0].shape = (64, 120, 8, 1)
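As a sanity check on these numbers, the split can be reproduced in NumPy (a sketch; torch.split takes a chunk size, while np.split takes a number of sections, hence the shape[2] // bs below):

```python
import numpy as np

# A stand-in for output after the first permute: (64, 120, 160, 1)
out = np.zeros((64, 120, 160, 1))
bs = 8

# torch.split(output, 8, 2) -> 20 chunks of width 8 along dim 2
t_1 = np.split(out, out.shape[2] // bs, axis=2)

print(len(t_1))        # 20
print(t_1[0].shape)    # (64, 120, 8, 1)
```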

Treating the batch size as 1 makes this split easier to picture, as in the figure below:
(Figure 1: the width-wise split, drawn with batch size 1)

stack = [t_t.reshape(batch_size, d_height, d_depth) for t_t in t_1]

Each of the 20 chunks in t_1, of shape (64, 120, 8, 1),

is reshaped to (64, 15, 64).

Again, treating the batch size as 1 makes this reshape easier to picture:
(Figure 2: the per-chunk reshape, drawn with batch size 1)

len(stack) = 20
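The reshape is a plain row-major flattening: within one chunk, the pixel at height i*8 + hh and width-within-block w lands at output row i, depth index hh*8 + w. A small check of that mapping (batch of 1 for clarity; the indices chosen are arbitrary):

```python
import numpy as np

# One chunk from the split, with batch size 1 for clarity: (1, 120, 8, 1)
chunk = np.arange(120 * 8).reshape(1, 120, 8, 1)
r = chunk.reshape(1, 15, 64)   # row-major: flat index h*8 + w

# Row i, depth hh*8 + w comes from source pixel (i*8 + hh, w):
i, hh, w = 3, 5, 2
assert r[0, i, hh * 8 + w] == chunk[0, i * 8 + hh, w, 0]
```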

output = torch.stack(stack, 1)

Stack the 20 chunks of stack along dimension 1:
(Figure 3: stacking along dimension 1)

output: (64, 20, 15, 64)
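torch.stack inserts a new dimension at the given position; the NumPy equivalent behaves the same way, which confirms the (64, 20, 15, 64) shape:

```python
import numpy as np

# 20 reshaped chunks, each (64, 15, 64)
stack = [np.zeros((64, 15, 64)) for _ in range(20)]
out = np.stack(stack, 1)   # new axis inserted at position 1

print(out.shape)           # (64, 20, 15, 64)
```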

output = output.permute(0, 2, 1, 3)

output: (64, 15, 20, 64)

output = output.permute(0, 3, 1, 2)

output: (64, 64, 15, 20)

The result is returned, and control comes back to labels = space2depth(labels).

In other words, the (64, 1, 120, 160) labels tensor, after passing through space2depth, becomes a (64, 64, 15, 20) tensor.
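Putting the steps together, every output element comes from exactly one input pixel: output[b, hh*8 + ww, i, j] = input[b, 0, i*8 + hh, j*8 + ww]. For a single-channel input like this one, that should coincide with what torch.nn.functional.pixel_unshuffle(input, 8) computes (for multi-channel inputs the channel ordering differs). A NumPy sketch of the whole forward pass, spot-checked against this mapping:

```python
import numpy as np

def space_to_depth_np(x, bs):
    """NumPy re-expression of SpaceToDepth.forward (a sketch, not the original)."""
    out = x.transpose(0, 2, 3, 1)                       # (B, H, W, C)
    b, sh, sw, sd = out.shape
    t_1 = np.split(out, sw // bs, axis=2)
    stack = [t.reshape(b, sh // bs, sd * bs * bs) for t in t_1]
    return np.stack(stack, 1).transpose(0, 2, 1, 3).transpose(0, 3, 1, 2)

x = np.arange(64 * 120 * 160).reshape(64, 1, 120, 160)
y = space_to_depth_np(x, 8)
assert y.shape == (64, 64, 15, 20)

# Spot-check: output[b, hh*8+ww, i, j] == input[b, 0, i*8+hh, j*8+ww]
b, i, j, hh, ww = 5, 7, 11, 3, 6
assert y[b, hh * 8 + ww, i, j] == x[b, 0, i * 8 + hh, j * 8 + ww]
```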
