pytorch中模型参数初始化方法

法一:

1、自定义初始化函数

def weights_init(m):
    classname = m.__class__.__name__
    if classname.find('Conv') != -1:
        nn.init.normal_(m.weight.data, 0.0, 0.02)
    elif classname.find('BatchNorm') != -1:
        nn.init.normal_(m.weight.data, 1.0, 0.02)
        nn.init.constant_(m.bias.data, 0)

2、应用到模型上

netG.apply(weights_init)

法二:

直接遍历整个网络参数,添加判断条件

for name, param in net.named_parameters():
    if 'weight' in name:
        init.normal_(param, mean=0, std=0.01)
        print(name, param.data)

 

你可能感兴趣的:(pytorch,机器学习,深度学习,人工智能)