对BN的理解

BN在网络中的位置和操作流程

引言

机器学习有一个重要假设:IID,就是假设训练数据和测试数据是满足相同分布的,BatchNorm就是在深度神经网络训练过程中使得每一层神经网络的输入保持相同分布的。为什么对输入数据做BN,原因在于神经网络学习过程本质上是为了学习数据的分布。

“Internal Covariate Shift”问题:

内部协变量偏移,Internal指的是网络深层的隐层,Covariate(协变量:不可控,但对结果有重要影响)指的是网络的参数。在训练过程中,因为各层参数不停在变化,导致隐层的输入分布老是变来变去。

BN的基本思想:

每个隐层节点的激活输入分布固定下来,避免了“Internal Covariate Shift”问题了,顺带解决反向传播中梯度消失问题。BN思路来源于:如果对图像做白化操作(0均值,1方差的正态分布),神经网络收敛较快,深度神经网络的每一个隐层都是输入层,不过是相对下一层来说而已,BN可以理解为对深层神经网络每个隐层神经元的激活值做简化版本的白化操作。

一句话:对于每个隐层神经元,把逐渐向非线性函数映射后向取值区间极限饱和区靠拢的输入分布强制拉回到均值为0方差为1的比较标准的正态分布,使得非线性变换函数的输入值落入对输入比较敏感的区域,以此避免梯度消失问题。经过BN后,目前大部分Activation的值落入非线性函数的线性区内,其对应的导数远离导数饱和区,这样来加速训练收敛过程。

疑点:BN操作之后,非线性激活函数变成了和线性函数一样的效果,显然是不行的,为了保证非线性的获得,对变换后的满足均值为0方差为1的x又进行了scale加上shift操作(y=scale*x+shift),每个神经元增加了两个参数scale和shift参数,这两个参数是通过训练学习到的,意思是通过scale和shift把这个值从标准正态分布左移或者右移一点并长胖一点或者变瘦一点,每个实例挪动的程度不一样,这样等价于非线性函数的值从正中心周围的线性区往非线性区动了动。这样找到一个线性和非线性的较好的平衡点,既能享受非线性的较强表达能力的好处,又避免太靠非线性区两头使得网络收敛速度太慢。这里理想状态的scale和shift操作会不会又把x变换成未变换之前的状态,又回到Internal Covariate Shift问题哪里?应该不会哈哈哈,否则BN完全没用了啊,事实证明。

Inference时的BN操作:

一个实例是没法算实例集合求出的均值和方差,既然没有从Mini-Batch数据里可以得到的统计量,那就想其它办法来获得这个统计量,就是均值和方差。可以用从所有训练实例中获得的统计量来代替Mini-Batch里面m个训练实例获得的均值和方差统计量,因为本来就打算用全局的统计量,只是因为计算量等太大所以才会用Mini-Batch这种简化方式的,那么在推理的时候直接用全局统计量即可。把每个Mini-Batch的均值和方差统计量记住,然后对这些均值和方差求其对应的数学期望即可得出全局统计量。设置model.eval()的一个作用就是固定BN层,不像在训练阶段去求每个mini-batch的均值方差,而是直接取出之前记录在网络里面的每个mini-batch的方差,去求期望.

个人理解

为什么bs越大越好,因为bs越大,每个bs的分布就越趋近于同分布,这样网络比较容易学习数据的分布规律,梯度更新方向比较一致,收敛更快。

BN中的参数

看一个例子

import torch
import torch.nn as nn
class Net(nn.Module):
    def __init__(self):
        super(Net, self).__init__()
        self.conv1 = nn.Conv2d(3, 6, kernel_size=3, padding=1)
        self.bn1 = nn.BatchNorm2d(6)


    def forward(self, x):
        x = self.conv1(x)
        x = self.bn1(x)

        return x

model = Net()
for name, para in model.named_parameters():
    print(name, para)

print('************************************************************')
for name, buffer in model.named_buffers():
    print(name, buffer)

