【PyTorch】模型参数初始化 weights_init

Backto PyTorch Index

方法一:调用 apply

torch.nn.Module.apply(fn)
# 递归的调用weights_init函数,遍历nn.Module的submodule作为参数
# 常用来对模型的参数进行初始化
# fn是对参数进行初始化的函数的句柄,fn以nn.Module或者自己定义的nn.Module的子类作为参数
# fn (Module -> None) – function to be applied to each submodule
# Returns:  self
# Return type:  Module

例子:

def weights_init(m):
    classname = m.__class__.__name__
    if classname.find('Conv') != -1:
        m.weight.data.normal_(0.0, 0.02) 
        # m.weight.data是卷积核参数, m.bias.data是偏置项参数
    elif classname.find('BatchNorm') != -1:
        m.weight.data.normal_(1.0, 0.02)
        m.bias.data.fill_(0)

netG = _netG(ngpu) # 生成模型实例
netG.apply(weights_init) # 递归的调用weights_init函数,遍历netG的submodule作为参数

Ref

  • pytorch的weight-initilzation: 还有多种其他方式

你可能感兴趣的:(PyTorch)