pytorch模型初始化参数

这个一般在构建模型nn.Module类时就完成比较容易。首先介绍几个常用的函数

1.常用函数

1.1 self.modules()

这是在继承了nn.Module后的属性,存储所有层结构,可遍历(用来判断层类型)。

1.2 isinstance()

判断类型,如果类型符合就开始初始化参数了

1.3 实例

class Model(torch.nn.Module):
    def __init__(self):
        # 初始化操作
	
    def __forward__(self, x):
        # 计算
    
    # 重点!!!初始化
    def _initialize_weights(self):
        for m in self.modules():  # self.modules()获得所有层,可遍历
            if isinstance(m, nn.Linear):  # isinstance判断层类型
                torch.nn.init.xavier_uniform_(m.weight, gain=1)  # 这里执行初始化操作

初始化操作有很多类型,甚至可以自定义,下面将遇到的进行总结。

目前我遇到的有:sigmod,Xavier

2.sigmod

这里会涉及到梯度爆炸和梯度消失的问题,遇到后补充!!!

3.Xavier

为了解决梯度爆炸和消失的问题,解决方案是保障每一层网络的输入和输出的方差基本一致。

3.1 原理

大佬原文:https://blog.csdn.net/u011534057/article/details/51673458

Xavier有两种形式,uniform和normal

3.2 torch.nn.init.xavier_uniform_

构造如下:

torch.nn.init.xavier_uniform_(tensor, gain=1)  

# tensor传入的是模型指定层,比如上面代码中符合条件的m

符合均匀分布~U(-a, a),a的计算公式如下:
pytorch模型初始化参数_第1张图片

3.3 torch.nn.init.xavier_normal_

构造如下:

torch.nn.init.xavier_normal_(tensor, gain=1)  # tensor同上

符合正态分布~N(0, std),std计算公式如下:

pytorch模型初始化参数_第2张图片

3.4 fan_in和fan_out

以Linear层为例,经过调试源码,我们发现fan_in和fan_out就是Linear层的输入和输出,即Linear参数

in_features和out_features。
pytorch模型初始化参数_第3张图片
其中一些判断条件直接跳过了,可能在别的地方会用到,遇到补充!!!

你可能感兴趣的:(Pytorch使用,python,深度学习,人工智能)