PyTorch【7】PyTorch网络权重初始化

一.常用的权重初始化方法

方法名 说明
nn.init.uniform_(tensor,a=0,b=1) 从均匀分布U(a,b)中生成值,填充输入的张量或变量
nn.init.normal_(tensor,mean=0,std=1) 从正态分布(mean,std)中生成值,填充输入的张量或变量
nn.init.constant_(tensor,val) 用val的值填充输入的张量或变量
nn.init.eye_(tensor) 用单位矩阵填充输入的张量或变量
nn.init.dirac_(tensor,groups=1) 用Dirac delta函数填充{3,4,5}维的输入张量或变量,groups为卷积层的组数
nn.init.xavier_normal_(tensor,gain=1) 使用Glorot方法均匀分布生成值,填充输入的张量或变量(gain为缩放尺度)
nn.init.xavier_uniform_(tensor,gain=1) 使用Glorot方法正态分布生成值,填充输入的张量或变量
nn.init.kaiming_normal_(
tensor,a=0,mode='fan_in',nonlinearity='leaky_relu')
使用He方法均匀分布生成值,填充输入的张量或变量
nn.init.kaiming_uniform_(
tensor,a=0,mode='fan_in',nonlinearity='leaky_relu')
使用He方法正态分布生成值,填充输入的张量或变量
nn.init.orthogonal_(tensor,gain=1) 使用正交矩阵填充输入的张量或变量

二.不用的初始化调用方式

1.针对某一层的权重初始化

con1 = torch.nn.Conv2d(3,16,3)
print(con1.weight.shape)	# torch.Size([16,3,3,3])
print(con1.bias.shape)		# torch.Size([16])

torch.nn.init.normal(con1.weight,mean=0,std=1)
torch.nn.init.constant(con1.bias,val=0.1)

2.针对一个网络的权重初始化

class TestNet(nn.Module):
    def __init__(self):
        super(TestNet, self).__init__()
        self.conv1 = nn.Conv2d(3,16,3)
        self.hidden = nn.Sequential(
            nn.Linear(100,100),
            nn.ReLU(),
            nn.Linear(100,50),
            nn.ReLU()
        )
        self.cla = nn.Linear(50,10)

    def forward(self,x):
        x = self.conv1(x)
        x = x.view(x.shape[0],-1)
        x = self.hidden(x)
        output = self.cla(x)
        return output

testnet = TestNet()
print(testnet)
'''
TestNet(
  (conv1): Conv2d(3, 16, kernel_size=(3, 3), stride=(1, 1))
  (hidden): Sequential(
    (0): Linear(in_features=100, out_features=100, bias=True)
    (1): ReLU()
    (2): Linear(in_features=100, out_features=50, bias=True)
    (3): ReLU()
  )
  (cla): Linear(in_features=50, out_features=10, bias=True)
)
'''
def init_weights(m):
    if type(m)==nn.Conv2d:
        torch.nn.init.normal(m.weight,mean=0,std=0.5)
    if type(m)==nn.Linear:
        torch.nn.init.uniform(m.weight,a=-0.1,b=0.1)
        m.bias.data.fill_(0.01)

testnet.apply(init_weights)     

你可能感兴趣的:(PyTorch框架学习,pytorch,深度学习,python)