pytroch用自定义的tensor初始化nn.sequential中linear或者conv层的一种简单方法。

话不多说,上代码,上面写的很清楚。

import torch.nn as nn
import torch
net= nn.Sequential(
    nn.Linear(1024, 512),
    nn.ReLU(inplace=True),
    nn.Linear(512, 256),
    nn.ReLU(inplace=True),
    nn.Linear(256, 6),
)
net[4].weight.data=torch.zeros(6,256)
net[4].bias.data=torch.ones(6)
t=torch.randn(32,1024)
print(net(t).size())
cnn=nn.Sequential(
    nn.Conv2d(2,8,3,1,1),
    nn.Conv2d(8,19,3,1,1)
)
cnn[0].weight.data=torch.randn(8*2*3*3).view(8,2,3,3)
cnn[0].bias.data=torch.ones(8)
t=torch.randn(32,2,100,100)
print(cnn(t).size())

注:

  1. 注意线性层和卷积层输入通道和输出通道的关系,初始化的时候要是转置的形式。
  2. 后面生成一个tensor送入网络是为了测试初始化的正确性
  3. 出错的话请参考:The expanded size of the tensor (256) must match the existing size (81) at non-singleton dimension1
  4. 虽然代码中写的是sequential,对于module同样是可以用的

你可能感兴趣的:(pytorch学习)