pytorch——weights_init(m)

本文参考连接:https://www.jianshu.com/p/c982d55db463
个人认为是讲的比较通俗易懂的一篇好文。

针对于不同层类型定制化初始化
举个栗子:

def weights_init(m):    ##定义参数初始化函数                  
    classname = m.__class__.__name__    # m作为一个形参,原则上可以传递很多的内容,为了实现多实参传递,每一个moudle要给出自己的name. 所以这句话就是返回m的名字。具体例子下边会详细说明。
    if classname.find('Conv') != -1:#find()函数,实现查找classname中是否含有conv字符,没有返回-1;有返回0.
        nn.init.normal_(m.weight.data, 0.0, 0.02)#m.weight.data表示需要初始化的权重。 nn.init.normal_()表示随机初始化采用正态分布,均值为0,标准差为0.02.
    elif classname.find('BatchNorm') != -1:           
        nn.init.normal_(m.weight.data, 1.0, 0.02)
        nn.init.constant_(m.bias.data, 0) # nn.init.constant_()表示将偏差定义为常量0 

netG.apply(weights_init)#netG是我们给写的神经网络定义的类实例。apply函数会递归地搜索网络内的所有module并把参数表示的函数应用到所有的module上。也就是说apply函数,会一层一层的去拜访Generator网络层。

具体例子见下:

class Generator(nn.Module):  # 创建Generator子类,括号内指定父类的名称
    def __init__(self):  # 初始化父类的属性
        super(Generator, self).__init__()  # 将父类和子类关联,调用父类nn.Moudle的方法__init__(),让Generator实例包含父类的所有属
        self.main = nn.Sequential(  # 按照顺序构造神经层,序列容器
            nn.ConvTranspose2d(5, 10, 4, 1, 0, bias=False),  # 转置卷积,输出
            nn.BatchNorm2d(10)  # 对每个特征图上的点,进行减均值除标准差的操作,affine设置为true(默认),引入权重w和b两个可学习的参数)

netG = Generator()        # 定义类实例

在这个例子中:
1、第一个代码中的classname=ConvTranspose2d,classname=BatchNorm2d。2、第一个代码中的netG.apply(weights_init),会按顺序:先看看nn.ConvTranspose2d这一层,需不需要初始化,然后再看看nn.BatchNorm2d这一层需不需要初始化。
将以上两个例子放在一块,得到如下的代码。个人认为可以这样来理解。对于有些地方还不是很明白,欢迎各位大神指点!!!!!

class Generator(nn.Module):  # 创建Generator子类,括号内指定父类的名称
    def __init__(self):  # 初始化父类的属性
        super(Generator, self).__init__()  # 将父类和子类关联,调用父类nn.Moudle的方法__init__(),让Generator实例包含父类的所有属
        self.main = nn.Sequential(  # 按照顺序构造神经层,序列容器
            nn.ConvTranspose2d(5, 10, 4, 1, 0, bias=False),  # 转置卷积,输出
            nn.BatchNorm2d(10)  # 对每个特征图上的点,进行减均值除标准差的操作,affine设置为true(默认),引入权重w和b两个可学习的参数)

netG = Generator()        # 定义类实例


def weights_init(m):    ##定义参数初始化函数                  
    classname = m.__class__.__name__    # m作为一个形参,原则上可以传递很多的内容,为了实现多实参传递,每一个moudle要给出自己的name. 所以这句话就是返回m的名字。具体例子下边会详细说明。
    if classname.find('Conv') != -1:#find()函数,实现查找classname中是否含有conv字符,没有返回-1;有返回0.
        nn.init.normal_(m.weight.data, 0.0, 0.02)#m.weight.data表示需要初始化的权重。 nn.init.normal_()表示随机初始化采用正态分布,均值为0,标准差为0.02.
    elif classname.find('BatchNorm') != -1:           
        nn.init.normal_(m.weight.data, 1.0, 0.02)
        nn.init.constant_(m.bias.data, 0) # nn.init.constant_()表示将偏差定义为常量0 

netG.apply(weights_init)#netG是我们给写的神经网络定义的类实例。apply函数会递归地搜索网络内的所有module并把参数表示的函数应用到所有的module上。也就是说apply函数,会一层一层的去拜访Generator网络层。

你可能感兴趣的:(PyTorch,weights_init)