pytorch中模型参数初始化

参数初始化(Weight Initialization)

PyTorch 中参数的默认初始化在各个层的 reset_parameters() 方法中。例如:nn.Linear 和 nn.Conv2D,都是在 [-limit, limit] 之间的均匀分布(Uniform distribution),其中 limit 是 1. / sqrt(fan_in) ,fan_in 是指参数张量(tensor)的输入单元的数量

下面是几种常见的初始化方式。

Xavier Initialization

Xavier初始化的基本思想是保持输入和输出的方差一致,这样就避免了所有输出值都趋向于0。这是通用的方法,适用于任何激活函数。

    def init_model(self):
        for m in self.children():
            if isinstance(m, (nn.Conv2d, nn.Linear)):
                nn.init.xavier_normal_(m.weight)
                logging.info('init mode successful')

或者

for m in model.modules():
    if isinstance(m, (nn.Conv2d, nn.Linear)):
        nn.init.xavier_uniform_(m.weight(), gain=nn.init.calculate_gain('relu'))

也可以使用 gain 参数来自定义初始化的标准差来匹配特定的激活函数:

for m in model.modules():
    if isinstance(m, (nn.Conv2d, nn.Linear)):
        nn.init.xavier_uniform_(m.weight(), gain=nn.init.calculate_gain('relu'))

参考资料:

  • Understanding the difficulty of training deep feedforward neural networks

He et. al Initialization

torch.nn.init.kaiming_uniform_(tensor, a=0, mode='fan_in', nonlinearity='leaky_relu')

He initialization的思想是:在ReLU网络中,假定每一层有一半的神经元被激活,另一半为0。推荐在ReLU网络中使用。

# he initialization
for m in model.modules():
    if isinstance(m, (nn.Conv2d, nn.Linear)):
        nn.init.kaiming_normal_(m.weight, mode='fan_in')

正交初始化(Orthogonal Initialization)

主要用以解决深度网络下的梯度消失、梯度爆炸问题,在RNN中经常使用的参数初始化方法。

for m in model.modules():
    if isinstance(m, (nn.Conv2d, nn.Linear)):
        nn.init.orthogonal(m.weight)

Batchnorm Initialization

在非线性激活函数之前,我们想让输出值有比较好的分布(例如高斯分布),以便于计算梯度和更新参数。Batch Normalization 将输出值强行做一次 Gaussian Normalization 和线性变换:

for m in model:
    if isinstance(m, nn.BatchNorm2d):
        nn.init.constant(m.weight, 1)
        nn.init.constant(m.bias, 0)

单层初始化

conv1 = nn.Conv2d(3, 64, kernel_size=7, stride=2, padding=3)
nn.init.xavier_uniform(conv1.weight)
nn.init.constant(conv1.bias, 0.1)

模型初始化

def weights_init(m):
    classname = m.__class__.__name__
    if classname.find('Conv2d') != -1:
        nn.init.xavier_normal_(m.weight.data)
        nn.init.constant_(m.bias.data, 0.0)
    elif classname.find('Linear') != -1:
        nn.init.xavier_normal_(m.weight)
        nn.init.constant_(m.bias, 0.0)
net = Net()
net.apply(weights_init) #apply函数会递归地搜索网络内的所有module并把参数表示的函数应用到所有的module上。

不建议访问以下划线为前缀的成员,他们是内部的,如果有改变不会通知用户。更推荐的一种方法是检查某个module是否是某种类型:

def weights_init(m):
    if isinstance(m, (nn.Conv2d, nn.Linear)):
        nn.init.xavier_normal_(m.weight)
        nn.init.constant_(m.bias, 0.0)

转载于https://blog.csdn.net/rocking_struggling/article/details/109149791

你可能感兴趣的:(计算机视觉,深度学习)