输出为

OrderedDict([('conv1.weight', tensor([[[[ 0.0108,  0.1240,  0.0641],
          [ 0.0838,  0.0657,  0.0785],
          [ 0.0755, -0.1763, -0.0934]],

         [[-0.1210, -0.1455, -0.1416],
          [ 0.0903,  0.0632,  0.0489],
          [-0.0614, -0.1614,  0.1625]],

         [[ 0.1661, -0.0992, -0.1398],
          [ 0.1170,  0.1084,  0.1536],
          [ 0.0179,  0.1310, -0.0289]]],


        [[[ 0.1363,  0.1840,  0.1140],
          [ 0.0471,  0.0555,  0.1758],
          [-0.0386,  0.1077,  0.1612]],

         [[ 0.1177,  0.1799, -0.0495],
          [-0.0314, -0.1714,  0.1125],
          [-0.0723, -0.0770,  0.1663]],

         [[-0.1474,  0.0866, -0.0111],
          [ 0.1476, -0.0468, -0.0683],
          [ 0.0535,  0.1440,  0.1900]]],


        [[[-0.0954,  0.0743, -0.0975],
          [ 0.0741,  0.1436, -0.1203],
          [-0.0047,  0.1317, -0.1513]],

         [[-0.1422,  0.1404,  0.1614],
          [ 0.0025, -0.1499,  0.1647],
          [ 0.0192,  0.0324,  0.0593]],

         [[-0.0041,  0.1813, -0.1696],
          [ 0.0822,  0.1765, -0.1627],
          [ 0.0262,  0.1857, -0.0359]]],


        [[[-0.1816, -0.1198, -0.1289],
          [-0.0138,  0.1118, -0.0687],
          [-0.0078, -0.0975, -0.0646]],

         [[ 0.1763, -0.0490, -0.1117],
          [ 0.0976, -0.0156,  0.1104],
          [-0.0755,  0.0067,  0.0637]],

         [[-0.0131, -0.1783,  0.0628],
          [ 0.1020,  0.1713, -0.0764],
          [-0.1752,  0.0589, -0.0661]]],


        [[[-0.0292,  0.1491,  0.1690],
          [-0.1483,  0.1089, -0.1463],
          [-0.1159,  0.0097,  0.1525]],

         [[-0.0439, -0.0683, -0.0691],
          [-0.0465, -0.0289,  0.1653],
          [ 0.1307, -0.0170, -0.1875]],

         [[-0.0941,  0.1616,  0.0168],
          [ 0.1385,  0.1919,  0.0238],
          [-0.0705,  0.1550,  0.1585]]],


        [[[ 0.1091,  0.0602, -0.1886],
          [ 0.0663,  0.1151, -0.1629],
          [ 0.0955, -0.1370, -0.1030]],

         [[-0.1690,  0.1786,  0.0723],
          [-0.0280, -0.0451, -0.0303],
          [-0.0342, -0.0909, -0.1883]],

         [[ 0.1072,  0.1869,  0.0249],
          [ 0.1028, -0.1043,  0.0852],
          [-0.0532, -0.1132, -0.0372]]]])), ('conv1.bias', tensor([-0.0907,  0.1700, -0.0342,  0.1511,  0.0931,  0.0797])), ('bn1.weight', tensor([1., 1., 1., 1., 1., 1.])), ('bn1.bias', tensor([0., 0., 0., 0., 0., 0.])), ('bn1.running_mean', tensor([0., 0., 0., 0., 0., 0.])), ('bn1.running_var', tensor([1., 1., 1., 1., 1., 1.])), ('bn1.num_batches_tracked', tensor(0))])

