PyTorch initialization

Table of contents

  • Preface
  • I. Initializing the whole model
  • II. Using apply
    • 1. Checking the module type with isinstance()
    • 2. Checking the class name
  • Summary


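Preface

For study purposes only; will be removed on request in case of infringement.

Notes on how to initialize model weights in PyTorch.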


I. Initializing the whole model

From the official PyTorch (torchvision) ResNet code:

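# Excerpt from ResNet.__init__ (runs after all layers have been built)
for m in self.modules():
    if isinstance(m, nn.Conv2d):
        # Kaiming (He) normal init, fan_out mode, gain computed for ReLU
        nn.init.kaiming_normal_(m.weight, mode="fan_out", nonlinearity="relu")
    elif isinstance(m, (nn.BatchNorm2d, nn.GroupNorm)):
        # Norm layers start as the identity: weight (gamma) = 1, bias (beta) = 0
        nn.init.constant_(m.weight, 1)
        nn.init.constant_(m.bias, 0)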
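To see the ResNet-style loop in a self-contained setting, here is a minimal sketch. The `TinyConvNet` class is hypothetical and is not part of torchvision. It only reproduces the pattern of running the init loop at the end of `__init__`, after every submodule exists. The `nn.Linear` head is deliberately left at its default initialization.

```python
import torch
import torch.nn as nn


class TinyConvNet(nn.Module):
    """Hypothetical toy network; only the init loop mirrors torchvision's ResNet."""

    def __init__(self, num_classes: int = 10) -> None:
        super().__init__()
        self.conv = nn.Conv2d(3, 16, kernel_size=3, padding=1, bias=False)
        self.bn = nn.BatchNorm2d(16)
        self.relu = nn.ReLU(inplace=True)
        self.pool = nn.AdaptiveAvgPool2d(1)
        self.fc = nn.Linear(16, num_classes)

        # Same pattern as ResNet: once all layers are created, walk every
        # submodule (self.modules() includes self) and re-initialize by type.
        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                # He normal init scaled by fan_out, with the gain for ReLU
                nn.init.kaiming_normal_(m.weight, mode="fan_out", nonlinearity="relu")
            elif isinstance(m, (nn.BatchNorm2d, nn.GroupNorm)):
                # Norm layers start as the identity transform
                nn.init.constant_(m.weight, 1)
                nn.init.constant_(m.bias, 0)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        x = self.relu(self.bn(self.conv(x)))
        return self.fc(torch.flatten(self.pool(x), 1))


net = TinyConvNet()
print(net(torch.randn(2, 3, 8, 8)).shape)  # torch.Size([2, 10])
```

With `mode="fan_out"`, the standard deviation is `sqrt(2 / fan_out)`. Here `fan_out = 16 * 3 * 3 = 144`, which gives roughly 0.118.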

Other examples borrowed from elsewhere:

import torch
import torch.nn as nn


def initialize_weights(model):
    # Initializes weights according to the DCGAN paper: N(0, 0.02)
    for m in model.modules():
        if isinstance(m, (nn.Conv2d, nn.ConvTranspose2d, nn.BatchNorm2d)):
            nn.init.normal_(m.weight.data, 0.0, 0.02)
        # If you also want to cover linear layers, add one more elif branch


def init_all(model, init_func, *params, **kwargs):
    # Applies init_func to every parameter tensor (weights and biases alike)
    for p in model.parameters():
        init_func(p, *params, **kwargs)


model = UNet(3, 10)  # UNet is assumed to be defined elsewhere
init_all(model, torch.nn.init.normal_, mean=0., std=1)
# or
init_all(model, torch.nn.init.constant_, 1.)
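There is a difference between the two helpers. `init_all` passes every tensor from `model.parameters()` to `init_func`, including biases and BatchNorm scales/shifts. `initialize_weights` only touches `weight` on selected layer types.

The sketch below makes that difference visible. It makes two assumptions:

* A small `nn.Sequential` stands in for `UNet`, which this post does not define.
* `init_weights_only` is a hypothetical variant that skips 1-D tensors. Biases and norm parameters are 1-D.

```python
import torch
import torch.nn as nn


def init_all(model, init_func, *params, **kwargs):
    # Applies init_func to EVERY parameter: weights, biases, norm scales/shifts
    for p in model.parameters():
        init_func(p, *params, **kwargs)


def init_weights_only(model, init_func, *params, **kwargs):
    # Hypothetical variant: skip 1-D tensors (biases, BatchNorm weight/bias)
    for p in model.parameters():
        if p.dim() > 1:
            init_func(p, *params, **kwargs)


# Stand-in model; UNet(3, 10) from the text is assumed to exist elsewhere
model = nn.Sequential(nn.Conv2d(3, 8, 3), nn.BatchNorm2d(8), nn.ReLU(), nn.Conv2d(8, 10, 1))

init_all(model, nn.init.constant_, 1.0)
print(bool((model[0].bias == 1).all()))  # True: biases were overwritten too

init_weights_only(model, nn.init.normal_, mean=0.0, std=0.02)
print(bool((model[0].bias == 1).all()))  # True: 1-D tensors were skipped this time
```

The `nn.init` functions already run under `torch.no_grad()` internally, so calling them directly on parameters that require gradients is fine.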

References:

  1. Official PyTorch code
  2. Stack Overflow

II. Using apply

1. Checking the module type with isinstance()

The code is as follows:

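import torch
import torch.nn as nn


def init_weights(m):
    if isinstance(m, nn.Linear):
        # xavier_uniform (no underscore) is deprecated; use in-place xavier_uniform_
        torch.nn.init.xavier_uniform_(m.weight)
        m.bias.data.fill_(0.01)


net = nn.Sequential(nn.Linear(2, 2), nn.Linear(2, 2))
net.apply(init_weights)  # calls init_weights on every submodule and on net itself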
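`nn.Module.apply(fn)` works recursively. It calls `fn` on each child module first, then on the module itself, and returns the module. The minimal sketch below records the visiting order and then checks the effect of `init_weights`. A `ReLU` was added to the `Sequential` here, and the `visited` list exists only for illustration.

```python
import torch
import torch.nn as nn


def init_weights(m):
    # Only Linear layers are touched; other module types pass through unchanged
    if isinstance(m, nn.Linear):
        nn.init.xavier_uniform_(m.weight)
        m.bias.data.fill_(0.01)


net = nn.Sequential(nn.Linear(2, 2), nn.ReLU(), nn.Linear(2, 2))

# Record which modules apply() visits, and in what order
visited = []
net.apply(lambda m: visited.append(type(m).__name__))
print(visited)  # ['Linear', 'ReLU', 'Linear', 'Sequential']

net.apply(init_weights)  # returns net itself, so calls can be chained
print(torch.allclose(net[0].bias, torch.full((2,), 0.01)))  # True
```

Because `apply` hands every module to `fn`, the type check inside `init_weights` decides which layers actually change.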

2. Checking the class name

The code is as follows:

import torch.nn as nn


# Takes in a module and applies the specified weight initialization
def weights_init_uniform(m):
    classname = m.__class__.__name__
    # For every Linear layer in a model...
    if classname.find('Linear') != -1:
        # apply a uniform distribution to the weights and a bias = 0
        m.weight.data.uniform_(0.0, 1.0)
        if m.bias is not None:  # Linear(..., bias=False) has no bias
            m.bias.data.fill_(0)


# Custom weights initialization called on netG and netD (DCGAN)
def weights_init(m):
    classname = m.__class__.__name__
    if classname.find('Conv') != -1:
        nn.init.normal_(m.weight.data, 0.0, 0.02)
    elif classname.find('BatchNorm') != -1:
        nn.init.normal_(m.weight.data, 1.0, 0.02)
        nn.init.constant_(m.bias.data, 0)


model_uniform = Net()  # Net is assumed to be defined elsewhere
model_uniform.apply(weights_init_uniform)
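Matching on the class name is a substring test, so it selects more modules than `isinstance` does. For example, `'Conv'` also matches `ConvTranspose2d`, which is not a subclass of `nn.Conv2d`. That breadth is convenient for a DCGAN-style generator built from `ConvTranspose2d` layers.

The sketch below compares which modules each check selects. The layer sizes are arbitrary and only for illustration.

```python
import torch.nn as nn

# Arbitrary toy layers, chosen only to compare the two selection rules
layers = nn.Sequential(
    nn.ConvTranspose2d(8, 4, 4, 2, 1, bias=False),
    nn.BatchNorm2d(4),
    nn.ReLU(True),
    nn.Conv2d(4, 3, 3, padding=1),
)

# modules() yields the Sequential itself first, then each child in order
by_name = [type(m).__name__ for m in layers.modules()
           if type(m).__name__.find('Conv') != -1]
by_isinstance = [type(m).__name__ for m in layers.modules()
                 if isinstance(m, nn.Conv2d)]

print(by_name)        # ['ConvTranspose2d', 'Conv2d']
print(by_isinstance)  # ['Conv2d']
```

The reverse also holds: a custom subclass of `nn.Linear` whose name lacks `'Linear'` would be missed by the class-name check but caught by `isinstance`.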

References:

  1. How to customize parameter initialization
  2. Official PyTorch documentation

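Summary

These are notes on how to initialize the layers of a multi-layer network. I have not written my own version from scratch; I collected and studied existing approaches.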
