pytorch:nn.BatchNormal

这里用的是batchnormal1d

import torch
from torch import nn

x = torch.rand([2,3,16]) #2张照片,3个通道,每个通道16个pixel
batch_normal = nn.BatchNorm1d(3) #括号里面的数字要与通道数一样
normal_result=batch_normal(x)

print(normal_result.size())

print(batch_normal.running_mean) #每个batch上16*2=32个数字的平均值,一共3个平均值

print(batch_normal.running_var) #每个batch上16*2=32个数字的方差,一共3个方差

在这里插入图片描述
pytorch:nn.BatchNormal_第1张图片
我们可以看到batch_normal的mean和var和直接算出来的mean和var不太一样,并且batch_normal的mean刚好是直接算出来的mean的1/10.
这个与batchnormal的momentum参数有关。
pytorch:nn.BatchNormal_第2张图片
把momentum设置成1,那么batch_normal的mean和var就和直接算出来的一样了:
pytorch:nn.BatchNormal_第3张图片

你可能感兴趣的:(pytorch)