怎么理解Batch Normalization及其Back Propagation

主要参考:

  1. cs231n Andrej 的讲解


    怎么理解Batch Normalization及其Back Propagation_第1张图片
  2. 国外一个blog:
    https://kratzert.github.io/2016/02/12/understanding-the-gradient-flow-through-the-batch-normalization-layer.html
  3. 数学公式推导解法:
    http://cthorey.github.io./backpropagation/

1. Forward

之所以叫Batch,是因为mean和variance只基于一部分数据来计算。在assignment中,train的时候batch normalization的计算方式是:

    # get the mean of every feature
    u_b = (np.sum(x, axis=0))/N
    # get the variance
    sigma_squared_b = np.sum((x-u_b)**2, axis=0)/N
    # get x_hat
    x_hat = (x-u_b)/np.sqrt(sigma_squared_b+eps)

    out = gamma*x_hat + beta
    cache['mean'] = u_b
    cache['variance'] = sigma_squared_b
    cache['x_hat'] = x_hat
    cache['gamma'] = gamma
    cache['beta'] = beta
    cache['eps'] = eps
    cache['x'] = x

    # keep tracking of running mean and var
    running_mean = momentum*running_mean + (1-momentum)*u_b
    running_var = momentum*running_var + (1-momentum)*sigma_squared_b

注意到,最后还有一个在实时更新的带有momentum的running_mean和running_var。这个update是有意义的,因为它可以调整在训练过程中的数据偏离。

2. Backward

backward是重点介绍对象,一开始我觉得实现起来很直接很简单,后来发现本质上是一个有些复杂的小网络的back prop。cs231n在介绍back prop的时候讲的很好,在面对任何可微分的系统的时候,我们都可以按照graph的方式来进行一点一点的回溯。以下是blog中介绍的batch normalization的graph:

怎么理解Batch Normalization及其Back Propagation_第2张图片

具体一步步的计算直接参考第二个链接就可以了。良心分析啊...

笔者本人感到比较无奈的一点是最近实在是太忙了,对于这个batch normalization的assignment直接去看了外网上的回答和思路没有自己推导。希望之后有时间的时候完全自己再过一遍推一遍,毕竟看懂别人是怎么做的和自己真正完全明白了还是有区别的,可能还有一些隐形的问题我没有发现。

你可能感兴趣的:(怎么理解Batch Normalization及其Back Propagation)