Understanding batch normalization (BatchNorm) in PyTorch

Section 1: A look at the official PyTorch documentation for BatchNorm2d

[Figures 1–2: screenshots of the official torch.nn.BatchNorm2d documentation]
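For reference, the documentation defines the normalization as

y = \frac{x - \mathrm{E}[x]}{\sqrt{\mathrm{Var}[x] + \epsilon}} \cdot \gamma + \beta

where the mean and variance are computed per channel over the mini-batch, and γ and β are learnable parameter vectors (present when affine=True).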

  From the figures above, when affine=True, in addition to the mean and variance needed for batch normalization there are two extra parameters, γ and β. These are the parameters of the affine transformation, and they are learnable during training, i.e., they are updated continuously along with the rest of the network. So when counting a network's parameters, the two parameters inside each batch normalization layer should also be included. Both are vectors whose length equals the number of input feature maps (the channel count C).
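As a quick illustrative check (the variable name bn below is just for this example), with PyTorch's defaults γ is stored in weight and initialized to ones, β is stored in bias and initialized to zeros, and both are length-C vectors that require gradients:

import torch

bn = torch.nn.BatchNorm2d(3, affine=True)  # C = 3 input feature maps
bn.weight.shape, bn.bias.shape             # both torch.Size([3]): one gamma/beta per channel
bn.weight.requires_grad                    # True: gamma and beta are updated during training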

Section 2: Understanding batch normalization

  For a convolutional neural network, the input has shape (N, C, H, W), where N is the number of samples per batch (the batch size), C is the number of input feature maps (channels), and H and W are the height and width of the input. Batch normalization computes the mean and variance per channel, i.e., over the N, H, and W dimensions for each of the C channels. The resulting mean and variance are therefore vectors whose length equals the number of feature maps C.
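Written out, for each channel c the mean and (biased) variance are taken over all N samples and all H × W spatial positions:

\mu_c = \frac{1}{N H W} \sum_{n,h,w} x_{n,c,h,w}, \qquad
\sigma_c^2 = \frac{1}{N H W} \sum_{n,h,w} \left( x_{n,c,h,w} - \mu_c \right)^2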

Section 3: Verifying the batch normalization computation

  To verify how batch normalization is computed, we ignore the two affine transformation parameters for now (affine=False).

  Below is the result obtained by calling PyTorch's built-in layer.

import torch

# Random input: N=2 samples, C=3 channels, H=W=4
input = torch.rand((2, 3, 4, 4))
input
Out[7]: 
tensor([[[[0.1129, 0.1401, 0.5713, 0.7547],
          [0.2804, 0.3129, 0.4287, 0.5344],
          [0.7010, 0.4444, 0.8572, 0.0266],
          [0.4833, 0.1943, 0.8078, 0.9936]],
         [[0.6653, 0.9816, 0.3019, 0.6163],
          [0.3263, 0.9584, 0.5891, 0.9371],
          [0.3813, 0.3770, 0.8869, 0.3120],
          [0.6032, 0.5349, 0.9897, 0.4070]],
         [[0.3461, 0.4580, 0.9292, 0.2834],
          [0.3356, 0.1652, 0.2538, 0.9413],
          [0.2629, 0.1728, 0.7218, 0.4769],
          [0.6479, 0.2530, 0.5863, 0.2996]]],
        [[[0.7151, 0.0034, 0.2873, 0.2120],
          [0.2684, 0.8313, 0.0947, 0.6129],
          [0.4107, 0.9293, 0.6260, 0.7687],
          [0.8199, 0.7393, 0.2184, 0.3194]],
         [[0.6304, 0.9363, 0.9413, 0.9579],
          [0.9455, 0.5982, 0.7114, 0.6138],
          [0.0993, 0.0609, 0.0189, 0.5691],
          [0.7284, 0.2882, 0.5706, 0.5343]],
         [[0.7061, 0.1233, 0.8495, 0.8738],
          [0.5106, 0.9541, 0.9235, 0.8846],
          [0.1265, 0.8666, 0.0171, 0.9323],
          [0.0212, 0.1634, 0.5212, 0.4888]]]])


# BatchNorm over C=3 channels; affine=False, so no gamma/beta parameters
m = torch.nn.BatchNorm2d(3, affine=False)
m(input)
Out[10]: 
tensor([[[[-1.3104, -1.2145,  0.3066,  0.9536],
          [-0.7195, -0.6048, -0.1963,  0.1763],
          [ 0.7641, -0.1412,  1.3150, -1.6147],
          [-0.0040, -1.0232,  1.1406,  1.7962]],
         [[ 0.2478,  1.3790, -1.0521,  0.0724],
          [-0.9648,  1.2961, -0.0246,  1.2200],
          [-0.7680, -0.7833,  1.0405, -1.0159],
          [ 0.0258, -0.2185,  1.4080, -0.6759]],
         [[-0.5114, -0.1467,  1.3892, -0.7160],
          [-0.5456, -1.1011, -0.8124,  1.4287],
          [-0.7828, -1.0765,  0.7132, -0.0850],
          [ 0.4722, -0.8151,  0.2716, -0.6631]]],
        [[[ 0.8139, -1.6967, -0.6952, -0.9608],
          [-0.7618,  1.2238, -1.3744,  0.4534],
          [-0.2600,  1.5692,  0.4993,  1.0029],
          [ 1.1833,  0.8991, -0.9381, -0.5818]],
         [[ 0.1231,  1.2170,  1.2348,  1.2944],
          [ 1.2500,  0.0078,  0.4126,  0.0635],
          [-1.7766, -1.9138, -2.0640, -0.0964],
          [ 0.4734, -1.1011, -0.0908, -0.2208]],
         [[ 0.6620, -1.2376,  1.1296,  1.2087],
          [ 0.0249,  1.4705,  1.3707,  1.2438],
          [-1.2274,  1.1852, -1.5840,  1.3992],
          [-1.5706, -1.1069,  0.0592, -0.0464]]]])

  Next, let's verify this result against the understanding described above.

  1: First, compute the mean per channel, i.e., over the batch and spatial dimensions of the input feature maps.

import numpy as np
# Mean per channel: average over the batch axis (0) and spatial axes (2, 3)
mean_arr = np.mean(input.numpy(), axis=(0, 2, 3))
mean_arr
Out[12]: array([0.48439348, 0.5960132 , 0.50301135], dtype=float32)

       2: Then compute the variance.

# Biased (population) variance per channel, which is what BatchNorm uses for normalization
val_arr = np.var(input.numpy(), axis=(0, 2, 3))
val_arr
Out[13]: array([0.08036561, 0.07816137, 0.09410284], dtype=float32)

        3: Finally, apply the formula from Section 1 (the data is converted to numpy to make the computation easier).

def cal_batchNorm(input, mean, val, eps=1e-5):
    # Normalize: (x - mean) / sqrt(var + eps); eps matches BatchNorm2d's default of 1e-5
    out = (input - mean) / np.sqrt(val + eps)
    return out

# Reshape the per-channel statistics to (C, 1, 1) so they broadcast over (N, C, H, W)
out = cal_batchNorm(input.numpy(),
                    mean_arr[:, np.newaxis, np.newaxis],
                    val_arr[:, np.newaxis, np.newaxis])
out
Out[18]: 
array([[[[-1.3104438 , -1.2145115 ,  0.30663943,  0.95358294],
         [-0.7195165 , -0.60476387, -0.19628356,  0.17627575],
         [ 0.7641234 , -0.14118858,  1.3150083 , -1.6147343 ],
         [-0.00395737, -1.0231562 ,  1.1405642 ,  1.7962272 ]],
        [[ 0.24783905,  1.3790345 , -1.0520871 ,  0.07244595],
         [-0.96478295,  1.29615   , -0.02457083,  1.2200397 ],
         [-0.7679552 , -0.78325504,  1.0404555 , -1.0158582 ],
         [ 0.0257508 , -0.2185238 ,  1.4080186 , -0.67593056]],
        [[-0.51142544, -0.1467187 ,  1.3892365 , -0.7159999 ],
         [-0.54556066, -1.1010932 , -0.81242603,  1.4286704 ],
         [-0.78281313, -1.0764958 ,  0.71321374, -0.08504389],
         [ 0.47220108, -0.8150544 ,  0.2716269 , -0.6630923 ]]],
       [[[ 0.8138572 , -1.6967356 , -0.695164  , -0.96077466],
         [-0.7617672 ,  1.2238052 , -1.3743901 ,  0.4533996 ],
         [-0.26004952,  1.5692399 ,  0.49934447,  1.0029478 ],
         [ 1.1833199 ,  0.8990661 , -0.9381481 , -0.5818167 ]],
        [[ 0.12309013,  1.2169923 ,  1.2348384 ,  1.2944146 ],
         [ 1.2500226 ,  0.00781386,  0.41258934,  0.0634794 ],
         [-1.7765759 , -1.9137751 , -2.0639913 , -0.09636869],
         [ 0.47342288, -1.101118  , -0.09080926, -0.22079165]],
        [[ 0.6620024 , -1.2376062 ,  1.1295809 ,  1.2087289 ],
         [ 0.02485057,  1.4704677 ,  1.370655  ,  1.2438185 ],
         [-1.2273672 ,  1.1852448 , -1.5840211 ,  1.399184  ],
         [-1.5706285 , -1.1069006 ,  0.05915949, -0.04639584]]]],
      dtype=float32)
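
As an extra sanity check (not part of the original run, just a direct comparison), the manual result can be compared with PyTorch's output numerically; since both use eps=1e-5, they should agree up to float32 rounding:

# Compare the manual computation with the BatchNorm2d output from above
np.allclose(out, m(input).detach().numpy(), atol=1e-5)   # expected: True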

         The result above matches the output of torch.nn.BatchNorm2d, which verifies how batch normalization is computed.