conv1.weight Parameter containing:
tensor([[[[-0.1410,  0.0936, -0.0152],
          [-0.1397, -0.1212, -0.1048],
          [-0.1421, -0.0171,  0.0640]],

         [[ 0.1423, -0.1203, -0.0369],
          [-0.0067,  0.0966,  0.1195],
          [ 0.0143,  0.0839, -0.0283]],

         [[-0.1537, -0.1123, -0.1345],
          [ 0.0886,  0.1017,  0.0533],
          [-0.0084, -0.1251,  0.1744]]],


        [[[ 0.1859, -0.1693, -0.1616],
          [ 0.0567,  0.1256,  0.0887],
          [-0.0761, -0.1245, -0.0764]],

         [[ 0.1298, -0.1307, -0.0978],
          [ 0.0780,  0.0860, -0.0598],
          [-0.0295, -0.1884,  0.0191]],

         [[-0.1898, -0.0489,  0.1485],
          [-0.1887, -0.0618, -0.1429],
          [ 0.1066, -0.0593,  0.0559]]],


        [[[ 0.0189,  0.0575,  0.1358],
          [-0.1079, -0.0591, -0.1221],
          [ 0.0100, -0.0392,  0.0423]],

         [[ 0.1072,  0.1461, -0.1267],
          [-0.1478,  0.1647,  0.1149],
          [ 0.0258, -0.1862, -0.0070]],

         [[ 0.1138, -0.0968,  0.0016],
          [-0.0955,  0.1802,  0.0822],
          [-0.1311,  0.0945, -0.0038]]],


        [[[ 0.1647, -0.0404,  0.0610],
          [-0.1558,  0.1357,  0.1779],
          [-0.0070,  0.1030, -0.0585]],

         [[ 0.1592,  0.0970,  0.0614],
          [-0.0068, -0.0732,  0.1352],
          [ 0.0447,  0.0769, -0.0384]],

         [[-0.0589, -0.0711, -0.0543],
          [ 0.0926, -0.0984, -0.0573],
          [ 0.0687,  0.1849,  0.0993]]],


        [[[ 0.0730,  0.0036,  0.0584],
          [ 0.0568,  0.0311, -0.1742],
          [ 0.1582, -0.0496, -0.0620]],

         [[ 0.0348, -0.1415,  0.0212],
          [-0.1688,  0.0436, -0.1485],
          [ 0.0154, -0.1302,  0.1255]],

         [[ 0.1393,  0.0575, -0.1821],
          [ 0.0244, -0.1584,  0.0886],
          [-0.0158, -0.1907, -0.1038]]],


        [[[ 0.0019, -0.0077, -0.0073],
          [ 0.0667,  0.1904,  0.1622],
          [-0.1315,  0.1265,  0.0110]],

         [[ 0.0979,  0.0211, -0.1126],
          [ 0.1260,  0.1614,  0.0309],
          [-0.0724, -0.1381,  0.1275]],

         [[-0.0206, -0.0674, -0.0358],
          [-0.0800, -0.0408,  0.1636],
          [ 0.0082, -0.0014, -0.0292]]]], requires_grad=True)
conv1.bias Parameter containing:
tensor([-0.0660,  0.0184,  0.0102,  0.1804, -0.0702,  0.0977],
       requires_grad=True)
bn1.weight Parameter containing:
tensor([1., 1., 1., 1., 1., 1.], requires_grad=True)
bn1.bias Parameter containing:
tensor([0., 0., 0., 0., 0., 0.], requires_grad=True)
************************************************************
bn1.running_mean tensor([0., 0., 0., 0., 0., 0.])
bn1.running_var tensor([1., 1., 1., 1., 1., 1.])
bn1.num_batches_tracked tensor(0)

可以看到,网络中的参数除了parameters,还有一些不用更新的参数,主要是BN中的'bn1.running_mean'bn1.running_var,这些参数只在forward时进行统计计算,backward时并不会被更新,这些参数也称为buffer,可以用model.buffers()获取。顺便提一下,在进行推理时设置model.val(),会固定这些参数,不会计算,而是采用记录的全局统计量,如上所述。

创建于2020.11.26

你可能感兴趣的:(对BN的理解)