Understanding BatchNorm2d in PyTorch

import torch

b1 = torch.nn.BatchNorm2d(3)   # 3 channels; affine weight/bias and eps=1e-5 by default
a = torch.randn(2, 3, 4, 4)    # (N, C, H, W)

c = b1(a)
c.size()
Out[14]: torch.Size([2, 3, 4, 4])
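
The output shape is unchanged. In training mode BatchNorm2d normalizes each channel independently over the batch and spatial dimensions (N, H, W), using the biased variance:

y = (x - mean_c) / sqrt(var_c + eps) * weight_c + bias_c

The next cell verifies this by hand for channel 0: concatenating a[0,0] and a[1,0] collects all 2×4×4 = 32 values of that channel.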

(a[0,0] - torch.cat((a[0,0], a[1,0]), dim=1).mean()) \
    / torch.pow(torch.cat((a[0,0], a[1,0]), dim=1).var(unbiased=False) + 1e-5, 0.5) \
    * b1.weight[0] + b1.bias[0]
Out[23]: 
tensor([[-1.4331,  0.4803, -0.9487,  0.4142],
        [-0.4953,  0.2832,  0.0450, -0.2222],
        [-0.1621, -0.7239, -0.6519, -0.1368],
        [-1.2073,  0.3538, -0.9681, -0.1016]], grad_fn=<AddBackward0>)
c[0,0]
Out[24]: 
tensor([[-1.4331,  0.4803, -0.9487,  0.4142],
        [-0.4953,  0.2832,  0.0450, -0.2222],
        [-0.1621, -0.7239, -0.6519, -0.1368],
        [-1.2073,  0.3538, -0.9681, -0.1016]], grad_fn=<SelectBackward>)
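
The hand computation matches c[0,0] exactly. The same check can be vectorized over all channels at once; a minimal sketch, reusing a, b1, and c from above:

mean = a.mean(dim=(0, 2, 3), keepdim=True)                # per-channel mean over N, H, W
var = a.var(dim=(0, 2, 3), unbiased=False, keepdim=True)  # biased variance, as used for normalization
c_manual = (a - mean) / torch.sqrt(var + b1.eps) \
           * b1.weight.view(1, -1, 1, 1) + b1.bias.view(1, -1, 1, 1)
torch.allclose(c_manual, c, atol=1e-6)                    # should be True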

b1.weight
Out[26]: 
Parameter containing:
tensor([0.7185, 0.6812, 0.0770], requires_grad=True)
b1.bias
Out[27]: 
Parameter containing:
tensor([0., 0., 0.], requires_grad=True)
b1.eps
Out[28]: 1e-05
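
weight and bias are the learnable affine parameters (gamma and beta in the batch-norm paper), and eps is only added under the square root for numerical stability. With affine=False the layer has no such parameters; a quick check (bn_na is an illustrative name):

bn_na = torch.nn.BatchNorm2d(3, affine=False)
print(bn_na.weight, bn_na.bias)   # None None — the output is just the normalized input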


b1.running_mean
Out[29]: tensor([-0.0067,  0.0155, -0.0003])
b1.running_var
Out[30]: tensor([0.9891, 0.9808, 0.9868])
torch.cat((a[0,0], a[1,0])).mean()
Out[31]: tensor(-0.0668)  # r_mean = 0.9*r_mean + 0.1*batch_mean, initialized to 0
torch.cat((a[0,0], a[1,0])).var(unbiased=False)
Out[32]: tensor(0.8631)  # r_var = 0.9*r_var + 0.1*batch_var, initialized to 1; batch_var here is the unbiased estimate
# Note: the claim in ref [4] that running_var is initialized to 0 is wrong.
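
These numbers are consistent with the update rule: running_mean[0] ≈ 0.1 × (−0.0668) = −0.0067, and running_var[0] ≈ 0.9 × 1 + 0.1 × (0.8631 × 32/31) ≈ 0.9891 (32 values per channel, so the unbiased variance is the biased one scaled by 32/31). A minimal sketch of the update rule on a fresh layer (default momentum = 0.1; bn and x are illustrative names):

bn = torch.nn.BatchNorm2d(3)          # running_mean starts at 0, running_var at 1
x = torch.randn(2, 3, 4, 4)
bn(x)                                 # one training-mode forward pass updates the stats

m = bn.momentum                       # 0.1 by default
batch_mean = x.mean(dim=(0, 2, 3))
batch_var = x.var(dim=(0, 2, 3), unbiased=True)   # running_var uses the unbiased estimate

torch.allclose(bn.running_mean, m * batch_mean)              # should be True (init 0)
torch.allclose(bn.running_var, (1 - m) * 1 + m * batch_var)  # should be True (init 1)

After bn.eval(), the forward pass normalizes with these running statistics instead of the batch statistics.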


References:

[1] https://blog.csdn.net/tmk_01/article/details/80679549
[2] https://blog.csdn.net/LoseInVain/article/details/86476010
[3] https://blog.csdn.net/xk_snail/article/details/80006624
[4] https://blog.csdn.net/qunnie_yi/article/details/80128445
