PyTorch 中参数的默认初始化在各个层的 reset_parameters()
方法中。例如:nn.Linear
和 nn.Conv2D
,都是在 [-limit, limit] 之间的均匀分布(Uniform distribution),其中 limit 是 1. / sqrt(fan_in)
,fan_in
是指参数张量(tensor)的输入单元的数量
下面是几种常见的初始化方式。
Xavier初始化的基本思想是保持输入和输出的方差一致,这样就避免了所有输出值都趋向于0。这是通用的方法,适用于任何激活函数。
def init_model(self):
for m in self.children():
if isinstance(m, (nn.Conv2d, nn.Linear)):
nn.init.xavier_normal_(m.weight)
logging.info('init mode successful')
或者
for m in model.modules():
if isinstance(m, (nn.Conv2d, nn.Linear)):
nn.init.xavier_uniform_(m.weight(), gain=nn.init.calculate_gain('relu'))
也可以使用 gain
参数来自定义初始化的标准差来匹配特定的激活函数:
for m in model.modules():
if isinstance(m, (nn.Conv2d, nn.Linear)):
nn.init.xavier_uniform_(m.weight(), gain=nn.init.calculate_gain('relu'))
参考资料:
torch.nn.init.kaiming_uniform_(tensor, a=0, mode='fan_in', nonlinearity='leaky_relu')
He initialization的思想是:在ReLU网络中,假定每一层有一半的神经元被激活,另一半为0。推荐在ReLU网络中使用。
# he initialization
for m in model.modules():
if isinstance(m, (nn.Conv2d, nn.Linear)):
nn.init.kaiming_normal_(m.weight, mode='fan_in')
主要用以解决深度网络下的梯度消失、梯度爆炸问题,在RNN中经常使用的参数初始化方法。
for m in model.modules():
if isinstance(m, (nn.Conv2d, nn.Linear)):
nn.init.orthogonal(m.weight)
在非线性激活函数之前,我们想让输出值有比较好的分布(例如高斯分布),以便于计算梯度和更新参数。Batch Normalization 将输出值强行做一次 Gaussian Normalization 和线性变换:
for m in model:
if isinstance(m, nn.BatchNorm2d):
nn.init.constant(m.weight, 1)
nn.init.constant(m.bias, 0)
conv1 = nn.Conv2d(3, 64, kernel_size=7, stride=2, padding=3)
nn.init.xavier_uniform(conv1.weight)
nn.init.constant(conv1.bias, 0.1)
def weights_init(m):
classname = m.__class__.__name__
if classname.find('Conv2d') != -1:
nn.init.xavier_normal_(m.weight.data)
nn.init.constant_(m.bias.data, 0.0)
elif classname.find('Linear') != -1:
nn.init.xavier_normal_(m.weight)
nn.init.constant_(m.bias, 0.0)
net = Net()
net.apply(weights_init) #apply函数会递归地搜索网络内的所有module并把参数表示的函数应用到所有的module上。
不建议访问以下划线为前缀的成员,他们是内部的,如果有改变不会通知用户。更推荐的一种方法是检查某个module是否是某种类型:
def weights_init(m):
if isinstance(m, (nn.Conv2d, nn.Linear)):
nn.init.xavier_normal_(m.weight)
nn.init.constant_(m.bias, 0.0)
转载于https://blog.csdn.net/rocking_struggling/article/details/109149791