池化层的实现

池化层

池化是缩小高、长方向上的空间的运算。如下图所示,进行将2 × 2的区域集约成1个元素的处理,缩小空间大小:
池化层的实现_第1张图片
上述例子是按步幅2进行2 × 2的Max池化时的处理顺序。“Max池化”是获取最大值的运算,“2 × 2”表示目标区域的大小。从2 × 2的区域中取出最大的元素。此外,这个例子中将步幅设为了2,所以2 × 2的窗口的移动间隔为2个元素。另外,一般来说,池化的窗口大小会和步幅设定成相同的值。比如,3 × 3的窗口的步幅会设为3,4 × 4的窗口的步幅会设为4等。除了Max池化之外,还有Average池化等。相对于Max池化是从目标区域中取出最大值,Average池化则是计算目标区域的平均值。在图像识别领域,主要使用Max池化。

池化层的特征

1.没有要学习的参数

池化层和卷积层不同,没有要学习的参数。池化只是从目标区域中取最大值(或者平均值),所以不存在要学习的参数。

2.通道数不发生变化

经过池化运算,输入数据和输出数据的通道数不会发生变化。如下图所示,计算是按通道独立进行的。
池化层的实现_第2张图片

对微小的位置变化具有鲁棒性(健壮)

输入数据发生微小偏差时,池化仍会返回相同的结果。因此,池化对输入数据的微小偏差具有鲁棒性。比如,3 × 3的池化的情况下,如下图所示,池化会吸收输入数据的偏差(根据数据的不同,结果有可能不一致)。
池化层的实现_第3张图片

池化层的实现

池化层的实现和卷积层相同,都使用im2col展开输入数据。不过,池化的情况下,在通道方向上是独立的,这一点和卷积层不同。具体地讲,池化的应用区域按通道单独展开:
池化层的实现_第4张图片
像这样展开之后,只需对展开的矩阵求各行的最大值,并转换为合适的形状即可。
池化层的实现_第5张图片
上面就是池化层的forward处理的实现流程。池化层的backward处理可以参考ReLU层的实现中使用的max的反向传播。Python实现如下:

class Pooling:
    def __init__(self, pool_h, pool_w, stride=1, pad=0):
        self.pool_h = pool_h
        self.pool_w = pool_w
        self.stride = stride
        self.pad = pad
        
        self.x = None
        self.arg_max = None

    def forward(self, x):
        N, C, H, W = x.shape
        out_h = int(1 + (H - self.pool_h) / self.stride)
        out_w = int(1 + (W - self.pool_w) / self.stride)

        col = im2col(x, self.pool_h, self.pool_w, self.stride, self.pad)
        col = col.reshape(-1, self.pool_h*self.pool_w)

        arg_max = np.argmax(col, axis=1)
        out = np.max(col, axis=1)
        out = out.reshape(N, out_h, out_w, C).transpose(0, 3, 1, 2)

        self.x = x
        self.arg_max = arg_max

        return out

    def backward(self, dout):
        dout = dout.transpose(0, 2, 3, 1)
        
        pool_size = self.pool_h * self.pool_w
        dmax = np.zeros((dout.size, pool_size))
        dmax[np.arange(self.arg_max.size), self.arg_max.flatten()] = dout.flatten()
        dmax = dmax.reshape(dout.shape + (pool_size,)) 
        
        dcol = dmax.reshape(dmax.shape[0] * dmax.shape[1] * dmax.shape[2], -1)
        dx = col2im(dcol, self.x.shape, self.pool_h, self.pool_w, self.stride, self.pad)
        
        return dx

你可能感兴趣的:(基于python的深度学习,深度学习,python)