A Simple Understanding of Batch Normalization

1: Background

While training a neural network, the params of every layer are updated continuously, and each update changes the distribution of the inputs seen by the following layer. This forces us to use careful weight initialization and small learning rates. This phenomenon is called internal covariate shift.

2: Idea

Convergence can be accelerated by whitening the inputs, but whitening is computationally expensive.

The idea of Batch Normalization is instead: for each mini-batch, at every layer of the network, normalize the input feature by feature. That is, for each individual neuron input in a layer, compute the mean and variance over the mini-batch and then normalize that input.
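Concretely, a minimal NumPy sketch of this idea (the function name, shapes, and epsilon value are illustrative assumptions, not the paper's exact algorithm):

import numpy as np

def batchnorm_forward(X, gamma, beta, eps=1e-6):
    # X: (batch_size, num_features); statistics are computed per feature,
    # i.e. for each column independently, over the mini-batch dimension
    mu = X.mean(axis=0)
    var = X.var(axis=0)
    X_hat = (X - mu) / np.sqrt(var + eps)   # zero mean, unit variance per feature
    return gamma * X_hat + beta             # learned scale and shift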

For CNNs, what gets normalized is "Wx+b" rather than "x" (b can also be dropped, i.e. normalize "Wx"), and the mean and variance are computed per feature map rather than per individual feature, so all spatial positions in a feature map share the same normalization.
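For a convolutional layer the only change is which axes the statistics are taken over. A sketch assuming NCHW layout (the layout and shapes are assumptions for illustration):

import numpy as np

def batchnorm_conv_forward(X, gamma, beta, eps=1e-6):
    # X: (batch_size, channels, height, width); one mean/variance pair per
    # feature map (channel), shared over all samples and spatial positions
    mu = X.mean(axis=(0, 2, 3), keepdims=True)
    var = X.var(axis=(0, 2, 3), keepdims=True)
    X_hat = (X - mu) / np.sqrt(var + eps)
    return gamma * X_hat + beta              # gamma, beta broadcast with shape (1, C, 1, 1)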

3: Algorithm (normalizing the network)

Algorithm 1

[Image: Algorithm 1]
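The image shows Algorithm 1 of the paper, the Batch Normalizing Transform. For a mini-batch $\mathcal{B} = \{x_1, \dots, x_m\}$ of values of one activation it computes

$$\mu_{\mathcal{B}} = \frac{1}{m}\sum_{i=1}^{m} x_i, \qquad \sigma_{\mathcal{B}}^2 = \frac{1}{m}\sum_{i=1}^{m}(x_i - \mu_{\mathcal{B}})^2, \qquad \hat{x}_i = \frac{x_i - \mu_{\mathcal{B}}}{\sqrt{\sigma_{\mathcal{B}}^2 + \epsilon}}, \qquad y_i = \gamma\,\hat{x}_i + \beta$$

where $\gamma$ and $\beta$ are learned parameters.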

Algorithm 2

[Image: Algorithm 2]
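Algorithm 2 of the paper describes training and inference with a batch-normalized network: during training, each chosen activation is replaced by its BN transform and $\gamma$, $\beta$ are learned by backpropagation together with the other parameters; at inference, the mini-batch statistics are replaced by population estimates $E[x]$ and $\mathrm{Var}[x]$ (in practice, moving averages collected during training), so the layer becomes a fixed affine transform:

$$y = \frac{\gamma}{\sqrt{\mathrm{Var}[x] + \epsilon}}\, x + \left(\beta - \frac{\gamma\, E[x]}{\sqrt{\mathrm{Var}[x] + \epsilon}}\right)$$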

4: Code (from the old Keras 0.x implementation):

# Imports as in the old Keras 0.x (Theano backend) code base this layer comes from;
# the module paths below follow that version and are stated as an assumption.
from keras.layers.core import Layer
from keras.utils.theano_utils import shared_zeros, shared_ones, floatX, ndim_tensor
from keras import initializations
import theano.tensor as T


class BatchNormalization(Layer):
    '''
        Reference:
            Batch Normalization: Accelerating Deep Network Training by Reducing Internal Covariate Shift
                http://arxiv.org/pdf/1502.03167v3.pdf
            mode: 0 -> featurewise normalization
                  1 -> samplewise normalization (may sometimes outperform featurewise mode)
            momentum: momentum term in the computation of a running estimate of the mean and std of the data
    '''
    def __init__(self, input_shape, epsilon=1e-6, mode=0, momentum=0.9, weights=None):
        super(BatchNormalization, self).__init__()
        self.init = initializations.get("uniform")
        self.input_shape = input_shape
        self.epsilon = epsilon
        self.mode = mode
        self.momentum = momentum
        self.input = ndim_tensor(len(self.input_shape) + 1)

        # gamma (scale) and beta (shift) are the learnable parameters applied after normalization
        self.gamma = self.init((self.input_shape))
        self.beta = shared_zeros(self.input_shape)

        self.params = [self.gamma, self.beta]
        # running estimates of the mean and std, updated as moving averages during
        # training (see init_updates) and used to normalize the input in mode 0
        self.running_mean = shared_zeros(self.input_shape)
        self.running_std = shared_ones((self.input_shape))
        if weights is not None:
            self.set_weights(weights)

    def get_weights(self):
        return super(BatchNormalization, self).get_weights() + [self.running_mean.get_value(), self.running_std.get_value()]

    def set_weights(self, weights):
        self.running_mean.set_value(floatX(weights[-2]))
        self.running_std.set_value(floatX(weights[-1]))
        super(BatchNormalization, self).set_weights(weights[:-2])

    def init_updates(self):
        # compute the batch mean/std of the training input and register
        # exponential-moving-average updates for the running statistics
        X = self.get_input(train=True)
        m = X.mean(axis=0)
        std = T.mean((X - m) ** 2 + self.epsilon, axis=0) ** 0.5
        mean_update = self.momentum * self.running_mean + (1-self.momentum) * m
        std_update = self.momentum * self.running_std + (1-self.momentum) * std
        self.updates = [(self.running_mean, mean_update), (self.running_std, std_update)]

    def get_output(self, train):
        X = self.get_input(train)

        if self.mode == 0:
            # featurewise: normalize with the running (population) statistics
            X_normed = (X - self.running_mean) / (self.running_std + self.epsilon)

        elif self.mode == 1:
            # samplewise: normalize each sample by its own mean and std across features
            m = X.mean(axis=-1, keepdims=True)
            std = X.std(axis=-1, keepdims=True)
            X_normed = (X - m) / (std + self.epsilon)

        # learned scale and shift
        out = self.gamma * X_normed + self.beta
        return out

    def get_config(self):
        return {"name": self.__class__.__name__,
                "input_shape": self.input_shape,
                "epsilon": self.epsilon,
                "mode": self.mode}
