PyTorch模型参数初始化

文章目录

  • PyTorch默认模型参数初始化
    • Conv2d
    • BatchNorm2d
    • Linear
  • PyTorch提供的初始化方式
    • 初始化为常数
    • 初始化使值采样于某种分布
    • Xavier初始化
    • Kaiming初始化
    • 其他
    • gain值计算
  • 如何进行参数初始化
  • 单层初始化样例
  • 模型初始化样例

总体来说,模型的初始化是为了让模型能够更快收敛,提高训练速度。当然,也算一个小trick,合理设置是能够提升模型的performance的,当然这就有点炼丹了。

先说明一下,非特殊情况,其实大可不必太关注模型参数初始化。PyTorch默认会进行初始化,如Conv2d,BatchNorm2d和Linear。当然如果有特殊考虑,恰当的初始化是能够给模型Performance有加成的。下面介绍一下PyTorch默认的参数初始化,可以选用的初始化方法以及对整个模型如何进行参数初始化。

PyTorch默认模型参数初始化

Conv2d

Conv2d是集成ConvNd的,在ConvNd中我们可以找到如下初始化代码

def reset_parameters(self) -> None:
        init.kaiming_uniform_(self.weight, a=math.sqrt(5))
        if self.bias is not None:
            fan_in, _ = init._calculate_fan_in_and_fan_out(self.weight)
            bound = 1 / math.sqrt(fan_in)
            init.uniform_(self.bias, -bound, bound)

采用的是Kaiming uniform
Kaiming initialization的详情可以参见论文:Delving deep into rectifiers: Surpassing human-level performance on ImageNet classification He, K. et al. (2015)
其思想可简单概括为:在ReLU网络中,假定每一层有一半的神经元被激活,另一半为0,所以,要保持方差不变,只需要在 Xavier 的基础上再除以2。
PyTorch提供了两个版本的Kaiming initialization

torch.nn.init.kaiming_uniform_(tensor, a=0, mode='fan_in', nonlinearity='leaky_relu')

其初始化结果是张量值采样自 U ( − bound , + bound ) U(-\text{bound},+\text{bound}) U(bound,+bound),其中bound的计算公式:
bound = 6 ( 1 + a 2 ) × fan_in \text{bound} = \sqrt{\frac{6}{(1+a^2) \times \text{fan\_in} }} bound=(1+a2)×fan_in6

torch.nn.init.kaiming_normal_(tensor, a=0, mode='fan_in', nonlinearity='leaky_relu')

其初始化结果是张量值采样自 N ( 0 , s t d ) N(0,std) N(0,std),其中std的计算公式:
s t d = 2 ( 1 + a 2 ) × fan_in std = \sqrt{\frac{2}{(1+a^2) \times \text{fan\_in}}} std=(1+a2)×fan_in2

BatchNorm2d

def reset_running_stats(self) -> None:
	if self.track_running_stats:
		self.running_mean.zero_()
		self.running_var.fill_(1)
		self.num_batches_tracked.zero_()

def reset_parameters(self) -> None:
   self.reset_running_stats()
   if self.affine:
       init.ones_(self.weight)
       init.zeros_(self.bias)

weight初始化为1,bias初始化为0。也有看到网上有人提到weight是初始化为 U ( 0 , 1 ) U(0,1) U(0,1),不确定是不是不同版本的PyTorch有改动,这边参考来源是点这里

Linear

def reset_parameters(self) -> None:
        init.kaiming_uniform_(self.weight, a=math.sqrt(5))
        if self.bias is not None:
            fan_in, _ = init._calculate_fan_in_and_fan_out(self.weight)
            bound = 1 / math.sqrt(fan_in)
            init.uniform_(self.bias, -bound, bound) 

对weights和bias都在 U ( − bound , + bound ) U(-\text{bound},+\text{bound}) U(bound,+bound)总采样。

PyTorch提供的初始化方式

初始化为常数

torch.nn.init.constant_(tensor, val) #初始化为一个特定常数
torch.nn.init.ones_(tensor) #初始化为1
torch.nn.init.zeros_(tensor) #初始化为0

初始化使值采样于某种分布

torch.nn.init.uniform_(tensor, a=0.0, b=1.0) #均匀分布
torch.nn.init.normal_(tensor, mean=0.0, std=1.0) #正态分布

Xavier初始化

torch.nn.init.xavier_uniform_(tensor, gain=1.0)
torch.nn.init.xavier_normal_(tensor, gain=1.0)

Kaiming初始化

torch.nn.init.kaiming_uniform_(tensor, a=0, mode='fan_in', nonlinearity='leaky_relu')
torch.nn.init.kaiming_normal_(tensor, a=0, mode='fan_in', nonlinearity='leaky_relu')

其他

torch.nn.init.orthogonal_(tensor, gain=1) #正交初始化
torch.nn.init.sparse_(tensor, sparsity, std=0.01) #稀疏初始化
torch.nn.init.eye_(tensor)
torch.nn.init.dirac_(tensor, groups=1)

具体可以参见 点这里

gain值计算

Pytorch给出了不同层的推荐gain值
PyTorch模型参数初始化_第1张图片也可以通过调用下面函数计算

torch.nn.init.calculate_gain(nonlinearity, param=None)

比如

gain = nn.init.calculate_gain('leaky_relu', 0.2)  # leaky_relu with negative_slope=0.2

如何进行参数初始化

了解了PyTorch的初始化方式后,下面举例说明一下如何进行参数初始化。

单层初始化样例

conv1 = nn.Conv2d(3, 64, kernel_size=7, stride=2, padding=3)
nn.init.xavier_uniform(conv1.weight)
nn.init.constant(conv1.bias, 0.1)

模型初始化样例


class Net(nn.Module): 
    def __init__(self, in_dim, n_hidden_1, n_hidden_2, out_dim):
        super().__init__()

        self.layer = nn.Sequential(
            nn.Linear(in_dim, n_hidden_1), 
            nn.ReLU(True),
            nn.Linear(n_hidden_1, n_hidden_2),
            nn.ReLU(True),
            nn.Linear(n_hidden_2, out_dim)
             ) 
        self.init_weights()
             
    def _init_weights(self):
        for  m in self._modules:
            if isinstance(m,nn.Conv2d):
            	nn.init.xavier_normal_(m.weight)
            	nn.init.constant_(m.bias, 0.0)

你可能感兴趣的:(实用工具,pytorch,深度学习,神经网络)